Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ Edit `claude_desktop_config.json` (macOS: `~/Library/Application Support/Claude/
**`NodeFilter` notes:**

- `filter` is a JSON object matching the `NodeFilter` schema. Wire types are `object` or, as a fallback, a JSON-encoded string for clients that flatten objects.
- Unknown filter keys and populated fields that are not applicable to the effective node kind fail loudly with `success=false` and a teaching `message` (no silent key dropping).
- For `neighbors`, mixed-kind neighborhoods fail on the first evaluated neighbor row whose kind makes populated filter fields inapplicable.
- Symbol-only keys: `symbol_kind` (single value) and `symbol_kinds` (set membership) for declaration granularity (`class`, `interface`, `enum`, `record`, `annotation`, `method`, `constructor`).
- `find(kind="symbol", ...)` results include `symbol_kind` so callers can see declaration granularity without a follow-up `describe`.
- For `find`, an empty / whitespace-only filter string or the JSON literal `null` is treated like `{}` (match anything).
Expand Down
123 changes: 110 additions & 13 deletions mcp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field, TypeAdapter, ValidationError, validate_call
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, validate_call
from sentence_transformers import SentenceTransformer

from index_common import SBERT_MODEL
Expand Down Expand Up @@ -57,6 +57,8 @@ def _get_sentence_transformer(model_name: str, device: str | None) -> SentenceTr


class NodeFilter(BaseModel):
model_config = ConfigDict(extra="forbid")

microservice: str | None = None
module: str | None = None
source_layer: str | None = None
Expand All @@ -76,6 +78,86 @@ class NodeFilter(BaseModel):
client_method: str | None = None


_NODEFILTER_FIELD_ORDER: tuple[str, ...] = tuple(NodeFilter.model_fields.keys())

_NODEFILTER_APPLICABLE_FIELDS: dict[Literal["symbol", "route", "client"], tuple[str, ...]] = {
"symbol": (
"microservice",
"module",
"role",
"exclude_roles",
"annotation",
"capability",
"fqn_prefix",
"symbol_kind",
"symbol_kinds",
),
"route": (
"microservice",
"module",
"http_method",
"path_prefix",
"framework",
),
"client": (
"microservice",
"module",
"source_layer",
"client_kind",
"target_service",
"target_path_prefix",
"client_method",
),
}


def _ordered_nodefilter_fields(field_names: set[str]) -> list[str]:
return [name for name in _NODEFILTER_FIELD_ORDER if name in field_names]


def _populated_nodefilter_fields(nf: NodeFilter) -> set[str]:
populated: set[str] = set()
for field_name in _NODEFILTER_FIELD_ORDER:
value = getattr(nf, field_name)
if value is None:
continue
if isinstance(value, list) and not value:
continue
populated.add(field_name)
return populated


def _nodefilter_inapplicable_fields(kind: Literal["symbol", "route", "client"], nf: NodeFilter) -> list[str]:
populated = _populated_nodefilter_fields(nf)
applicable = set(_NODEFILTER_APPLICABLE_FIELDS[kind])
return _ordered_nodefilter_fields(populated - applicable)


def _nodefilter_applicability_error(kind: Literal["symbol", "route", "client"], nf: NodeFilter) -> str | None:
inapplicable = _nodefilter_inapplicable_fields(kind, nf)
if not inapplicable:
return None
applicable = ", ".join(_NODEFILTER_APPLICABLE_FIELDS[kind])
bad = ", ".join(inapplicable)
return (
f"Invalid filter for kind='{kind}': populated field(s) not applicable: [{bad}]. "
f"Applicable field(s): [{applicable}]"
)


def _filter_validation_error_message(exc: ValidationError) -> str:
items: list[str] = []
for err in exc.errors():
loc = ".".join(str(part) for part in err.get("loc", ()))
msg = str(err.get("msg") or "invalid value")
if loc:
items.append(f"{loc}: {msg}")
else:
items.append(msg)
details = "; ".join(items) if items else str(exc)
return f"Invalid filter: {details}"


def _coerce_filter(
value: NodeFilter | dict[str, Any] | str | None,
) -> NodeFilter | dict[str, Any] | None:
Expand Down Expand Up @@ -432,11 +514,16 @@ def search_v2(
model=model,
)
raw_filter = _coerce_filter(filter)
nf = (
NodeFilter.model_validate(raw_filter)
if raw_filter is not None and not isinstance(raw_filter, NodeFilter)
else raw_filter
)
try:
nf = (
NodeFilter.model_validate(raw_filter)
if raw_filter is not None and not isinstance(raw_filter, NodeFilter)
else raw_filter
)
except ValidationError as exc:
return SearchOutput(success=False, message=_filter_validation_error_message(exc))
if nf and (err := _nodefilter_applicability_error("symbol", nf)):
return SearchOutput(success=False, message=err)
hits: list[SearchHit] = []
for row in rows:
if path_contains and path_contains not in str(row.get("filename") or ""):
Expand All @@ -463,7 +550,12 @@ def find_v2(
raw_filter = _coerce_filter(filter)
if raw_filter is None:
raw_filter = {}
nf = NodeFilter.model_validate(raw_filter) if not isinstance(raw_filter, NodeFilter) else raw_filter
try:
nf = NodeFilter.model_validate(raw_filter) if not isinstance(raw_filter, NodeFilter) else raw_filter
except ValidationError as exc:
return FindOutput(success=False, message=_filter_validation_error_message(exc))
if err := _nodefilter_applicability_error(kind, nf):
return FindOutput(success=False, message=err)
if kind == "symbol":
where, params = _symbol_where_from_filter(nf)
params["lim"] = int(limit) + int(offset)
Expand Down Expand Up @@ -539,12 +631,15 @@ def neighbors_v2(
label_params = [f"l{i}" for i in range(len(labels))]
label_predicate = "(" + " OR ".join(f"label(e) = ${name}" for name in label_params) + ")"
g = graph or KuzuGraph.get()
raw_filter = _coerce_filter(filter)
nf = (
NodeFilter.model_validate(raw_filter)
if raw_filter is not None and not isinstance(raw_filter, NodeFilter)
else raw_filter
)
try:
raw_filter = _coerce_filter(filter)
nf = (
NodeFilter.model_validate(raw_filter)
if raw_filter is not None and not isinstance(raw_filter, NodeFilter)
else raw_filter
)
except ValidationError as exc:
return NeighborsOutput(success=False, message=_filter_validation_error_message(exc))
origins = [ids] if isinstance(ids, str) else list(ids)
results: list[Edge] = []
for origin_id in origins:
Expand Down Expand Up @@ -580,6 +675,8 @@ def neighbors_v2(
other_rec = _load_node_record(g, other_id, other_kind)
if other_rec is None:
continue
if nf and (err := _nodefilter_applicability_error(other_kind, nf)):
return NeighborsOutput(success=False, message=err)
if not _node_matches_filter(other_kind, other_rec, nf):
continue
attrs = {
Expand Down
13 changes: 9 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"Tools: search (NL/code locate), find (structured NodeFilter), describe (one node + edge_summary: stored edge-label counts and optional composed keys for type Symbols and override-axis virtual keys for method Symbols), "
"neighbors (one hop; you MUST pass direction in|out AND edge_types list — no defaults). "
"NodeFilter `filter` is a JSON object (preferred); a JSON-encoded string is also accepted as a fallback. "
"Unknown filter keys and populated fields not applicable to the effective node kind fail with success=false and message. "
"Edge labels: EXTENDS, IMPLEMENTS, INJECTS, DECLARES, DECLARES_CLIENT, CALLS, EXPOSES, HTTP_CALLS, ASYNC_CALLS. "
"Reprocess/init, meta, tables, diagnose-ignore, analyze-pr: use java-codebase-rag CLI — not MCP."
)
Expand Down Expand Up @@ -348,7 +349,8 @@ async def search(
filter: dict[str, Any] | str | None = Field(
default=None,
description=(
"Optional NodeFilter (symbol-oriented keys) applied to each hit after search. "
"Optional NodeFilter (symbol applicability). Unknown keys and populated non-symbol fields return success=false "
"with a teaching message. "
"Prefer a JSON object; a JSON-encoded string is accepted as a fallback."
),
),
Expand Down Expand Up @@ -376,8 +378,9 @@ async def find(
filter: dict[str, Any] | str = Field(
...,
description=(
"Required NodeFilter (shared schema; irrelevant keys ignored per kind). "
"Symbol filters also support symbol_kind and symbol_kinds. "
"Required NodeFilter (shared schema, strict extras). Unknown keys and populated fields not applicable to "
"the selected kind return success=false with a teaching message. Symbol filters also support symbol_kind "
"and symbol_kinds. "
"Prefer a JSON object; a JSON-encoded string is accepted as a fallback."
),
),
Expand Down Expand Up @@ -430,7 +433,9 @@ async def neighbors(
filter: dict[str, Any] | str | None = Field(
default=None,
description=(
"Optional NodeFilter applied to the other endpoint of each edge. "
"Optional NodeFilter applied to the other endpoint of each edge. Unknown keys and populated fields not "
"applicable to an evaluated neighbor kind return success=false with a teaching message. For mixed "
"neighbor kinds, evaluation fails on the first inapplicable row. "
"Prefer a JSON object; a JSON-encoded string is accepted as a fallback."
),
),
Expand Down
82 changes: 79 additions & 3 deletions tests/test_mcp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from mcp.server.fastmcp.exceptions import ToolError

from mcp_v2 import (
NodeFilter,
_NODEFILTER_APPLICABLE_FIELDS,
describe_v2,
find_v2,
neighbors_v2,
Expand Down Expand Up @@ -198,10 +200,42 @@ def test_find_client_by_path_prefix(kuzu_graph) -> None:
assert bits[-1].startswith(prefix)


def test_find_silent_ignore_irrelevant_filter_keys(kuzu_graph) -> None:
def test_find_cross_kind_filter_fields_return_failure(kuzu_graph) -> None:
out = find_v2("symbol", {"path_prefix": "/api"}, graph=kuzu_graph)
assert out.success is True
assert isinstance(out.results, list)
assert out.success is False
assert out.message is not None
assert "path_prefix" in out.message
assert "kind='symbol'" in out.message


def test_find_unknown_filter_key_returns_failure(kuzu_graph) -> None:
out = find_v2("symbol", {"typo_key": "x"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "Invalid filter" in out.message
assert "typo_key" in out.message


def test_find_symbol_only_field_with_kind_client_returns_failure(kuzu_graph) -> None:
out = find_v2("client", {"fqn_prefix": "com.example"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "fqn_prefix" in out.message
assert "kind='client'" in out.message


def test_find_client_only_field_with_kind_symbol_returns_failure(kuzu_graph) -> None:
out = find_v2("symbol", {"client_kind": "feign_method"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "client_kind" in out.message
assert "kind='symbol'" in out.message


def test_nodefilter_applicability_table_covers_all_fields() -> None:
declared = set(NodeFilter.model_fields.keys())
covered = set().union(*_NODEFILTER_APPLICABLE_FIELDS.values())
assert declared == covered


async def test_find_missing_filter_rejected(mcp_server) -> None:
Expand Down Expand Up @@ -433,6 +467,24 @@ def test_search_filter_accepts_json_string(monkeypatch, kuzu_graph) -> None:
assert out_dict.results == out_str.results


def test_search_unknown_filter_key_returns_failure(monkeypatch, kuzu_graph) -> None:
monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows())
out = search_v2("ChatService", filter={"typo_key": "x"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "Invalid filter" in out.message
assert "typo_key" in out.message


def test_search_cross_kind_filter_returns_failure(monkeypatch, kuzu_graph) -> None:
monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows())
out = search_v2("ChatService", filter={"path_prefix": "/api"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "path_prefix" in out.message
assert "kind='symbol'" in out.message


def test_search_filter_empty_string_treated_as_none(monkeypatch, kuzu_graph) -> None:
monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows())
baseline = search_v2("ChatService", graph=kuzu_graph)
Expand Down Expand Up @@ -487,6 +539,30 @@ def test_neighbors_filter_accepts_json_string(kuzu_graph) -> None:
assert out_dict.results == out_str.results


def test_neighbors_filter_unknown_key_returns_failure(kuzu_graph) -> None:
mid = _method_id_with_calls(kuzu_graph, "out")
out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], filter={"typo_key": "x"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "Invalid filter" in out.message
assert "typo_key" in out.message


def test_neighbors_filter_cross_kind_on_neighbor_returns_failure(kuzu_graph) -> None:
mid = _method_id_with_calls(kuzu_graph, "out")
out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], filter={"path_prefix": "/api"}, graph=kuzu_graph)
assert out.success is False
assert out.message is not None
assert "path_prefix" in out.message
assert "kind='symbol'" in out.message


def test_neighbors_validate_call_still_raises(kuzu_graph) -> None:
mid = _method_id_with_calls(kuzu_graph, "out")
with pytest.raises(ValidationError):
neighbors_v2(mid, direction="upstream", edge_types=["CALLS"], graph=kuzu_graph)


def test_filter_invalid_json_returns_failure(monkeypatch, kuzu_graph) -> None:
monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: _fake_search_rows())
out = search_v2("ChatService", filter="{not json", graph=kuzu_graph)
Expand Down
Loading