Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,7 @@ jobs:
num_workers=int('${NUM_WORKERS}'),
skip_national='${SKIP_NATIONAL}' == 'true',
)
print(f'Pipeline spawned. Monitor on the Modal dashboard.')
print(f'::notice ::Modal call ID: {fc.object_id}')
print(f'::notice ::Dashboard: https://modal.com/apps/policyengine/main/deployed/policyengine-us-data-pipeline')
print(f'Pipeline spawned. Call ID: {fc.object_id}')
"
11 changes: 11 additions & 0 deletions .github/workflows/versioning.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@ jobs:
if: (github.event.head_commit.message != 'Update package version')
runs-on: ubuntu-latest
steps:
# Checkout requires a PAT (POLICYENGINE_GITHUB) with repo write
# access so the workflow can push the version-bump commit back to
# main. If the secret is missing or expired the step fails with a
# cryptic git-auth error. See issue #677 for PAT rotation.
- name: Checkout repo
id: checkout
continue-on-error: true
uses: actions/checkout@v4
with:
token: ${{ secrets.POLICYENGINE_GITHUB }}
fetch-depth: 0
- name: Abort if checkout failed (PAT issue)
if: steps.checkout.outcome == 'failure'
run: |
echo "::error ::Checkout failed — the POLICYENGINE_GITHUB PAT is likely expired or missing. See https://github.com/PolicyEngine/policyengine-us-data/issues/677"
exit 1
- name: Setup Python
uses: actions/setup-python@v5
with:
Expand Down
6 changes: 6 additions & 0 deletions modal_app/worker_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def main():
NYC_COUNTY_FIPS,
AT_LARGE_DISTRICTS,
)
from policyengine_us_data.utils.validate_h5 import validate_h5_or_raise
from policyengine_us_data.calibration.calibration_utils import (
STATE_CODES,
)
Expand Down Expand Up @@ -426,6 +427,11 @@ def main():
raise ValueError(f"Unknown item type: {item_type}")

