diff --git a/README.md b/README.md index 66b1d21..84e6b76 100644 --- a/README.md +++ b/README.md @@ -244,7 +244,7 @@ Edit `claude_desktop_config.json` (macOS: `~/Library/Application Support/Claude/ | `find` | Locate nodes by structured filter. | `kind: "symbol"\|"route"\|"client"\|"producer"`, `filter: NodeFilter \| str`, `limit: int=25`, `offset: int=0` | `{"kind":"symbol","filter":{"role":"CONTROLLER"}}` | | `describe` | Full record + edge counts for one node. For **type** symbols, `edge_summary` may include composed dot-keys (`DECLARES.DECLARES_CLIENT`, `DECLARES.EXPOSES`); for **method** symbols it may include override-axis virtual keys (`OVERRIDDEN_BY`, …) and an `OVERRIDES` row that **merges** stored `[:OVERRIDES]` in/out with the dispatch-up rollup (per direction `max`). See [`docs/AGENT-GUIDE.md`](./docs/AGENT-GUIDE.md) (`describe`). | `id: str` | `{"id":"sym:com.bank.chat.core.api.ChatController#joinOperator(JoinOperatorRequest)"}` | | `resolve` | Identifier-shaped node lookup (symbol / route / client / producer). Returns `status` `one`, `many`, or `none`; prefer over `describe(fqn=…)` when an FQN may collide. See [`docs/AGENT-GUIDE.md`](./docs/AGENT-GUIDE.md) (`resolve`). | `identifier: str`, `hint_kind: "symbol"|"route"|"client"|"producer" \| null` | `{"identifier":"com.bank.chat.core.api.ChatController","hint_kind":"symbol"}` | -| `neighbors` | Graph walk. **Required**: `direction` and `edge_types` (stored labels; type Symbols may also pass composed `DECLARES.DECLARES_CLIENT`, `DECLARES.DECLARES_PRODUCER`, `DECLARES.EXPOSES` — `out` only — see [`docs/AGENT-GUIDE.md`](./docs/AGENT-GUIDE.md)). | `ids: str \| list[str]`, `direction: "in"\|"out"`, `edge_types: list[str]`, `limit: int=25`, `offset: int=0`, `filter: NodeFilter \| str \| None` | `{"ids":"sym:…ChatController","direction":"out","edge_types":["DECLARES.DECLARES_CLIENT"]}` | +| `neighbors` | Graph walk. 
**Required**: `direction` and `edge_types` (stored labels; type Symbols may also pass composed `DECLARES.DECLARES_CLIENT`, `DECLARES.DECLARES_PRODUCER`, `DECLARES.EXPOSES` — `out` only — see [`docs/AGENT-GUIDE.md`](./docs/AGENT-GUIDE.md)). | `ids: str \| list[str]`, `direction: "in"\|"out"`, `edge_types: list[str]`, `limit: int=25`, `offset: int=0`, `filter: NodeFilter \| str \| None`, `edge_filter: EdgeFilter \| str \| None` (`CALLS` only; see guide) | `{"ids":"sym:…ChatController","direction":"out","edge_types":["DECLARES.DECLARES_CLIENT"]}` | **`NodeFilter` notes:** diff --git a/docs/AGENT-GUIDE.md b/docs/AGENT-GUIDE.md index 88ead2b..79a07ea 100644 --- a/docs/AGENT-GUIDE.md +++ b/docs/AGENT-GUIDE.md @@ -211,13 +211,15 @@ Identifier lookup; three statuses above. Args: `identifier`, optional `hint_kind #### `neighbors` -One hop. Args: `ids` (string or array), **`direction`**, **`edge_types`**, `limit` (default 25), `offset`, optional `filter` on the other node. +One hop. Args: `ids` (string or array), **`direction`**, **`edge_types`**, `limit` (default 25), `offset`, optional `filter` on the other node, optional **`edge_filter`** (`edge_types` must be exactly `['CALLS']` — no composed dot-keys or second stored label; fail-loud otherwise). + +**Multiple origin ids:** each id loads the full CALLS stream (or generic hop) in list order; `offset`/`limit` apply to the **concatenated** edge list (`ids[0]` edges first, then `ids[1]`, …), not global source order across origins — a large first origin can leave no rows for later ids within the same page. High fan-out methods are slow; prefer one id per call or a smaller `limit`. Returns **edges** with `attrs` (`confidence`, `strategy`, `match`, … on cross-service edges) and **`other`** node. **Cross-service edges** (`HTTP_CALLS`, `ASYNC_CALLS`): read `attrs.confidence` and `attrs.match` — low confidence or `unresolved`/`phantom`/`ambiguous` means treat as a resolver signal, not ground truth. 
-**`CALLS` edges:** `attrs.resolved=false` or low `attrs.confidence` may be JDK/external or unresolved static sites — still a lower bound, not exhaustive runtime behaviour. +**`CALLS` edges:** source-ordered (`call_site_line`, `call_site_byte`). `attrs.resolved=false` or low `attrs.confidence` may be JDK/external or unresolved static sites — still a lower bound, not exhaustive runtime behaviour. **`filter` + `edge_filter` together** load the ordered CALLS stream then apply callee `NodeFilter` in Python — expect higher latency on hot methods than `edge_filter` alone. Optional **`edge_filter`** projects before pagination: `min_confidence`; `include_strategies` / `exclude_strategies` (mutually exclusive); `callee_declaring_role`, `callee_declaring_roles`, `exclude_callee_declaring_roles` (`["OTHER"]` also drops known-external rows). **`filter.role` filters the neighbor method (usually `OTHER`), not the callee stereotype** — use `edge_filter.callee_declaring_role` for repository/service hops. **`exclude_external` applies to `find_callers` / `find_callees` only** (FQN-prefix); trim JDK noise on CALLS via `edge_filter`. Accessor noise: role excludes help; getter/setter heuristics in [`propose/AGENT-SKILLS-AND-COMMANDS-PROPOSE.md`](../propose/AGENT-SKILLS-AND-COMMANDS-PROPOSE.md) `/mini-map`. 
### Ontology glossary diff --git a/kuzu_queries.py b/kuzu_queries.py index e058b0d..1792d35 100644 --- a/kuzu_queries.py +++ b/kuzu_queries.py @@ -667,6 +667,76 @@ def member_edge_traversal_for(self, type_id: str, composed_key: str) -> list[dic {"id": type_id, "rel": rel}, ) + def count_calls_for_symbol(self, origin_id: str, *, direction: Literal["in", "out"]) -> int: + """Count CALLS edges incident on a Symbol (hints / diagnostics).""" + if direction == "out": + pattern = "MATCH (origin:Symbol {id: $id})-[e:CALLS]->() RETURN count(e) AS n" + else: + pattern = "MATCH (origin:Symbol {id: $id})<-[e:CALLS]-() RETURN count(e) AS n" + rows = self._rows(pattern, {"id": origin_id}) + return int(rows[0].get("n") or 0) if rows else 0 + + def neighbor_calls_for_symbol( + self, + origin_id: str, + *, + direction: Literal["in", "out"], + offset: int = 0, + limit: int | None = None, + sql_pagination: bool = True, + min_confidence: float | None = None, + include_strategies: list[str] | None = None, + exclude_strategies: list[str] | None = None, + callee_declaring_role: str | None = None, + callee_declaring_roles: list[str] | None = None, + exclude_callee_declaring_roles: list[str] | None = None, + ) -> list[dict[str, Any]]: + """CALLS neighbors with source-order delivery and optional edge-attribute pushdown. + + When ``sql_pagination`` is True and ``limit`` is set, ``SKIP``/``LIMIT`` apply after + ``ORDER BY e.call_site_line, e.call_site_byte``. Otherwise the full ordered stream is + returned for caller-side ``NodeFilter`` / pagination. 
+ """ + wh_parts = ["origin.id = $id"] + params: dict[str, Any] = {"id": origin_id} + if min_confidence is not None: + wh_parts.append("e.confidence >= $min_confidence") + params["min_confidence"] = min_confidence + if include_strategies: + wh_parts.append("e.strategy IN $include_strategies") + params["include_strategies"] = include_strategies + if exclude_strategies: + wh_parts.append("NOT (e.strategy IN $exclude_strategies)") + params["exclude_strategies"] = exclude_strategies + if callee_declaring_role is not None: + wh_parts.append("e.callee_declaring_role = $callee_declaring_role") + params["callee_declaring_role"] = callee_declaring_role + if callee_declaring_roles: + wh_parts.append("e.callee_declaring_role IN $callee_declaring_roles") + params["callee_declaring_roles"] = callee_declaring_roles + if exclude_callee_declaring_roles: + wh_parts.append("NOT (e.callee_declaring_role IN $exclude_callee_declaring_roles)") + params["exclude_callee_declaring_roles"] = exclude_callee_declaring_roles + where = " AND ".join(wh_parts) + if direction == "out": + match = "MATCH (origin:Symbol)-[e:CALLS]->(other:Symbol)" + else: + match = "MATCH (origin:Symbol)<-[e:CALLS]-(other:Symbol)" + q = ( + f"{match} WHERE {where} " + "RETURN other.id AS other_id, 'CALLS' AS edge_type, " + "e.confidence AS confidence, e.strategy AS strategy, e.source AS source, " + "e.call_site_line AS call_site_line, e.call_site_byte AS call_site_byte, " + "e.arg_count AS arg_count, e.resolved AS resolved, " + "e.callee_declaring_role AS callee_declaring_role " + "ORDER BY e.call_site_line, e.call_site_byte" + ) + if sql_pagination and limit is not None: + q += " SKIP $offset LIMIT $limit" + params["offset"] = offset + params["limit"] = limit + return self._rows(q, params) + def _edge_row_count_from_method_ids(self, method_ids: list[str], rel: str) -> int: """Count outgoing ``rel`` edges from method symbols (describe rollup helper).""" total = 0 diff --git a/mcp_hints.py b/mcp_hints.py index 
27d3958..07a0bfd 100644 --- a/mcp_hints.py +++ b/mcp_hints.py @@ -21,7 +21,10 @@ "Maximum 5 hints per output. Describe-time type rollup hints may recommend " "DECLARES.* dot-keys for neighbors(); empty neighbors structural hints never use " "dot-key edge labels. For neighbors with multiple origin ids, empty-result " - "structural hints describe the first origin only." + "structural hints describe the first origin only. On neighbors with " + "edge_types=['CALLS'] only, optional edge_filter projects the ordered CALLS stream " + "(min_confidence, strategies, callee_declaring_role axes); fail-loud with composed " + "dot-keys or additional stored labels." ) # --- Appendix A verbatim templates (substitute {id}, {kind}, {limit}) --- @@ -109,6 +112,18 @@ "some edges resolved via brownfield/fallback strategy — check attrs.strategy on each row" ) +TPL_NEIGHBORS_CALLS_ROLE_FILTER_OTHER_FALLBACK = ( + "0 CALLS matched callee_declaring_role filter but method has many callees — " + "targets may be OTHER (interface/JDK); try " + "edge_filter={{exclude_callee_declaring_roles: ['ENTITY','DTO']}} instead of role exact match" +) + +TPL_NEIGHBORS_CALLS_NODEFILTER_ROLE_COLLISION = ( + "NodeFilter.role filters the neighbor method's role (usually OTHER), not the callee's " + "declaring type — use edge_filter={{callee_declaring_role: 'SERVICE'}} (or REPOSITORY) " + "for CALLS stereotype projection" +) + # v4 neighbors success-path (propose/HINTS-V4-SUCCESS-PATH-PROPOSE.md); N1a/N1b alias describe templates. 
TPL_NEIGHBORS_SUCCESS_HTTP_TARGETS = "HTTP targets: neighbors(client_ids,'out',['HTTP_CALLS'])" TPL_NEIGHBORS_SUCCESS_ASYNC_TARGETS = "async targets: neighbors(producer_ids,'out',['ASYNC_CALLS'])" @@ -329,6 +344,45 @@ def _append_neighbors_success_hint(pairs: list[tuple[int, str]], text: str) -> N pairs.append((PRIORITY_LEAF_FOLLOWUP, text)) +def neighbors_calls_meta_hints(payload: dict[str, Any]) -> list[tuple[int, str]]: + """CALLS-specific hints: role-filter OTHER fallback (Decision 20) and NodeFilter.role trap (30).""" + pairs: list[tuple[int, str]] = [] + req_types = payload.get("requested_edge_types") + if not isinstance(req_types, list) or req_types != ["CALLS"]: + return pairs + results = list(payload.get("results") or []) + edge_flt = payload.get("edge_filter") if isinstance(payload.get("edge_filter"), dict) else {} + node_flt = payload.get("node_filter") if isinstance(payload.get("node_filter"), dict) else {} + role_exact = edge_flt.get("callee_declaring_role") + if ( + role_exact in ("SERVICE", "REPOSITORY") + and not results + and int(payload.get("unfiltered_calls_count") or 0) >= 5 + ): + pairs.append((PRIORITY_META, TPL_NEIGHBORS_CALLS_ROLE_FILTER_OTHER_FALLBACK)) + node_role = node_flt.get("role") + if node_role and results: + method_rows = [ + r + for r in results + if str(((r.get("other") or {}) if isinstance(r.get("other"), dict) else {}).get("symbol_kind") or "") + == "method" + ] + if method_rows: + other_roles = [ + str( + ((r.get("other") or {}) if isinstance(r.get("other"), dict) else {}).get("role") + or "" + ) + for r in method_rows + ] + if other_roles and sum(1 for role in other_roles if role == "OTHER") >= max( + 1, (len(other_roles) * 3) // 4 + ): + pairs.append((PRIORITY_META, TPL_NEIGHBORS_CALLS_NODEFILTER_ROLE_COLLISION)) + return pairs + + def neighbors_success_hints(payload: dict[str, Any]) -> list[tuple[int, str]]: """v4 non-empty neighbors follow-ups (N1a–N7); no graph I/O.""" if not payload.get("success"): @@ -573,11 +627,11 @@ 
def generate_hints( requested_direction=requested_direction, ) ) - else: - if results and offset == 0: - success_pairs = neighbors_success_hints(payload) - if _any_fuzzy_strategy(results): - meta_pairs.append((PRIORITY_META, TPL_NEIGHBORS_FUZZY_STRATEGY)) + elif results and offset == 0: + success_pairs = neighbors_success_hints(payload) + meta_pairs.extend(neighbors_calls_meta_hints(payload)) + if results and _any_fuzzy_strategy(results): + meta_pairs.append((PRIORITY_META, TPL_NEIGHBORS_FUZZY_STRATEGY)) return finalize_hint_list( _filter_neighbors_dotkey_hints(empty_pairs) + success_pairs + meta_pairs, ) diff --git a/mcp_v2.py b/mcp_v2.py index 8df97f8..c59110b 100644 --- a/mcp_v2.py +++ b/mcp_v2.py @@ -24,12 +24,12 @@ import threading from typing import Annotated, Any, Literal, get_args -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, validate_call +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_validator, validate_call from sentence_transformers import SentenceTransformer from index_common import SBERT_MODEL from java_codebase_rag.config import resolved_sbert_model_for_process_env -from java_ontology import ResolveReason +from java_ontology import EDGE_SCHEMA, ResolveReason from kuzu_queries import KuzuGraph from mcp_hints import MCP_HINTS_FIELD_DESCRIPTION, generate_hints from search_lancedb import TABLES, run_search @@ -132,7 +132,53 @@ class NodeFilter(BaseModel): topic_prefix: str | None = None +class EdgeFilter(BaseModel): + model_config = ConfigDict(extra="forbid") + + min_confidence: float | None = None + exclude_strategies: list[str] | None = None + include_strategies: list[str] | None = None + callee_declaring_role: str | None = None + callee_declaring_roles: list[str] | None = None + exclude_callee_declaring_roles: list[str] | None = None + + @model_validator(mode="after") + def _strategy_axes_mutually_exclusive(self) -> EdgeFilter: + has_include = bool(self.include_strategies) + 
has_exclude = bool(self.exclude_strategies) + if has_include and has_exclude: + raise ValueError("include_strategies and exclude_strategies are mutually exclusive") + return self + + @model_validator(mode="after") + def _role_axes_mutually_exclusive(self) -> EdgeFilter: + role_axes = ( + self.callee_declaring_role is not None, + bool(self.callee_declaring_roles), + bool(self.exclude_callee_declaring_roles), + ) + if sum(role_axes) > 1: + raise ValueError( + "callee_declaring_role, callee_declaring_roles, and " + "exclude_callee_declaring_roles are mutually exclusive" + ) + return self + + _NODEFILTER_FIELD_ORDER: tuple[str, ...] = tuple(NodeFilter.model_fields.keys()) +_EDGEFILTER_FIELD_ORDER: tuple[str, ...] = tuple(EdgeFilter.model_fields.keys()) + +# Populated EdgeFilter field -> EDGE_SCHEMA attribute name used in Cypher pushdown. +_EDGEFILTER_FIELD_TO_ATTR: dict[str, str] = { + "min_confidence": "confidence", + "exclude_strategies": "strategy", + "include_strategies": "strategy", + "callee_declaring_role": "callee_declaring_role", + "callee_declaring_roles": "callee_declaring_role", + "exclude_callee_declaring_roles": "callee_declaring_role", +} + +_ROLE_FILTER_OTHER_FALLBACK_VALUES = frozenset({"SERVICE", "REPOSITORY"}) _NODEFILTER_APPLICABLE_FIELDS: dict[Literal["symbol", "route", "client", "producer"], tuple[str, ...]] = { "symbol": ( @@ -237,6 +283,80 @@ def _filter_validation_error_message(exc: ValidationError) -> str: return f"Invalid filter: {details}" +def _populated_edgefilter_fields(ef: EdgeFilter) -> set[str]: + populated: set[str] = set() + for field_name in _EDGEFILTER_FIELD_ORDER: + value = getattr(ef, field_name) + if value is None: + continue + if isinstance(value, list) and not value: + continue + populated.add(field_name) + return populated + + +def _edge_schema_attr_names(edge_type: str) -> set[str]: + spec = EDGE_SCHEMA.get(edge_type) + if spec is None: + return set() + return {attr.name for attr in spec.attrs} + + +def 
_edgefilter_applicability_error(edge_types: list[str], ef: EdgeFilter) -> str | None: + populated = _populated_edgefilter_fields(ef) + if not populated: + return None + flat_types = [et for et in edge_types if et not in _COMPOSED_EDGE_TYPES] + composed = [et for et in edge_types if et in _COMPOSED_EDGE_TYPES] + if composed or flat_types != ["CALLS"]: + parts: list[str] = [] + if flat_types != ["CALLS"]: + parts.append(f"stored labels {flat_types!r}") + if composed: + parts.append(f"composed keys {composed!r}") + detail = " and ".join(parts) if parts else "requested edge_types" + return ( + f"edge_filter requires edge_types=['CALLS'] only; {detail} is not supported — " + "split into separate neighbors calls" + ) + for edge_type in flat_types: + available = _edge_schema_attr_names(edge_type) + for field_name in _EDGEFILTER_FIELD_ORDER: + if field_name not in populated: + continue + attr = _EDGEFILTER_FIELD_TO_ATTR[field_name] + if attr not in available: + return ( + f"{attr} is not on {edge_type}; restrict edge_types to ['CALLS'] " + "or split into two neighbors_v2 calls" + ) + return None + + +def _coerce_edge_filter( + value: EdgeFilter | dict[str, Any] | str | None, +) -> EdgeFilter | dict[str, Any] | None: + """Normalize MCP tool input: weak clients sometimes pass JSON-encoded strings.""" + if value is None or isinstance(value, EdgeFilter): + return value + if isinstance(value, str): + s = value.strip() + if not s: + return None + try: + decoded = json.loads(s) + except json.JSONDecodeError as exc: + raise ValueError(f"edge_filter must be a JSON object; invalid JSON: {exc.msg}") from exc + if decoded is None: + return None + if not isinstance(decoded, dict): + raise ValueError( + f"edge_filter must decode to a JSON object, got {type(decoded).__name__}" + ) + return decoded + return value + + def _coerce_filter( value: NodeFilter | dict[str, Any] | str | None, ) -> NodeFilter | dict[str, Any] | None: @@ -1314,6 +1434,87 @@ def _neighbor_edge_attrs(row: dict[str, 
Any]) -> dict[str, Any]: } +def _edgefilter_pushdown_kwargs(ef: EdgeFilter | None) -> dict[str, Any]: + if ef is None: + return {} + return { + "min_confidence": ef.min_confidence, + "include_strategies": ef.include_strategies, + "exclude_strategies": ef.exclude_strategies, + "callee_declaring_role": ef.callee_declaring_role, + "callee_declaring_roles": ef.callee_declaring_roles, + "exclude_callee_declaring_roles": ef.exclude_callee_declaring_roles, + } + + +def _rows_to_call_edges( + g: Any, + *, + origin_id: str, + direction: Literal["in", "out"], + rows: list[dict[str, Any]], + nf: NodeFilter | None, +) -> list[Edge]: + edges: list[Edge] = [] + for row in rows: + other_id = str(row.get("other_id") or "") + other_kind = _resolve_node_kind(g, other_id) + other_rec = _load_node_record(g, other_id, other_kind) + if other_rec is None: + continue + if nf and (err := _nodefilter_applicability_error(other_kind, nf)): + _log_fail_loud("applicability") + raise ValueError(err) + if not _node_matches_filter(other_kind, other_rec, nf): + continue + edges.append( + Edge( + origin_id=origin_id, + edge_type=str(row.get("edge_type") or "CALLS"), + direction=direction, + other=_node_ref_from_row(other_kind, other_rec), + attrs=_neighbor_edge_attrs(row), + ) + ) + return edges + + +def _neighbors_calls_for_origin( + g: Any, + origin_id: str, + *, + direction: Literal["in", "out"], + nf: NodeFilter | None, + ef: EdgeFilter | None, + offset: int, + limit: int | None, +) -> list[Edge]: + pushdown = _edgefilter_pushdown_kwargs(ef) + sql_pagination = nf is None and limit is not None + if sql_pagination: + rows = g.neighbor_calls_for_symbol( + origin_id, + direction=direction, + offset=offset, + limit=limit, + sql_pagination=True, + **pushdown, + ) + return _rows_to_call_edges(g, origin_id=origin_id, direction=direction, rows=rows, nf=nf) + rows = g.neighbor_calls_for_symbol( + origin_id, + direction=direction, + offset=0, + limit=None, + sql_pagination=False, + **pushdown, + ) + edges 
= _rows_to_call_edges(g, origin_id=origin_id, direction=direction, rows=rows, nf=nf) + if limit is None: + return edges + return edges[offset : offset + limit] + + @validate_call(config={"arbitrary_types_allowed": True}) def neighbors_v2( ids: str | list[str], @@ -1324,6 +1525,7 @@ def neighbors_v2( limit: int = 25, offset: int = 0, filter: NodeFilter | dict[str, Any] | str | None = None, + edge_filter: EdgeFilter | dict[str, Any] | str | None = None, graph: Any | None = None, ) -> NeighborsOutput: try: @@ -1347,6 +1549,32 @@ def neighbors_v2( hints=[], requested_edge_types=[], ) + try: + raw_edge_filter = _coerce_edge_filter(edge_filter) + ef = ( + EdgeFilter.model_validate(raw_edge_filter) + if raw_edge_filter is not None and not isinstance(raw_edge_filter, EdgeFilter) + else raw_edge_filter + ) + except ValidationError as exc: + _log_fail_loud("edge_filter") + return NeighborsOutput( + success=False, + message=_filter_validation_error_message(exc), + hints=[], + requested_edge_types=[], + ) + except ValueError as exc: + _log_fail_loud("edge_filter") + return NeighborsOutput(success=False, message=str(exc), hints=[], requested_edge_types=[]) + if ef and (err := _edgefilter_applicability_error(requested_edge_types, ef)): + _log_fail_loud("edge_filter") + return NeighborsOutput( + success=False, + message=err, + hints=[], + requested_edge_types=requested_edge_types, + ) if nf and (err := _validate_no_wildcards(nf)): _log_fail_loud("wildcard") return NeighborsOutput(success=False, message=err, hints=[], requested_edge_types=[]) @@ -1357,8 +1585,10 @@ def neighbors_v2( hints=[], requested_edge_types=requested_edge_types, ) + use_calls_path = flat_labels == ["CALLS"] and not composed_keys origins = [ids] if isinstance(ids, str) else list(ids) results: list[Edge] = [] + unfiltered_calls_count: int | None = None for origin_id in origins: origin_kind = _resolve_node_kind(g, origin_id) if composed_keys: @@ -1382,6 +1612,34 @@ def neighbors_v2( hints=[], 
requested_edge_types=requested_edge_types, ) + if use_calls_path: + paginate_in_sql = len(origins) == 1 and nf is None + try: + origin_edges = _neighbors_calls_for_origin( + g, + origin_id, + direction=direction, + nf=nf, + ef=ef, + offset=offset if paginate_in_sql else 0, + limit=limit if paginate_in_sql else None, + ) + except ValueError as exc: + return NeighborsOutput( + success=False, + message=str(exc), + hints=[], + requested_edge_types=requested_edge_types, + ) + if ( + ef is not None + and ef.callee_declaring_role in _ROLE_FILTER_OTHER_FALLBACK_VALUES + and not origin_edges + and unfiltered_calls_count is None + ): + unfiltered_calls_count = g.count_calls_for_symbol(origin_id, direction=direction) + results.extend(origin_edges) + continue if flat_labels: # Kuzu 0.11.x can drop `label(e) IN $list` in WHERE; use OR of scalar equalities. label_params = [f"l{i}" for i in range(len(flat_labels))] @@ -1456,7 +1714,10 @@ def neighbors_v2( attrs=_neighbor_edge_attrs(row), ) ) - sliced = results[offset : offset + limit] + if use_calls_path and len(origins) == 1 and nf is None: + sliced = results  # single-origin CALLS without NodeFilter: already paginated in SQL + else: + sliced = results[offset : offset + limit]  # NodeFilter, multi-origin, and generic paths paginate here first_origin = origins[0] origin_kind = _resolve_node_kind(g, first_origin) subject_record = _load_node_record(g, first_origin, origin_kind) @@ -1468,6 +1729,9 @@ def neighbors_v2( "offset": offset, "origin_id": first_origin, "subject_record": subject_record, + "node_filter": nf.model_dump(exclude_none=True) if nf else None, + "edge_filter": ef.model_dump(exclude_none=True) if ef else None, + "unfiltered_calls_count": unfiltered_calls_count, } return NeighborsOutput( success=True, diff --git a/server.py b/server.py index f86a724..a1d86ec 100644 --- a/server.py +++ b/server.py @@ -454,7 +454,9 @@ async def describe( "for 2-hop member rollups — out only, with via_id in attrs). OVERRIDDEN_BY* keys are not valid edge_types. 
" "Optional `filter` applies to each neighbor endpoint row; populated fields must be applicable to that " "neighbor's kind—mixed-kind result sets fail on the first inapplicable neighbor (strict frame). " - "Wildcards in prefix fields are rejected. Unknown NodeFilter keys return success=false. " + "Optional `edge_filter` requires edge_types=['CALLS'] only (no composed dot-keys or extra stored " + "labels); projects the ordered CALLS stream by edge attributes (min_confidence, strategies, " + "callee_declaring_role). Wildcards in prefix fields are rejected. Unknown filter keys return success=false. " "Successful responses echo `requested_edge_types` and may include `hints` (advisory next-step strings; " "empty results may include EDGE_SCHEMA-driven traversal hints). " "Each edge's `attrs.strategy` indicates resolution quality (brownfield/fallback vs primary paths)." @@ -478,7 +480,9 @@ async def neighbors( default=25, ge=1, le=500, - description="Max edges after merge (batch expands all origins first)", + description=( + "Max edges after concatenating all origins (ids order; offset/limit on merged list)" + ), ), offset: int = Field( default=0, @@ -493,6 +497,14 @@ async def neighbors( "Prefer a JSON object; a JSON-encoded string is accepted." ), ), + edge_filter: dict[str, Any] | str | None = Field( + default=None, + description=( + "Optional EdgeFilter on CALLS edge attributes (edge_types=['CALLS'] only). Use " + "callee_declaring_role for callee stereotype projection — not NodeFilter.role on method neighbors. " + "Prefer a JSON object; a JSON-encoded string is accepted." 
+ ), + ), ) -> mcp_v2.NeighborsOutput: return await asyncio.to_thread( mcp_v2.neighbors_v2, @@ -502,6 +514,7 @@ async def neighbors( limit, offset, filter, + edge_filter, None, ) diff --git a/tests/fixtures/perf_baselines.json b/tests/fixtures/perf_baselines.json new file mode 100644 index 0000000..df921a7 --- /dev/null +++ b/tests/fixtures/perf_baselines.json @@ -0,0 +1,6 @@ +{ + "neighbors_calls_empty_filter_client_message_processor": { + "median_sec": 0.2, + "notes": "Reference median (5 runs, pass5 bank-chat Kuzu, ClientMessageProcessor#process empty-filter neighbors limit=500). Test asserts observed median <= 1.5× this value. Re-capture with JAVA_CODEBASE_RAG_RUN_HEAVY=1 when the CALLS path or fixture changes materially." + } +} diff --git a/tests/pinned_ids.py b/tests/pinned_ids.py new file mode 100644 index 0000000..99cfe42 --- /dev/null +++ b/tests/pinned_ids.py @@ -0,0 +1,22 @@ +"""Pinned bank-chat symbol ids shared across MCP regression tests.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +# Bank-chat fixture anchor for CALLS-NOISE perf / ordering tests (HV34, Decision 31). 
+CLIENT_MESSAGE_PROCESSOR_PROCESS_FQN = ( + "com.bank.chat.engine.processors.ClientMessageProcessor#process(ProcessingContext,InternalEvent)" +) + +if TYPE_CHECKING: + from kuzu_queries import KuzuGraph + + +def client_message_processor_process_id(graph: KuzuGraph) -> str: + rows = graph._rows( # noqa: SLF001 + "MATCH (m:Symbol {fqn: $fqn}) RETURN m.id AS id LIMIT 1", + {"fqn": CLIENT_MESSAGE_PROCESSOR_PROCESS_FQN}, + ) + assert rows, f"missing pinned method {CLIENT_MESSAGE_PROCESSOR_PROCESS_FQN}" + return str(rows[0]["id"]) diff --git a/tests/test_mcp_hints.py b/tests/test_mcp_hints.py index e6ac001..2dd62e5 100644 --- a/tests/test_mcp_hints.py +++ b/tests/test_mcp_hints.py @@ -20,7 +20,16 @@ generate_hints, neighbors_empty_hints, ) -from mcp_v2 import FindOutput, SearchOutput, describe_v2, find_v2, neighbors_v2, resolve_v2, search_v2 +from mcp_v2 import ( + FindOutput, + SearchOutput, + describe_v2, + find_v2, + neighbors_v2, + resolve_v2, + search_v2, +) +from pinned_ids import client_message_processor_process_id _TYPE_KINDS = frozenset({"class", "interface", "enum", "record", "annotation"}) @@ -1665,3 +1674,46 @@ def test_hints_pagination_none_skips_page_derived_hints() -> None: def test_hints_template_rendered_length_leq_120(template: str, fmt: dict[str, Any]) -> None: rendered = template.format(**fmt) if fmt else template assert len(rendered) <= 120, rendered + + +def test_neighbors_calls_other_fallback_hint(kuzu_graph) -> None: + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (m:Symbol {kind: 'method'})-[:CALLS]->() " + "WITH m, count(*) AS n WHERE n >= 5 " + "RETURN m.id AS id LIMIT 1", + ) + if not rows: + pytest.skip("no method with >=5 outbound CALLS in bank fixture") + mid = str(rows[0]["id"]) + out = neighbors_v2( + mid, + direction="out", + edge_types=["CALLS"], + edge_filter={"callee_declaring_role": "REPOSITORY"}, + limit=25, + graph=kuzu_graph, + ) + assert out.success is True + assert out.results == [] + assert 
mcp_hints.TPL_NEIGHBORS_CALLS_ROLE_FILTER_OTHER_FALLBACK in out.hints + + +def test_neighbors_calls_nodefilter_role_collision_hint(kuzu_graph) -> None: + mid = client_message_processor_process_id(kuzu_graph) + out = neighbors_v2( + mid, + direction="out", + edge_types=["CALLS"], + filter={"role": "OTHER"}, + limit=50, + graph=kuzu_graph, + ) + assert out.success is True + assert out.results + method_neighbors = [e for e in out.results if e.other.symbol_kind == "method"] + if len(method_neighbors) < 2: + pytest.skip("need multiple method-kind CALLS neighbors for collision hint") + other_roles = [str(e.other.role or "") for e in method_neighbors] + if sum(1 for r in other_roles if r == "OTHER") < max(1, (len(other_roles) * 3) // 4): + pytest.skip("CALLS neighbors are not dominantly OTHER for this method") + assert mcp_hints.TPL_NEIGHBORS_CALLS_NODEFILTER_ROLE_COLLISION in out.hints diff --git a/tests/test_mcp_v2.py b/tests/test_mcp_v2.py index 9b4a3a4..e6e9454 100644 --- a/tests/test_mcp_v2.py +++ b/tests/test_mcp_v2.py @@ -1,8 +1,13 @@ from __future__ import annotations import asyncio +import json +import os import re +import statistics +import time from collections import Counter +from pathlib import Path from typing import Any import pytest @@ -21,6 +26,7 @@ resolve_v2, search_v2, ) +from pinned_ids import client_message_processor_process_id _PR2_CHAIN_SEARCH_DESCRIBE = re.compile(r"search\(query=.*\).*describe") _PR2_SENTINEL_PATTERNS: tuple[re.Pattern[str], ...] 
= (
@@ -1303,3 +1309,180 @@ def test_resolve_success_output_invariants(kuzu_graph, kuzu_graph_fqn_collision_
     assert single.candidates == []
+_PERF_BASELINES_PATH = (
+    Path(__file__).resolve().parent / "fixtures" / "perf_baselines.json"
+)
+
+
+def test_neighbors_calls_ordered_by_call_site(kuzu_graph) -> None:
+    mid = client_message_processor_process_id(kuzu_graph)
+    out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], limit=500, graph=kuzu_graph)
+    assert out.success is True
+    assert len(out.results) >= 2
+    sites = [
+        (int(e.attrs.get("call_site_line") or 0), int(e.attrs.get("call_site_byte") or 0))
+        for e in out.results
+    ]
+    assert sites == sorted(sites)
+
+
+def test_neighbors_calls_edge_filter_callee_declaring_role(kuzu_graph) -> None:
+    mid = client_message_processor_process_id(kuzu_graph)
+    out = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS"],
+        edge_filter={"callee_declaring_role": "SERVICE"},
+        limit=500,
+        graph=kuzu_graph,
+    )
+    assert out.success is True
+    assert out.results
+    for edge in out.results:
+        assert edge.attrs.get("callee_declaring_role") == "SERVICE"
+
+
+def test_neighbors_calls_edge_filter_pushdown_in_cypher(kuzu_graph, monkeypatch) -> None:
+    mid = _method_id_with_calls(kuzu_graph, "out")
+    captured: list[str] = []
+    orig_rows = kuzu_graph._rows
+
+    def _capture_rows(query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
+        captured.append(query)
+        return orig_rows(query, params)
+
+    monkeypatch.setattr(kuzu_graph, "_rows", _capture_rows)
+    out = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS"],
+        edge_filter={"callee_declaring_role": "SERVICE", "min_confidence": 0.5},
+        graph=kuzu_graph,
+    )
+    assert out.success is True
+    calls_queries = [q for q in captured if "ORDER BY e.call_site_line" in q]
+    assert calls_queries
+    q = calls_queries[0]
+    assert "callee_declaring_role" in q
+    assert "confidence" in q
+
+
+def test_neighbors_calls_edge_filter_before_limit(kuzu_graph) -> None:
+    mid = client_message_processor_process_id(kuzu_graph)
+    unfiltered = neighbors_v2(
+        mid, direction="out", edge_types=["CALLS"], limit=500, graph=kuzu_graph
+    )
+    assert unfiltered.success is True
+    non_other_total = sum(
+        1 for e in unfiltered.results if e.attrs.get("callee_declaring_role") != "OTHER"
+    )
+    assert non_other_total >= 6
+    unfiltered_cap = neighbors_v2(
+        mid, direction="out", edge_types=["CALLS"], limit=5, graph=kuzu_graph
+    )
+    assert unfiltered_cap.success is True
+    assert len(unfiltered_cap.results) == 5
+    other_in_cap = sum(
+        1 for e in unfiltered_cap.results if e.attrs.get("callee_declaring_role") == "OTHER"
+    )
+    filtered = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS"],
+        edge_filter={"exclude_callee_declaring_roles": ["OTHER"]},
+        limit=5,
+        graph=kuzu_graph,
+    )
+    assert filtered.success is True
+    assert len(filtered.results) == 5
+    assert all(e.attrs.get("callee_declaring_role") != "OTHER" for e in filtered.results)
+    assert other_in_cap >= 1
+
+
+def test_neighbors_calls_edge_filter_mixed_types_fail_loud(kuzu_graph) -> None:
+    mid = _method_id_with_calls(kuzu_graph, "out")
+    out = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS", "OVERRIDES"],
+        edge_filter={"callee_declaring_role": "SERVICE"},
+        graph=kuzu_graph,
+    )
+    assert out.success is False
+    assert out.message
+    assert "edge_types=['CALLS']" in out.message
+    assert "OVERRIDES" in out.message
+
+
+def test_neighbors_calls_edge_filter_composed_types_fail_loud(kuzu_graph) -> None:
+    rows = kuzu_graph._rows(  # noqa: SLF001
+        "MATCH (t:Symbol)-[:DECLARES]->(m:Symbol)-[e:EXPOSES]->(:Route) "
+        "WHERE t.role = 'CONTROLLER' AND t.kind = 'class' "
+        "RETURN t.id AS id LIMIT 1",
+    )
+    assert rows
+    tid = str(rows[0]["id"])
+    out = neighbors_v2(
+        tid,
+        direction="out",
+        edge_types=["CALLS", "DECLARES.EXPOSES"],
+        edge_filter={"callee_declaring_role": "SERVICE"},
+        graph=kuzu_graph,
+    )
+    assert out.success is False
+    assert out.message
+    assert "edge_types=['CALLS']" in out.message
+    assert "DECLARES.EXPOSES" in out.message
+
+
+def test_neighbors_calls_edge_filter_role_axes_xor(kuzu_graph) -> None:
+    mid = _method_id_with_calls(kuzu_graph, "out")
+    out = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS"],
+        edge_filter={
+            "callee_declaring_role": "SERVICE",
+            "exclude_callee_declaring_roles": ["OTHER"],
+        },
+        graph=kuzu_graph,
+    )
+    assert out.success is False
+    assert out.message
+    assert "mutually exclusive" in out.message.lower()
+
+
+def test_neighbors_calls_edge_filter_strategy_xor(kuzu_graph) -> None:
+    mid = _method_id_with_calls(kuzu_graph, "out")
+    out = neighbors_v2(
+        mid,
+        direction="out",
+        edge_types=["CALLS"],
+        edge_filter={"include_strategies": ["exact"], "exclude_strategies": ["phantom"]},
+        graph=kuzu_graph,
+    )
+    assert out.success is False
+    assert out.message
+    assert "mutually exclusive" in out.message.lower()
+
+
+@pytest.mark.skipif(
+    os.environ.get("JAVA_CODEBASE_RAG_RUN_HEAVY", "").strip() != "1",
+    reason="perf gate; set JAVA_CODEBASE_RAG_RUN_HEAVY=1",
+)
+def test_neighbors_calls_perf_empty_filter_client_message_processor(kuzu_graph) -> None:
+    mid = client_message_processor_process_id(kuzu_graph)
+    baseline = json.loads(_PERF_BASELINES_PATH.read_text())[
+        "neighbors_calls_empty_filter_client_message_processor"
+    ]["median_sec"]
+    times: list[float] = []
+    for _ in range(5):
+        t0 = time.perf_counter()
+        out = neighbors_v2(mid, direction="out", edge_types=["CALLS"], limit=500, graph=kuzu_graph)
+        times.append(time.perf_counter() - t0)
+        assert out.success is True
+        assert out.results
+    median_sec = statistics.median(times)
+    assert median_sec <= float(baseline) * 1.5
+
+
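The tests above pin down three behaviours of `edge_filter`: it is rejected unless `edge_types` is exactly `['CALLS']`, the include/exclude axes for roles and strategies are mutually exclusive, and filtering happens before `limit` is applied. The sketch below illustrates those rules in isolation. It is not the repository's implementation: `validate_edge_filter`, `page_edges`, and `_XOR_PAIRS` are hypothetical names, edges are modelled as plain dicts, and the error wording is assumed only to contain the substrings the tests check for.

```python
from __future__ import annotations

from typing import Any

# Hypothetical helper names; the real neighbors_v2 internals are not shown in this diff.
_XOR_PAIRS = [
    ("callee_declaring_role", "exclude_callee_declaring_roles"),
    ("include_strategies", "exclude_strategies"),
]


def validate_edge_filter(edge_types: list[str], edge_filter: dict[str, Any] | None) -> str | None:
    """Return an error message, or None when the combination is acceptable."""
    if not edge_filter:
        return None
    # edge_filter is only meaningful for a pure CALLS walk.
    if edge_types != ["CALLS"]:
        extras = [t for t in edge_types if t != "CALLS"]
        return f"edge_filter requires edge_types=['CALLS']; got extra types {extras}"
    # Each include/exclude pair may set at most one side.
    for include_key, exclude_key in _XOR_PAIRS:
        if include_key in edge_filter and exclude_key in edge_filter:
            return f"{include_key} and {exclude_key} are mutually exclusive"
    return None


def page_edges(
    edges: list[dict[str, Any]],
    edge_filter: dict[str, Any] | None,
    limit: int,
    offset: int = 0,
) -> list[dict[str, Any]]:
    """Apply the role exclusion first, then slice, so a page is not starved by excluded rows."""
    excluded = set((edge_filter or {}).get("exclude_callee_declaring_roles", []))
    kept = [e for e in edges if e.get("callee_declaring_role") not in excluded]
    return kept[offset : offset + limit]
```

Filtering before slicing is what lets `test_neighbors_calls_edge_filter_before_limit` expect a full page of five non-`OTHER` edges even though the unfiltered first page contains `OTHER` rows; in the real code the test for Cypher pushdown suggests the predicate lives in the query rather than in Python.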