Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 38 additions & 55 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,35 +140,27 @@ def get_version() -> str:


def partition_work(
    work_items: List[Dict],
    num_workers: int,
    completed: set,
) -> List[List[Dict]]:
    """Partition work items across up to ``num_workers`` workers using LPT.

    Items already recorded in ``completed`` (keyed ``"type:id"``) are
    skipped.  The remaining items are sorted by descending weight and each
    is assigned greedily to the currently least-loaded worker (Longest
    Processing Time scheduling), balancing total weight per worker.

    Args:
        work_items: Dicts with ``"type"``, ``"id"``, and a numeric
            ``"weight"`` key.
        num_workers: Maximum number of worker chunks to produce.
        completed: Set of ``"type:id"`` strings already finished.

    Returns:
        Non-empty lists of work-item dicts, one per active worker.
        Empty list when no work remains.
    """
    remaining = [
        item
        for item in work_items
        if f"{item['type']}:{item['id']}" not in completed
    ]
    remaining.sort(key=lambda x: -x["weight"])

    # Never spawn more workers than there are items.
    n_workers = min(num_workers, len(remaining))
    if n_workers == 0:
        return []

    # Min-heap of (current load, worker index); pop = least-loaded worker.
    heap = [(0, i) for i in range(n_workers)]
    chunks = [[] for _ in range(n_workers)]

    for item in remaining:
        load, idx = heapq.heappop(heap)
        chunks[idx].append(item)
        heapq.heappush(heap, (load + item["weight"], idx))

    return [c for c in chunks if c]

Expand Down Expand Up @@ -197,9 +189,7 @@ def get_completed_from_volume(version_dir: Path) -> set:

