diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 2630d0e15..4d5d847bf 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -140,35 +140,27 @@ def get_version() -> str: def partition_work( - states: List[str], - districts: List[str], - cities: List[str], + work_items: List[Dict], num_workers: int, completed: set, ) -> List[List[Dict]]: - """Partition work items across N workers.""" - remaining = [] - - for s in states: - item_id = f"state:{s}" - if item_id not in completed: - remaining.append({"type": "state", "id": s, "weight": 5}) - - for d in districts: - item_id = f"district:{d}" - if item_id not in completed: - remaining.append({"type": "district", "id": d, "weight": 1}) + """Partition work items across N workers using LPT scheduling.""" + remaining = [ + item for item in work_items if f"{item['type']}:{item['id']}" not in completed + ] + remaining.sort(key=lambda x: -x["weight"]) - for c in cities: - item_id = f"city:{c}" - if item_id not in completed: - remaining.append({"type": "city", "id": c, "weight": 3}) + n_workers = min(num_workers, len(remaining)) + if n_workers == 0: + return [] - remaining.sort(key=lambda x: -x["weight"]) + heap = [(0, i) for i in range(n_workers)] + chunks = [[] for _ in range(n_workers)] - chunks = [[] for _ in range(num_workers)] - for i, item in enumerate(remaining): - chunks[i % num_workers].append(item) + for item in remaining: + load, idx = heapq.heappop(heap) + chunks[idx].append(item) + heapq.heappush(heap, (load + item["weight"], idx)) return [c for c in chunks if c] @@ -197,9 +189,7 @@ def get_completed_from_volume(version_dir: Path) -> set: def run_phase( phase_name: str, - states: List[str], - districts: List[str], - cities: List[str], + work_items: List[Dict], num_workers: int, completed: set, branch: str, @@ -216,7 +206,7 @@ def run_phase( and crashes, and validation_rows is a list of per-target validation result dicts. """ - work_chunks = partition_work(states, districts, cities, num_workers, completed) + work_chunks = partition_work(work_items, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) print(f"\n--- Phase: {phase_name} ---") @@ -228,7 +218,8 @@ def run_phase( handles = [] for i, chunk in enumerate(work_chunks): - print(f" Worker {i}: {len(chunk)} items") + total_weight = sum(item["weight"] for item in chunk) + print(f" Worker {i}: {len(chunk)} items, weight {total_weight}") handle = build_areas_worker.spawn( branch=branch, version=version, @@ -753,7 +744,7 @@ def coordinate_publish( cds = get_all_cds_from_database(db_uri) states = list(STATE_CODES.values()) districts = [get_district_friendly_name(cd) for cd in cds] -print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"]}})) +print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"], "cds": cds}})) """, ], capture_output=True, @@ -769,6 +760,22 @@ def coordinate_publish( districts = work_info["districts"] cities = work_info["cities"] + from collections import Counter + from policyengine_us_data.calibration.calibration_utils import STATE_CODES + + raw_cds = work_info["cds"] + cds_per_state = Counter(STATE_CODES.get(int(cd) // 100, "??") for cd in raw_cds) + + CITY_WEIGHTS = {"NYC": 11} + + work_items = [] + for s in states: + work_items.append({"type": "state", "id": s, "weight": cds_per_state.get(s, 1)}) + for d in districts: + work_items.append({"type": "district", "id": d, "weight": 1}) + for c in cities: + work_items.append({"type": "city", "id": c, "weight": CITY_WEIGHTS.get(c, 3)}) + staging_volume.reload() completed = get_completed_from_volume(version_dir) print(f"Found {len(completed)} already-completed items on volume") @@ -786,32 +793,8 @@ def coordinate_publish( accumulated_validation_rows = [] completed, phase_errors, v_rows = run_phase( - "States", - states=states, - districts=[], - cities=[], - completed=completed, - **phase_args, - ) - accumulated_errors.extend(phase_errors) - accumulated_validation_rows.extend(v_rows) - - completed, phase_errors, v_rows = run_phase( - "Districts", - states=[], - districts=districts, - cities=[], - completed=completed, - **phase_args, - ) - accumulated_errors.extend(phase_errors) - accumulated_validation_rows.extend(v_rows) - - completed, phase_errors, v_rows = run_phase( - "Cities", - states=[], - districts=[], - cities=cities, + "All areas", + work_items=work_items, completed=completed, **phase_args, ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index e610736b5..27dbb8c2a 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -208,8 +208,7 @@ def main(): from policyengine_us_data.calibration.publish_local_area import ( build_h5, - NYC_COUNTIES, - NYC_CDS, + NYC_COUNTY_FIPS, AT_LARGE_DISTRICTS, ) from policyengine_us_data.calibration.calibration_utils import ( @@ -388,13 +387,6 @@ def main(): ) elif item_type == "city": - cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] - if not cd_subset: - print( - "No NYC CDs found, skipping", - file=sys.stderr, - ) - continue cities_dir = output_dir / "cities" cities_dir.mkdir(parents=True, exist_ok=True) path = build_h5( @@ -402,8 +394,7 @@ def main(): geography=geography, dataset_path=dataset_path, output_path=cities_dir / "NYC.h5", - cd_subset=cd_subset, - county_filter=NYC_COUNTIES, + county_fips_filter=NYC_COUNTY_FIPS, takeup_filter=takeup_filter, ) diff --git a/policyengine_us_data/calibration/block_assignment.py b/policyengine_us_data/calibration/block_assignment.py index 83af388f2..3754ad5af 100644 --- a/policyengine_us_data/calibration/block_assignment.py +++ b/policyengine_us_data/calibration/block_assignment.py @@ -580,92 +580,3 @@ def derive_geography_from_blocks( "zcta": np.array(zcta_list), "county_index": county_indices, } - - -# === County Filter Functions (for city-level datasets) === - - -def get_county_filter_probability( - cd_geoid: str, - county_filter: set, -) -> float: - """ - Calculate P(county in filter | CD) using block-level data. - - Returns the probability that a household in this CD would be in the - target area (e.g., NYC). Used for weight scaling when building - city-level datasets. - - Args: - cd_geoid: Congressional district geoid (e.g., "3610") - county_filter: Set of county enum names that define the target area - - Returns: - Probability between 0 and 1 - """ - distributions = _get_block_distributions() - cd_key = str(int(cd_geoid)) - - if cd_key not in distributions: - return 0.0 - - dist = distributions[cd_key] - - # Convert county enum names to FIPS codes for comparison - fips_to_enum = _build_county_fips_to_enum() - enum_to_fips = {v: k for k, v in fips_to_enum.items()} - target_fips = {enum_to_fips.get(name) for name in county_filter} - target_fips.discard(None) - - # Sum probabilities of blocks in target counties - return sum( - prob - for block, prob in dist.items() - if get_county_fips_from_block(block) in target_fips - ) - - -def get_filtered_block_distribution( - cd_geoid: str, - county_filter: set, -) -> Dict[str, float]: - """ - Get normalized distribution over blocks in target counties only. - - Used when building city-level datasets to assign only blocks in valid - counties while maintaining relative proportions within the target area. - - Args: - cd_geoid: Congressional district geoid (e.g., "3610") - county_filter: Set of county enum names that define the target area - - Returns: - Dictionary mapping block GEOIDs to normalized probabilities. - Empty dict if CD has no overlap with target area. - """ - distributions = _get_block_distributions() - cd_key = str(int(cd_geoid)) - - if cd_key not in distributions: - return {} - - dist = distributions[cd_key] - - # Convert county enum names to FIPS codes for comparison - fips_to_enum = _build_county_fips_to_enum() - enum_to_fips = {v: k for k, v in fips_to_enum.items()} - target_fips = {enum_to_fips.get(name) for name in county_filter} - target_fips.discard(None) - - # Filter to blocks in target counties - filtered = { - block: prob - for block, prob in dist.items() - if get_county_fips_from_block(block) in target_fips - } - - # Normalize - total = sum(filtered.values()) - if total > 0: - return {block: prob / total for block, prob in filtered.items()} - return {} diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py index 0fc1e0f61..52a53c20e 100644 --- a/policyengine_us_data/calibration/clone_and_assign.py +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -67,10 +67,23 @@ def load_global_block_distribution(): return block_geoids, cd_geoids, state_fips, probs +def _build_agi_block_probs(cds, pop_probs, cd_agi_targets): + """Multiply population block probs by CD AGI target weights.""" + agi_weights = np.array([cd_agi_targets.get(cd, 0.0) for cd in cds]) + agi_weights = np.maximum(agi_weights, 0.0) + if agi_weights.sum() == 0: + return pop_probs + agi_probs = pop_probs * agi_weights + return agi_probs / agi_probs.sum() + + def assign_random_geography( n_records: int, n_clones: int = 10, seed: int = 42, + household_agi: np.ndarray = None, + cd_agi_targets: dict = None, + agi_threshold_pctile: float = 90.0, ) -> GeographyAssignment: """Assign random census block geography to cloned CPS records. @@ -95,17 +108,48 @@ def assign_random_geography( n_total = n_records * n_clones rng = np.random.default_rng(seed) + agi_probs = None + extreme_mask = None + if household_agi is not None and cd_agi_targets is not None: + threshold = np.percentile(household_agi, agi_threshold_pctile) + extreme_mask = household_agi >= threshold + agi_probs = _build_agi_block_probs(cds, probs, cd_agi_targets) + logger.info( + "AGI-conditional assignment: %d extreme HHs (AGI >= $%.0f) " + "use AGI-weighted block probs", + extreme_mask.sum(), + threshold, + ) + + def _sample(size, mask_slice=None): + """Sample block indices, using AGI-weighted probs for extreme HHs.""" + if ( + extreme_mask is not None + and agi_probs is not None + and mask_slice is not None + ): + out = np.empty(size, dtype=np.int64) + ext = mask_slice + n_ext = ext.sum() + n_norm = size - n_ext + if n_ext > 0: + out[ext] = rng.choice(len(blocks), size=n_ext, p=agi_probs) + if n_norm > 0: + out[~ext] = rng.choice(len(blocks), size=n_norm, p=probs) + return out + return rng.choice(len(blocks), size=size, p=probs) + indices = np.empty(n_total, dtype=np.int64) # Clone 0: unrestricted draw - indices[:n_records] = rng.choice(len(blocks), size=n_records, p=probs) + indices[:n_records] = _sample(n_records, extreme_mask) assigned_cds = np.empty((n_clones, n_records), dtype=object) assigned_cds[0] = cds[indices[:n_records]] for clone_idx in range(1, n_clones): start = clone_idx * n_records - clone_indices = rng.choice(len(blocks), size=n_records, p=probs) + clone_indices = _sample(n_records, extreme_mask) clone_cds = cds[clone_indices] collisions = np.zeros(n_records, dtype=bool) @@ -116,7 +160,20 @@ def assign_random_geography( n_bad = collisions.sum() if n_bad == 0: break - clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs) + bad_mask = collisions + if extreme_mask is not None and agi_probs is not None: + bad_ext = bad_mask & extreme_mask + bad_norm = bad_mask & ~extreme_mask + if bad_ext.sum() > 0: + clone_indices[bad_ext] = rng.choice( + len(blocks), size=bad_ext.sum(), p=agi_probs + ) + if bad_norm.sum() > 0: + clone_indices[bad_norm] = rng.choice( + len(blocks), size=bad_norm.sum(), p=probs + ) + else: + clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs) clone_cds = cds[clone_indices] collisions = np.zeros(n_records, dtype=bool) for prev in range(clone_idx): diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 5751640de..b3e6085a9 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -30,7 +30,6 @@ ) from policyengine_us_data.calibration.block_assignment import ( derive_geography_from_blocks, - get_county_filter_probability, ) from policyengine_us_data.calibration.clone_and_assign import ( GeographyAssignment, @@ -46,29 +45,7 @@ CHECKPOINT_FILE_CITIES = Path("completed_cities.txt") WORK_DIR = Path("local_area_build") -NYC_COUNTIES = { - "QUEENS_COUNTY_NY", - "BRONX_COUNTY_NY", - "RICHMOND_COUNTY_NY", - "NEW_YORK_COUNTY_NY", - "KINGS_COUNTY_NY", -} - -NYC_CDS = [ - "3603", - "3605", - "3606", - "3607", - "3608", - "3609", - "3610", - "3611", - "3612", - "3613", - "3614", - "3615", - "3616", -] +NYC_COUNTY_FIPS = {"36005", "36047", "36061", "36081", "36085"} META_FILE = WORK_DIR / "checkpoint_meta.json" @@ -186,7 +163,7 @@ def build_h5( dataset_path: Path, output_path: Path, cd_subset: List[str] = None, - county_filter: set = None, + county_fips_filter: set = None, takeup_filter: List[str] = None, ) -> Path: """Build an H5 file by cloning records for each nonzero weight. @@ -197,8 +174,8 @@ def build_h5( dataset_path: Path to base dataset H5 file. output_path: Where to write the output H5 file. cd_subset: If provided, only include clones for these CDs. - county_filter: If provided, scale weights by P(target|CD) - for city datasets. + county_fips_filter: If provided, zero out weights for clones + whose county FIPS is not in this set. takeup_filter: List of takeup vars to apply. Returns: @@ -239,17 +216,11 @@ def build_h5( cd_mask = np.vectorize(lambda cd: cd in cd_subset_set)(clone_cds_matrix) W[~cd_mask] = 0 - # County filtering: scale weights by P(target_counties | CD) - if county_filter is not None: - unique_cds = np.unique(clone_cds_matrix) - cd_prob = { - cd: get_county_filter_probability(cd, county_filter) for cd in unique_cds - } - p_matrix = np.vectorize( - cd_prob.__getitem__, - otypes=[float], - )(clone_cds_matrix) - W *= p_matrix + # County FIPS filtering: zero out clones not in target counties + if county_fips_filter is not None: + fips_array = np.asarray(geography.county_fips).reshape(n_clones_total, n_hh) + fips_mask = np.isin(fips_array, list(county_fips_filter)) + W[~fips_mask] = 0 label = ( f"CD subset {cd_subset}" @@ -266,7 +237,7 @@ def build_h5( if n_clones == 0: raise ValueError( f"No active clones after filtering. " - f"cd_subset={cd_subset}, county_filter={county_filter}" + f"cd_subset={cd_subset}, county_fips_filter={county_fips_filter}" ) clone_weights = W[active_geo, active_hh] active_blocks = blocks.reshape(n_clones_total, n_hh)[active_geo, active_hh] @@ -783,8 +754,6 @@ def build_cities( """Build city H5 files with checkpointing, optionally uploading.""" w = np.load(weights_path) - all_cds = sorted(set(geography.cd_geoid.astype(str))) - cities_dir = output_dir / "cities" cities_dir.mkdir(parents=True, exist_ok=True) @@ -794,34 +763,29 @@ def build_cities( if "NYC" in completed_cities: print("Skipping NYC (already completed)") else: - cd_subset = [cd for cd in all_cds if cd in NYC_CDS] - if not cd_subset: - print("No NYC-related CDs found, skipping") - else: - output_path = cities_dir / "NYC.h5" - - try: - build_h5( - weights=w, - geography=geography, - dataset_path=dataset_path, - output_path=output_path, - cd_subset=cd_subset, - county_filter=NYC_COUNTIES, - takeup_filter=takeup_filter, - ) - - if upload: - print("Uploading NYC.h5 to GCP...") - upload_local_area_file(str(output_path), "cities", skip_hf=True) - hf_queue.append((str(output_path), "cities")) - - record_completed_city("NYC") - print("Completed NYC") - - except Exception as e: - print(f"ERROR building NYC: {e}") - raise + output_path = cities_dir / "NYC.h5" + + try: + build_h5( + weights=w, + geography=geography, + dataset_path=dataset_path, + output_path=output_path, + county_fips_filter=NYC_COUNTY_FIPS, + takeup_filter=takeup_filter, + ) + + if upload: + print("Uploading NYC.h5 to GCP...") + upload_local_area_file(str(output_path), "cities", skip_hf=True) + hf_queue.append((str(output_path), "cities")) + + record_completed_city("NYC") + print("Completed NYC") + + except Exception as e: + print(f"ERROR building NYC: {e}") + raise if upload and hf_queue: print(f"\nUploading batch of {len(hf_queue)} city files to HuggingFace...") diff --git a/policyengine_us_data/calibration/target_config.yaml b/policyengine_us_data/calibration/target_config.yaml index 1d36747bb..03a8df147 100644 --- a/policyengine_us_data/calibration/target_config.yaml +++ b/policyengine_us_data/calibration/target_config.yaml @@ -6,11 +6,15 @@ include: domain_variable: age # === DISTRICT — count targets === - # REMOVED: person_count by AGI — filer-gated, all AGI bins 100% underestimated + - variable: person_count + geo_level: district + domain_variable: adjusted_gross_income - variable: household_count geo_level: district - # === DISTRICT — dollar targets (all <8% mean error, restored) === + # === DISTRICT — dollar targets === + - variable: adjusted_gross_income + geo_level: district - variable: real_estate_taxes geo_level: district - variable: self_employment_income @@ -36,14 +40,28 @@ include: # REMOVED: is_pregnant — 100% unachievable across all 51 state geos - variable: snap geo_level: state + - variable: adjusted_gross_income + geo_level: state + # REMOVED: state_income_tax — ETL hardcodes $0 for WA and NH, but + # PolicyEngine correctly computes non-zero tax (WA capital gains tax, + # NH interest/dividends tax). The $0 targets produce catastrophic loss + # that crushes WA/NH weights to zero. Fix the ETL before re-enabling. + # - variable: state_income_tax + # geo_level: state # === NATIONAL — aggregate dollar targets === - # REMOVED: adjusted_gross_income — filer-gated + - variable: adjusted_gross_income + geo_level: national + - variable: alimony_expense + geo_level: national + - variable: alimony_income + geo_level: national - variable: child_support_expense geo_level: national - variable: child_support_received geo_level: national - # REMOVED: eitc — filer-gated + - variable: eitc + geo_level: national - variable: health_insurance_premiums_without_medicare_part_b geo_level: national - variable: medicaid @@ -54,18 +72,22 @@ include: geo_level: national - variable: over_the_counter_health_expenses geo_level: national - # REMOVED: qualified_business_income_deduction — filer-gated + - variable: real_estate_taxes + geo_level: national - variable: rent geo_level: national - # REMOVED: salt_deduction — 11.3x overestimate, worst variable in model - variable: snap geo_level: national - variable: social_security geo_level: national + - variable: social_security_dependents + geo_level: national - variable: social_security_disability geo_level: national - variable: social_security_retirement geo_level: national + - variable: social_security_survivors + geo_level: national - variable: spm_unit_capped_housing_subsidy geo_level: national - variable: spm_unit_capped_work_childcare_expenses @@ -74,10 +96,23 @@ include: geo_level: national - variable: tanf geo_level: national - # REMOVED: tip_income — filer-gated + - variable: tip_income + geo_level: national - variable: unemployment_compensation geo_level: national + # === NATIONAL — retirement contribution targets === + - variable: traditional_ira_contributions + geo_level: national + - variable: traditional_401k_contributions + geo_level: national + - variable: roth_401k_contributions + geo_level: national + - variable: roth_ira_contributions + geo_level: national + - variable: self_employed_pension_contribution_ald + geo_level: national + # === NATIONAL — IRS SOI domain-constrained dollar targets (restored: |rel_err| < 15%) === - variable: aca_ptc geo_level: national @@ -100,14 +135,6 @@ include: - variable: unemployment_compensation geo_level: national domain_variable: unemployment_compensation - # REMOVED (|rel_err| > 15% or tension with counts): - # adjusted_gross_income (28%), dividend_income (26%, tension), eitc (23%), - # eitc by child_count (14-77%, tension), income_tax_before_credits (21%), - # income_tax_positive (22%), qualified_business_income_deduction (55-63%), - # qualified_dividend_income (29%, tension), rental_income (20%), - # salt (102%), salt_deduction (1130%), tax_exempt_interest_income (61%), - # taxable_interest_income (61%), taxable_ira_distributions (68%), - # taxable_social_security (55%) # === NATIONAL — IRS SOI filer count targets (restored: |rel_err| < 10%) === - variable: tax_unit_count @@ -116,4 +143,32 @@ include: - variable: tax_unit_count geo_level: national domain_variable: refundable_ctc - # REMOVED (|rel_err| > 10%): all other filer count targets (22-706% error) + + # === NATIONAL — SOI deduction totals (non-reform) === + - variable: medical_expense_deduction + geo_level: national + domain_variable: medical_expense_deduction,tax_unit_itemizes + - variable: qualified_business_income_deduction + geo_level: national + domain_variable: qualified_business_income_deduction + + # === NATIONAL — JCT tax expenditure targets (reform_id=1..5) === + - variable: salt_deduction + geo_level: national + - variable: charitable_deduction + geo_level: national + - variable: deductible_mortgage_interest + geo_level: national + - variable: medical_expense_deduction + geo_level: national + - variable: qualified_business_income_deduction + geo_level: national + + # NOT INCLUDED — high error or tension (from prior validation) + # ===================================================================== + # dividend_income (26%, tension), qualified_dividend_income (29%, tension), + # eitc by child_count (14-77%, tension), rental_income (20%), + # income_tax_before_credits (21%), income_tax_positive (22%), + # salt SOI (102%), taxable_interest_income (61%), + # tax_exempt_interest_income (61%), taxable_ira_distributions (68%), + # taxable_social_security (55%), person_count by AGI bins (100%) diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 420e9006f..e135b8493 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -931,6 +931,33 @@ def run_calibration( time_period, ) + # Compute base household AGI for conditional geographic assignment + base_agi = sim.calculate("adjusted_gross_income", map_to="household").values.astype( + np.float64 + ) + + # Load CD-level AGI targets from database + import sqlite3 + + from policyengine_us_data.storage import STORAGE_FOLDER + + db_path = str(STORAGE_FOLDER / "calibration" / "policy_data.db") + conn = sqlite3.connect(db_path) + rows = conn.execute( + "SELECT sc.value, t.value " + "FROM targets t " + "JOIN stratum_constraints sc ON t.stratum_id = sc.stratum_id " + "WHERE t.variable = 'adjusted_gross_income' " + "AND sc.constraint_variable = 'congressional_district_geoid' " + "AND t.active = 1" + ).fetchall() + conn.close() + cd_agi_targets = {str(row[0]): float(row[1]) for row in rows} + logger.info( + "Loaded %d CD AGI targets for conditional assignment", + len(cd_agi_targets), + ) + # Step 2: Clone and assign geography logger.info( "Assigning geography: %d x %d = %d total", @@ -942,6 +969,8 @@ def run_calibration( n_records=n_records, n_clones=n_clones, seed=seed, + household_agi=base_agi, + cd_agi_targets=cd_agi_targets, ) # Step 3: Source imputation (if requested) diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index fb8865b80..1917e1a14 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -493,14 +493,9 @@ def _assemble_clone_values_standalone( arr = np.zeros(n_records, dtype=np.float32) for state in unique_clone_states: mask = state_masks[int(state)] - arr[mask] = ( - state_values[int(state)] - .get("reform_hh", {}) - .get( - var, - np.zeros(mask.sum(), dtype=np.float32), - ) - ) + reform_data = state_values[int(state)].get("reform_hh", {}) + if var in reform_data: + arr[mask] = reform_data[var][mask] reform_hh_vars[var] = arr return hh_vars, person_vars, reform_hh_vars @@ -852,7 +847,7 @@ def _process_single_clone( ) if variable.endswith("_count"): - vkey = (variable, constraint_key) + vkey = (variable, constraint_key, reform_id) if vkey not in count_cache: count_cache[vkey] = _calculate_target_values_standalone( variable, @@ -864,6 +859,7 @@ def _process_single_clone( entity_rel, household_ids, variable_entity_map, + reform_id=reform_id, ) values = count_cache[vkey] else: @@ -1495,14 +1491,9 @@ def _assemble_clone_values( arr = np.zeros(n_records, dtype=np.float32) for state in unique_clone_states: mask = state_masks[int(state)] - arr[mask] = ( - state_values[int(state)] - .get("reform_hh", {}) - .get( - var, - np.zeros(mask.sum(), dtype=np.float32), - ) - ) + reform_data = state_values[int(state)].get("reform_hh", {}) + if var in reform_data: + arr[mask] = reform_data[var][mask] reform_hh_vars[var] = arr return hh_vars, person_vars, reform_hh_vars @@ -2497,6 +2488,7 @@ def build_matrix( vkey = ( variable, constraint_key, + reform_id, ) if vkey not in count_cache: count_cache[vkey] = _calculate_target_values_standalone( @@ -2509,6 +2501,7 @@ def build_matrix( entity_rel=entity_rel, household_ids=household_ids, variable_entity_map=variable_entity_map, + reform_id=reform_id, ) values = count_cache[vkey] else: diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 42ec2232f..6e6b3b3ab 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -708,6 +708,22 @@ def load_national_targets( notes=combined_notes, ) session.add(target) + session.flush() + + persisted = ( + session.query(Target) + .filter(Target.target_id == target.target_id) + .first() + ) + if persisted.reform_id != target_reform_id: + print( + f" WARNING: {target_data['variable']} persisted " + f"with reform_id={persisted.reform_id}, " + f"correcting to {target_reform_id}" + ) + persisted.reform_id = target_reform_id + session.flush() + print(f"Added tax expenditure target: {target_data['variable']}") # Process conditional count targets (enrollment counts) @@ -805,6 +821,31 @@ def load_national_targets( session.commit() + tax_exp_vars = [ + "salt_deduction", + "charitable_deduction", + "deductible_mortgage_interest", + "medical_expense_deduction", + "qualified_business_income_deduction", + ] + bad_targets = ( + session.query(Target) + .join(Stratum, Target.stratum_id == Stratum.stratum_id) + .filter( + Target.variable.in_(tax_exp_vars), + Target.active == True, + Stratum.parent_stratum_id == None, + Target.reform_id == 0, + ) + .all() + ) + if bad_targets: + bad_names = [t.variable for t in bad_targets] + raise ValueError( + f"Post-commit check failed: tax expenditure targets " + f"have reform_id=0 in root stratum: {bad_names}" + ) + total_targets = ( len(direct_targets_df) + len(tax_filer_df) diff --git a/policyengine_us_data/db/validate_database.py b/policyengine_us_data/db/validate_database.py index b57a83c32..8f769d766 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -21,3 +21,29 @@ for var_name in set(stratum_constraints_df["constraint_variable"]): if not var_name in system.variables.keys(): raise ValueError(f"{var_name} not a policyengine-us variable") + +TAX_EXPENDITURE_VARS = [ + "salt_deduction", + "charitable_deduction", + "deductible_mortgage_interest", + "medical_expense_deduction", + "qualified_business_income_deduction", +] + +root_stratum_ids = pd.read_sql( + "SELECT stratum_id FROM strata WHERE parent_stratum_id IS NULL", conn +)["stratum_id"].tolist() + +for var in TAX_EXPENDITURE_VARS: + matches = targets_df[ + (targets_df["variable"] == var) + & (targets_df["active"] == 1) + & (targets_df["stratum_id"].isin(root_stratum_ids)) + & (targets_df["reform_id"] > 0) + ] + if matches.empty: + raise ValueError( + f"Validation failed: {var} has no active target with " + f"reform_id > 0 in the root stratum. Tax expenditure targets " + f"must have a non-zero reform_id for correct calibration." + )