diff --git a/README.md b/README.md index ad48b3c..5caae22 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,8 @@ Edit `claude_desktop_config.json` (macOS: `~/Library/Application Support/Claude/ **`NodeFilter` notes:** - `filter` is a JSON object matching the `NodeFilter` schema. Wire types are `object` or, as a fallback, a JSON-encoded string for clients that flatten objects. +- Unknown filter keys and populated fields that are not applicable to the effective node kind fail loudly with `success=false` and a teaching `message` (no silent key dropping). +- For `neighbors`, mixed-kind neighborhoods fail on the first evaluated neighbor row whose kind makes populated filter fields inapplicable. - Symbol-only keys: `symbol_kind` (single value) and `symbol_kinds` (set membership) for declaration granularity (`class`, `interface`, `enum`, `record`, `annotation`, `method`, `constructor`). - `find(kind="symbol", ...)` results include `symbol_kind` so callers can see declaration granularity without a follow-up `describe`. - For `find`, an empty / whitespace-only filter string or the JSON literal `null` is treated like `{}` (match anything). diff --git a/mcp_v2.py b/mcp_v2.py index d2ebfc2..27fe570 100644 --- a/mcp_v2.py +++ b/mcp_v2.py @@ -6,7 +6,7 @@ import threading from typing import Annotated, Any, Literal -from pydantic import BaseModel, Field, TypeAdapter, ValidationError, validate_call +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, validate_call from sentence_transformers import SentenceTransformer from index_common import SBERT_MODEL @@ -57,6 +57,8 @@ def _get_sentence_transformer(model_name: str, device: str | None) -> SentenceTr class NodeFilter(BaseModel): + model_config = ConfigDict(extra="forbid") + microservice: str | None = None module: str | None = None source_layer: str | None = None @@ -76,6 +78,86 @@ class NodeFilter(BaseModel): client_method: str | None = None +_NODEFILTER_FIELD_ORDER: tuple[str, ...] = tuple(NodeFilter.model_fields.keys()) + +_NODEFILTER_APPLICABLE_FIELDS: dict[Literal["symbol", "route", "client"], tuple[str, ...]] = { + "symbol": ( + "microservice", + "module", + "role", + "exclude_roles", + "annotation", + "capability", + "fqn_prefix", + "symbol_kind", + "symbol_kinds", + ), + "route": ( + "microservice", + "module", + "http_method", + "path_prefix", + "framework", + ), + "client": ( + "microservice", + "module", + "source_layer", + "client_kind", + "target_service", + "target_path_prefix", + "client_method", + ), +} + + +def _ordered_nodefilter_fields(field_names: set[str]) -> list[str]: + return [name for name in _NODEFILTER_FIELD_ORDER if name in field_names] + + +def _populated_nodefilter_fields(nf: NodeFilter) -> set[str]: + populated: set[str] = set() + for field_name in _NODEFILTER_FIELD_ORDER: + value = getattr(nf, field_name) + if value is None: + continue + if isinstance(value, list) and not value: + continue + populated.add(field_name) + return populated + + +def _nodefilter_inapplicable_fields(kind: Literal["symbol", "route", "client"], nf: NodeFilter) -> list[str]: + populated = _populated_nodefilter_fields(nf) + applicable = set(_NODEFILTER_APPLICABLE_FIELDS[kind]) + return _ordered_nodefilter_fields(populated - applicable) + + +def _nodefilter_applicability_error(kind: Literal["symbol", "route", "client"], nf: NodeFilter) -> str | None: + inapplicable = _nodefilter_inapplicable_fields(kind, nf) + if not inapplicable: + return None + applicable = ", ".join(_NODEFILTER_APPLICABLE_FIELDS[kind]) + bad = ", ".join(inapplicable) + return ( + f"Invalid filter for kind='{kind}': populated field(s) not applicable: [{bad}]. " + f"Applicable field(s): [{applicable}]" + ) + + +def _filter_validation_error_message(exc: ValidationError) -> str: + items: list[str] = [] + for err in exc.errors(): + loc = ".".join(str(part) for part in err.get("loc", ())) + msg = str(err.get("msg") or "invalid value") + if loc: + items.append(f"{loc}: {msg}") + else: + items.append(msg) + details = "; ".join(items) if items else str(exc) + return f"Invalid filter: {details}" + + def _coerce_filter( value: NodeFilter | dict[str, Any] | str | None, ) -> NodeFilter | dict[str, Any] | None: @@ -432,11 +514,16 @@ def search_v2( model=model, ) raw_filter = _coerce_filter(filter) - nf = ( - NodeFilter.model_validate(raw_filter) - if raw_filter is not None and not isinstance(raw_filter, NodeFilter) - else raw_filter - ) + try: + nf = ( + NodeFilter.model_validate(raw_filter) + if raw_filter is not None and not isinstance(raw_filter, NodeFilter) + else raw_filter + ) + except ValidationError as exc: + return SearchOutput(success=False, message=_filter_validation_error_message(exc)) + if nf and (err := _nodefilter_applicability_error("symbol", nf)): + return SearchOutput(success=False, message=err) hits: list[SearchHit] = [] for row in rows: if path_contains and path_contains not in str(row.get("filename") or ""): @@ -463,7 +550,12 @@ def find_v2( raw_filter = _coerce_filter(filter) if raw_filter is None: raw_filter = {} - nf = NodeFilter.model_validate(raw_filter) if not isinstance(raw_filter, NodeFilter) else raw_filter + try: + nf = NodeFilter.model_validate(raw_filter) if not isinstance(raw_filter, NodeFilter) else raw_filter + except ValidationError as exc: + return FindOutput(success=False, message=_filter_validation_error_message(exc)) + if err := _nodefilter_applicability_error(kind, nf): + return FindOutput(success=False, message=err) if kind == "symbol": where, params = _symbol_where_from_filter(nf) params["lim"] = int(limit) + int(offset) @@ -539,12 +631,15 @@ def neighbors_v2( label_params = [f"l{i}" for i in range(len(labels))] label_predicate = "(" + " OR ".join(f"label(e) = ${name}" for name in label_params) + ")" g = graph or KuzuGraph.get() - raw_filter = _coerce_filter(filter) - nf = ( - NodeFilter.model_validate(raw_filter) - if raw_filter is not None and not isinstance(raw_filter, NodeFilter) - else raw_filter - ) + try: + raw_filter = _coerce_filter(filter) + nf = ( + NodeFilter.model_validate(raw_filter) + if raw_filter is not None and not isinstance(raw_filter, NodeFilter) + else raw_filter + ) + except ValidationError as exc: + return NeighborsOutput(success=False, message=_filter_validation_error_message(exc)) origins = [ids] if isinstance(ids, str) else list(ids) results: list[Edge] = [] for origin_id in origins: @@ -580,6 +675,8 @@ def neighbors_v2( other_rec = _load_node_record(g, other_id, other_kind) if other_rec is None: continue + if nf and (err := _nodefilter_applicability_error(other_kind, nf)): + return NeighborsOutput(success=False, message=err) if not _node_matches_filter(other_kind, other_rec, nf): continue attrs = { diff --git a/server.py b/server.py index 8650323..5e00ad0 100644 --- a/server.py +++ b/server.py @@ -28,6 +28,7 @@ "Tools: search (NL/code locate), find (structured NodeFilter), describe (one node + edge_summary: stored edge-label counts and optional composed keys for type Symbols and override-axis virtual keys for method Symbols), " "neighbors (one hop; you MUST pass direction in|out AND edge_types list — no defaults). " "NodeFilter `filter` is a JSON object (preferred); a JSON-encoded string is also accepted as a fallback. " + "Unknown filter keys and populated fields not applicable to the effective node kind fail with success=false and message. " "Edge labels: EXTENDS, IMPLEMENTS, INJECTS, DECLARES, DECLARES_CLIENT, CALLS, EXPOSES, HTTP_CALLS, ASYNC_CALLS. " "Reprocess/init, meta, tables, diagnose-ignore, analyze-pr: use java-codebase-rag CLI — not MCP." ) @@ -348,7 +349,8 @@ async def search( filter: dict[str, Any] | str | None = Field( default=None, description=( - "Optional NodeFilter (symbol-oriented keys) applied to each hit after search. " + "Optional NodeFilter (symbol applicability). Unknown keys and populated non-symbol fields return success=false " + "with a teaching message. " "Prefer a JSON object; a JSON-encoded string is accepted as a fallback." ), ), @@ -376,8 +378,9 @@ async def find( filter: dict[str, Any] | str = Field( ..., description=( - "Required NodeFilter (shared schema; irrelevant keys ignored per kind). " - "Symbol filters also support symbol_kind and symbol_kinds. " + "Required NodeFilter (shared schema, strict extras). Unknown keys and populated fields not applicable to " + "the selected kind return success=false with a teaching message. Symbol filters also support symbol_kind " + "and symbol_kinds. " "Prefer a JSON object; a JSON-encoded string is accepted as a fallback." ), ), @@ -430,7 +433,9 @@ async def neighbors( filter: dict[str, Any] | str | None = Field( default=None, description=( - "Optional NodeFilter applied to the other endpoint of each edge. " + "Optional NodeFilter applied to the other endpoint of each edge. Unknown keys and populated fields not " + "applicable to an evaluated neighbor kind return success=false with a teaching message. For mixed " + "neighbor kinds, evaluation fails on the first inapplicable row. " "Prefer a JSON object; a JSON-encoded string is accepted as a fallback." ), ), diff --git a/tests/test_mcp_v2.py b/tests/test_mcp_v2.py index fa29d43..e1bfc5c 100644 --- a/tests/test_mcp_v2.py +++ b/tests/test_mcp_v2.py @@ -7,6 +7,8 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp_v2 import ( + NodeFilter, + _NODEFILTER_APPLICABLE_FIELDS, describe_v2, find_v2, neighbors_v2, @@ -198,10 +200,42 @@ def test_find_client_by_path_prefix(kuzu_graph) -> None: assert bits[-1].startswith(prefix) -def test_find_silent_ignore_irrelevant_filter_keys(kuzu_graph) -> None: +def test_find_cross_kind_filter_fields_return_failure(kuzu_graph) -> None: out = find_v2("symbol", {"path_prefix": "/api"}, graph=kuzu_graph) - assert out.success is True - assert isinstance(out.results, list) + assert out.success is False + assert out.message is not None + assert "path_prefix" in out.message + assert "kind='symbol'" in out.message + + +def test_find_unknown_filter_key_returns_failure(kuzu_graph) -> None: + out = find_v2("symbol", {"typo_key": "x"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "Invalid filter" in out.message + assert "typo_key" in out.message + + +def test_find_symbol_only_field_with_kind_client_returns_failure(kuzu_graph) -> None: + out = find_v2("client", {"fqn_prefix": "com.example"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "fqn_prefix" in out.message + assert "kind='client'" in out.message + + +def test_find_client_only_field_with_kind_symbol_returns_failure(kuzu_graph) -> None: + out = find_v2("symbol", {"client_kind": "feign_method"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "client_kind" in out.message + assert "kind='symbol'" in out.message + + +def test_nodefilter_applicability_table_covers_all_fields() -> None: + declared = set(NodeFilter.model_fields.keys()) + covered = set().union(*_NODEFILTER_APPLICABLE_FIELDS.values()) + assert declared == covered async def test_find_missing_filter_rejected(mcp_server) -> None: @@ -433,6 +467,24 @@ def test_search_filter_accepts_json_string(monkeypatch, kuzu_graph) -> None: assert out_dict.results == out_str.results +def test_search_unknown_filter_key_returns_failure(monkeypatch, kuzu_graph) -> None: + monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows()) + out = search_v2("ChatService", filter={"typo_key": "x"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "Invalid filter" in out.message + assert "typo_key" in out.message + + +def test_search_cross_kind_filter_returns_failure(monkeypatch, kuzu_graph) -> None: + monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows()) + out = search_v2("ChatService", filter={"path_prefix": "/api"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "path_prefix" in out.message + assert "kind='symbol'" in out.message + + def test_search_filter_empty_string_treated_as_none(monkeypatch, kuzu_graph) -> None: monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows()) baseline = search_v2("ChatService", graph=kuzu_graph) @@ -487,6 +539,30 @@ def test_neighbors_filter_accepts_json_string(kuzu_graph) -> None: assert out_dict.results == out_str.results +def test_neighbors_filter_unknown_key_returns_failure(kuzu_graph) -> None: + mid = _method_id_with_calls(kuzu_graph, "out") + out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], filter={"typo_key": "x"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "Invalid filter" in out.message + assert "typo_key" in out.message + + +def test_neighbors_filter_cross_kind_on_neighbor_returns_failure(kuzu_graph) -> None: + mid = _method_id_with_calls(kuzu_graph, "out") + out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], filter={"path_prefix": "/api"}, graph=kuzu_graph) + assert out.success is False + assert out.message is not None + assert "path_prefix" in out.message + assert "kind='symbol'" in out.message + + +def test_neighbors_validate_call_still_raises(kuzu_graph) -> None: + mid = _method_id_with_calls(kuzu_graph, "out") + with pytest.raises(ValidationError): + neighbors_v2(mid, direction="upstream", edge_types=["CALLS"], graph=kuzu_graph) + + def test_filter_invalid_json_returns_failure(monkeypatch, kuzu_graph) -> None: monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows()) out = search_v2("ChatService", filter="{not json", graph=kuzu_graph)