diff --git a/pyproject.toml b/pyproject.toml index fd9e10d..4992a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "gitpython>=3.1.45", "starfix>=0.1.3", "pygraphviz>=1.14", + "tzdata>=2024.1", "uuid-utils>=0.11.1", "s3fs>=2025.12.0", "pymongo>=4.15.5", diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index e30b7f6..bec4645 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -309,7 +309,20 @@ def as_table( drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) if not column_config.context: drop_columns.append(constants.CONTEXT_KEY) - + if not column_config.meta: + drop_columns.extend( + c + for c in self._cached_output_table.column_names + if c.startswith(constants.META_PREFIX) + ) + elif not isinstance(column_config.meta, bool): + # Collection[str]: keep only meta columns matching the specified prefixes + drop_columns.extend( + c + for c in self._cached_output_table.column_names + if c.startswith(constants.META_PREFIX) + and not any(c.startswith(p) for p in column_config.meta) + ) output_table = self._cached_output_table.drop( [c for c in drop_columns if c in self._cached_output_table.column_names] ) diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index 961c8ea..329c80c 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -129,7 +129,7 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: stream = ordered_streams[0] tag_keys, _ = [set(k) for k in stream.keys()] - table = stream.as_table(columns={"source": True, "system_tags": True}) + table = stream.as_table(columns={"source": True, "system_tags": True, "meta": True}) # trick to get cartesian product table = table.add_column(0, COMMON_JOIN_KEY, pa.array([0] * len(table))) table = arrow_data_utils.append_to_system_tags( @@ -139,9 +139,7 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: for idx, next_stream in enumerate(ordered_streams[1:], start=1): next_tag_keys, _ = next_stream.keys() - next_table = next_stream.as_table( - columns={"source": True, "system_tags": True} - ) + next_table = next_stream.as_table(columns={"source": True, "system_tags": True, "meta": True}) next_table = arrow_data_utils.append_to_system_tags( next_table, f"{next_stream.pipeline_hash().to_hex(n_char)}:{idx}", @@ -151,6 +149,31 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: next_table = next_table.add_column( 0, COMMON_JOIN_KEY, pa.array([0] * len(next_table)) ) + + # Rename any non-key columns in next_table that would collide with + # the accumulated table, using stream-index-based suffixes instead of + # Polars' default ``_right`` suffix which causes cascading collisions + # on 3+ stream joins. The only legitimately shared column names are + # the tag join keys; everything else (meta columns, their derived + # source-info columns, etc.) must be unique. + join_key_set = tag_keys.intersection(next_tag_keys) | {COMMON_JOIN_KEY} + existing_names = set(table.column_names) + rename_map = {} + for col in next_table.column_names: + if col not in join_key_set and col in existing_names: + new_name = f"{col}_{idx}" + counter = idx + while new_name in existing_names or new_name in rename_map.values(): + counter += 1 + new_name = f"{col}_{counter}" + rename_map[col] = new_name + if rename_map: + next_table = ( + pl.DataFrame(next_table) + .rename(rename_map) + .to_arrow() + ) + common_tag_keys = tag_keys.intersection(next_tag_keys) common_tag_keys.add(COMMON_JOIN_KEY) diff --git a/tests/test_core/operators/test_operators.py b/tests/test_core/operators/test_operators.py index ef38509..7817a5c 100644 --- a/tests/test_core/operators/test_operators.py +++ b/tests/test_core/operators/test_operators.py @@ -471,6 +471,50 @@ def test_join_is_commutative(self, simple_stream, disjoint_stream): assert isinstance(sym, frozenset) +class TestJoinMetaColumnCollision: + """Verify that a 3-way join with identical meta columns on all inputs does not + raise a DuplicateError. Instead, colliding meta columns should be renamed + with stream-index-based suffixes (e.g. ``__computed_1``, ``__computed_2``).""" + + def _make_stream(self, id_vals, pkt_col, pkt_vals, meta_val): + """Helper: stream with shared tag 'id', one packet column, and ``__computed``.""" + table = pa.table( + { + "id": pa.array(id_vals, type=pa.int64()), + pkt_col: pa.array(pkt_vals, type=pa.int64()), + "__computed": pa.array([meta_val] * len(id_vals), type=pa.bool_()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + def test_three_way_join_with_shared_meta_column_succeeds(self): + """Three streams each carrying ``__computed`` should join without DuplicateError.""" + s1 = self._make_stream([1, 2], "alpha", [10, 20], True) + s2 = self._make_stream([1, 2], "beta", [100, 200], True) + s3 = self._make_stream([1, 2], "gamma", [1000, 2000], True) + + result = Join().static_process(s1, s2, s3) + table = result.as_table() + + assert len(table) == 2 + assert {"id", "alpha", "beta", "gamma"}.issubset(set(table.column_names)) + + def test_three_way_join_meta_columns_renamed_with_index_suffix(self): + """Colliding meta columns from streams 2+ get an index-based suffix.""" + s1 = self._make_stream([1, 2], "alpha", [10, 20], True) + s2 = self._make_stream([1, 2], "beta", [100, 200], False) + s3 = self._make_stream([1, 2], "gamma", [1000, 2000], True) + + result = Join().static_process(s1, s2, s3) + table = result.as_table() + col_names = set(table.column_names) + + # Original meta column preserved; colliding ones renamed with suffix + assert "__computed" in col_names + assert "__computed_1" in col_names + assert "__computed_2" in col_names + + class TestJoinOutputSchemaSystemTags: """Verify that Join.output_schema correctly predicts system tag columns.""" diff --git a/uv.lock b/uv.lock index 6364033..bbede9a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11.0" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -1904,6 +1904,7 @@ dependencies = [ { name = "s3fs" }, { name = "starfix" }, { name = "typing-extensions" }, + { name = "tzdata" }, { name = "uuid-utils" }, { name = "xxhash" }, ] @@ -1970,6 +1971,7 @@ requires-dist = [ { name = "s3fs", specifier = ">=2025.12.0" }, { name = "starfix", specifier = ">=0.1.3" }, { name = "typing-extensions" }, + { name = "tzdata", specifier = ">=2024.1" }, { name = "uuid-utils", specifier = ">=0.11.1" }, { name = "xxhash" }, ]