def run_phase(
phase_name: str,
states: List[str],
districts: List[str],
cities: List[str],
work_items: List[Dict],
num_workers: int,
completed: set,
branch: str,
Expand All @@ -216,7 +206,7 @@ def run_phase(
and crashes, and validation_rows is a list of per-target
validation result dicts.
"""
work_chunks = partition_work(states, districts, cities, num_workers, completed)
work_chunks = partition_work(work_items, num_workers, completed)
total_remaining = sum(len(c) for c in work_chunks)

print(f"\n--- Phase: {phase_name} ---")
Expand All @@ -228,7 +218,8 @@ def run_phase(

handles = []
for i, chunk in enumerate(work_chunks):
print(f" Worker {i}: {len(chunk)} items")
total_weight = sum(item["weight"] for item in chunk)
print(f" Worker {i}: {len(chunk)} items, weight {total_weight}")
handle = build_areas_worker.spawn(
branch=branch,
version=version,
Expand Down Expand Up @@ -753,7 +744,7 @@ def coordinate_publish(
cds = get_all_cds_from_database(db_uri)
states = list(STATE_CODES.values())
districts = [get_district_friendly_name(cd) for cd in cds]
print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"]}}))
print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"], "cds": cds}}))
""",
],
capture_output=True,
Expand All @@ -769,6 +760,22 @@ def coordinate_publish(
districts = work_info["districts"]
cities = work_info["cities"]

from collections import Counter
from policyengine_us_data.calibration.calibration_utils import STATE_CODES

raw_cds = work_info["cds"]
cds_per_state = Counter(STATE_CODES.get(int(cd) // 100, "??") for cd in raw_cds)

CITY_WEIGHTS = {"NYC": 11}

work_items = []
for s in states:
work_items.append({"type": "state", "id": s, "weight": cds_per_state.get(s, 1)})
for d in districts:
work_items.append({"type": "district", "id": d, "weight": 1})
for c in cities:
work_items.append({"type": "city", "id": c, "weight": CITY_WEIGHTS.get(c, 3)})

staging_volume.reload()
completed = get_completed_from_volume(version_dir)
print(f"Found {len(completed)} already-completed items on volume")
Expand All @@ -786,32 +793,8 @@ def coordinate_publish(
accumulated_validation_rows = []

completed, phase_errors, v_rows = run_phase(
"States",
states=states,
districts=[],
cities=[],
completed=completed,
**phase_args,
)
accumulated_errors.extend(phase_errors)
accumulated_validation_rows.extend(v_rows)

completed, phase_errors, v_rows = run_phase(
"Districts",
states=[],
districts=districts,
cities=[],
completed=completed,
**phase_args,
)
accumulated_errors.extend(phase_errors)
accumulated_validation_rows.extend(v_rows)

completed, phase_errors, v_rows = run_phase(
"Cities",
states=[],
districts=[],
cities=cities,
"All areas",
work_items=work_items,
completed=completed,
**phase_args,
)
Expand Down
13 changes: 2 additions & 11 deletions modal_app/worker_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def main():

from policyengine_us_data.calibration.publish_local_area import (
build_h5,
NYC_COUNTIES,
NYC_CDS,
NYC_COUNTY_FIPS,
AT_LARGE_DISTRICTS,
)
from policyengine_us_data.calibration.calibration_utils import (
Expand Down Expand Up @@ -388,22 +387,14 @@ def main():
)

elif item_type == "city":
cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
if not cd_subset:
print(
"No NYC CDs found, skipping",
file=sys.stderr,
)
continue
cities_dir = output_dir / "cities"
cities_dir.mkdir(parents=True, exist_ok=True)
path = build_h5(
weights=weights,
geography=geography,
dataset_path=dataset_path,
output_path=cities_dir / "NYC.h5",
cd_subset=cd_subset,
county_filter=NYC_COUNTIES,
county_fips_filter=NYC_COUNTY_FIPS,
takeup_filter=takeup_filter,
)

Expand Down
89 changes: 0 additions & 89 deletions policyengine_us_data/calibration/block_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,92 +580,3 @@ def derive_geography_from_blocks(
"zcta": np.array(zcta_list),
"county_index": county_indices,
}


# === County Filter Functions (for city-level datasets) ===


def get_county_filter_probability(
    cd_geoid: str,
    county_filter: set,
) -> float:
    """
    Calculate P(county in filter | CD) using block-level data.

    Returns the probability that a household in this CD falls inside the
    target area (e.g., NYC), used to scale weights when building
    city-level datasets.

    Args:
        cd_geoid: Congressional district geoid (e.g., "3610")
        county_filter: Set of county enum names that define the target area

    Returns:
        Probability between 0 and 1
    """
    distributions = _get_block_distributions()
    cd_key = str(int(cd_geoid))

    dist = distributions.get(cd_key)
    if dist is None:
        return 0.0

    # Map the county enum names in the filter to their FIPS codes,
    # dropping any names with no known FIPS code.
    fips_to_enum = _build_county_fips_to_enum()
    enum_to_fips = {enum: fips for fips, enum in fips_to_enum.items()}
    target_fips = {
        enum_to_fips[name] for name in county_filter if name in enum_to_fips
    }

    # Accumulate probability mass over blocks within the target counties.
    total = 0.0
    for block, prob in dist.items():
        if get_county_fips_from_block(block) in target_fips:
            total += prob
    return total


def get_filtered_block_distribution(
    cd_geoid: str,
    county_filter: set,
) -> Dict[str, float]:
    """
    Get normalized distribution over blocks in target counties only.

    Used when building city-level datasets to assign only blocks in valid
    counties while keeping relative proportions within the target area.

    Args:
        cd_geoid: Congressional district geoid (e.g., "3610")
        county_filter: Set of county enum names that define the target area

    Returns:
        Dictionary mapping block GEOIDs to normalized probabilities.
        Empty dict if CD has no overlap with target area.
    """
    distributions = _get_block_distributions()
    cd_key = str(int(cd_geoid))

    if cd_key not in distributions:
        return {}

    # Map county enum names to FIPS codes; unknown names are dropped.
    fips_to_enum = _build_county_fips_to_enum()
    enum_to_fips = {enum: fips for fips, enum in fips_to_enum.items()}
    target_fips = {
        enum_to_fips[name] for name in county_filter if name in enum_to_fips
    }

    # Keep only blocks belonging to the target counties.
    filtered = {}
    for block, prob in distributions[cd_key].items():
        if get_county_fips_from_block(block) in target_fips:
            filtered[block] = prob

    # Renormalize so the retained mass sums to 1 (empty when no overlap).
    total = sum(filtered.values())
    if total <= 0:
        return {}
    return {block: prob / total for block, prob in filtered.items()}
63 changes: 60 additions & 3 deletions policyengine_us_data/calibration/clone_and_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,23 @@ def load_global_block_distribution():
return block_geoids, cd_geoids, state_fips, probs


def _build_agi_block_probs(cds, pop_probs, cd_agi_targets):
"""Multiply population block probs by CD AGI target weights."""
agi_weights = np.array([cd_agi_targets.get(cd, 0.0) for cd in cds])
agi_weights = np.maximum(agi_weights, 0.0)
if agi_weights.sum() == 0:
return pop_probs
agi_probs = pop_probs * agi_weights
return agi_probs / agi_probs.sum()


def assign_random_geography(
n_records: int,
n_clones: int = 10,
seed: int = 42,
household_agi: np.ndarray = None,
cd_agi_targets: dict = None,
agi_threshold_pctile: float = 90.0,
) -> GeographyAssignment:
"""Assign random census block geography to cloned
CPS records.
Expand All @@ -95,17 +108,48 @@ def assign_random_geography(
n_total = n_records * n_clones
rng = np.random.default_rng(seed)

agi_probs = None
extreme_mask = None
if household_agi is not None and cd_agi_targets is not None:
threshold = np.percentile(household_agi, agi_threshold_pctile)
extreme_mask = household_agi >= threshold
agi_probs = _build_agi_block_probs(cds, probs, cd_agi_targets)
logger.info(
"AGI-conditional assignment: %d extreme HHs (AGI >= $%.0f) "
"use AGI-weighted block probs",
extreme_mask.sum(),
threshold,
)

def _sample(size, mask_slice=None):
"""Sample block indices, using AGI-weighted probs for extreme HHs."""
if (
extreme_mask is not None
and agi_probs is not None
and mask_slice is not None
):
out = np.empty(size, dtype=np.int64)
ext = mask_slice
n_ext = ext.sum()
n_norm = size - n_ext
if n_ext > 0:
out[ext] = rng.choice(len(blocks), size=n_ext, p=agi_probs)
if n_norm > 0:
out[~ext] = rng.choice(len(blocks), size=n_norm, p=probs)
return out
return rng.choice(len(blocks), size=size, p=probs)

indices = np.empty(n_total, dtype=np.int64)

# Clone 0: unrestricted draw
indices[:n_records] = rng.choice(len(blocks), size=n_records, p=probs)
indices[:n_records] = _sample(n_records, extreme_mask)

assigned_cds = np.empty((n_clones, n_records), dtype=object)
assigned_cds[0] = cds[indices[:n_records]]

for clone_idx in range(1, n_clones):
start = clone_idx * n_records
clone_indices = rng.choice(len(blocks), size=n_records, p=probs)
clone_indices = _sample(n_records, extreme_mask)
clone_cds = cds[clone_indices]

collisions = np.zeros(n_records, dtype=bool)
Expand All @@ -116,7 +160,20 @@ def assign_random_geography(
n_bad = collisions.sum()
if n_bad == 0:
break
clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs)
bad_mask = collisions
if extreme_mask is not None and agi_probs is not None:
bad_ext = bad_mask & extreme_mask
bad_norm = bad_mask & ~extreme_mask
if bad_ext.sum() > 0:
clone_indices[bad_ext] = rng.choice(
len(blocks), size=bad_ext.sum(), p=agi_probs
)
if bad_norm.sum() > 0:
clone_indices[bad_norm] = rng.choice(
len(blocks), size=bad_norm.sum(), p=probs
)
else:
clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs)
clone_cds = cds[clone_indices]
collisions = np.zeros(n_records, dtype=bool)
for prev in range(clone_idx):
Expand Down
Loading
Loading