From 68bbb6ab50dc47487a3a28666b8ca8cdfdbbd3a0 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Fri, 22 May 2026 10:01:58 -0400 Subject: [PATCH 1/2] Allow specifying the mask and bbox keys in the GraphArrayView --- src/tracksdata/array/_graph_array.py | 25 +++- .../array/_test/test_graph_array.py | 116 ++++++++++++++++++ 2 files changed, 135 insertions(+), 6 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 6e88828d..5debb4bd 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -140,6 +140,10 @@ class GraphArrayView(BaseReadOnlyArray): buffer_cache_size : int, optional The maximum number of buffers to keep in the cache for the array. If None, the default buffer cache size is used. + mask_attr_key : str, optional + The node attribute key used to retrieve mask data. Defaults to "mask". + bbox_attr_key : str, optional + The node attribute key used to retrieve bounding box data. Defaults to "bbox". """ def __init__( @@ -152,6 +156,8 @@ def __init__( chunk_shape: tuple[int, ...] | int | None = None, buffer_cache_size: int | None = None, dtype: np.dtype | None = None, + mask_attr_key: str = DEFAULT_ATTR_KEYS.MASK, + bbox_attr_key: str = DEFAULT_ATTR_KEYS.BBOX, ): if attr_key not in graph.node_attr_keys(return_ids=True): raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys()}'") @@ -159,6 +165,8 @@ def __init__( self.graph = graph self._attr_key = attr_key self._offset = offset + self._mask_attr_key = mask_attr_key + self._bbox_attr_key = bbox_attr_key if dtype is None: # Infer the dtype from the graph's attribute @@ -198,7 +206,7 @@ def __init__( self._spatial_filter = self.graph.bbox_spatial_filter( frame_attr_key=DEFAULT_ATTR_KEYS.T, - bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX, + bbox_attr_key=self._bbox_attr_key, ) self.graph.node_added.connect(self._on_node_added) self.graph.node_removed.connect(self._on_node_removed) @@ -348,10 +356,10 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda """ subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)] df = subgraph.node_attrs( - attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], + attr_keys=[self._attr_key, self._mask_attr_key], ) - for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): + for mask, value in zip(df[self._mask_attr_key], df[self._attr_key], strict=True): mask: Mask mask.paint_buffer(buffer, value, offset=self._offset) @@ -394,13 +402,18 @@ def _invalidate_from_attrs(self, attrs: dict) -> None: Invalidate cache region touched by node attributes. Falls back to larger invalidation windows when metadata is incomplete. + When the bbox attribute key is not present in the attrs dict (e.g. when + using a non-default bbox key and the update doesn't affect this view), + the method returns without invalidating. """ time_value = attrs.get(DEFAULT_ATTR_KEYS.T) if time_value is None: raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.") - if DEFAULT_ATTR_KEYS.BBOX not in attrs: - raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.") + if self._bbox_attr_key not in attrs: + # The update doesn't involve this view's bbox attribute — + # nothing to invalidate (e.g. a nuclear GAV seeing a membrane-only update). + return try: time = int(np.asarray(time_value).item()) @@ -411,7 +424,7 @@ def _invalidate_from_attrs(self, attrs: dict) -> None: if not (0 <= time < self.original_shape[0]): return - slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX]) + slices = self._bbox_to_slices(attrs[self._bbox_attr_key]) if slices is not None: self._cache.invalidate(time=time, volume_slicing=slices) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index 7362f89c..4bb0d680 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -497,3 +497,119 @@ def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph) output = np.asarray(array_view[0]) assert output[1, 1] == 1 assert output[5, 5] == 0 + + +def test_graph_array_view_custom_mask_bbox_keys(graph_backend: BaseGraph) -> None: + """Test GraphArrayView with custom mask_attr_key and bbox_attr_key.""" + + # Standard mask/bbox attributes + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + # Custom (nuclear) mask/bbox attributes + graph_backend.add_node_attr_key("nuc_mask", pl.Object) + graph_backend.add_node_attr_key("nuc_bbox", pl.Array(pl.Int64, 4), default_value=[0, 0, 0, 0]) + + # Create masks: membrane is 4x4, nuclear is 2x2 at same location + mem_mask = Mask(np.ones((4, 4), dtype=bool), bbox=np.array([10, 20, 14, 24])) + nuc_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([11, 21, 13, 23])) + + graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 5, + DEFAULT_ATTR_KEYS.MASK: mem_mask, + DEFAULT_ATTR_KEYS.BBOX: mem_mask.bbox, + "nuc_mask": nuc_mask, + "nuc_bbox": nuc_mask.bbox, + } + ) + + # Standard GAV uses default mask/bbox + std_view = GraphArrayView( + graph=graph_backend, shape=(2, 50, 50), attr_key="label" + ) + # Custom GAV uses nuclear mask/bbox + nuc_view = GraphArrayView( + graph=graph_backend, + shape=(2, 50, 50), + attr_key="label", + mask_attr_key="nuc_mask", + bbox_attr_key="nuc_bbox", + ) + + std_result = np.asarray(std_view[0]) + nuc_result = np.asarray(nuc_view[0]) + + # Standard view should have label painted at membrane mask area (4x4) + assert std_result[10, 20] == 5 + assert std_result[13, 23] == 5 + # Nuclear view should have label painted at nuclear mask area (2x2) + assert nuc_result[11, 21] == 5 + assert nuc_result[12, 22] == 5 + + # Point inside membrane but outside nuclear mask + assert std_result[10, 20] == 5 # inside membrane + assert nuc_result[10, 20] == 0 # outside nuclear + + # Total painted area differs + assert np.sum(std_result > 0) == 16 # 4x4 + assert np.sum(nuc_result > 0) == 4 # 2x2 + + +def test_graph_array_view_custom_keys_survives_membrane_update(graph_backend: BaseGraph) -> None: + """Test that a nuclear GAV handles membrane-only updates without error. + + When update_node_attrs is called with only the standard mask/bbox, + the nuclear GAV should handle the signal gracefully (the signal attrs + may or may not include 'nuc_bbox' depending on the graph backend). + After the update, the nuclear GAV should still return correct data. + """ + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + graph_backend.add_node_attr_key("nuc_mask", pl.Object) + graph_backend.add_node_attr_key("nuc_bbox", pl.Array(pl.Int64, 4), default_value=[0, 0, 0, 0]) + + mem_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([1, 1, 3, 3])) + nuc_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([1, 1, 3, 3])) + + node_id = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: mem_mask, + DEFAULT_ATTR_KEYS.BBOX: mem_mask.bbox, + "nuc_mask": nuc_mask, + "nuc_bbox": nuc_mask.bbox, + } + ) + + nuc_view = GraphArrayView( + graph=graph_backend, + shape=(2, 8, 8), + attr_key="label", + mask_attr_key="nuc_mask", + bbox_attr_key="nuc_bbox", + ) + + # Verify initial nuclear data is correct + output = np.asarray(nuc_view[0]) + assert output[1, 1] == 1 + assert output[2, 2] == 1 + + # Update only the membrane mask — should not crash the nuclear GAV + moved_mem_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([5, 5, 7, 7])) + graph_backend.update_node_attrs( + attrs={ + DEFAULT_ATTR_KEYS.MASK: [moved_mem_mask], + DEFAULT_ATTR_KEYS.BBOX: [moved_mem_mask.bbox], + }, + node_ids=[node_id], + ) + + # Nuclear data should still be correct after membrane-only update + output = np.asarray(nuc_view[0]) + assert output[1, 1] == 1 + assert output[2, 2] == 1 From 230a8e5540ceb251ee69d4aa0df2661968460404 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Fri, 22 May 2026 14:26:03 -0400 Subject: [PATCH 2/2] More checks for missing bbox attribute --- src/tracksdata/graph/filters/_spatial_filter.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index e13e1711..5ab5c499 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -427,6 +427,9 @@ def _add_node( new_attrs : dict[str, Any] Current node attributes to insert into the spatial index. """ + if self._bbox_attr_key not in new_attrs: + return + from spatial_graph import PointRTree if self._node_rtree is None: @@ -470,6 +473,8 @@ def _remove_node( """ if self._node_rtree is None: return + if self._bbox_attr_key not in old_attrs: + return positions_min, positions_max = self._attrs_to_bb_window(old_attrs) @@ -485,6 +490,8 @@ def _update_node( old_attrs: dict[str, Any], new_attrs: dict[str, Any], ) -> None: + if self._bbox_attr_key not in old_attrs: + return self._remove_node(node_id, old_attrs=old_attrs) self._add_node(node_id, new_attrs=new_attrs)