From 070e20d8206623015f1c519fbe4c5b00c06a1a5b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 14 Apr 2026 10:47:45 +0900 Subject: [PATCH 1/3] working version --- src/tracksdata/graph/_sql_graph.py | 125 ++++++++++++++++---- src/tracksdata/graph/_test/test_subgraph.py | 37 ++++++ 2 files changed, 140 insertions(+), 22 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index f0b9ab86..eaa8512c 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,4 +1,6 @@ import binascii +import uuid +import weakref from collections.abc import Callable, Sequence from enum import Enum from typing import TYPE_CHECKING, Any, TypeVar @@ -56,6 +58,15 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _drop_scratch_tables(engine: sa.Engine, tables: list[sa.Table]) -> None: + """Drop scratch tables, swallowing errors (e.g. at interpreter shutdown).""" + for table in tables: + try: + table.drop(engine) + except Exception as exc: + LOG.debug("Failed to drop scratch table %s: %s", table.name, exc) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -99,6 +110,7 @@ def __init__( self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters) self._include_targets = include_targets self._include_sources = include_sources + self._scratch_tables: list[sa.Table] = [] # creating initial query self._node_query: sa.Select = sa.select(self._graph.Node) @@ -109,15 +121,24 @@ def __init__( if hasattr(node_ids, "tolist"): node_ids = node_ids.tolist() - self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids)) + # Large IN lists hit SQL bound-variable limits (e.g. SQLite's + # SQLITE_MAX_VARIABLE_NUMBER). Materialize into a scratch table and + # filter via a subquery instead. 
+ scratch: sa.Table | None = None + if len(node_ids) > self._graph._sql_chunk_size(): + scratch = self._graph._create_id_scratch_table(node_ids) + self._scratch_tables.append(scratch) + + def _in_ids(column: sa.Column) -> sa.ColumnElement[bool]: + if scratch is None: + return column.in_(node_ids) + return column.in_(sa.select(scratch.c.id)) + + self._node_query = self._node_query.filter(_in_ids(self._graph.Node.node_id)) if not self._include_targets: - self._edge_query = self._edge_query.filter( - self._graph.Edge.target_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(_in_ids(self._graph.Edge.target_id)) if not self._include_sources: - self._edge_query = self._edge_query.filter( - self._graph.Edge.source_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(_in_ids(self._graph.Edge.source_id)) node_filtered = True if self._node_attr_comps: @@ -182,6 +203,14 @@ def __init__( self._node_query = sa.union(*nodes_query) + if self._scratch_tables: + weakref.finalize( + self, + _drop_scratch_tables, + self._graph._engine, + self._scratch_tables, + ) + @cache_method def node_ids(self) -> list[int]: """ @@ -1095,16 +1124,28 @@ def overlaps( if hasattr(node_ids, "tolist"): node_ids = node_ids.tolist() - with Session(self._engine) as session: - query = session.query(self.Overlap.source_id, self.Overlap.target_id) - - if node_ids is not None: - query = query.filter( - self.Overlap.source_id.in_(node_ids), - self.Overlap.target_id.in_(node_ids), - ) - - return [[source_id, target_id] for source_id, target_id in query.all()] + scratch: sa.Table | None = None + try: + with Session(self._engine) as session: + query = session.query(self.Overlap.source_id, self.Overlap.target_id) + + if node_ids is not None: + if len(node_ids) > self._sql_chunk_size(): + scratch = self._create_id_scratch_table(node_ids) + query = query.filter( + self.Overlap.source_id.in_(sa.select(scratch.c.id)), + self.Overlap.target_id.in_(sa.select(scratch.c.id)), + ) + else: + 
query = query.filter( + self.Overlap.source_id.in_(node_ids), + self.Overlap.target_id.in_(node_ids), + ) + + return [[source_id, target_id] for source_id, target_id in query.all()] + finally: + if scratch is not None: + _drop_scratch_tables(self._engine, [scratch]) def has_overlaps(self) -> bool: """ @@ -1794,6 +1835,35 @@ def _sql_chunk_size(self) -> int: return chunk_size + def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table: + """Create a uniquely-named helper table holding ``ids``. + + Used to work around SQL bound-variable limits when filtering by large + ``IN (...)`` lists: callers replace ``col.in_(ids)`` with + ``col.in_(sa.select(table.c.id))``. The caller owns the returned table + and is responsible for dropping it. + """ + if hasattr(ids, "tolist"): + ids = ids.tolist() + unique_ids = list({int(v) for v in ids}) + + name = f"_tracksdata_ids_{uuid.uuid4().hex}" + table = sa.Table( + name, + sa.MetaData(), + sa.Column("id", sa.BigInteger, primary_key=True), + ) + table.create(self._engine) + + chunk_size = max(1, self._sql_chunk_size()) + with self._engine.begin() as conn: + for i in range(0, len(unique_ids), chunk_size): + conn.execute( + table.insert(), + [{"id": v} for v in unique_ids[i : i + chunk_size]], + ) + return table + def _update_table( self, table_class: type[DeclarativeBase], @@ -2009,12 +2079,23 @@ def _get_degree( return int(session.execute(stmt).scalar()) stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) - if node_ids is not None: - stmt = stmt.where(edge_key_col.in_(node_ids)) + scratch: sa.Table | None = None + try: + if node_ids is not None: + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + if len(node_ids) > self._sql_chunk_size(): + scratch = self._create_id_scratch_table(node_ids) + stmt = stmt.where(edge_key_col.in_(sa.select(scratch.c.id))) + else: + stmt = stmt.where(edge_key_col.in_(node_ids)) - with Session(self._engine) as session: - # get the number of edges for each using 
group by and count - degree = dict(session.execute(stmt).all()) + with Session(self._engine) as session: + # get the number of edges for each using group by and count + degree = dict(session.execute(stmt).all()) + finally: + if scratch is not None: + _drop_scratch_tables(self._engine, [scratch]) if node_ids is None: # this is necessary to make sure it's the same order as node_ids diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 96cbfa1a..6ef0d71d 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1,3 +1,4 @@ +import itertools import re from collections.abc import Callable from contextlib import contextmanager @@ -1302,3 +1303,39 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: ) ) assert edge_list == expected_edge_list + + +def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """Filtering with more ids than SQLite's variable limit must not raise. + + Reproduces the ``OperationalError: too many SQL variables`` failure by + forcing the scratch-table code path via a tiny chunk size. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + + n_nodes = 40 + node_ids: list[int] = [] + for t in range(n_nodes): + node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) + for src, tgt in itertools.pairwise(node_ids): + graph.add_edge(src, tgt, {}) + graph.add_overlap(node_ids[0], node_ids[1]) + graph.add_overlap(node_ids[2], node_ids[3]) + + # Force scratch-table path: all three call sites gate on this size. + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) + + filtered = graph.filter(node_ids=node_ids) + # Confirm the scratch-table code path was taken rather than raw IN (...). 
+ assert filtered._scratch_tables + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + in_deg = graph.in_degree(node_ids) + out_deg = graph.out_degree(node_ids) + assert sum(in_deg) == n_nodes - 1 + assert sum(out_deg) == n_nodes - 1 + + overlaps = graph.overlaps(node_ids) + assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) From c43ac58d3f1f7497b705d9bf8beb776e091db5f8 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 14 Apr 2026 11:47:42 +0900 Subject: [PATCH 2/3] further fix --- src/tracksdata/graph/_sql_graph.py | 164 ++++++++++++-------- src/tracksdata/graph/_test/test_subgraph.py | 53 +++++-- 2 files changed, 140 insertions(+), 77 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index eaa8512c..6657a2fa 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -67,6 +67,73 @@ def _drop_scratch_tables(engine: sa.Engine, tables: list[sa.Table]) -> None: LOG.debug("Failed to drop scratch table %s: %s", table.name, exc) +class _SqlIdSet: + """A set of ids usable in SQL ``IN`` clauses without overflowing bind limits. + + Small sets compile to inline ``col.in_([...])``; larger sets are materialized + into a per-instance scratch table and matched via ``col.in_(SELECT id FROM + scratch)``. The same ``_SqlIdSet`` may be reused against multiple columns — + each call to :meth:`in_clause` emits a fresh expression backed by the same + underlying ids. + + ``occurrences`` is the maximum number of times the id set will be expanded + in a single compiled statement (e.g. filtering both ``source_id`` and + ``target_id`` of an edge table counts as 2). The scratch-table cutoff is + scaled by it so that ``len(ids) * occurrences`` stays safely under the + backend's bound-variable limit. 
+ """ + + def __init__( + self, + graph: "SQLGraph", + ids: Sequence[int], + *, + occurrences: int = 1, + ) -> None: + if hasattr(ids, "tolist"): + ids = ids.tolist() + self._ids: list[int] = list(ids) + self._graph = graph + + limit = max(1, graph._sql_chunk_size() // max(1, occurrences)) + if len(self._ids) > limit: + self._scratch: sa.Table | None = graph._create_id_scratch_table(self._ids) + else: + self._scratch = None + + @property + def ids(self) -> list[int]: + return self._ids + + @property + def uses_scratch_table(self) -> bool: + return self._scratch is not None + + def in_clause(self, column: sa.ColumnElement) -> sa.ColumnElement[bool]: + if self._scratch is None: + return column.in_(self._ids) + return column.in_(sa.select(self._scratch.c.id)) + + def close(self) -> None: + if self._scratch is not None: + _drop_scratch_tables(self._graph._engine, [self._scratch]) + self._scratch = None + + def __enter__(self) -> "_SqlIdSet": + return self + + def __exit__(self, *exc: object) -> None: + self.close() + + +def _close_id_sets(id_sets: list[_SqlIdSet]) -> None: + for id_set in id_sets: + try: + id_set.close() + except Exception as exc: + LOG.debug("Failed to close _SqlIdSet: %s", exc) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -110,7 +177,7 @@ def __init__( self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters) self._include_targets = include_targets self._include_sources = include_sources - self._scratch_tables: list[sa.Table] = [] + self._id_sets: list[_SqlIdSet] = [] # creating initial query self._node_query: sa.Select = sa.select(self._graph.Node) @@ -118,27 +185,19 @@ def __init__( node_filtered = False if node_ids is not None: - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - - # Large IN lists hit SQL bound-variable limits (e.g. SQLite's - # SQLITE_MAX_VARIABLE_NUMBER). Materialize into a scratch table and - # filter via a subquery instead. 
- scratch: sa.Table | None = None - if len(node_ids) > self._graph._sql_chunk_size(): - scratch = self._graph._create_id_scratch_table(node_ids) - self._scratch_tables.append(scratch) - - def _in_ids(column: sa.Column) -> sa.ColumnElement[bool]: - if scratch is None: - return column.in_(node_ids) - return column.in_(sa.select(scratch.c.id)) - - self._node_query = self._node_query.filter(_in_ids(self._graph.Node.node_id)) + # ``node_ids`` may be expanded in up to three places within a + # single compiled statement (Node.node_id, Edge.source_id, + # Edge.target_id — the node-attr-filter path at L159-L160 inlines + # the edge query as a subquery), so the scratch-table cutoff is + # scaled accordingly. + id_set = _SqlIdSet(self._graph, node_ids, occurrences=3) + self._id_sets.append(id_set) + + self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id)) if not self._include_targets: - self._edge_query = self._edge_query.filter(_in_ids(self._graph.Edge.target_id)) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id)) if not self._include_sources: - self._edge_query = self._edge_query.filter(_in_ids(self._graph.Edge.source_id)) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id)) node_filtered = True if self._node_attr_comps: @@ -203,13 +262,8 @@ def _in_ids(column: sa.Column) -> sa.ColumnElement[bool]: self._node_query = sa.union(*nodes_query) - if self._scratch_tables: - weakref.finalize( - self, - _drop_scratch_tables, - self._graph._engine, - self._scratch_tables, - ) + if any(id_set.uses_scratch_table for id_set in self._id_sets): + weakref.finalize(self, _close_id_sets, self._id_sets) @cache_method def node_ids(self) -> list[int]: @@ -1121,31 +1175,18 @@ def overlaps( """ Get the overlaps between the nodes in `node_ids`. 
""" - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() + with Session(self._engine) as session: + query = session.query(self.Overlap.source_id, self.Overlap.target_id) - scratch: sa.Table | None = None - try: - with Session(self._engine) as session: - query = session.query(self.Overlap.source_id, self.Overlap.target_id) - - if node_ids is not None: - if len(node_ids) > self._sql_chunk_size(): - scratch = self._create_id_scratch_table(node_ids) - query = query.filter( - self.Overlap.source_id.in_(sa.select(scratch.c.id)), - self.Overlap.target_id.in_(sa.select(scratch.c.id)), - ) - else: - query = query.filter( - self.Overlap.source_id.in_(node_ids), - self.Overlap.target_id.in_(node_ids), - ) + if node_ids is None: + return [[source_id, target_id] for source_id, target_id in query.all()] + with _SqlIdSet(self, node_ids, occurrences=2) as id_set: + query = query.filter( + id_set.in_clause(self.Overlap.source_id), + id_set.in_clause(self.Overlap.target_id), + ) return [[source_id, target_id] for source_id, target_id in query.all()] - finally: - if scratch is not None: - _drop_scratch_tables(self._engine, [scratch]) def has_overlaps(self) -> bool: """ @@ -2079,29 +2120,18 @@ def _get_degree( return int(session.execute(stmt).scalar()) stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) - scratch: sa.Table | None = None - try: - if node_ids is not None: - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - if len(node_ids) > self._sql_chunk_size(): - scratch = self._create_id_scratch_table(node_ids) - stmt = stmt.where(edge_key_col.in_(sa.select(scratch.c.id))) - else: - stmt = stmt.where(edge_key_col.in_(node_ids)) + if node_ids is None: with Session(self._engine) as session: - # get the number of edges for each using group by and count degree = dict(session.execute(stmt).all()) - finally: - if scratch is not None: - _drop_scratch_tables(self._engine, [scratch]) - - if node_ids is None: - # this is necessary to make sure 
it's the same order as node_ids + # preserve the canonical node ordering return [degree.get(node_id, 0) for node_id in self.node_ids()] - return [degree.get(node_id, 0) for node_id in node_ids] + with _SqlIdSet(self, node_ids, occurrences=1) as id_set: + stmt = stmt.where(id_set.in_clause(edge_key_col)) + with Session(self._engine) as session: + degree = dict(session.execute(stmt).all()) + return [degree.get(node_id, 0) for node_id in id_set.ids] def in_degree(self, node_ids: list[int] | int | None = None) -> list[int] | int: """ diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 6ef0d71d..8b781f50 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1305,6 +1305,17 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: assert edge_list == expected_edge_list +def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]: + node_ids: list[int] = [] + for t in range(n_nodes): + node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) + for src, tgt in itertools.pairwise(node_ids): + graph.add_edge(src, tgt, {}) + graph.add_overlap(node_ids[0], node_ids[1]) + graph.add_overlap(node_ids[2], node_ids[3]) + return node_ids + + def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: """Filtering with more ids than SQLite's variable limit must not raise. @@ -1312,22 +1323,15 @@ def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPat forcing the scratch-table code path via a tiny chunk size. 
""" graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) - n_nodes = 40 - node_ids: list[int] = [] - for t in range(n_nodes): - node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) - for src, tgt in itertools.pairwise(node_ids): - graph.add_edge(src, tgt, {}) - graph.add_overlap(node_ids[0], node_ids[1]) - graph.add_overlap(node_ids[2], node_ids[3]) + node_ids = _build_chain_graph(graph, n_nodes) - # Force scratch-table path: all three call sites gate on this size. + # Force scratch-table path on every call site. monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) filtered = graph.filter(node_ids=node_ids) # Confirm the scratch-table code path was taken rather than raw IN (...). - assert filtered._scratch_tables + assert filtered._id_sets[0].uses_scratch_table subgraph = filtered.subgraph() assert subgraph.num_nodes() == n_nodes assert subgraph.num_edges() == n_nodes - 1 @@ -1339,3 +1343,32 @@ def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPat overlaps = graph.overlaps(node_ids) assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) + + +def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """The scratch cutoff must account for how many times ids appear per statement. + + With ``_sql_chunk_size() == 12`` and ``SQLFilter`` using ``occurrences=3``, + a list of 5 ids would compile to ~15 bound variables — above the limit — + even though ``len(node_ids) <= chunk_size``. The helper must still switch + to the scratch-table path in that band. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + n_nodes = 5 + node_ids = _build_chain_graph(graph, n_nodes) + + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 12) + + filtered = graph.filter(node_ids=node_ids) + # 5 ids fits under chunk_size=12 inline, but with occurrences=3 the + # effective cutoff is 12 // 3 == 4, so scratch must kick in. 
+ assert filtered._id_sets[0].uses_scratch_table + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + # overlaps() uses occurrences=2 → cutoff 6, so len==5 stays inline. + # Still assert it returns the right data regardless of path. + assert sorted(map(tuple, graph.overlaps(node_ids))) == sorted( + [(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])] + ) From 95bac1e24e19390d8a9a3acc2b1877bdc7bad6b5 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 13 May 2026 14:32:23 +0900 Subject: [PATCH 3/3] update --- src/tracksdata/graph/_sql_graph.py | 26 +++++++++------ src/tracksdata/graph/_test/test_subgraph.py | 37 +++++++++++++++------ 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 6657a2fa..b4a64793 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -58,6 +58,9 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +# Module-level (not methods) so they can be registered with ``weakref.finalize`` +# without holding a bound reference to the owning object, which would prevent +# it from ever being collected. def _drop_scratch_tables(engine: sa.Engine, tables: list[sa.Table]) -> None: """Drop scratch tables, swallowing errors (e.g. at interpreter shutdown).""" for table in tables: @@ -79,7 +82,7 @@ class _SqlIdSet: ``occurrences`` is the maximum number of times the id set will be expanded in a single compiled statement (e.g. filtering both ``source_id`` and ``target_id`` of an edge table counts as 2). The scratch-table cutoff is - scaled by it so that ``len(ids) * occurrences`` stays safely under the + divided by it so that ``len(ids) * occurrences`` stays safely under the backend's bound-variable limit. 
""" @@ -185,12 +188,13 @@ def __init__( node_filtered = False if node_ids is not None: - # ``node_ids`` may be expanded in up to three places within a - # single compiled statement (Node.node_id, Edge.source_id, - # Edge.target_id — the node-attr-filter path at L159-L160 inlines - # the edge query as a subquery), so the scratch-table cutoff is - # scaled accordingly. - id_set = _SqlIdSet(self._graph, node_ids, occurrences=3) + # The id set is expanded once for Node.node_id, plus once each for + # Edge.target_id / Edge.source_id when those filters are not + # suppressed by ``include_targets`` / ``include_sources``. The + # scratch-table cutoff is divided accordingly so that the total + # number of bound variables stays under the backend's limit. + occurrences = 1 + (not self._include_targets) + (not self._include_sources) + id_set = _SqlIdSet(self._graph, node_ids, occurrences=occurrences) self._id_sets.append(id_set) self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id)) @@ -262,9 +266,13 @@ def __init__( self._node_query = sa.union(*nodes_query) - if any(id_set.uses_scratch_table for id_set in self._id_sets): + if self._uses_scratch_tables(): weakref.finalize(self, _close_id_sets, self._id_sets) + def _uses_scratch_tables(self) -> bool: + """Whether any id set backing this filter materialized a scratch table.""" + return any(id_set.uses_scratch_table for id_set in self._id_sets) + @cache_method def node_ids(self) -> list[int]: """ @@ -1884,8 +1892,6 @@ def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table: ``col.in_(sa.select(table.c.id))``. The caller owns the returned table and is responsible for dropping it. 
""" - if hasattr(ids, "tolist"): - ids = ids.tolist() unique_ids = list({int(v) for v in ids}) name = f"_tracksdata_ids_{uuid.uuid4().hex}" diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 8b781f50..89484803 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1316,6 +1316,16 @@ def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]: return node_ids +def _scratch_table_count(graph: SQLGraph) -> int: + """Count leftover ``_tracksdata_ids_*`` scratch tables in a SQLite graph.""" + import sqlalchemy as sa + + with graph._engine.connect() as conn: + return conn.execute( + sa.text("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name LIKE '_tracksdata_ids_%'") + ).scalar() + + def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: """Filtering with more ids than SQLite's variable limit must not raise. @@ -1329,21 +1339,28 @@ def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPat # Force scratch-table path on every call site. monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) - filtered = graph.filter(node_ids=node_ids) - # Confirm the scratch-table code path was taken rather than raw IN (...). - assert filtered._id_sets[0].uses_scratch_table - subgraph = filtered.subgraph() - assert subgraph.num_nodes() == n_nodes - assert subgraph.num_edges() == n_nodes - 1 - + # Context-manager paths (overlaps, _get_degree) must drop their scratch + # tables once the block exits — the count should return to baseline after + # each call regardless of whether the scratch path fired inside. 
+ assert _scratch_table_count(graph) == 0 in_deg = graph.in_degree(node_ids) + assert _scratch_table_count(graph) == 0 out_deg = graph.out_degree(node_ids) + assert _scratch_table_count(graph) == 0 + overlaps = graph.overlaps(node_ids) + assert _scratch_table_count(graph) == 0 + assert sum(in_deg) == n_nodes - 1 assert sum(out_deg) == n_nodes - 1 - - overlaps = graph.overlaps(node_ids) assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) + filtered = graph.filter(node_ids=node_ids) + # Confirm the scratch-table code path was taken rather than raw IN (...). + assert filtered._uses_scratch_tables() + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: """The scratch cutoff must account for how many times ids appear per statement. @@ -1362,7 +1379,7 @@ def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.Monk filtered = graph.filter(node_ids=node_ids) # 5 ids fits under chunk_size=12 inline, but with occurrences=3 the # effective cutoff is 12 // 3 == 4, so scratch must kick in. - assert filtered._id_sets[0].uses_scratch_table + assert filtered._uses_scratch_tables() subgraph = filtered.subgraph() assert subgraph.num_nodes() == n_nodes assert subgraph.num_edges() == n_nodes - 1