diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index f0b9ab86..b4a64793 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,4 +1,6 @@ import binascii +import uuid +import weakref from collections.abc import Callable, Sequence from enum import Enum from typing import TYPE_CHECKING, Any, TypeVar @@ -56,6 +58,85 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +# Module-level (not methods) so they can be registered with ``weakref.finalize`` +# without holding a bound reference to the owning object, which would prevent +# it from ever being collected. +def _drop_scratch_tables(engine: sa.Engine, tables: list[sa.Table]) -> None: + """Drop scratch tables, swallowing errors (e.g. at interpreter shutdown).""" + for table in tables: + try: + table.drop(engine) + except Exception as exc: + LOG.debug("Failed to drop scratch table %s: %s", table.name, exc) + + +class _SqlIdSet: + """A set of ids usable in SQL ``IN`` clauses without overflowing bind limits. + + Small sets compile to inline ``col.in_([...])``; larger sets are materialized + into a per-instance scratch table and matched via ``col.in_(SELECT id FROM + scratch)``. The same ``_SqlIdSet`` may be reused against multiple columns — + each call to :meth:`in_clause` emits a fresh expression backed by the same + underlying ids. + + ``occurrences`` is the maximum number of times the id set will be expanded + in a single compiled statement (e.g. filtering both ``source_id`` and + ``target_id`` of an edge table counts as 2). The scratch-table cutoff is + divided by it so that ``len(ids) * occurrences`` stays safely under the + backend's bound-variable limit. + """ + + def __init__( + self, + graph: "SQLGraph", + ids: Sequence[int], + *, + occurrences: int = 1, + ) -> None: + if hasattr(ids, "tolist"): + ids = ids.tolist() + self._ids: list[int] = list(ids) + self._graph = graph + + limit = max(1, graph._sql_chunk_size() // max(1, occurrences)) + if len(self._ids) > limit: + self._scratch: sa.Table | None = graph._create_id_scratch_table(self._ids) + else: + self._scratch = None + + @property + def ids(self) -> list[int]: + return self._ids + + @property + def uses_scratch_table(self) -> bool: + return self._scratch is not None + + def in_clause(self, column: sa.ColumnElement) -> sa.ColumnElement[bool]: + if self._scratch is None: + return column.in_(self._ids) + return column.in_(sa.select(self._scratch.c.id)) + + def close(self) -> None: + if self._scratch is not None: + _drop_scratch_tables(self._graph._engine, [self._scratch]) + self._scratch = None + + def __enter__(self) -> "_SqlIdSet": + return self + + def __exit__(self, *exc: object) -> None: + self.close() + + +def _close_id_sets(id_sets: list[_SqlIdSet]) -> None: + for id_set in id_sets: + try: + id_set.close() + except Exception as exc: + LOG.debug("Failed to close _SqlIdSet: %s", exc) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -99,6 +180,7 @@ def __init__( self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters) self._include_targets = include_targets self._include_sources = include_sources + self._id_sets: list[_SqlIdSet] = [] # creating initial query self._node_query: sa.Select = sa.select(self._graph.Node) @@ -106,18 +188,20 @@ def __init__( node_filtered = False if node_ids is not None: - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - - self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids)) + # The id set is expanded once for Node.node_id, plus once each for + # Edge.target_id / Edge.source_id when those filters are not + # suppressed by ``include_targets`` / ``include_sources``. The + # scratch-table cutoff is divided accordingly so that the total + # number of bound variables stays under the backend's limit. + occurrences = 1 + (not self._include_targets) + (not self._include_sources) + id_set = _SqlIdSet(self._graph, node_ids, occurrences=occurrences) + self._id_sets.append(id_set) + + self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id)) if not self._include_targets: - self._edge_query = self._edge_query.filter( - self._graph.Edge.target_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id)) if not self._include_sources: - self._edge_query = self._edge_query.filter( - self._graph.Edge.source_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id)) node_filtered = True if self._node_attr_comps: @@ -182,6 +266,13 @@ def __init__( self._node_query = sa.union(*nodes_query) + if self._uses_scratch_tables(): + weakref.finalize(self, _close_id_sets, self._id_sets) + + def _uses_scratch_tables(self) -> bool: + """Whether any id set backing this filter materialized a scratch table.""" + return any(id_set.uses_scratch_table for id_set in self._id_sets) + @cache_method def node_ids(self) -> list[int]: """ @@ -1092,19 +1183,18 @@ def overlaps( """ Get the overlaps between the nodes in `node_ids`. """ - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - with Session(self._engine) as session: query = session.query(self.Overlap.source_id, self.Overlap.target_id) - if node_ids is not None: + if node_ids is None: + return [[source_id, target_id] for source_id, target_id in query.all()] + + with _SqlIdSet(self, node_ids, occurrences=2) as id_set: query = query.filter( - self.Overlap.source_id.in_(node_ids), - self.Overlap.target_id.in_(node_ids), + id_set.in_clause(self.Overlap.source_id), + id_set.in_clause(self.Overlap.target_id), ) - - return [[source_id, target_id] for source_id, target_id in query.all()] + return [[source_id, target_id] for source_id, target_id in query.all()] def has_overlaps(self) -> bool: """ @@ -1794,6 +1884,33 @@ def _sql_chunk_size(self) -> int: return chunk_size + def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table: + """Create a uniquely-named helper table holding ``ids``. + + Used to work around SQL bound-variable limits when filtering by large + ``IN (...)`` lists: callers replace ``col.in_(ids)`` with + ``col.in_(sa.select(table.c.id))``. The caller owns the returned table + and is responsible for dropping it. + """ + unique_ids = list({int(v) for v in ids}) + + name = f"_tracksdata_ids_{uuid.uuid4().hex}" + table = sa.Table( + name, + sa.MetaData(), + sa.Column("id", sa.BigInteger, primary_key=True), + ) + table.create(self._engine) + + chunk_size = max(1, self._sql_chunk_size()) + with self._engine.begin() as conn: + for i in range(0, len(unique_ids), chunk_size): + conn.execute( + table.insert(), + [{"id": v} for v in unique_ids[i : i + chunk_size]], + ) + return table + def _update_table( self, table_class: type[DeclarativeBase], @@ -2009,18 +2126,18 @@ def _get_degree( return int(session.execute(stmt).scalar()) stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) - if node_ids is not None: - stmt = stmt.where(edge_key_col.in_(node_ids)) - - with Session(self._engine) as session: - # get the number of edges for each using group by and count - degree = dict(session.execute(stmt).all()) if node_ids is None: - # this is necessary to make sure it's the same order as node_ids + with Session(self._engine) as session: + degree = dict(session.execute(stmt).all()) + # preserve the canonical node ordering return [degree.get(node_id, 0) for node_id in self.node_ids()] - return [degree.get(node_id, 0) for node_id in node_ids] + with _SqlIdSet(self, node_ids, occurrences=1) as id_set: + stmt = stmt.where(id_set.in_clause(edge_key_col)) + with Session(self._engine) as session: + degree = dict(session.execute(stmt).all()) + return [degree.get(node_id, 0) for node_id in id_set.ids] def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int: """ diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 96cbfa1a..89484803 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1,3 +1,4 @@ +import itertools import re from collections.abc import Callable from contextlib import contextmanager @@ -1302,3 +1303,89 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: ) ) assert edge_list == expected_edge_list + + +def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]: + node_ids: list[int] = [] + for t in range(n_nodes): + node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) + for src, tgt in itertools.pairwise(node_ids): + graph.add_edge(src, tgt, {}) + graph.add_overlap(node_ids[0], node_ids[1]) + graph.add_overlap(node_ids[2], node_ids[3]) + return node_ids + + +def _scratch_table_count(graph: SQLGraph) -> int: + """Count leftover ``_tracksdata_ids_*`` scratch tables in a SQLite graph.""" + import sqlalchemy as sa + + with graph._engine.connect() as conn: + return conn.execute( + sa.text("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name LIKE '_tracksdata_ids_%'") + ).scalar() + + +def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """Filtering with more ids than SQLite's variable limit must not raise. + + Reproduces the ``OperationalError: too many SQL variables`` failure by + forcing the scratch-table code path via a tiny chunk size. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + n_nodes = 40 + node_ids = _build_chain_graph(graph, n_nodes) + + # Force scratch-table path on every call site. + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) + + # Context-manager paths (overlaps, _get_degree) must drop their scratch + # tables once the block exits — the count should return to baseline after + # each call regardless of whether the scratch path fired inside. + assert _scratch_table_count(graph) == 0 + in_deg = graph.in_degree(node_ids) + assert _scratch_table_count(graph) == 0 + out_deg = graph.out_degree(node_ids) + assert _scratch_table_count(graph) == 0 + overlaps = graph.overlaps(node_ids) + assert _scratch_table_count(graph) == 0 + + assert sum(in_deg) == n_nodes - 1 + assert sum(out_deg) == n_nodes - 1 + assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) + + filtered = graph.filter(node_ids=node_ids) + # Confirm the scratch-table code path was taken rather than raw IN (...). + assert filtered._uses_scratch_tables() + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + +def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """The scratch cutoff must account for how many times ids appear per statement. + + With ``_sql_chunk_size() == 12`` and ``SQLFilter`` using ``occurrences=3``, + a list of 5 ids would compile to ~15 bound variables — above the limit — + even though ``len(node_ids) <= chunk_size``. The helper must still switch + to the scratch-table path in that band. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + n_nodes = 5 + node_ids = _build_chain_graph(graph, n_nodes) + + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 12) + + filtered = graph.filter(node_ids=node_ids) + # 5 ids fits under chunk_size=12 inline, but with occurrences=3 the + # effective cutoff is 12 // 3 == 4, so scratch must kick in. + assert filtered._uses_scratch_tables() + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + # overlaps() uses occurrences=2 → cutoff 6, so len==5 stays inline. + # Still assert it returns the right data regardless of path. + assert sorted(map(tuple, graph.overlaps(node_ids))) == sorted( + [(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])] + )