From 20dca625a238b9cfb1fc1268bb6d918e896a26e6 Mon Sep 17 00:00:00 2001 From: ZimingHua Date: Wed, 1 Apr 2026 12:33:12 -0400 Subject: [PATCH 1/2] Add dimension guards for ORG hourly_wage imputation Prevent hourly_wage and other ORG variables from being stored with wrong dimensions (e.g. ORG donor count instead of CPS person count) by adding validation checks at every stage of the pipeline: - predict_org_features: assert output matches receiver frame size - add_org_labor_market_inputs: assert predictions match CPS person count - _splice_cps_only_predictions: assert predictions match entity half-size - _impute_org (source_impute): assert predictions match person count - CPS.downsample: drop unknown variables instead of keeping stale arrays Closes #675 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../calibration/source_impute.py | 6 ++++++ policyengine_us_data/datasets/cps/cps.py | 19 ++++++++++++++++++- .../datasets/cps/extended_cps.py | 5 +++++ policyengine_us_data/datasets/org/org.py | 13 ++++++++++++- 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/policyengine_us_data/calibration/source_impute.py b/policyengine_us_data/calibration/source_impute.py index 883d68dc7..1b3a0703b 100644 --- a/policyengine_us_data/calibration/source_impute.py +++ b/policyengine_us_data/calibration/source_impute.py @@ -761,6 +761,7 @@ def _impute_org( if "self_employment_income" in cps_df.columns else None ) + n_persons = len(data["person_id"][time_period]) predictions = predict_org_features( receiver, self_employment_income=self_employment_income, @@ -768,6 +769,11 @@ def _impute_org( for var in ORG_IMPUTED_VARIABLES: values = predictions[var].values + if len(values) != n_persons: + raise ValueError( + f"ORG prediction for '{var}' has {len(values)} entries " + f"but dataset has {n_persons} persons" + ) if var in ORG_BOOL_VARIABLES: data[var] = {time_period: values.astype(bool)} else: diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 148e83e4c..c801b709e 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -108,11 +108,14 @@ def downsample(self, frac: float): sim = Microsimulation(dataset=self) sim.subsample(frac=frac) + keys_to_drop = [] for key in original_data: if key not in sim.tax_benefit_system.variables: logging.warning( - f"Attempting to downsample the variable {key} but failing because it is not in the given country package." + f"Dropping variable {key} during downsample: " + f"not in the current country package." ) + keys_to_drop.append(key) continue values = sim.calculate(key).values @@ -133,6 +136,9 @@ def downsample(self, frac: float): else: original_data[key] = values + for key in keys_to_drop: + del original_data[key] + self.save_dataset(original_data) @@ -1838,6 +1844,7 @@ def add_tips(self, cps: h5py.File): def add_org_labor_market_inputs(cps: h5py.File) -> None: """Impute ORG-derived wage and union inputs onto CPS persons.""" + n_persons = len(np.asarray(cps["age"])) household_ids = np.asarray(cps["household_id"], dtype=np.int64) person_household_ids = np.asarray( cps["person_household_id"], @@ -1864,6 +1871,11 @@ def add_org_labor_market_inputs(cps: h5py.File) -> None: employment_income=cps["employment_income"], weekly_hours_worked=cps["weekly_hours_worked"], ) + if len(receiver) != n_persons: + raise ValueError( + f"ORG receiver frame has {len(receiver)} rows but CPS has " + f"{n_persons} persons" + ) self_employment_income = np.asarray( cps.get( "self_employment_income", @@ -1878,6 +1890,11 @@ def add_org_labor_market_inputs(cps: h5py.File) -> None: for variable in ORG_IMPUTED_VARIABLES: values = predictions[variable].values + if len(values) != n_persons: + raise ValueError( + f"ORG prediction for '{variable}' has {len(values)} entries " + f"but CPS has {n_persons} persons" + ) if variable in ORG_BOOL_VARIABLES: cps[variable] = values.astype(bool) else: diff --git a/policyengine_us_data/datasets/cps/extended_cps.py b/policyengine_us_data/datasets/cps/extended_cps.py index ca5c38afd..8b9041c0c 100644 --- a/policyengine_us_data/datasets/cps/extended_cps.py +++ b/policyengine_us_data/datasets/cps/extended_cps.py @@ -418,6 +418,11 @@ def _splice_cps_only_predictions( ) n_half = entity_half_lengths.get(entity_key, len(data[var][time_period]) // 2) + if len(pred_values) != n_half: + raise ValueError( + f"Stage-2 prediction for '{var}' has {len(pred_values)} " + f"entries but expected {n_half} (half of {entity_key})" + ) values = data[var][time_period] # First half: keep original CPS values. # Second half: replace with QRF predictions. diff --git a/policyengine_us_data/datasets/org/org.py b/policyengine_us_data/datasets/org/org.py index b709bbc69..74dd3c9a0 100644 --- a/policyengine_us_data/datasets/org/org.py +++ b/policyengine_us_data/datasets/org/org.py @@ -461,13 +461,24 @@ def predict_org_features( if missing: raise ValueError(f"ORG receiver frame missing required columns: {missing}") + n_receiver = len(receiver) predictions = get_org_model().predict(X_test=receiver[ORG_PREDICTORS]) + if len(predictions) != n_receiver: + raise ValueError( + f"ORG QRF returned {len(predictions)} rows but receiver has " + f"{n_receiver} rows; predictions must match receiver length" + ) predictions["is_union_member_or_covered"] = _predict_union_coverage_from_bls_tables( receiver, self_employment_income=self_employment_income, ) - return apply_org_domain_constraints( + result = apply_org_domain_constraints( predictions=predictions, receiver=receiver, self_employment_income=self_employment_income, ) + if len(result) != n_receiver: + raise ValueError( + f"ORG post-processing changed row count from {n_receiver} to {len(result)}" + ) + return result From 3f7642edc92a48157965f250f7a7c7f8ec7f9e0d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 1 Apr 2026 13:25:38 -0400 Subject: [PATCH 2/2] Fail closed on downsample schema skew --- policyengine_us_data/datasets/cps/cps.py | 43 ++----- policyengine_us_data/datasets/scf/scf.py | 36 ++---- policyengine_us_data/tests/test_downsample.py | 103 +++++++++++++++ policyengine_us_data/utils/downsample.py | 118 ++++++++++++++++++ 4 files changed, 238 insertions(+), 62 deletions(-) create mode 100644 policyengine_us_data/tests/test_downsample.py create mode 100644 policyengine_us_data/utils/downsample.py diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index c801b709e..8d1189aa2 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -24,6 +24,7 @@ build_org_receiver_frame, predict_org_features, ) +from policyengine_us_data.utils.downsample import downsample_dataset_arrays from policyengine_us_data.utils.randomness import seeded_rng @@ -102,44 +103,16 @@ def generate(self): def downsample(self, frac: float): from policyengine_us import Microsimulation - # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = {key: original_data[key].dtype for key in original_data} sim = Microsimulation(dataset=self) sim.subsample(frac=frac) - - keys_to_drop = [] - for key in original_data: - if key not in sim.tax_benefit_system.variables: - logging.warning( - f"Dropping variable {key} during downsample: " - f"not in the current country package." - ) - keys_to_drop.append(key) - continue - values = sim.calculate(key).values - - # Preserve the original dtype if possible - if ( - key in original_dtypes - and hasattr(values, "dtype") - and values.dtype != original_dtypes[key] - ): - try: - original_data[key] = values.astype(original_dtypes[key]) - except: - # If conversion fails, log it but continue - logging.warning( - f"Could not convert {key} back to {original_dtypes[key]}" - ) - original_data[key] = values - else: - original_data[key] = values - - for key in keys_to_drop: - del original_data[key] - - self.save_dataset(original_data) + self.save_dataset( + downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name=self.name, + ) + ) def add_rent(self, cps: h5py.File, person: DataFrame, household: DataFrame): diff --git a/policyengine_us_data/datasets/scf/scf.py b/policyengine_us_data/datasets/scf/scf.py index 5a1a3af59..caac23ee8 100644 --- a/policyengine_us_data/datasets/scf/scf.py +++ b/policyengine_us_data/datasets/scf/scf.py @@ -13,6 +13,8 @@ from typing import Type from filelock import FileLock +from policyengine_us_data.utils.downsample import downsample_dataset_arrays + class SCF(Dataset): """Dataset containing processed Survey of Consumer Finances data.""" @@ -115,36 +117,16 @@ def downsample(self, frac: float): """ from policyengine_us import Microsimulation - # Store original dtypes before modifying original_data: dict = self.load_dataset() - original_dtypes = {key: original_data[key].dtype for key in original_data} - sim = Microsimulation(dataset=self) sim.subsample(frac=frac) - - for key in original_data: - if key not in sim.tax_benefit_system.variables: - continue - values = sim.calculate(key).values - - # Preserve the original dtype if possible - if ( - key in original_dtypes - and hasattr(values, "dtype") - and values.dtype != original_dtypes[key] - ): - try: - original_data[key] = values.astype(original_dtypes[key]) - except: - # If conversion fails, log it but continue - print( - f"Warning: Could not convert {key} back to {original_dtypes[key]}" - ) - original_data[key] = values - else: - original_data[key] = values - - self.save_dataset(original_data) + self.save_dataset( + downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name=self.name, + ) + ) def _lock(self) -> FileLock: return FileLock(f"{self.file_path}.lock", timeout=600) diff --git a/policyengine_us_data/tests/test_downsample.py b/policyengine_us_data/tests/test_downsample.py new file mode 100644 index 000000000..8ca42504c --- /dev/null +++ b/policyengine_us_data/tests/test_downsample.py @@ -0,0 +1,103 @@ +from types import SimpleNamespace + +import numpy as np +import pytest + +from policyengine_us_data.utils.downsample import downsample_dataset_arrays + + +class _FakeArrayResult: + def __init__(self, values): + self.values = values + + +class _FakeMicrosimulation: + def __init__(self, variable_entities, calculated_values): + self.tax_benefit_system = SimpleNamespace( + variables={ + variable_name: SimpleNamespace(entity=SimpleNamespace(key=entity_key)) + for variable_name, entity_key in variable_entities.items() + } + ) + self._calculated_values = calculated_values + + def calculate(self, variable_name): + return _FakeArrayResult(self._calculated_values[variable_name]) + + +def test_downsample_dataset_arrays_preserves_original_dtypes(): + original_data = { + "person_id": np.array([101, 102], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "employment_income": np.array([100.0, 200.0], dtype=np.float32), + } + sim = _FakeMicrosimulation( + variable_entities={ + "person_id": "person", + "household_id": "household", + "employment_income": "person", + }, + calculated_values={ + "person_id": np.array([101], dtype=np.int64), + "household_id": np.array([201], dtype=np.int64), + "employment_income": np.array([150.0], dtype=np.float64), + }, + ) + + resampled = downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name="cps", + ) + + assert resampled["person_id"].dtype == np.int32 + assert resampled["household_id"].dtype == np.int32 + assert resampled["employment_income"].dtype == np.float32 + np.testing.assert_array_equal( + resampled["employment_income"], np.array([150.0], dtype=np.float32) + ) + + +def test_downsample_dataset_arrays_fails_closed_on_unknown_variables(): + original_data = { + "person_id": np.array([101, 102], dtype=np.int32), + "hourly_wage": np.array([25.0, 30.0], dtype=np.float32), + } + sim = _FakeMicrosimulation( + variable_entities={"person_id": "person"}, + calculated_values={"person_id": np.array([101], dtype=np.int64)}, + ) + + with pytest.raises(ValueError, match="out of sync"): + downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name="cps", + ) + + +def test_downsample_dataset_arrays_rejects_entity_length_mismatches(): + original_data = { + "person_id": np.array([101, 102], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "employment_income": np.array([100.0, 200.0], dtype=np.float32), + } + sim = _FakeMicrosimulation( + variable_entities={ + "person_id": "person", + "household_id": "household", + "employment_income": "person", + }, + calculated_values={ + "person_id": np.array([101], dtype=np.int64), + "household_id": np.array([201], dtype=np.int64), + "employment_income": np.array([150.0, 250.0], dtype=np.float64), + }, + ) + + with pytest.raises(ValueError, match="entity lengths are inconsistent"): + downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name="cps", + ) diff --git a/policyengine_us_data/utils/downsample.py b/policyengine_us_data/utils/downsample.py new file mode 100644 index 000000000..bc21ca71e --- /dev/null +++ b/policyengine_us_data/utils/downsample.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np + + +ENTITY_ID_VARIABLES = { + "person": "person_id", + "tax_unit": "tax_unit_id", + "family": "family_id", + "spm_unit": "spm_unit_id", + "household": "household_id", +} + + +def _format_variable_list(variable_names: list[str], max_display: int = 5) -> str: + displayed = variable_names[:max_display] + suffix = "" if len(variable_names) <= max_display else ", ..." + return ", ".join(displayed) + suffix + + +def _restore_original_dtype( + variable_name: str, + values: Any, + original_dtype: np.dtype | None, +): + if original_dtype is None or not hasattr(values, "dtype"): + return values + if values.dtype == original_dtype: + return values + try: + return values.astype(original_dtype) + except Exception: + logging.warning( + "Could not convert %s back to %s after downsampling.", + variable_name, + original_dtype, + ) + return values + + +def _validate_known_variables( + original_data: dict, tax_benefit_system, dataset_name: str +): + unknown_variables = sorted( + key for key in original_data if key not in tax_benefit_system.variables + ) + if unknown_variables: + raise ValueError( + f"Cannot downsample {dataset_name}: found {len(unknown_variables)} " + "dataset variables missing from the current country package " + f"({_format_variable_list(unknown_variables)}). This usually means " + "policyengine-us-data and policyengine-us are out of sync." + ) + + +def _validate_entity_lengths( + resampled_data: dict, + tax_benefit_system, + dataset_name: str, +): + entity_counts = { + entity_key: len(np.asarray(resampled_data[id_variable])) + for entity_key, id_variable in ENTITY_ID_VARIABLES.items() + if id_variable in resampled_data + } + + mismatches = [] + for variable_name, values in resampled_data.items(): + variable = tax_benefit_system.variables.get(variable_name) + if variable is None: + continue + entity_key = getattr(getattr(variable, "entity", None), "key", None) + expected_length = entity_counts.get(entity_key) + if expected_length is None: + continue + actual_length = len(np.asarray(values)) + if actual_length != expected_length: + mismatches.append( + f"{variable_name} ({entity_key}: expected {expected_length}, found {actual_length})" + ) + + if mismatches: + raise ValueError( + f"Cannot save downsampled {dataset_name}: entity lengths are inconsistent " + f"({_format_variable_list(mismatches)})." + ) + + +def downsample_dataset_arrays(original_data: dict, sim, dataset_name: str) -> dict: + _validate_known_variables( + original_data=original_data, + tax_benefit_system=sim.tax_benefit_system, + dataset_name=dataset_name, + ) + + original_dtypes = { + key: values.dtype + for key, values in original_data.items() + if hasattr(values, "dtype") + } + resampled_data = {} + for variable_name in original_data: + values = sim.calculate(variable_name).values + resampled_data[variable_name] = _restore_original_dtype( + variable_name=variable_name, + values=values, + original_dtype=original_dtypes.get(variable_name), + ) + + _validate_entity_lengths( + resampled_data=resampled_data, + tax_benefit_system=sim.tax_benefit_system, + dataset_name=dataset_name, + ) + return resampled_data