Skip to content
3 changes: 3 additions & 0 deletions changelog.d/fix-artifact-contract.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added fail-closed dataset contract validation for built CPS artifacts, including
`policyengine-us` lockfile version checks, per-entity HDF5 length validation,
and file-based `Microsimulation` smoke tests in both the build and upload paths.
45 changes: 34 additions & 11 deletions modal_app/data_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
if _p not in sys.path:
sys.path.insert(0, _p)

from modal_app.images import cpu_image as image
from modal_app.images import cpu_image as image # noqa: E402

app = modal.App("policyengine-us-data")

Expand Down Expand Up @@ -233,6 +233,34 @@ def run_script(
return script_path


def validate_and_maybe_upload_datasets(
*,
upload: bool,
skip_enhanced_cps: bool,
env: dict,
) -> None:
validation_args = ["--validate-only"]
if skip_enhanced_cps:
validation_args.append("--no-require-enhanced-cps")

print("=== Validating built datasets ===")
run_script(
"policyengine_us_data/storage/upload_completed_datasets.py",
args=validation_args,
env=env,
)

if upload:
upload_args = []
if skip_enhanced_cps:
upload_args.append("--no-require-enhanced-cps")
run_script(
"policyengine_us_data/storage/upload_completed_datasets.py",
args=upload_args,
env=env,
)


def run_script_with_checkpoint(
script_path: str,
output_files: str | list[str],
Expand Down Expand Up @@ -634,16 +662,11 @@ def build_datasets(
print("=== Running tests with checkpointing ===")
run_tests_with_checkpoints(branch, checkpoint_volume, env)

# Upload if requested (HF publication only)
if upload:
upload_args = []
if skip_enhanced_cps:
upload_args.append("--no-require-enhanced-cps")
run_script(
"policyengine_us_data/storage/upload_completed_datasets.py",
args=upload_args,
env=env,
)
validate_and_maybe_upload_datasets(
upload=upload,
skip_enhanced_cps=skip_enhanced_cps,
env=env,
)

# Clean up checkpoints after successful completion
cleanup_checkpoints(branch, checkpoint_volume)
Expand Down
13 changes: 8 additions & 5 deletions policyengine_us_data/datasets/cps/enhanced_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def loss(weights):
optimizer.zero_grad()
masked = torch.exp(weights) * gates()
l_main = loss(masked)
l = l_main + l0_lambda * gates.get_penalty()
total_loss = l_main + l0_lambda * gates.get_penalty()
if (log_path is not None) and (i % 10 == 0):
gates.eval()
estimates = (torch.exp(weights) * gates()) @ loss_matrix
Expand All @@ -108,10 +108,12 @@ def loss(weights):
if (log_path is not None) and (i % 1000 == 0):
performance.to_csv(log_path, index=False)
if start_loss is None:
start_loss = l.item()
loss_rel_change = (l.item() - start_loss) / start_loss
l.backward()
iterator.set_postfix({"loss": l.item(), "loss_rel_change": loss_rel_change})
start_loss = total_loss.item()
loss_rel_change = (total_loss.item() - start_loss) / start_loss
total_loss.backward()
iterator.set_postfix(
{"loss": total_loss.item(), "loss_rel_change": loss_rel_change}
)
optimizer.step()
if log_path is not None:
performance.to_csv(log_path, index=False)
Expand Down Expand Up @@ -248,6 +250,7 @@ class EnhancedCPS_2024(EnhancedCPS):
input_dataset = ExtendedCPS_2024_Half
start_year = 2024
end_year = 2024
time_period = 2024
name = "enhanced_cps_2024"
label = "Enhanced CPS 2024"
file_path = STORAGE_FOLDER / "enhanced_cps_2024.h5"
Expand Down
21 changes: 16 additions & 5 deletions policyengine_us_data/db/etl_state_income_tax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
CENSUS_STC_FLAT_FILE_URLS = {
2023: "https://www2.census.gov/programs-surveys/stc/datasets/2023/FY2023-Flat-File.txt",
}
LATEST_STC_YEAR = max(CENSUS_STC_FLAT_FILE_URLS)
CENSUS_STC_INDIVIDUAL_INCOME_TAX_ITEM = "T40"
CENSUS_STC_NOT_AVAILABLE = "X"

Expand Down Expand Up @@ -179,7 +180,9 @@ def transform_state_income_tax_data(df: pd.DataFrame) -> pd.DataFrame:
return result


def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict:
def load_state_income_tax_data(
df: pd.DataFrame, year: int, source_year: int | None = None
) -> dict:
"""
Load state income tax targets into the calibration database.

Expand Down Expand Up @@ -241,7 +244,7 @@ def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict:
value=row["income_tax_collections"],
active=True,
source="Census STC",
notes=f"Census STC FY{year}",
notes=f"Census STC FY{source_year or year}",
)
)

Expand All @@ -263,14 +266,22 @@ def main():
)
_, year = etl_argparser("ETL for state income tax calibration targets")

logger.info(f"Extracting Census STC data for FY{year}...")
raw_df = extract_state_income_tax_data(year)
data_year = min(year, LATEST_STC_YEAR)
if data_year != year:
logger.warning(
f"Census STC data not available for {year}; "
f"using latest available year ({LATEST_STC_YEAR})"
)
logger.info(f"Extracting Census STC data for FY{data_year}...")
raw_df = extract_state_income_tax_data(data_year)

logger.info("Transforming data...")
transformed_df = transform_state_income_tax_data(raw_df)

logger.info(f"Loading {len(transformed_df)} state income tax targets...")
stratum_lookup = load_state_income_tax_data(transformed_df, year)
stratum_lookup = load_state_income_tax_data(
transformed_df, year, source_year=data_year
)

# Print summary
total_collections = transformed_df["income_tax_collections"].sum()
Expand Down
75 changes: 52 additions & 23 deletions policyengine_us_data/storage/upload_completed_datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import h5py
from pathlib import Path

from policyengine_us_data.datasets import (
EnhancedCPS_2024,
)
import h5py
from policyengine_core.data import Dataset

from policyengine_us_data.datasets import EnhancedCPS_2024
from policyengine_us_data.datasets.cps.cps import CPS_2024
from policyengine_us_data.storage import STORAGE_FOLDER
from policyengine_us_data.utils.data_upload import upload_data_files
from policyengine_us_data.utils.dataset_validation import (
DatasetContractError,
load_dataset_for_validation,
validate_dataset_contract,
)

# Datasets that require full validation before upload.
# These are the main datasets used in production simulations.
Expand All @@ -15,14 +20,9 @@
"cps_2024.h5",
}

FILENAME_TO_DATASET = {
"enhanced_cps_2024.h5": EnhancedCPS_2024,
"cps_2024.h5": CPS_2024,
}

# Minimum file sizes in bytes for validated datasets.
MIN_FILE_SIZES = {
"enhanced_cps_2024.h5": 100 * 1024 * 1024, # 100 MB
"enhanced_cps_2024.h5": 95 * 1024 * 1024, # 95 MB
"cps_2024.h5": 50 * 1024 * 1024, # 50 MB
}

Expand Down Expand Up @@ -118,15 +118,23 @@ def _check_group_has_data(f, name):
+ "\n".join(f" - {e}" for e in errors)
)

try:
contract_summary = validate_dataset_contract(file_path)
except DatasetContractError as e:
errors.append(f"Dataset contract validation failed: {e}")
raise DatasetValidationError(
f"Validation failed for {filename}:\n"
+ "\n".join(f" - {e}" for e in errors)
) from e

# 3. Aggregate statistics check via Microsimulation
# Import here to avoid heavy import at module level.
from policyengine_us import Microsimulation

try:
dataset_cls = FILENAME_TO_DATASET.get(filename)
if dataset_cls is None:
raise DatasetValidationError(f"No dataset class registered for {filename}")
sim = Microsimulation(dataset=dataset_cls)
sim = Microsimulation(
dataset=load_dataset_for_validation(file_path, Dataset.from_file)
)
year = 2024

emp_income = sim.calculate("employment_income", year).sum()
Expand Down Expand Up @@ -159,6 +167,15 @@ def _check_group_has_data(f, name):

print(f" ✓ Validation passed for {filename}")
print(f" File size: {file_size / 1024 / 1024:.1f} MB")
print(
" policyengine-us: "
f"{contract_summary.policyengine_us.version}"
+ (
f" (locked {contract_summary.policyengine_us.locked_version})"
if contract_summary.policyengine_us.locked_version
else ""
)
)
print(f" employment_income sum: ${emp_income:,.0f}")
print(f" Household weight sum: {hh_weight:,.0f}")

Expand Down Expand Up @@ -210,14 +227,18 @@ def upload_datasets(require_enhanced_cps: bool = True):

def validate_all_datasets():
"""Validate all main datasets in storage. Called by `make validate-data`."""
for filename in VALIDATED_FILENAMES:
file_path = STORAGE_FOLDER / filename
if file_path.exists():
validate_dataset(file_path)
else:
raise FileNotFoundError(
f"Expected dataset {filename} not found at {file_path}"
)
validate_built_datasets(require_enhanced_cps=True)


def validate_built_datasets(require_enhanced_cps: bool = True):
required_files = [CPS_2024.file_path]
if require_enhanced_cps:
required_files.append(EnhancedCPS_2024.file_path)

for file_path in required_files:
if not file_path.exists():
raise FileNotFoundError(f"Expected dataset not found at {file_path}")
validate_dataset(file_path)
print("\nAll dataset validations passed.")


Expand All @@ -230,5 +251,13 @@ def validate_all_datasets():
action="store_true",
help="Treat enhanced_cps and small_enhanced_cps as optional.",
)
parser.add_argument(
"--validate-only",
action="store_true",
help="Validate built datasets without uploading them.",
)
args = parser.parse_args()
upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)
if args.validate_only:
validate_built_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)
else:
upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps)
Loading
Loading