if path:
validate_h5_or_raise(
path,
label=f"{item_type}:{item_id}",
period=args.period,
)
results["completed"].append(f"{item_type}:{item_id}")
print(
f"Completed {item_type}:{item_id}",
Expand Down
165 changes: 165 additions & 0 deletions policyengine_us_data/tests/test_validate_h5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Tests for H5 pre-publish validation."""

from unittest.mock import patch, MagicMock

import h5py
import numpy as np
import pytest

from policyengine_us_data.utils.validate_h5 import (
validate_h5_entity_dimensions,
validate_h5_or_raise,
)


def _make_mock_tbs(variable_entities: dict[str, str]):
    """Return a stand-in tax-benefit system for the validator tests.

    Args:
        variable_entities: Maps each variable name to the key of the entity
            it belongs to (for example ``{"age": "person"}``).

    Returns:
        A ``MagicMock`` whose ``variables`` attribute is a plain dict of
        variable name -> mock variable, where each mock variable exposes the
        requested key as ``.entity.key``.
    """

    def _variable_stub(entity_key):
        # Only ``.entity.key`` is configured; everything else stays a MagicMock.
        stub = MagicMock()
        stub.entity.key = entity_key
        return stub

    system = MagicMock()
    # A real dict (not a mock) so ``.get()`` returns None for unknown names.
    system.variables = {
        name: _variable_stub(entity_key)
        for name, entity_key in variable_entities.items()
    }
    return system


def _write_h5_flat(path, datasets: dict[str, np.ndarray]):
    """Create an H5 file in the flat layout used by storage files.

    Every entry of ``datasets`` becomes a dataset at the top level of the
    file, named after its key, with no period sub-key.
    """
    with h5py.File(path, "w") as h5_file:
        for dataset_name in datasets:
            h5_file.create_dataset(dataset_name, data=datasets[dataset_name])


def _write_h5_nested(path, period, datasets: dict[str, np.ndarray]):
    """Create an H5 file in the nested layout used by pipeline-built files.

    Every entry of ``datasets`` becomes a top-level group named after its
    key, holding a single dataset whose name is ``str(period)``.
    """
    period_key = str(period)
    with h5py.File(path, "w") as h5_file:
        for variable, values in datasets.items():
            h5_file.create_group(variable).create_dataset(period_key, data=values)


# Period key used when writing and validating the test files.
PERIOD = 2024
# Entity sizes shared by all fixtures below.
N_PERSONS, N_HOUSEHOLDS = 10, 5

# A dimensionally consistent dataset: every person-level array has
# N_PERSONS entries and every household-level array has N_HOUSEHOLDS.
GOOD_DATA = {
    # Entity ID arrays.
    "person_id": np.arange(N_PERSONS),
    "household_id": np.arange(N_HOUSEHOLDS),
    # Person-level variables.
    "age": np.ones(N_PERSONS),
    "income": np.ones(N_PERSONS),
    # Household-level variable; non-zero so the weight check passes.
    "household_weight": np.ones(N_HOUSEHOLDS),
}


@pytest.fixture
def mock_tbs():
    """Mock tax-benefit system that knows the entity of each GOOD_DATA key."""
    entity_by_variable = {
        "person_id": "person",
        "household_id": "household",
        "age": "person",
        "household_weight": "household",
        "income": "person",
    }
    return _make_mock_tbs(entity_by_variable)


class TestFlatLayout:
    """Dimension checks on files whose datasets sit at the top level."""

    def test_all_correct(self, tmp_path, mock_tbs):
        """A consistent flat file produces no validation results."""
        target = tmp_path / "good.h5"
        _write_h5_flat(target, GOOD_DATA)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        assert outcome == []

    def test_wrong_person_length(self, tmp_path, mock_tbs):
        """An over-long person-level array is reported exactly once."""
        target = tmp_path / "bad.h5"
        # Same keys as GOOD_DATA, but ``age`` no longer matches person_id.
        oversized = dict(GOOD_DATA, age=np.ones(N_PERSONS + 99))
        _write_h5_flat(target, oversized)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        dimension_details = [
            entry["detail"] for entry in outcome if entry["check"] == "dimension"
        ]
        assert len(dimension_details) == 1
        assert "age" in dimension_details[0]


class TestNestedLayout:
    """Dimension checks on files that nest each variable as variable/period."""

    def test_all_correct(self, tmp_path, mock_tbs):
        """A consistent nested file produces no validation results."""
        target = tmp_path / "good_nested.h5"
        _write_h5_nested(target, PERIOD, GOOD_DATA)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        assert outcome == []

    def test_wrong_person_length(self, tmp_path, mock_tbs):
        """An over-long person-level array is reported exactly once."""
        target = tmp_path / "bad_nested.h5"
        # Same keys as GOOD_DATA, but ``age`` no longer matches person_id.
        oversized = dict(GOOD_DATA, age=np.ones(N_PERSONS + 99))
        _write_h5_nested(target, PERIOD, oversized)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        dimension_details = [
            entry["detail"] for entry in outcome if entry["check"] == "dimension"
        ]
        assert len(dimension_details) == 1
        assert "age" in dimension_details[0]


class TestOrRaise:
    """Behaviour of the raising wrapper around the dimension validator."""

    def test_passes(self, tmp_path, mock_tbs):
        """A consistent file validates without raising."""
        target = tmp_path / "good.h5"
        _write_h5_flat(target, GOOD_DATA)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            validate_h5_or_raise(target, period=PERIOD)

    def test_raises_on_mismatch(self, tmp_path, mock_tbs):
        """A dimension mismatch surfaces as a ValueError naming the variable."""
        target = tmp_path / "bad.h5"
        # Same keys as GOOD_DATA, but ``age`` no longer matches person_id.
        oversized = dict(GOOD_DATA, age=np.ones(N_PERSONS + 99))
        _write_h5_flat(target, oversized)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch, pytest.raises(ValueError, match="age"):
            validate_h5_or_raise(target, period=PERIOD)


class TestMissingHouseholdWeight:
    """A file without ``household_weight`` must be flagged."""

    def test_missing_weight(self, tmp_path, mock_tbs):
        """Dropping household_weight yields a household_weight_exists result."""
        target = tmp_path / "no_weight.h5"
        without_weight = dict(GOOD_DATA)
        del without_weight["household_weight"]
        _write_h5_flat(target, without_weight)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        reported_checks = {entry["check"] for entry in outcome}
        assert "household_weight_exists" in reported_checks


class TestAllZeroWeights:
    """A file whose household weights are all zero must be flagged."""

    def test_zero_weights(self, tmp_path, mock_tbs):
        """All-zero household_weight yields a household_weight_nonzero result."""
        target = tmp_path / "zero_weight.h5"
        zeroed = dict(GOOD_DATA, household_weight=np.zeros(N_HOUSEHOLDS))
        _write_h5_flat(target, zeroed)
        tbs_patch = patch(
            "policyengine_us.CountryTaxBenefitSystem",
            return_value=mock_tbs,
        )
        with tbs_patch:
            outcome = validate_h5_entity_dimensions(target, period=PERIOD)
        reported_checks = {entry["check"] for entry in outcome}
        assert "household_weight_nonzero" in reported_checks
156 changes: 156 additions & 0 deletions policyengine_us_data/utils/validate_h5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Pre-publish validation for H5 dataset files.

Checks entity dimension consistency and weight sanity before upload.
"""

from __future__ import annotations

import sys
from pathlib import Path

import h5py
import numpy as np

from policyengine_us_data.utils.downsample import ENTITY_ID_VARIABLES


def _read_array(f: h5py.File, var_name: str, period: int):
    """Look up the stored values for ``var_name``, whichever layout is used.

    Two layouts are supported:

    * flat (storage files): ``var_name`` is itself a dataset at the top level;
    * nested (pipeline-built files): ``var_name`` is a group holding one
      dataset per period, keyed by ``str(period)``.

    Args:
        f: Open H5 file (or any mapping-like object supporting ``in`` and
            item access).
        var_name: Top-level key to look up.
        period: Period whose sub-key is used for the nested layout.

    Returns:
        The stored object for the variable (not copied into memory here), or
        ``None`` when the variable, or its period sub-key, is absent.
    """
    node = f[var_name] if var_name in f else None
    # Missing variable, or flat layout where the node is already the dataset.
    if node is None or isinstance(node, h5py.Dataset):
        return node
    # Nested layout: the node is a group, so descend to the period key.
    period_key = str(period)
    return node[period_key] if period_key in node else None


def validate_h5_entity_dimensions(
    h5_path: str | Path, period: int = 2024
) -> list[dict]:
    """Check entity-length consistency and household-weight sanity of an H5.

    The size of each entity is taken from the length of its ID array (as
    listed in ``ENTITY_ID_VARIABLES``). Every top-level variable in the file
    that the tax-benefit system knows about is then compared against the
    size of its entity. Variables unknown to the tax-benefit system, variables
    whose entity has no ID array in the file, and variables with no data for
    ``period`` are skipped rather than reported.

    Args:
        h5_path: Path to an H5 dataset file.
        period: Period key used for files in the nested variable/period layout.

    Returns:
        A list of ``{"check", "status", "detail"}`` dicts, one per failed
        check; an empty list means every check passed. Possible ``check``
        values are ``dimension``, ``household_weight_exists``,
        ``household_weight_nonzero`` and ``household_count``.
    """
    # Imported lazily so importing this module does not load policyengine_us.
    from policyengine_us import CountryTaxBenefitSystem

    def _failure(check: str, detail: str) -> dict:
        # Single place that fixes the shape (and key order) of a result entry.
        return {"check": check, "status": "FAIL", "detail": detail}

    system = CountryTaxBenefitSystem()
    failures: list[dict] = []

    with h5py.File(Path(h5_path), "r") as h5_file:
        stored_names = list(h5_file.keys())

        # Number of rows per entity, derived from each entity's ID array.
        rows_per_entity: dict[str, int] = {}
        for entity, id_variable in ENTITY_ID_VARIABLES.items():
            ids = _read_array(h5_file, id_variable, period)
            if ids is not None:
                rows_per_entity[entity] = len(ids)

        for name in stored_names:
            meta = system.variables.get(name)
            if meta is None:
                continue
            entity = getattr(getattr(meta, "entity", None), "key", None)
            expected = rows_per_entity.get(entity)
            if expected is None:
                continue
            values = _read_array(h5_file, name, period)
            if values is None:
                continue
            actual = len(values)
            if actual != expected:
                failures.append(
                    _failure(
                        "dimension",
                        f"{name} ({entity}): "
                        f"expected {expected}, got {actual}",
                    )
                )

        # household_weight must exist and must not be entirely zero.
        weights = _read_array(h5_file, "household_weight", period)
        if weights is None:
            failures.append(
                _failure(
                    "household_weight_exists",
                    "household_weight not found in H5",
                )
            )
        elif np.all(np.asarray(weights) == 0):
            failures.append(
                _failure(
                    "household_weight_nonzero",
                    "all household_weight values are zero",
                )
            )

        # A missing household ID array counts as zero households.
        if rows_per_entity.get("household", 0) == 0:
            failures.append(
                _failure("household_count", "household count is zero")
            )

    return failures


def validate_h5_or_raise(
    h5_path: str | Path, label: str = "", period: int = 2024
) -> None:
    """Validate an H5 file and raise if any check fails.

    Args:
        h5_path: Path to the H5 file.
        label: Optional tag included in the error message, in square
            brackets, to identify which item failed.
        period: Period key passed through to the validator.

    Raises:
        ValueError: If any check fails. The message has a header line naming
            the file, followed by one ``check: detail`` line per failure.
    """
    failures = validate_h5_entity_dimensions(h5_path, period=period)
    if not failures:
        return

    tag = f" [{label}]" if label else ""
    header = f"H5 validation failed{tag} for {h5_path}:"
    body = [f" {entry['check']}: {entry['detail']}" for entry in failures]
    raise ValueError("\n".join([header, *body]))


if __name__ == "__main__":
    # Command-line entry point: <h5_path> is required, [period] defaults to
    # 2024. Exits with status 1 on bad usage or when any check fails.
    cli_args = sys.argv[1:]
    if not cli_args:
        print(f"Usage: {sys.argv[0]} <h5_path> [period]", file=sys.stderr)
        sys.exit(1)

    path = cli_args[0]
    yr = int(cli_args[1]) if len(cli_args) > 1 else 2024

    issues = validate_h5_entity_dimensions(path, period=yr)
    if not issues:
        print(f"All checks passed for {path}")
    else:
        for issue in issues:
            print(f"[{issue['status']}] {issue['check']}: {issue['detail']}")
        sys.exit(1)
Loading