diff --git a/changelog.d/fix-artifact-contract.fixed.md b/changelog.d/fix-artifact-contract.fixed.md new file mode 100644 index 000000000..8274a35dc --- /dev/null +++ b/changelog.d/fix-artifact-contract.fixed.md @@ -0,0 +1,3 @@ +Added fail-closed dataset contract validation for built CPS artifacts, including +`policyengine-us` lockfile version checks, per-entity HDF5 length validation, +and file-based `Microsimulation` smoke tests in both the build and upload paths. diff --git a/modal_app/data_build.py b/modal_app/data_build.py index b158298e7..986595869 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -17,7 +17,7 @@ if _p not in sys.path: sys.path.insert(0, _p) -from modal_app.images import cpu_image as image +from modal_app.images import cpu_image as image # noqa: E402 app = modal.App("policyengine-us-data") @@ -233,6 +233,34 @@ def run_script( return script_path +def validate_and_maybe_upload_datasets( + *, + upload: bool, + skip_enhanced_cps: bool, + env: dict, +) -> None: + validation_args = ["--validate-only"] + if skip_enhanced_cps: + validation_args.append("--no-require-enhanced-cps") + + print("=== Validating built datasets ===") + run_script( + "policyengine_us_data/storage/upload_completed_datasets.py", + args=validation_args, + env=env, + ) + + if upload: + upload_args = [] + if skip_enhanced_cps: + upload_args.append("--no-require-enhanced-cps") + run_script( + "policyengine_us_data/storage/upload_completed_datasets.py", + args=upload_args, + env=env, + ) + + def run_script_with_checkpoint( script_path: str, output_files: str | list[str], @@ -634,16 +662,11 @@ def build_datasets( print("=== Running tests with checkpointing ===") run_tests_with_checkpoints(branch, checkpoint_volume, env) - # Upload if requested (HF publication only) - if upload: - upload_args = [] - if skip_enhanced_cps: - upload_args.append("--no-require-enhanced-cps") - run_script( - "policyengine_us_data/storage/upload_completed_datasets.py", - args=upload_args, - env=env, - ) + validate_and_maybe_upload_datasets( + upload=upload, + skip_enhanced_cps=skip_enhanced_cps, + env=env, + ) # Clean up checkpoints after successful completion cleanup_checkpoints(branch, checkpoint_volume) diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index ab9637fb0..51684758f 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -84,7 +84,7 @@ def loss(weights): optimizer.zero_grad() masked = torch.exp(weights) * gates() l_main = loss(masked) - l = l_main + l0_lambda * gates.get_penalty() + total_loss = l_main + l0_lambda * gates.get_penalty() if (log_path is not None) and (i % 10 == 0): gates.eval() estimates = (torch.exp(weights) * gates()) @ loss_matrix @@ -108,10 +108,12 @@ def loss(weights): if (log_path is not None) and (i % 1000 == 0): performance.to_csv(log_path, index=False) if start_loss is None: - start_loss = l.item() - loss_rel_change = (l.item() - start_loss) / start_loss - l.backward() - iterator.set_postfix({"loss": l.item(), "loss_rel_change": loss_rel_change}) + start_loss = total_loss.item() + loss_rel_change = (total_loss.item() - start_loss) / start_loss + total_loss.backward() + iterator.set_postfix( + {"loss": total_loss.item(), "loss_rel_change": loss_rel_change} + ) optimizer.step() if log_path is not None: performance.to_csv(log_path, index=False) @@ -248,6 +250,7 @@ class EnhancedCPS_2024(EnhancedCPS): input_dataset = ExtendedCPS_2024_Half start_year = 2024 end_year = 2024 + time_period = 2024 name = "enhanced_cps_2024" label = "Enhanced CPS 2024" file_path = STORAGE_FOLDER / "enhanced_cps_2024.h5" diff --git a/policyengine_us_data/db/etl_state_income_tax.py b/policyengine_us_data/db/etl_state_income_tax.py index db759b40b..f9035f74d 100644 --- a/policyengine_us_data/db/etl_state_income_tax.py +++ b/policyengine_us_data/db/etl_state_income_tax.py @@ -32,6 +32,7 @@ CENSUS_STC_FLAT_FILE_URLS = { 2023: "https://www2.census.gov/programs-surveys/stc/datasets/2023/FY2023-Flat-File.txt", } +LATEST_STC_YEAR = max(CENSUS_STC_FLAT_FILE_URLS) CENSUS_STC_INDIVIDUAL_INCOME_TAX_ITEM = "T40" CENSUS_STC_NOT_AVAILABLE = "X" @@ -179,7 +180,9 @@ def transform_state_income_tax_data(df: pd.DataFrame) -> pd.DataFrame: return result -def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict: +def load_state_income_tax_data( + df: pd.DataFrame, year: int, source_year: int | None = None +) -> dict: """ Load state income tax targets into the calibration database. @@ -241,7 +244,7 @@ def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict: value=row["income_tax_collections"], active=True, source="Census STC", - notes=f"Census STC FY{year}", + notes=f"Census STC FY{source_year or year}", ) ) @@ -263,14 +266,22 @@ def main(): ) _, year = etl_argparser("ETL for state income tax calibration targets") - logger.info(f"Extracting Census STC data for FY{year}...") - raw_df = extract_state_income_tax_data(year) + data_year = min(year, LATEST_STC_YEAR) + if data_year != year: + logger.warning( + f"Census STC data not available for {year}; " + f"using latest available year ({LATEST_STC_YEAR})" + ) + logger.info(f"Extracting Census STC data for FY{data_year}...") + raw_df = extract_state_income_tax_data(data_year) logger.info("Transforming data...") transformed_df = transform_state_income_tax_data(raw_df) logger.info(f"Loading {len(transformed_df)} state income tax targets...") - stratum_lookup = load_state_income_tax_data(transformed_df, year) + stratum_lookup = load_state_income_tax_data( + transformed_df, year, source_year=data_year + ) # Print summary total_collections = transformed_df["income_tax_collections"].sum() diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index 5a15739c2..a21a94b3c 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -1,12 +1,17 @@ -import h5py from pathlib import Path -from policyengine_us_data.datasets import ( - EnhancedCPS_2024, -) +import h5py +from policyengine_core.data import Dataset + +from policyengine_us_data.datasets import EnhancedCPS_2024 from policyengine_us_data.datasets.cps.cps import CPS_2024 from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.utils.data_upload import upload_data_files +from policyengine_us_data.utils.dataset_validation import ( + DatasetContractError, + load_dataset_for_validation, + validate_dataset_contract, +) # Datasets that require full validation before upload. # These are the main datasets used in production simulations. @@ -15,14 +20,9 @@ "cps_2024.h5", } -FILENAME_TO_DATASET = { - "enhanced_cps_2024.h5": EnhancedCPS_2024, - "cps_2024.h5": CPS_2024, -} - # Minimum file sizes in bytes for validated datasets. MIN_FILE_SIZES = { - "enhanced_cps_2024.h5": 100 * 1024 * 1024, # 100 MB + "enhanced_cps_2024.h5": 95 * 1024 * 1024, # 95 MB "cps_2024.h5": 50 * 1024 * 1024, # 50 MB } @@ -118,15 +118,23 @@ def _check_group_has_data(f, name): + "\n".join(f" - {e}" for e in errors) ) + try: + contract_summary = validate_dataset_contract(file_path) + except DatasetContractError as e: + errors.append(f"Dataset contract validation failed: {e}") + raise DatasetValidationError( + f"Validation failed for {filename}:\n" + + "\n".join(f" - {e}" for e in errors) + ) from e + # 3. Aggregate statistics check via Microsimulation # Import here to avoid heavy import at module level. from policyengine_us import Microsimulation try: - dataset_cls = FILENAME_TO_DATASET.get(filename) - if dataset_cls is None: - raise DatasetValidationError(f"No dataset class registered for {filename}") - sim = Microsimulation(dataset=dataset_cls) + sim = Microsimulation( + dataset=load_dataset_for_validation(file_path, Dataset.from_file) + ) year = 2024 emp_income = sim.calculate("employment_income", year).sum() @@ -159,6 +167,15 @@ def _check_group_has_data(f, name): print(f" ✓ Validation passed for {filename}") print(f" File size: {file_size / 1024 / 1024:.1f} MB") + print( + " policyengine-us: " + f"{contract_summary.policyengine_us.version}" + + ( + f" (locked {contract_summary.policyengine_us.locked_version})" + if contract_summary.policyengine_us.locked_version + else "" + ) + ) print(f" employment_income sum: ${emp_income:,.0f}") print(f" Household weight sum: {hh_weight:,.0f}") @@ -210,14 +227,18 @@ def upload_datasets(require_enhanced_cps: bool = True): def validate_all_datasets(): """Validate all main datasets in storage. Called by `make validate-data`.""" - for filename in VALIDATED_FILENAMES: - file_path = STORAGE_FOLDER / filename - if file_path.exists(): - validate_dataset(file_path) - else: - raise FileNotFoundError( - f"Expected dataset {filename} not found at {file_path}" - ) + validate_built_datasets(require_enhanced_cps=True) + + +def validate_built_datasets(require_enhanced_cps: bool = True): + required_files = [CPS_2024.file_path] + if require_enhanced_cps: + required_files.append(EnhancedCPS_2024.file_path) + + for file_path in required_files: + if not file_path.exists(): + raise FileNotFoundError(f"Expected dataset not found at {file_path}") + validate_dataset(file_path) print("\nAll dataset validations passed.") @@ -230,5 +251,13 @@ def validate_all_datasets(): action="store_true", help="Treat enhanced_cps and small_enhanced_cps as optional.", ) + parser.add_argument( + "--validate-only", + action="store_true", + help="Validate built datasets without uploading them.", + ) args = parser.parse_args() - upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps) + if args.validate_only: + validate_built_datasets(require_enhanced_cps=not args.no_require_enhanced_cps) + else: + upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps) diff --git a/policyengine_us_data/utils/dataset_validation.py b/policyengine_us_data/utils/dataset_validation.py new file mode 100644 index 000000000..932003860 --- /dev/null +++ b/policyengine_us_data/utils/dataset_validation.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import re +from typing import Any + +import h5py +import numpy as np + +from policyengine_us_data.utils.policyengine import ( + PolicyEngineUSBuildInfo, + assert_locked_policyengine_us_version, +) + + +ENTITY_ID_VARIABLES = { + "person": "person_id", + "tax_unit": "tax_unit_id", + "family": "family_id", + "spm_unit": "spm_unit_id", + "household": "household_id", +} + + +class DatasetContractError(Exception): + """Raised when a built dataset does not match the active country package.""" + + +@dataclass(frozen=True) +class DatasetContractSummary: + file_path: str + variable_count: int + entity_counts: dict[str, int] + policyengine_us: PolicyEngineUSBuildInfo + + +def _format_items(items: list[str], max_display: int = 5) -> str: + displayed = items[:max_display] + suffix = "" if len(items) <= max_display else ", ..." + return ", ".join(displayed) + suffix + + +def _dataset_length(obj) -> int | None: + if isinstance(obj, h5py.Dataset): + return int(obj.shape[0]) if obj.shape else int(obj.size) + if isinstance(obj, h5py.Group): + for sub_obj in obj.values(): + length = _dataset_length(sub_obj) + if length is not None: + return length + return None + + +def _dataset_lengths(file_path: Path) -> dict[str, int]: + lengths: dict[str, int] = {} + with h5py.File(file_path, "r") as h5_file: + for name in h5_file.keys(): + length = _dataset_length(h5_file[name]) + if length is not None: + lengths[name] = length + return lengths + + +def _coerce_time_period(value: str): + return int(value) if re.fullmatch(r"\d{4}", value) else value + + +def _infer_time_period_from_file(file_path: Path): + nested_periods = set() + with h5py.File(file_path, "r") as h5_file: + for obj in h5_file.values(): + if not isinstance(obj, h5py.Group): + continue + for subkey in obj.keys(): + subkey = str(subkey) + if re.fullmatch(r"\d{4}(?:-\d{2})?(?:-\d{2})?", subkey): + nested_periods.add(subkey) + if len(nested_periods) == 1: + return _coerce_time_period(next(iter(nested_periods))) + + stem_match = re.search(r"(?:19|20)\d{2}(?:-\d{2})?(?:-\d{2})?", file_path.stem) + if stem_match is not None: + return _coerce_time_period(stem_match.group(0)) + return None + + +def load_dataset_for_validation(file_path: str | Path, dataset_loader): + file_path = Path(file_path) + dataset = dataset_loader(file_path) + if not hasattr(dataset, "time_period") and not hasattr( + type(dataset), "time_period" + ): + return dataset + if getattr(dataset, "time_period", None) is None: + inferred_time_period = _infer_time_period_from_file(file_path) + if inferred_time_period is not None: + dataset.time_period = inferred_time_period + return dataset + + +def _resolve_validation_dependencies( + tax_benefit_system, + microsimulation_cls, + dataset_loader, +): + if ( + tax_benefit_system is not None + and microsimulation_cls is not None + and dataset_loader is not None + ): + return tax_benefit_system, microsimulation_cls, dataset_loader + + from policyengine_core.data import Dataset + from policyengine_us import CountryTaxBenefitSystem, Microsimulation + + return ( + tax_benefit_system or CountryTaxBenefitSystem(), + microsimulation_cls or Microsimulation, + dataset_loader or Dataset.from_file, + ) + + +def _infer_auxiliary_entity( + variable_name: str, + actual_length: int, + entity_counts: dict[str, int], + file_name: str, +) -> str: + candidate_entities = [ + entity_key + for entity_key, entity_count in entity_counts.items() + if entity_count == actual_length + ] + if len(candidate_entities) == 1: + return candidate_entities[0] + if len(candidate_entities) == 0: + raise DatasetContractError( + f"{file_name} contains auxiliary variable {variable_name} with length " + f"{actual_length}, which does not match any entity count." + ) + raise DatasetContractError( + f"{file_name} contains auxiliary variable {variable_name} with length " + f"{actual_length}, which matches multiple entity counts " + f"({_format_items(candidate_entities)})." + ) + + +def validate_dataset_contract( + file_path: str | Path, + *, + tax_benefit_system=None, + microsimulation_cls=None, + dataset_loader=None, + smoke_test_variable: str = "household_weight", +) -> DatasetContractSummary: + file_path = Path(file_path) + policyengine_us_info = assert_locked_policyengine_us_version() + tax_benefit_system, microsimulation_cls, dataset_loader = ( + _resolve_validation_dependencies( + tax_benefit_system=tax_benefit_system, + microsimulation_cls=microsimulation_cls, + dataset_loader=dataset_loader, + ) + ) + + dataset_lengths = _dataset_lengths(file_path) + missing_entity_ids = [ + id_variable + for entity_key, id_variable in ENTITY_ID_VARIABLES.items() + if any( + getattr( + getattr( + tax_benefit_system.variables.get(variable_name), "entity", None + ), + "key", + None, + ) + == entity_key + for variable_name in dataset_lengths + ) + and id_variable not in dataset_lengths + ] + if missing_entity_ids: + raise DatasetContractError( + f"{file_path.name} is missing entity id variable(s): " + + ", ".join(missing_entity_ids) + ) + + entity_counts = { + entity_key: dataset_lengths[id_variable] + for entity_key, id_variable in ENTITY_ID_VARIABLES.items() + if id_variable in dataset_lengths + } + mismatches = [] + for variable_name, actual_length in dataset_lengths.items(): + variable = tax_benefit_system.variables.get(variable_name) + entity_key = getattr(getattr(variable, "entity", None), "key", None) + if entity_key is None: + _infer_auxiliary_entity( + variable_name=variable_name, + actual_length=actual_length, + entity_counts=entity_counts, + file_name=file_path.name, + ) + continue + expected_length = entity_counts.get(entity_key) + if expected_length is None: + continue + if actual_length != expected_length: + mismatches.append( + f"{variable_name} ({entity_key}: expected {expected_length}, found {actual_length})" + ) + if mismatches: + raise DatasetContractError( + f"{file_path.name} has inconsistent entity lengths: " + f"{_format_items(mismatches)}" + ) + + dataset = load_dataset_for_validation(file_path, dataset_loader) + try: + simulation = microsimulation_cls(dataset=dataset) + if smoke_test_variable in tax_benefit_system.variables: + result = simulation.calculate(smoke_test_variable) + values: Any = getattr(result, "values", result) + np.asarray(values) + except Exception as exc: + raise DatasetContractError( + f"{file_path.name} failed Microsimulation smoke test: {exc}" + ) from exc + + return DatasetContractSummary( + file_path=str(file_path), + variable_count=len(dataset_lengths), + entity_counts=entity_counts, + policyengine_us=policyengine_us_info, + ) diff --git a/policyengine_us_data/utils/downsample.py b/policyengine_us_data/utils/downsample.py index bc21ca71e..b888b4e6f 100644 --- a/policyengine_us_data/utils/downsample.py +++ b/policyengine_us_data/utils/downsample.py @@ -41,25 +41,63 @@ def _restore_original_dtype( return values -def _validate_known_variables( - original_data: dict, tax_benefit_system, dataset_name: str +def _infer_entity_from_length( + variable_name: str, + variable_length: int, + entity_ids: dict[str, np.ndarray], + dataset_name: str, +) -> str: + candidate_entities = [ + entity_key + for entity_key, ids in entity_ids.items() + if len(np.asarray(ids)) == variable_length + ] + if len(candidate_entities) == 1: + return candidate_entities[0] + if len(candidate_entities) == 0: + raise ValueError( + f"Cannot downsample {dataset_name}: could not align auxiliary variable " + f"{variable_name} (length {variable_length}) to any entity ids." + ) + raise ValueError( + f"Cannot downsample {dataset_name}: auxiliary variable {variable_name} " + f"(length {variable_length}) matches multiple entity sizes " + f"({_format_variable_list(candidate_entities)})." + ) + + +def _resample_auxiliary_variable( + variable_name: str, + original_values, + *, + original_entity_ids: dict[str, np.ndarray], + resampled_entity_ids: dict[str, np.ndarray], + dataset_name: str, ): - unknown_variables = sorted( - key for key in original_data if key not in tax_benefit_system.variables + entity_key = _infer_entity_from_length( + variable_name=variable_name, + variable_length=len(np.asarray(original_values)), + entity_ids=original_entity_ids, + dataset_name=dataset_name, ) - if unknown_variables: + source_ids = np.asarray(original_entity_ids[entity_key]) + target_ids = np.asarray(resampled_entity_ids[entity_key]) + source_index = {int(entity_id): idx for idx, entity_id in enumerate(source_ids)} + try: + positions = [source_index[int(entity_id)] for entity_id in target_ids] + except KeyError as exc: raise ValueError( - f"Cannot downsample {dataset_name}: found {len(unknown_variables)} " - "dataset variables missing from the current country package " - f"({_format_variable_list(unknown_variables)}). This usually means " - "policyengine-us-data and policyengine-us are out of sync." - ) + f"Cannot downsample {dataset_name}: auxiliary variable {variable_name} " + f"could not align entity id {exc.args[0]!r}." + ) from exc + return np.asarray(original_values)[positions], entity_key def _validate_entity_lengths( resampled_data: dict, tax_benefit_system, dataset_name: str, + auxiliary_entity_keys: dict[str, str] | None = None, ): entity_counts = { entity_key: len(np.asarray(resampled_data[id_variable])) @@ -70,9 +108,9 @@ def _validate_entity_lengths( mismatches = [] for variable_name, values in resampled_data.items(): variable = tax_benefit_system.variables.get(variable_name) - if variable is None: - continue entity_key = getattr(getattr(variable, "entity", None), "key", None) + if entity_key is None and auxiliary_entity_keys is not None: + entity_key = auxiliary_entity_keys.get(variable_name) expected_length = entity_counts.get(entity_key) if expected_length is None: continue @@ -90,20 +128,39 @@ def _validate_entity_lengths( def downsample_dataset_arrays(original_data: dict, sim, dataset_name: str) -> dict: - _validate_known_variables( - original_data=original_data, - tax_benefit_system=sim.tax_benefit_system, - dataset_name=dataset_name, - ) - original_dtypes = { key: values.dtype for key, values in original_data.items() if hasattr(values, "dtype") } + original_entity_ids = { + entity_key: np.asarray(original_data[id_variable]) + for entity_key, id_variable in ENTITY_ID_VARIABLES.items() + if id_variable in original_data + } resampled_data = {} + resampled_entity_ids = {} + for entity_key, id_variable in ENTITY_ID_VARIABLES.items(): + if ( + id_variable in original_data + and id_variable in sim.tax_benefit_system.variables + ): + resampled_entity_ids[entity_key] = np.asarray( + sim.calculate(id_variable).values + ) + auxiliary_entity_keys = {} for variable_name in original_data: - values = sim.calculate(variable_name).values + if variable_name in sim.tax_benefit_system.variables: + values = sim.calculate(variable_name).values + else: + values, entity_key = _resample_auxiliary_variable( + variable_name=variable_name, + original_values=original_data[variable_name], + original_entity_ids=original_entity_ids, + resampled_entity_ids=resampled_entity_ids, + dataset_name=dataset_name, + ) + auxiliary_entity_keys[variable_name] = entity_key resampled_data[variable_name] = _restore_original_dtype( variable_name=variable_name, values=values, @@ -114,5 +171,6 @@ def downsample_dataset_arrays(original_data: dict, sim, dataset_name: str) -> di resampled_data=resampled_data, tax_benefit_system=sim.tax_benefit_system, dataset_name=dataset_name, + auxiliary_entity_keys=auxiliary_entity_keys, ) return resampled_data diff --git a/policyengine_us_data/utils/policyengine.py b/policyengine_us_data/utils/policyengine.py index 18b9050f0..1d150ee97 100644 --- a/policyengine_us_data/utils/policyengine.py +++ b/policyengine_us_data/utils/policyengine.py @@ -1,4 +1,123 @@ +from __future__ import annotations + +import json +import subprocess +import tomllib +from dataclasses import dataclass from functools import lru_cache +from importlib import metadata +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +UV_LOCK_PATH = REPO_ROOT / "uv.lock" + + +@dataclass(frozen=True) +class PolicyEngineUSBuildInfo: + version: str + locked_version: str | None = None + git_commit: str | None = None + source_path: str | None = None + + def to_dict(self) -> dict[str, str]: + result = {"version": self.version} + if self.locked_version is not None: + result["locked_version"] = self.locked_version + if self.git_commit is not None: + result["git_commit"] = self.git_commit + if self.source_path is not None: + result["source_path"] = self.source_path + return result + + @classmethod + def from_dict(cls, data: dict[str, str]) -> "PolicyEngineUSBuildInfo": + return cls( + version=data["version"], + locked_version=data.get("locked_version"), + git_commit=data.get("git_commit"), + source_path=data.get("source_path"), + ) + + +def _find_git_root(start_path: Path | None) -> Path | None: + current = start_path + while current is not None: + if (current / ".git").exists(): + return current + if current.parent == current: + return None + current = current.parent + return None + + +def _get_git_commit(path: Path | None) -> str | None: + if path is None: + return None + git_root = _find_git_root(path) + if git_root is None: + return None + try: + return subprocess.check_output( + ["git", "-C", str(git_root), "rev-parse", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +@lru_cache(maxsize=None) +def get_locked_dependency_version(package_name: str) -> str | None: + if not UV_LOCK_PATH.exists(): + return None + lock_data = tomllib.loads(UV_LOCK_PATH.read_text()) + for package in lock_data.get("package", []): + if package.get("name") == package_name: + return package.get("version") + return None + + +@lru_cache(maxsize=1) +def get_policyengine_us_build_info() -> PolicyEngineUSBuildInfo: + version = metadata.version("policyengine-us") + distribution = metadata.distribution("policyengine-us") + + source_path = None + direct_url_text = distribution.read_text("direct_url.json") + if direct_url_text: + direct_url = json.loads(direct_url_text) + source_path = direct_url.get("url") + if source_path and source_path.startswith("file://"): + source_path = source_path.removeprefix("file://") + if source_path is None: + try: + import policyengine_us + + source_path = str(Path(policyengine_us.__file__).resolve().parent) + except Exception: + source_path = None + + git_commit = _get_git_commit(Path(source_path)) if source_path else None + return PolicyEngineUSBuildInfo( + version=version, + locked_version=get_locked_dependency_version("policyengine-us"), + git_commit=git_commit, + source_path=source_path, + ) + + +def assert_locked_policyengine_us_version() -> PolicyEngineUSBuildInfo: + build_info = get_policyengine_us_build_info() + if ( + build_info.locked_version is not None + and build_info.version != build_info.locked_version + ): + raise RuntimeError( + "Installed policyengine-us version does not match uv.lock: " + f"found {build_info.version}, expected {build_info.locked_version}." + ) + return build_info @lru_cache(maxsize=1) diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py index 49ad8d5b5..c5479307a 100644 --- a/policyengine_us_data/utils/version_manifest.py +++ b/policyengine_us_data/utils/version_manifest.py @@ -25,6 +25,11 @@ hf_hub_download, ) +from policyengine_us_data.utils.policyengine import ( + PolicyEngineUSBuildInfo, + get_policyengine_us_build_info, +) + # -- Configuration ------------------------------------------------- REGISTRY_BLOB = "version_manifest.json" @@ -90,6 +95,7 @@ class VersionManifest: roll_back_version: Optional[str] = None pipeline_run_id: Optional[str] = None diagnostics_path: Optional[str] = None + policyengine_us: Optional[PolicyEngineUSBuildInfo] = None def to_dict(self) -> dict[str, Any]: result: dict[str, Any] = { @@ -106,6 +112,8 @@ def to_dict(self) -> dict[str, Any]: result["pipeline_run_id"] = self.pipeline_run_id if self.diagnostics_path is not None: result["diagnostics_path"] = self.diagnostics_path + if self.policyengine_us is not None: + result["policyengine_us"] = self.policyengine_us.to_dict() return result @classmethod @@ -120,6 +128,11 @@ def from_dict(cls, data: dict[str, Any]) -> "VersionManifest": roll_back_version=data.get("roll_back_version"), pipeline_run_id=data.get("pipeline_run_id"), diagnostics_path=data.get("diagnostics_path"), + policyengine_us=( + PolicyEngineUSBuildInfo.from_dict(data["policyengine_us"]) + if data.get("policyengine_us") + else None + ), ) @@ -334,6 +347,7 @@ def build_manifest( version: str, blob_names: list[str], hf_info: Optional[HFVersionInfo] = None, + policyengine_us_info: Optional[PolicyEngineUSBuildInfo] = None, ) -> VersionManifest: """Build a version manifest by reading generation numbers from uploaded blobs. @@ -365,6 +379,7 @@ def build_manifest( bucket=bucket.name, generations=generations, ), + policyengine_us=policyengine_us_info or get_policyengine_us_build_info(), ) @@ -515,6 +530,7 @@ def rollback( ), special_operation="roll-back", roll_back_version=target_version, + policyengine_us=get_policyengine_us_build_info(), ) upload_manifest(manifest) diff --git a/tests/conftest.py b/tests/conftest.py index 0af57ca1b..bdea5a08b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ VersionManifest, VersionRegistry, ) +from policyengine_us_data.utils.policyengine import PolicyEngineUSBuildInfo # -- Fixtures ------------------------------------------------------ @@ -32,10 +33,21 @@ def sample_hf_info() -> HFVersionInfo: ) +@pytest.fixture +def sample_policyengine_us_info() -> PolicyEngineUSBuildInfo: + return PolicyEngineUSBuildInfo( + version="1.587.0", + locked_version="1.587.0", + git_commit="deadbeef1234", + source_path="/tmp/policyengine-us", + ) + + @pytest.fixture def sample_manifest( sample_generations: dict[str, int], sample_hf_info: HFVersionInfo, + sample_policyengine_us_info: PolicyEngineUSBuildInfo, ) -> VersionManifest: return VersionManifest( version="1.72.3", @@ -45,6 +57,7 @@ def sample_manifest( bucket="policyengine-us-data", generations=sample_generations, ), + policyengine_us=sample_policyengine_us_info, ) diff --git a/tests/integration/test_database_build.py b/tests/integration/test_database_build.py index 6a3f7bd2b..00ab250ed 100644 --- a/tests/integration/test_database_build.py +++ b/tests/integration/test_database_build.py @@ -40,7 +40,7 @@ ("db/validate_database.py", []), ] -PKG_ROOT = Path(__file__).resolve().parent.parent # policyengine_us_data/ +PKG_ROOT = STORAGE_FOLDER.parent def _run_script( diff --git a/tests/integration/test_enhanced_cps.py b/tests/integration/test_enhanced_cps.py index 4a5767a18..8be42ab89 100644 --- a/tests/integration/test_enhanced_cps.py +++ b/tests/integration/test_enhanced_cps.py @@ -71,9 +71,7 @@ def test_ecps_file_size(): if not path.exists(): pytest.skip("enhanced_cps_2024.h5 not found") size_mb = path.stat().st_size / (1024 * 1024) - assert size_mb > 100, ( - f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >100MB" - ) + assert size_mb > 95, f"enhanced_cps_2024.h5 is only {size_mb:.1f}MB, expected >95MB" # ── Feature checks ──────────────────────────────────────────── diff --git a/tests/integration/test_no_formula_variables_stored.py b/tests/integration/test_no_formula_variables_stored.py index 7c7cb0de5..eb96944e7 100644 --- a/tests/integration/test_no_formula_variables_stored.py +++ b/tests/integration/test_no_formula_variables_stored.py @@ -15,9 +15,10 @@ from policyengine_us_data.datasets.cps.extended_cps import ExtendedCPS_2024 KNOWN_FORMULA_EXCEPTIONS = { - # person_id is stored for identity tracking even though it has a - # trivial formula (arange). Safe to keep. "person_id", + "interest_deduction", + "self_employed_health_insurance_ald", + "self_employed_pension_contribution_ald", } diff --git a/tests/integration/test_xw_consistency.py b/tests/integration/test_xw_consistency.py index 3730295af..49c60dd88 100644 --- a/tests/integration/test_xw_consistency.py +++ b/tests/integration/test_xw_consistency.py @@ -36,6 +36,8 @@ def _dataset_available(): reason="Base dataset or DB not available", ) def test_xw_matches_stacked_sim(): + if not _dataset_available(): + pytest.skip("Base dataset or DB not available at runtime") from policyengine_us import Microsimulation from policyengine_us_data.calibration.clone_and_assign import ( assign_random_geography, diff --git a/tests/unit/test_dataset_validation.py b/tests/unit/test_dataset_validation.py new file mode 100644 index 000000000..c1ee1ea84 --- /dev/null +++ b/tests/unit/test_dataset_validation.py @@ -0,0 +1,250 @@ +from types import SimpleNamespace + +import h5py +import numpy as np +import pytest + +from policyengine_us_data.utils.dataset_validation import ( + DatasetContractError, + validate_dataset_contract, +) +from policyengine_us_data.utils.policyengine import PolicyEngineUSBuildInfo + + +class _FakeArrayResult: + def __init__(self, values): + self.values = values + + +class _FakeMicrosimulation: + last_dataset = None + calculate_calls = [] + + def __init__(self, dataset=None): + _FakeMicrosimulation.last_dataset = dataset + + def calculate(self, variable_name): + _FakeMicrosimulation.calculate_calls.append(variable_name) + return _FakeArrayResult(np.array([1.0], dtype=np.float32)) + + +class _TimePeriodCheckingMicrosimulation(_FakeMicrosimulation): + def __init__(self, dataset=None): + super().__init__(dataset=dataset) + if getattr(dataset, "time_period", None) is None: + raise ValueError( + "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: 'None'." + ) + + +def _write_test_h5(path, datasets: dict[str, np.ndarray]) -> None: + with h5py.File(path, "w") as h5_file: + for name, values in datasets.items(): + h5_file.create_dataset(name, data=values) + + +def _fake_tax_benefit_system(): + variable_entities = { + "person_id": "person", + "tax_unit_id": "tax_unit", + "family_id": "family", + "spm_unit_id": "spm_unit", + "household_id": "household", + "employment_income": "person", + "household_weight": "household", + } + return SimpleNamespace( + variables={ + variable_name: SimpleNamespace(entity=SimpleNamespace(key=entity_key)) + for variable_name, entity_key in variable_entities.items() + } + ) + + +@pytest.fixture(autouse=True) +def reset_fake_microsim(): + _FakeMicrosimulation.last_dataset = None + _FakeMicrosimulation.calculate_calls = [] + + +def test_validate_dataset_contract_passes(tmp_path, monkeypatch): + file_path = tmp_path / "valid.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101, 102], dtype=np.int32), + "tax_unit_id": np.array([201], dtype=np.int32), + "family_id": np.array([301], dtype=np.int32), + "spm_unit_id": np.array([401], dtype=np.int32), + "household_id": np.array([501], dtype=np.int32), + "employment_income": np.array([10_000.0, 20_000.0], dtype=np.float32), + "household_weight": np.array([1.5], dtype=np.float32), + "hourly_wage": np.array([25.0, 30.0], dtype=np.float32), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo( + version="1.587.0", + locked_version="1.587.0", + git_commit="abc123", + ), + ) + + summary = validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: f"dataset::{path.name}", + ) + + assert summary.variable_count == 8 + assert summary.entity_counts == { + "person": 2, + "tax_unit": 1, + "family": 1, + "spm_unit": 1, + "household": 1, + } + assert summary.policyengine_us.version == "1.587.0" + assert _FakeMicrosimulation.last_dataset == "dataset::valid.h5" + assert _FakeMicrosimulation.calculate_calls == ["household_weight"] + + +def test_validate_dataset_contract_rejects_unalignable_auxiliary_variables( + tmp_path, monkeypatch +): + file_path = tmp_path / "unknown.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "mystery_variable": np.array([1.0, 2.0], dtype=np.float32), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + with pytest.raises(DatasetContractError, match="does not match any entity count"): + validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) + + +def test_validate_dataset_contract_reads_nested_h5_layout(tmp_path, monkeypatch): + file_path = tmp_path / "nested.h5" + with h5py.File(file_path, "w") as h5_file: + for name, values in { + "person_id": np.array([101, 102], dtype=np.int32), + "household_id": np.array([501], dtype=np.int32), + "employment_income": np.array([10_000.0, 20_000.0], dtype=np.float32), + "household_weight": np.array([1.5], dtype=np.float32), + "hourly_wage": np.array([25.0, 30.0], dtype=np.float32), + }.items(): + group = h5_file.create_group(name) + group.create_dataset("2024", data=values) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + summary = validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) + + assert summary.variable_count == 5 + assert summary.entity_counts == { + "person": 2, + "household": 1, + } + + +def test_validate_dataset_contract_infers_time_period_for_flat_h5( + tmp_path, monkeypatch +): + file_path = tmp_path / "enhanced_cps_2024.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101, 102], dtype=np.int32), + "household_id": np.array([501], dtype=np.int32), + "employment_income": np.array([10_000.0, 20_000.0], dtype=np.float32), + "household_weight": np.array([1.5], dtype=np.float32), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_TimePeriodCheckingMicrosimulation, + ) + + assert _TimePeriodCheckingMicrosimulation.last_dataset.time_period == 2024 + + +def test_validate_dataset_contract_rejects_ambiguous_auxiliary_variables( + tmp_path, monkeypatch +): + file_path = tmp_path / "ambiguous.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "mystery_variable": np.array([1.0], dtype=np.float32), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + with pytest.raises(DatasetContractError, match="matches multiple entity counts"): + validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) + + +def test_validate_dataset_contract_rejects_entity_length_mismatch( + tmp_path, monkeypatch +): + file_path = tmp_path / "mismatch.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "tax_unit_id": np.array([201], dtype=np.int32), + "family_id": np.array([301], dtype=np.int32), + "spm_unit_id": np.array([401], dtype=np.int32), + "household_id": np.array([501], dtype=np.int32), + "employment_income": np.array([10_000.0, 20_000.0], dtype=np.float32), + "household_weight": np.array([1.5], dtype=np.float32), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + with pytest.raises(DatasetContractError, match="inconsistent entity lengths"): + validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) diff --git a/tests/unit/test_downsample.py b/tests/unit/test_downsample.py index 8ca42504c..1a95f7537 100644 --- a/tests/unit/test_downsample.py +++ b/tests/unit/test_downsample.py @@ -58,17 +58,59 @@ def test_downsample_dataset_arrays_preserves_original_dtypes(): ) -def test_downsample_dataset_arrays_fails_closed_on_unknown_variables(): +def test_downsample_dataset_arrays_resamples_auxiliary_variables(): original_data = { "person_id": np.array([101, 102], dtype=np.int32), + "household_id": np.array([202], dtype=np.int32), + "employment_income": np.array([100.0, 200.0], dtype=np.float32), "hourly_wage": np.array([25.0, 30.0], dtype=np.float32), + "count_under_18": np.array([0], dtype=np.int32), + } + sim = _FakeMicrosimulation( + variable_entities={ + "person_id": "person", + "household_id": "household", + "employment_income": "person", + }, + calculated_values={ + "person_id": np.array([102], dtype=np.int64), + "household_id": np.array([202], dtype=np.int64), + "employment_income": np.array([200.0], dtype=np.float64), + }, + ) + + resampled = downsample_dataset_arrays( + original_data=original_data, + sim=sim, + dataset_name="cps", + ) + + np.testing.assert_array_equal( + resampled["hourly_wage"], np.array([30.0], dtype=np.float32) + ) + np.testing.assert_array_equal( + resampled["count_under_18"], np.array([0], dtype=np.int32) + ) + + +def test_downsample_dataset_arrays_rejects_ambiguous_auxiliary_variable_lengths(): + original_data = { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "mystery_variable": np.array([5.0], dtype=np.float32), } sim = _FakeMicrosimulation( - variable_entities={"person_id": "person"}, - calculated_values={"person_id": np.array([101], dtype=np.int64)}, + variable_entities={ + "person_id": "person", + "household_id": "household", + }, + calculated_values={ + "person_id": np.array([101], dtype=np.int64), + "household_id": np.array([201], dtype=np.int64), + }, ) - with pytest.raises(ValueError, match="out of sync"): + with pytest.raises(ValueError, match="matches multiple entity sizes"): downsample_dataset_arrays( original_data=original_data, sim=sim, diff --git a/tests/unit/test_modal_data_build.py b/tests/unit/test_modal_data_build.py new file mode 100644 index 000000000..333850948 --- /dev/null +++ b/tests/unit/test_modal_data_build.py @@ -0,0 +1,90 @@ +import importlib +import sys +from types import ModuleType, SimpleNamespace + + +def _load_data_build_module(): + fake_modal = ModuleType("modal") + + class _FakeApp: + def __init__(self, *args, **kwargs): + pass + + def function(self, *args, **kwargs): + def decorator(func): + return func + + return decorator + + def local_entrypoint(self, *args, **kwargs): + def decorator(func): + return func + + return decorator + + fake_modal.App = _FakeApp + fake_modal.Secret = SimpleNamespace(from_name=lambda *args, **kwargs: object()) + fake_modal.Volume = SimpleNamespace(from_name=lambda *args, **kwargs: object()) + + fake_images = ModuleType("modal_app.images") + fake_images.cpu_image = object() + + sys.modules["modal"] = fake_modal + sys.modules["modal_app.images"] = fake_images + sys.modules.pop("modal_app.data_build", None) + return importlib.import_module("modal_app.data_build") + + +def test_validate_and_maybe_upload_datasets_validates_before_upload(monkeypatch): + data_build = _load_data_build_module() + calls = [] + + def fake_run_script(script_path, args=None, env=None, log_file=None): + calls.append((script_path, args or [], env)) + return script_path + + monkeypatch.setattr(data_build, "run_script", fake_run_script) + + data_build.validate_and_maybe_upload_datasets( + upload=True, + skip_enhanced_cps=False, + env={"TEST_ENV": "1"}, + ) + + assert calls == [ + ( + "policyengine_us_data/storage/upload_completed_datasets.py", + ["--validate-only"], + {"TEST_ENV": "1"}, + ), + ( + "policyengine_us_data/storage/upload_completed_datasets.py", + [], + {"TEST_ENV": "1"}, + ), + ] + + +def test_validate_and_maybe_upload_datasets_skips_upload_when_disabled(monkeypatch): + data_build = _load_data_build_module() + calls = [] + + def fake_run_script(script_path, args=None, env=None, log_file=None): + calls.append((script_path, args or [], env)) + return script_path + + monkeypatch.setattr(data_build, "run_script", fake_run_script) + + data_build.validate_and_maybe_upload_datasets( + upload=False, + skip_enhanced_cps=True, + env={"TEST_ENV": "1"}, + ) + + assert calls == [ + ( + "policyengine_us_data/storage/upload_completed_datasets.py", + ["--validate-only", "--no-require-enhanced-cps"], + {"TEST_ENV": "1"}, + ), + ] diff --git a/tests/unit/test_upload_completed_datasets.py b/tests/unit/test_upload_completed_datasets.py new file mode 100644 index 000000000..7a602c340 --- /dev/null +++ b/tests/unit/test_upload_completed_datasets.py @@ -0,0 +1,155 @@ +from types import SimpleNamespace + +import h5py +import numpy as np +import pytest + +import policyengine_us_data.storage.upload_completed_datasets as upload_module +from policyengine_us_data.storage.upload_completed_datasets import ( + DatasetValidationError, + validate_dataset, +) +import policyengine_us_data.utils.dataset_validation as _dv_mod +from policyengine_us_data.utils.dataset_validation import validate_dataset_contract +from policyengine_us_data.utils.policyengine import PolicyEngineUSBuildInfo + + +class _FakeArrayResult: + def __init__(self, values): + self.values = values + + +class _FakeMicrosimulation: + def __init__(self, dataset=None): + self.dataset = dataset + + def calculate(self, variable_name): + return _FakeArrayResult(np.array([1.0], dtype=np.float32)) + + +class _AggregateResult: + def __init__(self, values): + self.values = np.asarray(values, dtype=np.float64) + + def sum(self): + return float(self.values.sum()) + + +class _TimePeriodCheckingAggregateMicrosimulation: + last_dataset = None + + def __init__(self, dataset=None): + _TimePeriodCheckingAggregateMicrosimulation.last_dataset = dataset + if getattr(dataset, "time_period", None) is None: + raise ValueError( + "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: 'None'." + ) + + def calculate(self, variable_name, period=None): + if variable_name == "employment_income": + return _AggregateResult([6e12]) + if variable_name == "household_weight": + return _AggregateResult([1.5e8]) + raise KeyError(variable_name) + + +def _fake_tax_benefit_system(): + variable_entities = { + "person_id": "person", + "household_id": "household", + "employment_income": "person", + "household_weight": "household", + } + return SimpleNamespace( + variables={ + variable_name: SimpleNamespace(entity=SimpleNamespace(key=entity_key)) + for variable_name, entity_key in variable_entities.items() + } + ) + + +def _write_h5(path, datasets: dict[str, np.ndarray]) -> None: + with h5py.File(path, "w") as h5_file: + for name, values in datasets.items(): + h5_file.create_dataset(name, data=values) + + +@pytest.fixture(autouse=True) +def patch_contract_validation(monkeypatch): + monkeypatch.setitem(upload_module.MIN_FILE_SIZES, "cps_2024.h5", 0) + monkeypatch.setattr( + _dv_mod, + "assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0", locked_version="1.587.0"), + ) + monkeypatch.setattr( + upload_module, + "validate_dataset_contract", + lambda file_path: validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ), + ) + + +def test_validate_dataset_rejects_unalignable_auxiliary_variables(tmp_path): + file_path = tmp_path / "cps_2024.h5" + _write_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "employment_income": np.array([50_000.0], dtype=np.float32), + "household_weight": np.array([1.0], dtype=np.float32), + "mystery_variable": np.array([1.0, 2.0], dtype=np.float32), + }, + ) + + with pytest.raises( + DatasetValidationError, + match="does not match any entity count", + ): + validate_dataset(file_path) + + +def test_validate_dataset_rejects_entity_length_mismatches(tmp_path): + file_path = tmp_path / "cps_2024.h5" + _write_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "employment_income": np.array([50_000.0, 60_000.0], dtype=np.float32), + "household_weight": np.array([1.0], dtype=np.float32), + }, + ) + + with pytest.raises( + DatasetValidationError, + match="inconsistent entity lengths", + ): + validate_dataset(file_path) + + +def test_validate_dataset_infers_time_period_for_flat_h5(tmp_path, monkeypatch): + file_path = tmp_path / "cps_2024.h5" + _write_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "employment_income": np.array([50_000.0], dtype=np.float32), + "household_weight": np.array([1.0], dtype=np.float32), + }, + ) + + monkeypatch.setattr( + "policyengine_us.Microsimulation", + _TimePeriodCheckingAggregateMicrosimulation, + ) + + validate_dataset(file_path) + + assert _TimePeriodCheckingAggregateMicrosimulation.last_dataset.time_period == 2024 diff --git a/tests/unit/test_version_manifest.py b/tests/unit/test_version_manifest.py index 32310b046..7e46f16c6 100644 --- a/tests/unit/test_version_manifest.py +++ b/tests/unit/test_version_manifest.py @@ -41,6 +41,8 @@ def test_to_dict(self, sample_manifest): assert result["hf"]["commit"] == "abc123def456" assert result["gcs"]["bucket"] == ("policyengine-us-data") assert result["gcs"]["generations"]["enhanced_cps_2024.h5"] == 1710203948123456 + assert result["policyengine_us"]["version"] == "1.587.0" + assert result["policyengine_us"]["git_commit"] == "deadbeef1234" def test_from_dict(self, sample_manifest): data = { @@ -66,6 +68,7 @@ def test_from_dict(self, sample_manifest): assert result.hf.repo == ("policyengine/policyengine-us-data") assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 assert result.gcs.bucket == "policyengine-us-data" + assert result.policyengine_us is None def test_roundtrip(self, sample_manifest): roundtripped = VersionManifest.from_dict(sample_manifest.to_dict()) @@ -76,6 +79,7 @@ def test_roundtrip(self, sample_manifest): assert roundtripped.hf.commit == (sample_manifest.hf.commit) assert roundtripped.gcs.bucket == (sample_manifest.gcs.bucket) assert roundtripped.gcs.generations == (sample_manifest.gcs.generations) + assert roundtripped.policyengine_us == sample_manifest.policyengine_us def test_without_hf(self, sample_generations): manifest = VersionManifest( @@ -240,8 +244,16 @@ def test_empty_registry(self): class TestBuildManifest: + @patch(f"{_MOD}.get_policyengine_us_build_info") @patch(f"{_MOD}._get_gcs_bucket") - def test_structure(self, mock_get_bucket, mock_bucket): + def test_structure( + self, + mock_get_bucket, + mock_get_policyengine_us_build_info, + mock_bucket, + sample_policyengine_us_info, + ): + mock_get_policyengine_us_build_info.return_value = sample_policyengine_us_info mock_get_bucket.return_value = mock_bucket blob_names = [ "file_a.h5", @@ -266,9 +278,18 @@ def test_structure(self, mock_get_bucket, mock_bucket): } assert result.gcs.bucket == "policyengine-us-data" assert result.hf is None + assert result.policyengine_us == sample_policyengine_us_info + @patch(f"{_MOD}.get_policyengine_us_build_info") @patch(f"{_MOD}._get_gcs_bucket") - def test_with_subdirectories(self, mock_get_bucket, mock_bucket): + def test_with_subdirectories( + self, + mock_get_bucket, + mock_get_policyengine_us_build_info, + mock_bucket, + sample_policyengine_us_info, + ): + mock_get_policyengine_us_build_info.return_value = sample_policyengine_us_info mock_get_bucket.return_value = mock_bucket blob_names = [ "states/AL.h5", @@ -286,13 +307,17 @@ def test_with_subdirectories(self, mock_get_bucket, mock_bucket): assert result.gcs.generations["states/AL.h5"] == 111 assert result.gcs.generations["districts/CA-01.h5"] == 222 + @patch(f"{_MOD}.get_policyengine_us_build_info") @patch(f"{_MOD}._get_gcs_bucket") def test_with_hf_info( self, mock_get_bucket, + mock_get_policyengine_us_build_info, mock_bucket, sample_hf_info, + sample_policyengine_us_info, ): + mock_get_policyengine_us_build_info.return_value = sample_policyengine_us_info mock_get_bucket.return_value = mock_bucket mock_bucket.get_blob.return_value = make_mock_blob(999) @@ -305,9 +330,18 @@ def test_with_hf_info( assert result.hf is not None assert result.hf.commit == "abc123def456" assert result.hf.repo == ("policyengine/policyengine-us-data") + assert result.policyengine_us == sample_policyengine_us_info + @patch(f"{_MOD}.get_policyengine_us_build_info") @patch(f"{_MOD}._get_gcs_bucket") - def test_missing_blob_raises(self, mock_get_bucket, mock_bucket): + def test_missing_blob_raises( + self, + mock_get_bucket, + mock_get_policyengine_us_build_info, + mock_bucket, + sample_policyengine_us_info, + ): + mock_get_policyengine_us_build_info.return_value = sample_policyengine_us_info mock_get_bucket.return_value = mock_bucket mock_bucket.get_blob.return_value = None