diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e36542..f3188ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt - rev: v2.19.0 + rev: v2.21.1 hooks: - id: pyproject-fmt - repo: https://github.com/lyz-code/yamlfix @@ -47,7 +47,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.6 + rev: v0.15.12 hooks: - id: ruff-check args: diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index af9dbc2..d0d9c1d 100644 Binary files a/src/aca_model/_benchmark_data/benchmark_params.pkl and b/src/aca_model/_benchmark_data/benchmark_params.pkl differ diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 8b4507a..6d39ac6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -8,42 +8,46 @@ from typing import Any from lcm import AgeGrid, DiscreteGrid, Model +from lcm.typing import UserParams from aca_model.aca import PolicyVariant from aca_model.aca.regimes import build_all_regimes from aca_model.baseline.regimes import RegimeId -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, - policy: PolicyVariant = PolicyVariant.ACA, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, - grid_config: GridConfig = GRID_CONFIG, + n_subjects: int, + policy: PolicyVariant, + fixed_params: UserParams, + wage_params: Mapping[str, Any], + derived_categoricals: Mapping[str, DiscreteGrid], + grid_config: GridConfig, + pref_type_grid: DiscreteGrid, ) -> Model: """Create an ACA policy variant model. Args: - policy: Which ACA policy combination to apply. - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. + policy: Which ACA policy combination to apply (e.g. + `PolicyVariant.ACA`). + fixed_params: Parameters to fix at model creation time. Pass + data-derived constants here; only estimation parameters + should go through `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values. + derived_categoricals: Categorical mappings for `pd.Series` + fixed_params index levels that aren't model state/action + grids — `target_his`, `his`, `good_health`, `is_married`, + `pref_type`. + grid_config: Continuous-grid point counts. + pref_type_grid: Pref-type `DiscreteGrid`. Returns: - pylcm Model with ACA-specific function overrides. + pylcm Model. """ ages = AgeGrid( @@ -56,6 +60,7 @@ def create_model( grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, + pref_type_grid=pref_type_grid, ) return Model( @@ -63,6 +68,7 @@ def create_model( ages=ages, regime_id_class=RegimeId, description=f"Structural retirement model ({policy.name})", - fixed_params=fixed_params or {}, + fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/aca/regimes/__init__.py b/src/aca_model/aca/regimes/__init__.py index 2c143bd..96447d5 100644 --- a/src/aca_model/aca/regimes/__init__.py +++ b/src/aca_model/aca/regimes/__init__.py @@ -4,25 +4,30 @@ from collections.abc import Mapping from typing import Any -from lcm import Regime +from lcm import DiscreteGrid, Regime +from lcm.typing import UserParams from aca_model.aca.health_insurance import PolicyVariant from aca_model.aca.regimes._overrides import apply_aca_overrides from aca_model.baseline.regimes import build_all_regimes as baseline_build_all_regimes from aca_model.baseline.regimes._common import REGIME_SPECS -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig def build_all_regimes( - policy: PolicyVariant, - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, + policy: PolicyVariant, + grid_config: GridConfig, + fixed_params: UserParams, + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> dict[str, Regime]: """Build all 19 regimes with ACA policy overrides.""" regimes = baseline_build_all_regimes( - grid_config, fixed_params=fixed_params, wage_params=wage_params + grid_config=grid_config, + fixed_params=fixed_params, + wage_params=wage_params, + pref_type_grid=pref_type_grid, ) result = {} for name, regime in regimes.items(): diff --git a/src/aca_model/aca/regimes/_overrides.py b/src/aca_model/aca/regimes/_overrides.py index 79bd9e4..4ab590e 100644 --- a/src/aca_model/aca/regimes/_overrides.py +++ b/src/aca_model/aca/regimes/_overrides.py @@ -8,11 +8,12 @@ from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant +from aca_model.baseline.regimes._common import RegimeSpec def apply_aca_overrides( functions: dict, - spec: dict[str, str], + spec: RegimeSpec, policy: PolicyVariant, ) -> None: """Override baseline functions with ACA versions in-place. diff --git a/src/aca_model/agent/assets_and_income.py b/src/aca_model/agent/assets_and_income.py index dfa83ef..92e9abb 100644 --- a/src/aca_model/agent/assets_and_income.py +++ b/src/aca_model/agent/assets_and_income.py @@ -35,41 +35,74 @@ def cash_on_hand( return assets + after_tax_income + ssi_benefit - hic_premium -def transfers( - cash_on_hand: FloatND, - consumption_floor: float, +def consumption_dollars_floor( + consumption_equiv_floor: float, equivalence_scale: FloatND, ) -> FloatND: - """Government transfers to enforce consumption floor. + """Per-household $-floor on consumption.""" + return consumption_equiv_floor * equivalence_scale - tr = max{0, C_min * equivalence_scale - cash_on_hand} - """ - floor = consumption_floor * equivalence_scale - return jnp.maximum(0.0, floor - cash_on_hand) + +def transfers( + cash_on_hand: FloatND, + consumption_dollars_floor: FloatND, +) -> FloatND: + """Government transfers to enforce the consumption floor.""" + return jnp.maximum(0.0, consumption_dollars_floor - cash_on_hand) def next_assets( cash_on_hand: FloatND, transfers: FloatND, pension_assets_adjustment: FloatND, - consumption: ContinuousAction, + consumption_dollars: ContinuousAction, oop_costs: FloatND, ) -> ContinuousState: - """Compute beginning-of-next-period assets. + """Compute beginning-of-next-period assets for non-terminal targets. OOP health costs are deducted here (not from cash_on_hand) so that the consumption choice does not condition on the HCC shock realization. """ return ( - cash_on_hand + transfers + pension_assets_adjustment - consumption - oop_costs + cash_on_hand + + transfers + + pension_assets_adjustment + - consumption_dollars + - oop_costs ) -def borrowing_constraint( - consumption: ContinuousAction, +def next_assets_when_dead( cash_on_hand: FloatND, transfers: FloatND, - pension_assets_adjustment: FloatND, + consumption_dollars: ContinuousAction, + oop_costs: FloatND, +) -> ContinuousState: + """Compute beginning-of-next-period assets for the dead/terminal target. + + No `pension_assets_adjustment` term: with no future, there is no + next-period pension wealth to impute against. Avoiding the dependency + also keeps the `dead` per-target transition's DAG free of `next_aime` + (which would otherwise need to come from a transition `dead` does not + have, since `aime` is not a state in the terminal regime). + """ + return cash_on_hand + transfers - consumption_dollars - oop_costs + + +def borrowing_constraint( + consumption_dollars: ContinuousAction, + cash_on_hand: FloatND, + consumption_dollars_floor: FloatND, ) -> BoolND: - """Consumption cannot exceed available resources (no borrowing).""" - return consumption <= cash_on_hand + transfers + pension_assets_adjustment + """Consumption cannot exceed post-transfer resources. + + Post-transfer resources are `max(cash_on_hand, consumption_dollars_floor)`: + the transfer system tops `cash_on_hand` to the floor when below, + otherwise resources are unchanged. The algebraic identity is + `cash_on_hand + transfers == max(cash_on_hand, floor)`; the `max` + form is preferred because the additive form rounds to `floor + ε` + (with `|ε| ~ ULP(|cash_on_hand|)`) at extreme cash, which flips + the kink-boundary comparison at large negative values of `assets`. + The `max` form returns `floor` exactly. + """ + return consumption_dollars <= jnp.maximum(cash_on_hand, consumption_dollars_floor) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 3b0bb5e..5c08541 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -121,33 +121,102 @@ def leisure_retired( return time_endowment - health_loss -def utility( - consumption: ContinuousAction, +def consumption_equiv( + consumption_dollars: ContinuousAction, + equivalence_scale: FloatND, +) -> FloatND: + """Utility-equivalized consumption.""" + return consumption_dollars / equivalence_scale + + +def u_can_work( + consumption_equiv: FloatND, leisure: FloatND, - pref_type: DiscreteState, consumption_weight: FloatND, coefficient_rra: FloatND, - equivalence_scale: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility: CES aggregator over consumption and leisure. - - u = utility_scale_factor * ((c/eq_scale)^α * l^(1-α))^(1-γ) / (1-γ) - with log case for γ=1. `consumption_weight`, `coefficient_rra`, and - `utility_scale_factor` are indexed by `pref_type`. - """ - alpha = consumption_weight[pref_type] - gamma = coefficient_rra[pref_type] - equiv_cons = consumption / equivalence_scale - composite = equiv_cons**alpha * leisure ** (1.0 - alpha) + """Within-period utility for canwork regimes: CES over consumption and leisure.""" + composite = consumption_equiv**consumption_weight * leisure ** ( + 1.0 - consumption_weight + ) - one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + one_minus_rra = jnp.where( + jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra + ) u = jnp.where( - jnp.isclose(gamma, 1.0), + jnp.isclose(coefficient_rra, 1.0), jnp.log(composite), - composite**one_minus_gamma / one_minus_gamma, + composite**one_minus_rra / one_minus_rra, + ) + return u * utility_scale_factor + + +def u_cannot_work( + consumption_equiv: FloatND, + good_health: IntND, + consumption_weight: FloatND, + coefficient_rra: FloatND, + utility_scale_factor: FloatND, + time_endowment: float, + leisure_cost_of_bad_health: float, +) -> FloatND: + """Within-period utility for forcedout regimes (no work, retired leisure).""" + leisure = leisure_retired( + good_health=good_health, + time_endowment=time_endowment, + leisure_cost_of_bad_health=leisure_cost_of_bad_health, + ) + return u_can_work( + consumption_equiv=consumption_equiv, + leisure=leisure, + consumption_weight=consumption_weight, + coefficient_rra=coefficient_rra, + utility_scale_factor=utility_scale_factor, + ) + + +def u_dead( + assets: ContinuousState, + bequest_shifter: float, + scaled_bequest_weight: float, + consumption_weight: FloatND, + coefficient_rra: FloatND, + utility_scale_factor: FloatND, +) -> FloatND: + """Terminal bequest utility for the dead regime.""" + return bequest( + assets=assets, + bequest_shifter=bequest_shifter, + scaled_bequest_weight=scaled_bequest_weight, + consumption_weight=consumption_weight, + coefficient_rra=coefficient_rra, + utility_scale_factor=utility_scale_factor, ) - return u * utility_scale_factor[pref_type] + + +def consumption_weight( + consumption_weights: FloatND, + pref_type: DiscreteState, +) -> FloatND: + """Per-type consumption weight indexed by preference type. + + Wired as a DAG function so pylcm broadcasts the scalar to every cell; + mirrors `discount_factor`. + """ + return consumption_weights[pref_type] + + +def coefficient_rra( + coefficients_rra: FloatND, + pref_type: DiscreteState, +) -> FloatND: + """Per-type CRRA coefficient indexed by preference type. + + Wired as a DAG function so pylcm broadcasts the scalar to every cell; + mirrors `discount_factor`. + """ + return coefficients_rra[pref_type] def discount_factor( @@ -164,38 +233,25 @@ def discount_factor( def utility_scale_factor( - average_consumption: float, + average_consumption_dollars: float, consumption_weight: FloatND, coefficient_rra: FloatND, time_endowment: float, fixed_cost_of_work_intercept: float, - fixed_cost_of_work_age_trend: float, - scale_reference_hours: float, - reference_age: int, - scale_reference_age: int, + reference_hours: float, ) -> FloatND: - """Compute scale factor so utility is approximately 1 at typical values. - - Uses leisure at `scale_reference_age` when working `scale_reference_hours` - (after fixed costs) and average consumption. Returns one scale per - preference type, indexed by pref_type. - """ - age_offset = scale_reference_age - reference_age - average_leisure = ( - time_endowment - - scale_reference_hours - - (fixed_cost_of_work_intercept + fixed_cost_of_work_age_trend * age_offset) - ) - u_cons = average_consumption**consumption_weight + """Compute the scale factor so utility is approximately 1 at typical values.""" + average_leisure = time_endowment - reference_hours - fixed_cost_of_work_intercept + u_cons = average_consumption_dollars**consumption_weight u_leisure = average_leisure ** (1.0 - consumption_weight) - one_minus_gamma = jnp.where( + one_minus_rra = jnp.where( jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra ) raw = jnp.where( jnp.isclose(coefficient_rra, 1.0), jnp.log(u_cons * u_leisure), - (u_cons * u_leisure) ** one_minus_gamma / one_minus_gamma, + (u_cons * u_leisure) ** one_minus_rra / one_minus_rra, ) return jnp.abs(1.0 / raw) @@ -227,7 +283,6 @@ def scaled_bequest_weight( def bequest( assets: ContinuousState, - pref_type: DiscreteState, bequest_shifter: float, scaled_bequest_weight: float, consumption_weight: FloatND, @@ -236,18 +291,18 @@ def bequest( ) -> FloatND: """Bequest function for terminal/dead states. - bequest = scale * bwt * (max(0,a) + shifter)^(α*(1-γ)) / (1-γ) - `consumption_weight`, `coefficient_rra`, and `utility_scale_factor` - are indexed by `pref_type`. + bequest = scale * bwt * + (max(0,a) + shifter)^(consumption_weight*(1 - coefficient_rra)) + / (1 - coefficient_rra) """ - alpha = consumption_weight[pref_type] - gamma = coefficient_rra[pref_type] assets_shifted = jnp.maximum(0.0, assets) + bequest_shifter - one_minus_gamma = jnp.where(jnp.isclose(gamma, 1.0), 1.0, 1.0 - gamma) + one_minus_rra = jnp.where( + jnp.isclose(coefficient_rra, 1.0), 1.0, 1.0 - coefficient_rra + ) val = jnp.where( - jnp.isclose(gamma, 1.0), + jnp.isclose(coefficient_rra, 1.0), jnp.log(assets_shifted), - assets_shifted ** (one_minus_gamma * alpha) / one_minus_gamma, + assets_shifted ** (one_minus_rra * consumption_weight) / one_minus_rra, ) - return val * scaled_bequest_weight * utility_scale_factor[pref_type] + return val * scaled_bequest_weight * utility_scale_factor diff --git a/src/aca_model/agent/utility.py b/src/aca_model/agent/utility.py deleted file mode 100644 index fd7bf16..0000000 --- a/src/aca_model/agent/utility.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Utility function variants for different regime types. - -- retired: forcedout regimes (no work, computes leisure_retired internally) -- dead: terminal bequest - -Canwork regimes use `preferences.utility` directly, with `leisure` computed -as a separate DAG function (`preferences.leisure` / `preferences.leisure_tied`). -""" - -from lcm.typing import ( - ContinuousAction, - ContinuousState, - DiscreteState, - FloatND, - IntND, -) - -from aca_model.agent import preferences - - -def retired( - consumption: ContinuousAction, - good_health: IntND, - equivalence_scale: FloatND, - pref_type: DiscreteState, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, - time_endowment: float, - leisure_cost_of_bad_health: float, -) -> FloatND: - """Utility for forcedout regimes (no work).""" - lei = preferences.leisure_retired( - good_health=good_health, - time_endowment=time_endowment, - leisure_cost_of_bad_health=leisure_cost_of_bad_health, - ) - return preferences.utility( - consumption=consumption, - leisure=lei, - pref_type=pref_type, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - equivalence_scale=equivalence_scale, - utility_scale_factor=utility_scale_factor, - ) - - -def dead( - assets: ContinuousState, - pref_type: DiscreteState, - bequest_shifter: float, - scaled_bequest_weight: float, - consumption_weight: FloatND, - coefficient_rra: FloatND, - utility_scale_factor: FloatND, -) -> FloatND: - """Terminal bequest utility for dead regime.""" - return preferences.bequest( - assets=assets, - pref_type=pref_type, - bequest_shifter=bequest_shifter, - scaled_bequest_weight=scaled_bequest_weight, - consumption_weight=consumption_weight, - coefficient_rra=coefficient_rra, - utility_scale_factor=utility_scale_factor, - ) diff --git a/src/aca_model/baseline/health_insurance.py b/src/aca_model/baseline/health_insurance.py index 741d160..3732d6d 100644 --- a/src/aca_model/baseline/health_insurance.py +++ b/src/aca_model/baseline/health_insurance.py @@ -246,6 +246,29 @@ def is_medicaid_eligible(is_ssi_eligible: BoolND) -> BoolND: return is_ssi_eligible +def target_his( + his: IntND, + labor_supply: DiscreteAction, + is_medicaid_eligible: BoolND, +) -> IntND: + """Return the HIS class of the surviving target regime. + + Mirrors the cross-HIS branches inside `_make_transition_canwork` (retiree, + tied, nongroup): tied agents who stop working become nongroup, and + Medicaid-eligible agents are overridden to nongroup. Used by + `imputed_pension_wealth_next_period` to look up next-period imputation + coefficients at the target's HIS. + """ + tied_to_ng = (his == HealthInsuranceState.tied) & ( + labor_supply == LaborSupply.do_not_work + ) + return jnp.where( + tied_to_ng | is_medicaid_eligible, + HealthInsuranceState.nongroup, + his, + ).astype(jnp.int32) + + def oop_with_medicaid( primary_oop: FloatND, is_medicaid_eligible: BoolND, diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index a886495..98416ce 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -5,7 +5,7 @@ Usage: from aca_model.baseline.model import create_model - model = create_model() + model = create_model(n_subjects=..., fixed_params=..., wage_params=..., ...) params = get_default_params() V = model.solve(params) """ @@ -14,40 +14,44 @@ from typing import Any from lcm import AgeGrid, DiscreteGrid, Model +from lcm.typing import UserParams from aca_model.baseline.regimes import RegimeId, build_all_regimes -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.config import MODEL_CONFIG, GridConfig def create_model( *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]] - | None = None, - grid_config: GridConfig = GRID_CONFIG, - pref_type_grid: DiscreteGrid | None = None, + n_subjects: int, + fixed_params: UserParams, + wage_params: Mapping[str, Any], + derived_categoricals: Mapping[str, DiscreteGrid], + grid_config: GridConfig, + pref_type_grid: DiscreteGrid, ) -> Model: """Create the baseline structural retirement model. Args: - fixed_params: Parameters to fix at model creation time. These are - partialled into compiled functions and removed from the params - template. Pass data-derived constants here; only estimation - parameters should go through `model.simulate(params=...)`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. + fixed_params: Parameters to fix at model creation time. Fixed + params are partialled into compiled functions and removed + from the params template. Pass data-derived constants here; + only estimation parameters should go through + `model.simulate(params=...)`. wage_params: Data-derived wage profile dict (`log_ft_wage_mean`, `log_ft_wage_std`, `adj_wage_hours_*`) used only at grid-build time to size the assets-floor to `-max_annual_labor_income`. Not routed to the pylcm Model. - derived_categoricals: Extra categorical mappings for derived variables - not in the model's state/action grids. Needed when `fixed_params` - contains `pd.Series` indexed by DAG function outputs. - grid_config: Continuous-grid point counts. Defaults to production - values; pass `BENCHMARK_GRID_CONFIG` for a fast-but-structurally- - faithful benchmark. - pref_type_grid: Optional override for the `pref_type` `DiscreteGrid`. - Defaults to `DiscreteGrid(PrefType)`. Used by the benchmark to - substitute a 2-type variant with `DispatchStrategy.PARTITION_SCAN`. + derived_categoricals: Categorical mappings for `pd.Series` + fixed_params index levels that aren't model state/action + grids — `target_his`, `his`, `good_health`, `is_married`, + `pref_type`. + grid_config: Continuous-grid point counts. Pass `GRID_CONFIG` for + production values or `BENCHMARK_GRID_CONFIG` for the + fast-but-structurally-faithful benchmark. + pref_type_grid: Pref-type `DiscreteGrid`. Pass + `DiscreteGrid(PrefType)` for the production 3-type layout, + or a compact variant (e.g. `DiscreteGrid(BenchmarkPrefType)`). Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -60,7 +64,7 @@ def create_model( step="Y", ) regimes = build_all_regimes( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, @@ -71,6 +75,7 @@ def create_model( ages=ages, regime_id_class=RegimeId, description="Baseline structural retirement model (pre-ACA)", - fixed_params=fixed_params or {}, + fixed_params=fixed_params, derived_categoricals=derived_categoricals, + n_subjects=n_subjects, ) diff --git a/src/aca_model/baseline/regimes/__init__.py b/src/aca_model/baseline/regimes/__init__.py index a0eaf9e..bd4c564 100644 --- a/src/aca_model/baseline/regimes/__init__.py +++ b/src/aca_model/baseline/regimes/__init__.py @@ -14,6 +14,7 @@ from typing import Any from lcm import DiscreteGrid, Regime +from lcm.typing import UserParams from aca_model.baseline.regimes import _nongroup as nongroup from aca_model.baseline.regimes import _retiree as retiree @@ -25,7 +26,7 @@ build_dead_regime, build_grids, ) -from aca_model.config import GRID_CONFIG, GridConfig +from aca_model.config import GridConfig __all__ = [ "REGIME_SPECS", @@ -58,23 +59,21 @@ def build_regime(name: str, grids: Grids) -> Regime: def build_all_regimes( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: UserParams, + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> dict[str, Regime]: """Build all 19 baseline regimes (18 non-terminal + dead). - `fixed_params` is forwarded to `build_grids` for data-driven AIME - breakpoints; `wage_params` for the data-driven assets floor; - either being `None` keeps the corresponding static fallback. - `pref_type_grid` lets callers inject a compact or partition-lifted - `DiscreteGrid(...)` (e.g. the benchmark uses a 2-type - `BenchmarkPrefType` with `DispatchStrategy.PARTITION_SCAN`). + `fixed_params` carries the PIA bends for the AIME piecewise grid; + `wage_params` sizes the assets-floor to `-max_annual_labor_income`; + `pref_type_grid` selects the pref-type cardinality (production + `DiscreteGrid(PrefType)` or the benchmark's 2-type variant). """ grids = build_grids( - grid_config, + grid_config=grid_config, fixed_params=fixed_params, wage_params=wage_params, pref_type_grid=pref_type_grid, diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 688a504..a2e3a13 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -6,7 +6,8 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import Any +from types import MappingProxyType +from typing import Any, Literal, TypedDict import jax.numpy as jnp import lcm.shocks.ar1 @@ -22,21 +23,19 @@ ) from lcm.grids.continuous import ContinuousGrid from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND +from lcm.typing import BoolND, FloatND, RegimeName, UserParams from aca_model.agent import ( assets_and_income, health, labor_market, preferences, - utility, ) from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome -from aca_model.agent.preferences import PrefType from aca_model.baseline import health_insurance -from aca_model.baseline.health_insurance import BuyPrivate -from aca_model.config import GRID_CONFIG, MODEL_CONFIG, GridConfig +from aca_model.baseline.health_insurance import BuyPrivate, HealthInsuranceState +from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -64,8 +63,17 @@ class RegimeId: dead: int +class RegimeSpec(TypedDict): + """Structural decomposition of a regime: (HIS, Medicare, SS, work) axes.""" + + his: Literal["retiree", "tied", "nongroup"] + mc: Literal["nomc", "dimc", "oamc"] + ss: Literal["inelig", "choose", "forced"] + canwork: Literal["canwork", "forcedout"] + + # {his}_{mc}_{ss}_{canwork} -REGIME_SPECS: dict[str, dict[str, str]] = { +REGIME_SPECS: dict[str, RegimeSpec] = { "retiree_nomc_inelig_canwork": { "his": "retiree", "mc": "nomc", @@ -183,7 +191,7 @@ class RegimeId: class Grids: assets: LinSpacedGrid aime: ContinuousGrid - consumption: ContinuousGrid + consumption_dollars: ContinuousGrid wage_res: Any hcc_persistent: Any hcc_transitory: Any @@ -194,42 +202,37 @@ class Grids: # bend points (0 → kink_0 → kink_1 → kink_2). Total = 32. _AIME_PIECE_N_POINTS: tuple[int, int, int] = (10, 11, 11) -# Consumption grid: log-spaced from the lower bound of the -# `consumption_floor` parameter (BOUNDS in task_estimate_parameters) -# up to a high value that brackets the unconstrained optimum for the -# richest agents in the state space. Mirrors the struct-ret design -# (concentrate gridpoints where CRRA curvature is highest). -_CONSUMPTION_GRID_START: float = 100.0 -_CONSUMPTION_GRID_STOP: float = 300_000.0 + +MAX_CONSUMPTION_DOLLARS: float = 300_000.0 +"""Upper bound of the runtime consumption_dollars grid in $/year. + +Lives here next to the other grid bounds (assets `stop=500_000.0`, +AIME `stop=8_000.0`). + +TODO: route through `fixed_params` once pylcm#348 lands (so the bound +can vary across optimizer iterations without re-importing this module). +""" def build_grids( - grid_config: GridConfig = GRID_CONFIG, *, - fixed_params: Mapping[str, Any] | None = None, - wage_params: Mapping[str, Any] | None = None, - pref_type_grid: DiscreteGrid | None = None, + grid_config: GridConfig, + fixed_params: UserParams, + wage_params: Mapping[str, Any], + pref_type_grid: DiscreteGrid, ) -> Grids: """Build continuous-state/action grids from a `GridConfig`. - When `fixed_params` carries `pia_aime_grid`, the AIME grid becomes - a `PiecewiseLinSpacedGrid` breakpointed at the PIA bends (total 32 - points). When `wage_params` provides `log_ft_wage_mean` and friends - (as produced by `aca_data.task_wages`), the assets grid's lower - bound is set to `-max_annual_labor_income` so that the worst shock - lands on a gridpoint inside the support. Without `fixed_params` / - `wage_params` (bare model for tests / compile-only paths), both - grids fall back to their historical static shapes. + The AIME grid is `PiecewiseLinSpacedGrid` breakpointed at the PIA + bends from `fixed_params["pia_aime_grid"]` (total 32 points). The + assets grid's lower bound is `-max_annual_labor_income` computed + from `wage_params` (`log_ft_wage_mean`, `log_ft_wage_std`, + `adj_wage_hours_*`). `wage_params` is passed separately rather than via `fixed_params` because `log_ft_wage_mean` is a per-iteration param at estimation time (reconstructed from `wage_bias_coeffs_*`), not a fixed one; the grid floor must still be known at build time. - - `pref_type_grid` lets callers (e.g. the benchmark) substitute a - compact or partition-lifted `DiscreteGrid(...)` for the production - `DiscreteGrid(PrefType)`. When `None`, defaults to the production - 3-type grid with the default `DispatchStrategy.FUSED_VMAP`. """ # Unit-variance standardised shocks: the total_costs / wage # formulas rescale these by fixed_params-level std parameters @@ -244,13 +247,7 @@ def build_grids( sigma=(1.0 - _WAGE_RHO**2) ** 0.5, mu=0.0, ) - _HCC_RHO = 0.925 - hcc_persistent = lcm.shocks.ar1.Rouwenhorst( - n_points=grid_config.n_hcc_persistent_gridpoints, - rho=_HCC_RHO, - sigma=(1.0 - _HCC_RHO**2) ** 0.5, - mu=0.0, - ) + hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config) hcc_transitory = lcm.shocks.iid.Normal( n_points=grid_config.n_hcc_transitory_gridpoints, gauss_hermite=True, @@ -258,11 +255,9 @@ def build_grids( sigma=1.0, ) - assets_start = 0.0 - if wage_params is not None and _has_required_wage_keys(wage_params=wage_params): - assets_start = -_compute_max_annual_labor_income( - wage_params=wage_params, wage_res_grid=wage_res - ) + assets_start = -_compute_max_annual_labor_income( + wage_params=wage_params, wage_res_grid=wage_res + ) return Grids( assets=LinSpacedGrid( @@ -272,56 +267,58 @@ def build_grids( batch_size=grid_config.n_assets_batch_size, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), - consumption=IrregSpacedGrid( - points=tuple( - float(c) - for c in np.geomspace( - _CONSUMPTION_GRID_START, - _CONSUMPTION_GRID_STOP, - num=grid_config.n_consumption_gridpoints, - ) - ), + consumption_dollars=IrregSpacedGrid( + n_points=grid_config.n_consumption_dollars_gridpoints, ), wage_res=wage_res, hcc_persistent=hcc_persistent, hcc_transitory=hcc_transitory, - pref_type=pref_type_grid or DiscreteGrid(PrefType), + pref_type=pref_type_grid, + ) + + +_HCC_RHO = 0.925 + + +def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: + """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. + + Exposed so callers that need the shock's gridpoints / transition + probs (e.g. `assemble_fixed_params`, the HCC insurer predictor) + can derive them from `grid_config` alone without instantiating a + full `Model`. + """ + return lcm.shocks.ar1.Rouwenhorst( + n_points=grid_config.n_hcc_persistent_gridpoints, + rho=_HCC_RHO, + sigma=(1.0 - _HCC_RHO**2) ** 0.5, + mu=0.0, ) +def get_hcc_persistent_grid_points(*, grid_config: GridConfig) -> FloatND: + """Materialise the persistent-HCC shock gridpoints for `grid_config`.""" + return get_hcc_persistent_shock(grid_config=grid_config).to_jax() + + def _build_aime_grid( - *, grid_config: GridConfig, fixed_params: Mapping[str, Any] | None + *, grid_config: GridConfig, fixed_params: UserParams ) -> ContinuousGrid: """Return the AIME grid. - With `pia_aime_grid` available, the grid is piecewise-linspaced with - breakpoints at the PIA bends and `_AIME_PIECE_N_POINTS` in each - segment. `n_aime_gridpoints` from `grid_config` is ignored on this - path; the total is fixed by the PIA structure (32 points). Without - the fixed params, falls back to the historical `LinSpacedGrid`. + The grid is piecewise-linspaced with breakpoints at the PIA bends + in `fixed_params["pia_aime_grid"]` and `_AIME_PIECE_N_POINTS` in + each segment. `n_aime_gridpoints` from `grid_config` is ignored on + this path; the total is fixed by the PIA structure (32 points). """ - if fixed_params is None or "pia_aime_grid" not in fixed_params: - return LinSpacedGrid( - start=0.0, stop=8_000.0, n_points=grid_config.n_aime_gridpoints - ) kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] pieces = ( Piece(interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]), Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]), ) - return PiecewiseLinSpacedGrid(pieces=pieces) - - -def _has_required_wage_keys(*, wage_params: Mapping[str, Any]) -> bool: - return all( - key in wage_params - for key in ( - "log_ft_wage_mean", - "log_ft_wage_std", - "adj_wage_hours_exp", - "adj_wage_hours_int", - ) + return PiecewiseLinSpacedGrid( + pieces=pieces, batch_size=grid_config.n_aime_batch_size ) @@ -381,7 +378,7 @@ def _compute_max_annual_labor_income( } -def make_active_func(spec: dict[str, str]) -> Callable[..., Any]: +def make_active_func(spec: RegimeSpec) -> Callable[..., Any]: """Return the age predicate for a regime spec.""" key = (spec["mc"], spec["ss"], spec["canwork"]) predicate = _ACTIVE_PREDICATES.get(key) @@ -391,7 +388,7 @@ def make_active_func(spec: dict[str, str]) -> Callable[..., Any]: return predicate -def build_states(spec: dict[str, str], grids: Grids) -> dict: +def build_states(spec: RegimeSpec, grids: Grids) -> dict: """Build the state dict for a non-dead regime.""" can_work = spec["canwork"] == "canwork" @@ -414,7 +411,7 @@ def build_states(spec: dict[str, str], grids: Grids) -> dict: return states -def build_actions(spec: dict[str, str], grids: Grids) -> dict: +def build_actions(spec: RegimeSpec, grids: Grids) -> dict: """Build the action dict for a non-dead regime.""" actions: dict = {} if spec["ss"] == "choose": @@ -423,7 +420,7 @@ def build_actions(spec: dict[str, str], grids: Grids) -> dict: actions["labor_supply"] = DiscreteGrid(LaborSupply) if spec["his"] == "nongroup" and spec["mc"] == "nomc": actions["buy_private"] = DiscreteGrid(BuyPrivate) - actions["consumption"] = grids.consumption + actions["consumption_dollars"] = grids.consumption_dollars return actions @@ -437,14 +434,17 @@ def build_regime_probs(target: FloatND, survival: FloatND) -> FloatND: def build_dead_regime(grids: Grids) -> Regime: """Build the terminal dead regime. - `pref_type` is retained as a state so type-indexed preference params - (`consumption_weight`, `coefficient_rra`, `utility_scale_factor`) can - be indexed by it in the bequest utility. + `pref_type` is retained as a state so the pref-type-indexed DAG + functions (`consumption_weight`, `coefficient_rra`, + `utility_scale_factor`) can resolve their per-cell scalar in the + bequest utility. """ return Regime( transition=None, functions={ - "utility": utility.dead, + "utility": preferences.u_dead, + "consumption_weight": preferences.consumption_weight, + "coefficient_rra": preferences.coefficient_rra, "utility_scale_factor": preferences.utility_scale_factor, }, states={ @@ -455,7 +455,7 @@ def build_dead_regime(grids: Grids) -> Regime: ) -def select_ss_benefit(spec: dict[str, str]) -> Callable[..., Any]: +def select_ss_benefit(spec: RegimeSpec) -> Callable[..., Any]: """Select the appropriate SS benefit function for a regime.""" ss = spec["ss"] @@ -468,21 +468,21 @@ def select_ss_benefit(spec: dict[str, str]) -> Callable[..., Any]: return social_security.benefit_inelig_pre65 -def select_utility(spec: dict[str, str]) -> Callable[..., Any]: +def select_utility(spec: RegimeSpec) -> Callable[..., Any]: """Select the utility function for a regime.""" if spec["canwork"] != "canwork": - return utility.retired - return preferences.utility + return preferences.u_cannot_work + return preferences.u_can_work -def _select_leisure(spec: dict[str, str]) -> Callable[..., Any]: +def _select_leisure(spec: RegimeSpec) -> Callable[..., Any]: """Select the leisure function for a canwork regime.""" if spec["his"] == "tied": return preferences.leisure_tied return preferences.leisure -def build_common_functions(spec: dict[str, str]) -> dict: +def build_common_functions(spec: RegimeSpec) -> dict: """Build the shared functions dict for a non-dead regime. Contains all functions common to every HIS type. Per-HIS modules add @@ -511,9 +511,8 @@ def build_common_functions(spec: dict[str, str]) -> dict: functions["is_married"] = labor_market.is_married functions["equivalence_scale"] = preferences.equivalence_scale functions["utility_scale_factor"] = preferences.utility_scale_factor - # `discount_factor` is a DAG function that indexes the per-type - # Series by the pref_type state and returns a scalar. pylcm's - # default H picks the scalar up as a DAG-output H input. + functions["consumption_weight"] = preferences.consumption_weight + functions["coefficient_rra"] = preferences.coefficient_rra functions["discount_factor"] = preferences.discount_factor # PIA from pre-computed lookup table @@ -545,12 +544,14 @@ def build_common_functions(spec: dict[str, str]) -> dict: # Cash on hand and transfers functions["cash_on_hand"] = assets_and_income.cash_on_hand + functions["consumption_dollars_floor"] = assets_and_income.consumption_dollars_floor functions["transfers"] = assets_and_income.transfers + functions["consumption_equiv"] = preferences.consumption_equiv return functions -def precompute_targets(spec: dict[str, str]) -> dict[str, int]: +def precompute_target_regimes(spec: RegimeSpec) -> MappingProxyType[str, int]: """Pre-compute target regime IDs for each next-age bracket.""" def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: @@ -566,22 +567,24 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: ng_his = "nongroup" if spec["his"] == "tied" else spec["his"] - return { - "forcedout": _resolve(ng_his, "oamc", "forced", "forcedout"), - "forcedout_ng": _resolve("nongroup", "oamc", "forced", "forcedout"), - "forced_forced": _resolve(spec["his"], "oamc", "forced", "canwork"), - "forced_forced_ng": _resolve("nongroup", "oamc", "forced", "canwork"), - "forced_choose": _resolve(spec["his"], "oamc", "choose", "canwork"), - "forced_choose_ng": _resolve("nongroup", "oamc", "choose", "canwork"), - "dimc_choose": _resolve(spec["his"], "dimc", "choose", "canwork"), - "dimc_choose_ng": _resolve("nongroup", "dimc", "choose", "canwork"), - "nomc_choose": _resolve(spec["his"], "nomc", "choose", "canwork"), - "nomc_choose_ng": _resolve("nongroup", "nomc", "choose", "canwork"), - "dimc_inelig": _resolve(spec["his"], "dimc", "inelig", "canwork"), - "dimc_inelig_ng": _resolve("nongroup", "dimc", "inelig", "canwork"), - "nomc_inelig": _resolve(spec["his"], "nomc", "inelig", "canwork"), - "nomc_inelig_ng": _resolve("nongroup", "nomc", "inelig", "canwork"), - } + return MappingProxyType( + { + "forcedout": _resolve(ng_his, "oamc", "forced", "forcedout"), + "forcedout_ng": _resolve("nongroup", "oamc", "forced", "forcedout"), + "forced_forced": _resolve(spec["his"], "oamc", "forced", "canwork"), + "forced_forced_ng": _resolve("nongroup", "oamc", "forced", "canwork"), + "forced_choose": _resolve(spec["his"], "oamc", "choose", "canwork"), + "forced_choose_ng": _resolve("nongroup", "oamc", "choose", "canwork"), + "dimc_choose": _resolve(spec["his"], "dimc", "choose", "canwork"), + "dimc_choose_ng": _resolve("nongroup", "dimc", "choose", "canwork"), + "nomc_choose": _resolve(spec["his"], "nomc", "choose", "canwork"), + "nomc_choose_ng": _resolve("nongroup", "nomc", "choose", "canwork"), + "dimc_inelig": _resolve(spec["his"], "dimc", "inelig", "canwork"), + "dimc_inelig_ng": _resolve("nongroup", "dimc", "inelig", "canwork"), + "nomc_inelig": _resolve(spec["his"], "nomc", "inelig", "canwork"), + "nomc_inelig_ng": _resolve("nongroup", "nomc", "inelig", "canwork"), + } + ) _TARGET_KEYS = ( @@ -597,9 +600,9 @@ def _resolve(his_val: str, mc_val: str, ss_val: str, canwork_val: str) -> int: def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]: """Build own and nongroup target subsets for a regime name.""" - tgts = precompute_targets(REGIME_SPECS[name]) - own = {k: tgts[k] for k in _TARGET_KEYS} - ng = {k: tgts[k + "_ng"] for k in _TARGET_KEYS} + target_regimes = precompute_target_regimes(REGIME_SPECS[name]) + own = {k: target_regimes[k] for k in _TARGET_KEYS} + ng = {k: target_regimes[k + "_ng"] for k in _TARGET_KEYS} return own, ng @@ -638,11 +641,17 @@ def select_target_for_age( ) -def build_state_transitions(spec: dict[str, str]) -> dict: +def build_state_transitions(spec: RegimeSpec) -> dict: """Build the state transitions dict for a non-dead regime.""" transitions: dict = {} - transitions["health"] = _build_per_target_health(spec) - transitions["assets"] = assets_and_income.next_assets + transitions["assets"] = _build_per_target_regime_assets(spec) + transitions["health"] = _build_per_target_regime_health(spec) + claimed_ss_transition = _build_per_target_regime_claimed_ss(spec) + if claimed_ss_transition: + transitions["claimed_ss"] = claimed_ss_transition + lagged_labor_supply_transition = _build_per_target_regime_lagged_labor_supply(spec) + if lagged_labor_supply_transition: + transitions["lagged_labor_supply"] = lagged_labor_supply_transition transitions["pref_type"] = None transitions["aime"] = ( social_security.next_aime @@ -650,28 +659,54 @@ def build_state_transitions(spec: dict[str, str]) -> dict: else social_security.next_aime_disabled ) transitions["spousal_income"] = MarkovTransition(labor_market.next_spousal_income) - lagged_supply_transition = _build_per_target_lagged_labor_supply(spec) - if lagged_supply_transition: - transitions["lagged_labor_supply"] = lagged_supply_transition - claimed_ss_transition = _build_per_target_claimed_ss(spec) - if claimed_ss_transition: - transitions["claimed_ss"] = claimed_ss_transition return transitions -def _build_per_target_health(spec: dict[str, str]) -> dict: +def _build_per_target_regime_assets( + spec: RegimeSpec, +) -> dict[RegimeName, Callable[..., FloatND]]: + """Build per-target assets transitions. + + The `dead` target uses `next_assets_when_dead` (no + `pension_assets_adjustment`), so the dead per-target DAG does not + pull in the `next_aime`-dependent imputation chain — `dead` has no + `aime` state and pylcm cannot resolve `next_aime` there. Non-dead + targets use the full `next_assets` with the pension correction. + """ + target_regimes = precompute_target_regimes(spec) + id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} + + result: dict[RegimeName, Callable[..., FloatND]] = {} + seen_ids: set[int] = set() + + for target_id in target_regimes.values(): + if target_id in seen_ids: + continue + seen_ids.add(target_id) + target_name = id_to_name.get(target_id) + if target_name is None: + continue + result[target_name] = assets_and_income.next_assets + + result["dead"] = assets_and_income.next_assets_when_dead + return result + + +def _build_per_target_regime_health( + spec: RegimeSpec, +) -> dict[RegimeName, MarkovTransition]: """Build per-target health transitions. Pre-65 regimes use HealthWithDisability (3-state), post-65 use Health (2-state). Cross-grid transitions (3->2) happen at the age-65 boundary. """ - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict[str, MarkovTransition] = {} + result: dict[RegimeName, MarkovTransition] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) @@ -689,7 +724,9 @@ def _build_per_target_health(spec: dict[str, str]) -> dict: return result -def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: +def _build_per_target_regime_claimed_ss( + spec: RegimeSpec, +) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target claimed_ss transitions. - `choose` regimes (source has `claimed_ss`): absorbing transition. @@ -699,13 +736,13 @@ def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: if spec["ss"] in ("forced", "forcedout"): return {} - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict = {} + result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) @@ -724,7 +761,9 @@ def _build_per_target_claimed_ss(spec: dict[str, str]) -> dict: return result -def _build_per_target_lagged_labor_supply(spec: dict[str, str]) -> dict: +def _build_per_target_regime_lagged_labor_supply( + spec: RegimeSpec, +) -> dict[RegimeName, Callable[..., BoolND]]: """Build per-target lagged_labor_supply transitions. `lagged_labor_supply` exists in canwork non-tied regimes. Tied regimes @@ -738,13 +777,13 @@ def _build_per_target_lagged_labor_supply(spec: dict[str, str]) -> dict: if spec["canwork"] != "canwork": return {} - targets = precompute_targets(spec) + target_regimes = precompute_target_regimes(spec) id_to_name = {getattr(RegimeId, name): name for name in REGIME_SPECS} - result: dict = {} + result: dict[RegimeName, Callable[..., BoolND]] = {} seen_ids: set[int] = set() - for target_id in targets.values(): + for target_id in target_regimes.values(): if target_id == RegimeId.dead or target_id in seen_ids: continue seen_ids.add(target_id) diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 5cdb6dc..a723b44 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -15,6 +15,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -74,7 +75,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a nongroup regime.""" can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) @@ -99,6 +100,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index ac76bfd..4f16faa 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -16,6 +16,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -86,7 +87,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a retiree regime.""" can_work = spec["canwork"] == "canwork" functions = build_common_functions(spec) @@ -109,6 +110,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index 5d59274..df76fa4 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -17,6 +17,7 @@ from aca_model.baseline.regimes._common import ( REGIME_SPECS, Grids, + RegimeSpec, build_actions, build_common_functions, build_regime_probs, @@ -65,7 +66,7 @@ def transition( return transition -def _build_functions(spec: dict[str, str]) -> dict: +def _build_functions(spec: RegimeSpec) -> dict: """Build functions dict for a tied regime.""" functions = build_common_functions(spec) @@ -83,6 +84,10 @@ def _build_functions(spec: dict[str, str]) -> dict: functions["pension_wealth_next_before_adjustment"] = ( pensions.wealth_next_before_adjustment ) + functions["target_his"] = health_insurance.target_his + functions["imputed_pension_wealth_next_period"] = ( + pensions.imputed_pension_wealth_next_period + ) functions["pension_assets_adjustment"] = pensions.assets_adjustment functions["total_to_pia"] = pensions.total_to_pia diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5a822c9..5b519d5 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -7,20 +7,15 @@ The benchmark substitutes a 2-type `BenchmarkPrefType` for the production 3-type `PrefType`, which saves ~33% of the compile + -execution volume over all 18 regimes. By default the pref_type axis -is handled via pylcm's fused-vmap dispatch (no `DispatchStrategy` -imported — this module stays compatible with pylcm versions that -pre-date the enum). Callers that want partition-lifted dispatch -(`PARTITION_SCAN` / `PARTITION_VMAP`) construct the grid themselves -and pass it via `pref_type_grid`. - -Parameters (`fixed_params` + `params`) are a committed stub fixture -packaged alongside the module at -`src/aca_model/_benchmark_data/benchmark_params.pkl` — aggregate-level -values (policy schedules, transition probabilities, fitted -coefficients) with no runtime dependency on aca-data or any data-prep -package. The pref-type-indexed entries in `params` are truncated to -two rows on load to match `BenchmarkPrefType`. +execution volume over all 18 regimes. + +Parameters (`fixed_params` + `params`) are a committed snapshot at +`src/aca_model/_benchmark_data/benchmark_params.pkl`, generated by +`scripts/regen_benchmark_params.py` against the current aca-data + +aca-estimation + aca-model code. Pref-type-indexed Series in `params` +are pre-truncated to two rows so the snapshot loads with no further +reshaping; regenerate after any change that affects `fixed_params` +shape (regime DAGs, aca-data outputs, key renames). Initial conditions are drawn randomly per call — assets/aime/wage_res from their grid ranges, discrete states from their categories, regimes @@ -34,7 +29,6 @@ import cloudpickle import jax.numpy as jnp import numpy as np -import pandas as pd from jax import Array from lcm import DiscreteGrid, Model @@ -44,6 +38,7 @@ from aca_model.baseline.health_insurance import HealthInsuranceState from aca_model.baseline.model import create_model from aca_model.config import BENCHMARK_GRID_CONFIG +from aca_model.consumption_dollars_grid import inject_consumption_dollars_points _PARAMS_FILE = ( Path(__file__).resolve().parent / "_benchmark_data" / "benchmark_params.pkl" @@ -55,6 +50,7 @@ "good_health": DiscreteGrid(GoodHealth), "is_married": DiscreteGrid(IsMarried), "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), "pref_type": DiscreteGrid(BenchmarkPrefType), } @@ -70,66 +66,58 @@ ) -def create_benchmark_model(*, pref_type_grid: DiscreteGrid | None = None) -> Model: +def create_benchmark_model( + *, + n_subjects: int, + pref_type_grid: DiscreteGrid, +) -> Model: """Create the aca baseline with `BENCHMARK_GRID_CONFIG` and frozen fixed_params. The benchmark uses a 2-type `BenchmarkPrefType`. No `batch_size != 0` on any grid (continuous grids inherit - `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0`). + `BENCHMARK_GRID_CONFIG.n_assets_batch_size = 0` and + `n_aime_batch_size = 0`). Args: - pref_type_grid: Override for the pref_type grid. Default is a plain - `DiscreteGrid(BenchmarkPrefType)` (fused vmap). Pass - `DiscreteGrid(BenchmarkPrefType, dispatch=DispatchStrategy.PARTITION_SCAN)` - (or `PARTITION_VMAP`) to get the partition-lifted kernel — the - recommended production setting for aca-model at scale, but only - supported on pylcm versions that expose `DispatchStrategy`. + n_subjects: Forwarded to `lcm.Model(n_subjects=...)`. When set, the + first matching `simulate(...)` call AOT-compiles all simulate + functions for that batch shape. + pref_type_grid: Pref-type grid; pass `DiscreteGrid(BenchmarkPrefType)`. """ - if pref_type_grid is None: - pref_type_grid = DiscreteGrid(BenchmarkPrefType) - fixed_params, _ = get_benchmark_params() + fixed_params, wage_params, _ = get_benchmark_params(model=None) return create_model( grid_config=BENCHMARK_GRID_CONFIG, fixed_params=fixed_params, + wage_params=wage_params, derived_categoricals=_DERIVED_CATEGORICALS, pref_type_grid=pref_type_grid, + n_subjects=n_subjects, ) -def get_benchmark_params() -> tuple[dict[str, Any], dict[str, Any]]: - """Load the frozen `(fixed_params, params)` snapshot. +def get_benchmark_params( + *, model: Model | None +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Load the frozen `(fixed_params, wage_params, params)` snapshot. - Pref-type-indexed `pd.Series` in `params` are truncated to - `_N_BENCHMARK_PREF_TYPES` rows so they line up with - `BenchmarkPrefType`'s categories. + When `model` is provided, consumption_dollars gridpoints are injected + into `params` for each regime that declares `consumption_dollars` as + an `IrregSpacedGrid` with runtime-supplied points. The lower bound is + read from `params["consumption_dollars_floor"]`. Pass `model=None` to + skip injection (e.g. when constructing the model with `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] - params = _truncate_pref_type_indexed(data["params"]) - return fixed_params, params - - -def _truncate_pref_type_indexed(params: dict[str, Any]) -> dict[str, Any]: - """Return a copy of `params` with pref_type-indexed Series cut to 2 rows. - - A Series is pref_type-indexed when its index labels start with - `"type_"`. The first `_N_BENCHMARK_PREF_TYPES` rows are kept so the - Series aligns with `BenchmarkPrefType.type_0`, `type_1`, ... - """ - out: dict[str, Any] = {} - for key, value in params.items(): - if isinstance(value, pd.Series) and all( - str(label).startswith("type_") for label in value.index - ): - out[key] = value.iloc[:_N_BENCHMARK_PREF_TYPES] - else: - out[key] = value - return out + wage_params = data["wage_params"] + params = data["params"] + if model is not None: + params = inject_consumption_dollars_points(params=params, model=model) + return fixed_params, wage_params, params def get_benchmark_initial_conditions( - *, model: Model, n_subjects: int = 100, seed: int = 42 + *, model: Model, n_subjects: int, seed: int ) -> dict[str, Array]: """Draw random feasible initial conditions across five age-51 regimes. @@ -143,10 +131,14 @@ def get_benchmark_initial_conditions( regime = rng.choice(regime_ids, size=n_subjects).astype(np.int32) # Grid ranges come from any of the five regimes (shared structure). + # Use to_jax() so the helper handles both LinSpacedGrid and + # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). ref_regime = model.regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states - assets_lo, assets_hi = grids["assets"].start, grids["assets"].stop - aime_lo, aime_hi = grids["aime"].start, grids["aime"].stop + assets_pts = np.asarray(grids["assets"].to_jax()) + aime_pts = np.asarray(grids["aime"].to_jax()) + assets_lo, assets_hi = float(assets_pts.min()), float(assets_pts.max()) + aime_lo, aime_hi = float(aime_pts.min()), float(aime_pts.max()) hcc_p_pts = np.asarray(grids["hcc_persistent"].to_jax()) hcc_t_pts = np.asarray(grids["hcc_transitory"].to_jax()) wage_res_pts = np.asarray(grids["log_ft_wage_res"].to_jax()) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 1bae45f..101ef2d 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -29,14 +29,16 @@ class ModelConfig: class GridConfig: n_assets_gridpoints: int = 24 n_aime_gridpoints: int = 12 - n_consumption_gridpoints: int = 70 + n_consumption_dollars_gridpoints: int = 70 n_wage_res_gridpoints: int = 5 n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 - # `batch_size` on the assets grid: chunked vmap stride for the - # outer state loop. Useful at prod sizes for memory reasons; set - # to 0 in BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 2 + # `batch_size` on the assets / AIME grids: chunked vmap stride for the + # outer state loop. Both partition the per-period Q intermediate so it + # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in + # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. + n_assets_batch_size: int = 1 + n_aime_batch_size: int = 1 MODEL_CONFIG = ModelConfig() @@ -45,9 +47,10 @@ class GridConfig: BENCHMARK_GRID_CONFIG = GridConfig( n_assets_gridpoints=3, n_aime_gridpoints=3, - n_consumption_gridpoints=5, + n_consumption_dollars_gridpoints=5, n_wage_res_gridpoints=3, n_hcc_persistent_gridpoints=3, n_hcc_transitory_gridpoints=3, n_assets_batch_size=0, + n_aime_batch_size=0, ) diff --git a/src/aca_model/consumption_dollars_grid.py b/src/aca_model/consumption_dollars_grid.py new file mode 100644 index 0000000..7487fd8 --- /dev/null +++ b/src/aca_model/consumption_dollars_grid.py @@ -0,0 +1,124 @@ +"""Runtime-supplied gridpoints for the consumption_dollars action. + +Consumption is declared as `IrregSpacedGrid(n_points=N)` in +`baseline.regimes._common.build_grids` so the bounds can track +runtime parameters: the lower bound from the per-iteration +`consumption_equiv_floor` parameter (and its couples-scaled twin), +the upper bound from `MAX_CONSUMPTION_DOLLARS` in +`baseline.regimes._common`. Callers must inject the actual gridpoints +into `params` via `inject_consumption_dollars_points` before calling +`model.solve()` / `model.simulate()`. + +The grid pins the two regime-relevant transfer-floor levels exactly +on the action grid so the borrowing constraint's +`max(cash_on_hand, floor)` boundary lands on a feasible action for +both single and married households: + +- `pts[0] = consumption_equiv_floor` (single household: equiv_scale=1) +- `pts[1] = consumption_equiv_floor * 2 ** exponent` (married) +- `pts[2:] = geomspace(pts[1], MAX_CONSUMPTION_DOLLARS, n_points - 1)` +""" + +from collections.abc import Mapping +from typing import Any + +import jax.numpy as jnp +from jax import Array +from lcm import IrregSpacedGrid, Model + +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS + + +def inject_consumption_dollars_points( + *, + params: Mapping[str, Any], + model: Model, +) -> dict[str, Any]: + """Inject consumption_dollars gridpoints into per-regime params. + + Walks every regime, reads its `consumption_dollars` action grid, + and writes `params[regime_name]["consumption_dollars"] = {"points": }`. + + The lower two gridpoints are the single and married unequiv + transfer floors; the rest are geomspaced from the married floor up + to `MAX_CONSUMPTION_DOLLARS`. + + Args: + params: Existing params mapping with `consumption_equiv_floor` + (per-equivalent floor, varies per iteration). Returned as a + new dict; the input is not mutated. + model: Model whose regimes carry the runtime-points grid and + whose `fixed_params["exponent"]` sets the married + equivalence-scale exponent. + + Returns: + New params dict with consumption_dollars points injected. + + Raises: + ValueError: If a regime is missing the `consumption_dollars` + action, or its grid is not an `IrregSpacedGrid` with + `pass_points_at_runtime=True`. + """ + consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"]) + exponent = jnp.asarray(model.fixed_params["exponent"]) + out: dict[str, Any] = dict(params) + for regime_name, regime in model.regimes.items(): + if regime.terminal: + continue + grid = regime.actions.get("consumption_dollars") + if grid is None: + msg = ( + f"Regime {regime_name!r} is missing the `consumption_dollars` " + f"action — the runtime-points grid must be on every regime." + ) + raise ValueError(msg) + if not (isinstance(grid, IrregSpacedGrid) and grid.pass_points_at_runtime): + msg = ( + f"Regime {regime_name!r} has a `consumption_dollars` action " + f"whose grid is not an `IrregSpacedGrid(pass_points_at_runtime=True)`; " + f"got {type(grid).__name__}." + ) + raise ValueError(msg) + # Runtime-points grids always have `n_points` set (the constructor + # rejects the (points=None, n_points=None) combo); narrow for ty. + assert grid.n_points is not None + points = _compute_consumption_dollars_points( + consumption_equiv_floor=consumption_equiv_floor, + exponent=exponent, + n_points=grid.n_points, + ) + regime_entry = dict(out.get(regime_name, {})) + regime_entry["consumption_dollars"] = {"points": points} + out[regime_name] = regime_entry + return out + + +def _compute_consumption_dollars_points( + *, + consumption_equiv_floor: Array, + exponent: Array, + n_points: int, +) -> Array: + """Return log-spaced consumption_dollars gridpoints with both floors pinned. + + Single and married households face different unequiv (in-$) floors + (`consumption_equiv_floor` and the married-scaled twin + respectively). Both must land exactly on the action grid so the + borrowing constraint's `max(cash_on_hand, floor)` kink boundary is + a feasible action; otherwise sub-ULP drift can flip the `<=` + comparison for subjects with very negative cash. The geomspace + tail starts at the married floor and runs to + `MAX_CONSUMPTION_DOLLARS` so the two pinned points stay strictly + increasing. + """ + married_unequiv_floor = consumption_equiv_floor * jnp.asarray(2.0) ** exponent + tail = jnp.geomspace( + married_unequiv_floor, MAX_CONSUMPTION_DOLLARS, num=n_points - 1 + ) + pts = jnp.concatenate([consumption_equiv_floor[None], tail]) + # `jnp.geomspace` returns `start * r^0` for the first tail element, + # which mathematically equals `married_unequiv_floor` but drifts by + # sub-ULP on some XLA backends. Pin the slot back to the exact + # arithmetic value so the borrowing-constraint kink boundary at the + # married floor is exactly representable. + return pts.at[1].set(married_unequiv_floor) diff --git a/src/aca_model/environment/pensions.py b/src/aca_model/environment/pensions.py index a23a800..eef72d4 100644 --- a/src/aca_model/environment/pensions.py +++ b/src/aca_model/environment/pensions.py @@ -4,7 +4,7 @@ """ import jax.numpy as jnp -from lcm.typing import FloatND, IntND, Period +from lcm.typing import ContinuousState, FloatND, IntND, Period def benefit( @@ -164,3 +164,42 @@ def assets_adjustment( * unconditional_survival_prob[period] * (pension_wealth_next_before_adjustment - imputed_pension_wealth_next_period) ) + + +def imputed_pension_wealth_next_period( + next_aime: ContinuousState, + target_his: IntND, + period: Period, + pia_table: FloatND, + pia_aime_grid: FloatND, + imp_intercept_next_period: FloatND, + imp_pia_coeff_next_period: FloatND, + imp_pia_kink_0_coeff_next_period: FloatND, + imp_pia_kink_1_coeff_next_period: FloatND, + imp_kink_0_next_period: FloatND, + imp_kink_1_next_period: FloatND, + imp_fraction_receiving_next_period: FloatND, + epdv_constant_pension_next_period: FloatND, +) -> FloatND: + """Imputed pension wealth at next period using the target regime's HIS. + + Mirrors `benefit` and `wealth` but indexes into 1-period-shifted views + of the imputation arrays so all subscripts use bare-name parameters + (`period`, `target_his`). Inlining is required: pylcm's AST shape + inference inspects the registered function's body and does not trace + through nested calls into `benefit`. + """ + next_pia = jnp.interp(next_aime, pia_aime_grid, pia_table) + + intercept = imp_intercept_next_period[period, target_his] + pia_pred = imp_pia_coeff_next_period[period, target_his] * next_pia + kink_0_adj = imp_pia_kink_0_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_0_next_period[period] + ) + kink_1_adj = imp_pia_kink_1_coeff_next_period[period, target_his] * jnp.maximum( + 0.0, next_pia - imp_kink_1_next_period[period] + ) + + full_benefit = jnp.maximum(0.0, intercept + pia_pred + kink_0_adj + kink_1_adj) + benefit_next = full_benefit * imp_fraction_receiving_next_period[period] + return benefit_next * epdv_constant_pension_next_period[period] diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index c9ce1f5..e3574cf 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -30,7 +30,7 @@ def next_claimed_ss( def enter_claimed_ss() -> DiscreteState: """Initial claimed_ss when entering the SS eligibility window.""" - return ClaimedSS.no + return jnp.int32(ClaimedSS.no) # --- PIA lookup (DAG functions) --- diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1da1dcf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) diff --git a/tests/helpers/model.py b/tests/helpers/model.py new file mode 100644 index 0000000..be778b4 --- /dev/null +++ b/tests/helpers/model.py @@ -0,0 +1,53 @@ +"""Tiny factories that wrap `create_model` with the benchmark snapshot. + +Used by tests that need a structurally faithful model without spelling +out fixed_params, wage_params, and a pref-type grid at every call site. +Production callers (aca-estimation, scripts) assemble these explicitly. +""" + +from lcm import DiscreteGrid, Model + +from aca_model.aca.health_insurance import PolicyVariant +from aca_model.aca.model import create_model as _create_aca_model +from aca_model.agent.health import GoodHealth +from aca_model.agent.labor_market import IsMarried +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.baseline.health_insurance import HealthInsuranceState +from aca_model.baseline.model import create_model as _create_baseline_model +from aca_model.benchmark import get_benchmark_params +from aca_model.config import BENCHMARK_GRID_CONFIG + +_DERIVED_CATEGORICALS = { + "good_health": DiscreteGrid(GoodHealth), + "is_married": DiscreteGrid(IsMarried), + "his": DiscreteGrid(HealthInsuranceState), + "target_his": DiscreteGrid(HealthInsuranceState), + "pref_type": DiscreteGrid(BenchmarkPrefType), +} + + +def make_baseline_model(*, n_subjects: int) -> Model: + """Baseline model on `BENCHMARK_GRID_CONFIG` with the benchmark snapshot params.""" + fixed_params, wage_params, _ = get_benchmark_params(model=None) + return _create_baseline_model( + n_subjects=n_subjects, + fixed_params=fixed_params, + wage_params=wage_params, + derived_categoricals=_DERIVED_CATEGORICALS, + grid_config=BENCHMARK_GRID_CONFIG, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + + +def make_aca_model(*, n_subjects: int, policy: PolicyVariant) -> Model: + """ACA model on `BENCHMARK_GRID_CONFIG` with the benchmark snapshot params.""" + fixed_params, wage_params, _ = get_benchmark_params(model=None) + return _create_aca_model( + n_subjects=n_subjects, + policy=policy, + fixed_params=fixed_params, + wage_params=wage_params, + derived_categoricals=_DERIVED_CATEGORICALS, + grid_config=BENCHMARK_GRID_CONFIG, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 72fb473..b1be815 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,7 +1,10 @@ """Integration test: the benchmark-sized baseline solves + simulates end-to-end.""" +import numpy as np import pytest +from lcm import DiscreteGrid +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.benchmark import ( create_benchmark_model, get_benchmark_initial_conditions, @@ -12,8 +15,11 @@ @pytest.mark.long_running def test_benchmark_model_simulates_end_to_end() -> None: n_subjects = 20 - model = create_benchmark_model() - _, params = get_benchmark_params() + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, _, params = get_benchmark_params(model=model) initial_conditions = get_benchmark_initial_conditions( model=model, n_subjects=n_subjects, seed=0 ) @@ -31,3 +37,48 @@ def test_benchmark_model_simulates_end_to_end() -> None: # Period 0 rows reflect initial conditions — no NaN in continuous states. period_0 = df.loc[df["period"] == 0] assert not period_0[["assets", "aime", "value"]].isna().any().any() + + +@pytest.mark.long_running +def test_benchmark_simulate_obeys_borrowing_constraint() -> None: + """`consumption_dollars <= max(cash_on_hand, floor)` holds for every alive row. + + The simulator only ever picks feasible actions — the borrowing + constraint must hold post-hoc on the simulated panel. A regression + that drops the constraint from a regime, replaces the floor with + something looser, or lets an action grid skip the floor would + surface as a row with `consumption_dollars > max(cash_on_hand, floor)`. + + The constraint's RHS is `max(cash_on_hand, floor)` rather than + `cash_on_hand + transfers`: the additive form rounds short by + sub-ULP at extreme `|cash_on_hand|`, so the post-hoc check would + also flip on the same kink. + """ + n_subjects = 4 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, _, params = get_benchmark_params(model=model) + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + + result = model.simulate( + params=params, + initial_conditions=initial_conditions, + period_to_regime_to_V_arr=None, + log_level="off", + check_initial_conditions=False, + ) + + df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) + alive = df.loc[df["regime"] != "dead"].copy() + consumption_dollars_floor = float(params["consumption_dollars_floor"]) + floor = consumption_dollars_floor * alive["equivalence_scale"].to_numpy() + rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) + slack = rhs - alive["consumption_dollars"].to_numpy() + assert (slack >= 0).all(), ( + f"borrowing_constraint violated on {int((slack < 0).sum())} row(s); " + f"min slack = {slack.min():.6g}" + ) diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index 8bf206c..f087d16 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -108,7 +108,7 @@ def test_retired_agent_with_pension() -> None: def test_transfers_kick_in_below_floor() -> None: - """When cash_on_hand < consumption_floor, transfers fill the gap.""" + """When cash_on_hand < consumption_dollars_floor, transfers fill the gap.""" functions = { "cash_on_hand": assets_and_income.cash_on_hand, "transfers": assets_and_income.transfers, @@ -126,14 +126,12 @@ def test_transfers_kick_in_below_floor() -> None: ssi_benefit=jnp.array(0.0), hic_premium=jnp.array(0.0), oop_costs=jnp.array(0.0), - consumption_floor=5000.0, - equivalence_scale=jnp.array(1.0), + consumption_dollars_floor=jnp.array(5000.0), pension_assets_adjustment=jnp.array(0.0), - consumption=jnp.array(4000.0), + consumption_dollars=jnp.array(4000.0), ) # cash_on_hand = 500 + 200 = 700 - # floor = 5000 * 1.0 = 5000 # transfers = max(0, 5000 - 700) = 4300 assert jnp.isclose(result["transfers"], 4300.0, atol=ATOL) # next_assets = 700 + 4300 + 0 - 4000 = 1000 diff --git a/tests/test_consumption_dollars_grid.py b/tests/test_consumption_dollars_grid.py new file mode 100644 index 0000000..1f42e6f --- /dev/null +++ b/tests/test_consumption_dollars_grid.py @@ -0,0 +1,86 @@ +"""Consumption-grid invariants required by the borrowing constraint. + +The borrowing constraint in `agent.assets_and_income.borrowing_constraint` +compares the lowest consumption_dollars action against +`max(cash_on_hand, consumption_dollars_floor)`. For subjects with cash +below the floor (large-negative-asset subjects, moderate-negative-asset +retirees, etc.) this RHS collapses to exactly +`consumption_dollars_floor`. The constraint is feasible iff the +relevant household-floor gridpoint is `<= consumption_dollars_floor`. + +For singles (`equivalence_scale = 1`) that floor is +`consumption_equiv_floor`; for married households +(`equivalence_scale = 2 ** exponent`) it is +`consumption_equiv_floor * 2 ** exponent`. Both must land **exactly** +on the consumption_dollars grid. + +`jnp.geomspace(start, stop, num=n)` returns `start * r^i` with +`r = (stop/start)^(1/(n-1))`; mathematically `r^0 == 1` so the first +point equals `start`, but XLA backends can drift by sub-ULP for some +`(start, stop, n)` combinations (observed: CUDA, n=70, drift +2.27e-13). +A positive drift above the floor flips the kink-boundary `<=` and +rejects every action for the affected subjects. + +`_compute_consumption_dollars_points` therefore prepends the singles' +floor as `pts[0]`, runs `geomspace` from the married floor up to +`MAX_CONSUMPTION_DOLLARS` for the rest, and pins the geomspace start +back to the married floor exactly. Test those invariants directly. +""" + +import jax.numpy as jnp +import pytest + +from aca_model.baseline.regimes._common import MAX_CONSUMPTION_DOLLARS +from aca_model.consumption_dollars_grid import _compute_consumption_dollars_points + +EXPONENT = 0.7 # production value (env_constants["exponent"]) +SINGLE_FLOOR = 1597.0921419521899 # production value +MARRIED_SCALE = 2.0**EXPONENT + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_dollars_points_first_equals_singles_floor( + n_points: int, +) -> None: + """`pts[0]` equals the singles' floor exactly under any `n_points`.""" + pts = _compute_consumption_dollars_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=n_points, + ) + assert float(pts[0]) == SINGLE_FLOOR + + +@pytest.mark.parametrize("n_points", [5, 16, 64, 70, 100]) +def test_compute_consumption_dollars_points_second_equals_married_floor( + n_points: int, +) -> None: + """`pts[1]` equals `consumption_equiv_floor * 2 ** exponent` exactly.""" + pts = _compute_consumption_dollars_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=n_points, + ) + expected = float(jnp.asarray(SINGLE_FLOOR) * jnp.asarray(2.0) ** EXPONENT) + assert float(pts[1]) == expected + + +def test_compute_consumption_dollars_points_strictly_increasing() -> None: + """Gridpoints are strictly increasing — no kink-pinning ties.""" + pts = _compute_consumption_dollars_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=70, + ) + diffs = jnp.diff(pts) + assert bool((diffs > 0).all()) + + +def test_compute_consumption_dollars_points_last_equals_max() -> None: + """The final point is the configured upper bound.""" + pts = _compute_consumption_dollars_points( + consumption_equiv_floor=jnp.asarray(SINGLE_FLOOR), + exponent=jnp.asarray(EXPONENT), + n_points=70, + ) + assert float(pts[-1]) == pytest.approx(MAX_CONSUMPTION_DOLLARS) diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py new file mode 100644 index 0000000..3b16522 --- /dev/null +++ b/tests/test_initial_conditions_extreme_assets.py @@ -0,0 +1,128 @@ +"""Subjects at extreme negative assets must clear `validate_initial_conditions`. + +The transfer system (`agent.assets_and_income.transfers`) tops cash-on-hand +to the household-$ consumption floor at any starting state, so the lowest +consumption_dollars-grid point is always a feasible action regardless of +how negative starting assets are. The model's constraints — and pylcm's +`validate_initial_conditions` pass — must reflect this. +""" + +import jax.numpy as jnp +from lcm import DiscreteGrid +from lcm.simulation.initial_conditions import validate_initial_conditions + +from aca_model.agent.assets_and_income import borrowing_constraint +from aca_model.agent.preferences import BenchmarkPrefType +from aca_model.benchmark import ( + create_benchmark_model, + get_benchmark_initial_conditions, + get_benchmark_params, +) + + +def test_borrowing_constraint_admits_consumption_dollars_at_floor() -> None: + """`consumption_dollars == consumption_dollars_floor` at the kink is feasible by equality.""" + consumption_dollars_floor = jnp.asarray(5_000.0) + cash_on_hand = jnp.asarray(-50_000.0) # below floor — RHS = floor + + admitted = bool( + borrowing_constraint( + consumption_dollars=consumption_dollars_floor, + cash_on_hand=cash_on_hand, + consumption_dollars_floor=consumption_dollars_floor, + ) + ) + assert admitted + + +def test_borrowing_constraint_admits_consumption_dollars_at_married_floor() -> None: + """At a married household's higher floor, the equivalence-scale-lifted floor is feasible.""" + consumption_equiv_floor = jnp.asarray(5_000.0) + married_floor = consumption_equiv_floor * jnp.asarray(2.0) ** 0.7 + cash_on_hand = jnp.asarray(-50_000.0) + + admitted = bool( + borrowing_constraint( + consumption_dollars=married_floor, + cash_on_hand=cash_on_hand, + consumption_dollars_floor=married_floor, + ) + ) + assert admitted + + +def test_borrowing_constraint_rejects_consumption_dollars_above_post_transfer_resources() -> ( + None +): + """`consumption_dollars > max(cash_on_hand, floor)` is rejected.""" + consumption_dollars_floor = jnp.asarray(5_000.0) + cash_on_hand = jnp.asarray(-50_000.0) + consumption_dollars = consumption_dollars_floor + 1.0 + + admitted = bool( + borrowing_constraint( + consumption_dollars=consumption_dollars, + cash_on_hand=cash_on_hand, + consumption_dollars_floor=consumption_dollars_floor, + ) + ) + assert not admitted + + +def test_borrowing_constraint_admits_floor_at_million_dollar_negative_cash() -> None: + """The kink-boundary check survives sub-ULP rounding at `|cash_on_hand| ~ 1e6`. + + At large negative `assets`, the algebraically equivalent + `cash_on_hand + transfers` form rounds to `floor - 5.7e-11` at fp64, + flipping `consumption_dollars <= ...` for the lowest + consumption_dollars gridpoint. The `max(cash_on_hand, floor)` form + returns `floor` exactly. + """ + consumption_dollars_floor = jnp.asarray(1597.0921419521899) # production value + cash_on_hand = jnp.asarray(-1_000_000.0) + consumption_dollars = consumption_dollars_floor # lowest grid point + + admitted = bool( + borrowing_constraint( + consumption_dollars=consumption_dollars, + cash_on_hand=cash_on_hand, + consumption_dollars_floor=consumption_dollars_floor, + ) + ) + assert admitted + + +def test_extreme_negative_assets_subject_passes_validation() -> None: + """A subject placed at `assets = -1_000_000` clears initial-conditions validation. + + A large-but-reasonable negative value (very bad draws for both HCC shocks) + should remain in the simulated population: the consumption floor / + transfer system absorbs them, with `c = c_floor` always feasible. + """ + n_subjects = 1 + model = create_benchmark_model( + n_subjects=n_subjects, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + _, _, params = get_benchmark_params(model=model) + + initial_conditions = get_benchmark_initial_conditions( + model=model, n_subjects=n_subjects, seed=0 + ) + initial_conditions = { + **initial_conditions, + "assets": jnp.asarray([-1_000_000.0]), + "regime": jnp.asarray( + [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], + dtype=jnp.int32, + ), + } + + internal_params = model._process_params(params) # noqa: SLF001 + validate_initial_conditions( + initial_conditions=initial_conditions, + internal_regimes=model.internal_regimes, + regime_names_to_ids=model.regime_names_to_ids, + internal_params=internal_params, + ages=model.ages, + ) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index cbb2f72..5b7df6a 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -77,27 +77,23 @@ def test_leisure_bad_health() -> None: def test_utility_positive_leisure() -> None: - result = preferences.utility( - consumption=jnp.array(10000.0), + result = preferences.u_can_work( + consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), - pref_type=jnp.array(0), - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) def test_utility_log_case() -> None: - result = preferences.utility( - consumption=jnp.array(10000.0), + result = preferences.u_can_work( + consumption_equiv=jnp.array(10000.0), leisure=jnp.array(3000.0), - pref_type=jnp.array(0), - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([1.0, 1.0, 1.0]), - equivalence_scale=jnp.array(1.0), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(1.0), + utility_scale_factor=jnp.array(1.0), ) composite = 10000.0**0.4 * 3000.0**0.6 expected = jnp.log(composite) @@ -107,12 +103,11 @@ def test_utility_log_case() -> None: def test_bequest_positive_assets() -> None: result = preferences.bequest( assets=jnp.array(100000.0), - pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) @@ -120,12 +115,11 @@ def test_bequest_positive_assets() -> None: def test_bequest_zero_assets() -> None: result = preferences.bequest( assets=jnp.array(0.0), - pref_type=jnp.array(0), bequest_shifter=5000.0, scaled_bequest_weight=0.5, - consumption_weight=jnp.array([0.4, 0.4, 0.4]), - coefficient_rra=jnp.array([2.0, 2.0, 2.0]), - utility_scale_factor=jnp.array([1.0, 1.0, 1.0]), + consumption_weight=jnp.array(0.4), + coefficient_rra=jnp.array(2.0), + utility_scale_factor=jnp.array(1.0), ) assert jnp.isfinite(result) assert result < 0 # CRRA with γ>1 gives negative values @@ -166,8 +160,8 @@ def test_next_aime_accrual() -> None: result = social_security.next_aime( aime=jnp.array(1000.0), labor_income=jnp.array(50000.0), - period=55, - age=55, + period=jnp.int32(55), + age=jnp.int32(55), benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), earnings_test_repealed_age=70, diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 841f2bb..7ae6e36 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -3,17 +3,38 @@ from collections.abc import Mapping import pytest +from helpers.model import make_aca_model, make_baseline_model +from lcm import DiscreteGrid from aca_model.aca import health_insurance as aca_hi from aca_model.aca.health_insurance import PolicyVariant -from aca_model.aca.model import create_model as create_aca_model -from aca_model.aca.regimes import build_all_regimes as build_aca_regimes -from aca_model.baseline.model import create_model +from aca_model.aca.regimes import build_all_regimes as _build_aca_regimes +from aca_model.agent.preferences import BenchmarkPrefType from aca_model.baseline.regimes import REGIME_SPECS, RegimeId from aca_model.baseline.regimes import build_regime as _build_regime from aca_model.baseline.regimes._common import build_grids +from aca_model.benchmark import get_benchmark_params +from aca_model.config import BENCHMARK_GRID_CONFIG -_GRIDS = build_grids() +_FIXED_PARAMS, _WAGE_PARAMS, _ = get_benchmark_params(model=None) + + +def build_aca_regimes(policy: PolicyVariant) -> dict: + return _build_aca_regimes( + policy=policy, + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + + +_GRIDS = build_grids( + grid_config=BENCHMARK_GRID_CONFIG, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), +) def build_regime(name: str): @@ -21,24 +42,24 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 assert model.n_periods == 45 def test_model_age_range() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert model.ages.values[0] == 51.0 assert model.ages.values[-1] == 95.0 def test_dead_regime_is_terminal() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) assert model.regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: - model = create_model() + model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: assert not model.regimes[name].terminal @@ -55,7 +76,7 @@ def test_forcedout_regimes_no_labor_supply(name: str) -> None: regime = build_regime(name) assert "labor_supply" not in regime.actions assert "log_ft_wage_res" not in regime.states - assert "consumption" in regime.actions + assert "consumption_dollars" in regime.actions @pytest.mark.parametrize( @@ -128,7 +149,7 @@ def test_pre65_regimes_use_health_with_disability() -> None: if spec["mc"] in ("nomc", "dimc"): regime = build_regime(name) grid = regime.states["health"] - assert len(grid.categories) == 3, f"{name} should use HealthWithDisability" # ty: ignore[unresolved-attribute] + assert len(grid.categories) == 3, f"{name} should use HealthWithDisability" def test_post65_regimes_use_health() -> None: @@ -136,7 +157,7 @@ def test_post65_regimes_use_health() -> None: if spec["mc"] == "oamc": regime = build_regime(name) grid = regime.states["health"] - assert len(grid.categories) == 2, f"{name} should use Health" # ty: ignore[unresolved-attribute] + assert len(grid.categories) == 2, f"{name} should use Health" def test_all_regimes_have_aime() -> None: @@ -170,7 +191,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: - model = create_aca_model() + model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) assert len(model.regimes) == 19 assert model.n_periods == 45 @@ -211,7 +232,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: @pytest.mark.parametrize("policy", list(PolicyVariant)) def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" - model = create_aca_model(policy=policy) + model = make_aca_model(n_subjects=1, policy=policy) assert len(model.regimes) == 19 @@ -251,5 +272,5 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" - model = create_model() + model = make_baseline_model(n_subjects=1) assert len(model.regimes) == 19 diff --git a/tests/test_pension_integration.py b/tests/test_pension_integration.py index 287cab6..0f6c07d 100644 --- a/tests/test_pension_integration.py +++ b/tests/test_pension_integration.py @@ -18,7 +18,7 @@ # HIS 0 (retiree): intercept = -50, HIS 1 (nongroup): intercept = -80. N_PERIODS = 30 N_HIS = 2 -PERIOD = 20 +PERIOD = jnp.int32(20) _intercept = jnp.zeros((N_PERIODS, N_HIS)) _intercept = _intercept.at[PERIOD, 0].set(-50.0) @@ -62,7 +62,7 @@ def test_benefit_wealth_dag() -> None: result = combined( pia=jnp.array(500.0), period=PERIOD, - his=0, + his=jnp.int32(0), epdv_constant_pension=EPDV, **IMP_KWARGS, ) @@ -80,7 +80,7 @@ def test_total_to_pia_inverts_benefit_via_dag() -> None: recovered = combined( pia=jnp.array(8000.0), period=PERIOD, - his=0, + his=jnp.int32(0), marginal_tax_rate=jnp.array(0.2), **IMP_KWARGS, ) @@ -95,7 +95,7 @@ def test_next_assets_includes_pension_adjustment() -> None: cash_on_hand=jnp.array(100_000.0), transfers=jnp.array(0.0), pension_assets_adjustment=jnp.array(5_000.0), - consumption=jnp.array(80_000.0), + consumption_dollars=jnp.array(80_000.0), oop_costs=jnp.array(0.0), ) assert jnp.isclose(result, 25_000.0, atol=ATOL) @@ -103,7 +103,7 @@ def test_next_assets_includes_pension_adjustment() -> None: def test_zero_adjustment_when_his_unchanged() -> None: """Pension adjustment is zero when HIS doesn't change.""" - his = 0 + his = jnp.int32(0) pia = jnp.array(8000.0) labor_income = jnp.array(30_000.0) mtr = jnp.array(0.2) @@ -149,8 +149,8 @@ def test_rebalancing_preserves_total_wealth_across_his_change() -> None: from HIS 0 (retiree) to HIS 1 (nongroup), the pension imputation changes. The assets_adjustment compensates so total wealth is preserved. """ - old_his = 0 - new_his = 1 + old_his = jnp.int32(0) + new_his = jnp.int32(1) pia = jnp.array(8000.0) labor_income = jnp.array(30_000.0) mtr = jnp.array(0.0) diff --git a/tests/test_pensions.py b/tests/test_pensions.py index beff910..514ab8c 100644 --- a/tests/test_pensions.py +++ b/tests/test_pensions.py @@ -42,8 +42,8 @@ def test_pension_benefit_zero_pia() -> None: result = pensions.benefit( pia=jnp.array(0.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -59,8 +59,8 @@ def test_pension_benefit_zero_pia() -> None: def test_pension_benefit_below_kink_0() -> None: result = pensions.benefit( pia=jnp.array(500.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -75,8 +75,8 @@ def test_pension_benefit_below_kink_0() -> None: def test_pension_benefit_between_kinks() -> None: result = pensions.benefit( pia=jnp.array(12000.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -92,8 +92,8 @@ def test_pension_benefit_between_kinks() -> None: def test_pension_benefit_above_kink_1() -> None: result = pensions.benefit( pia=jnp.array(20000.0), - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -109,8 +109,8 @@ def test_pension_benefit_above_kink_1() -> None: def test_pension_accrual_no_income() -> None: result = pensions.accrual( labor_income=jnp.array(-1000.0), - period=20, - his=0, + period=jnp.int32(20), + his=jnp.int32(0), accrual_intercept=ACCRUAL_INTERCEPT, accrual_log_earnings=ACCRUAL_LOG_EARNINGS, accrual_prob_intercept=ACCRUAL_PROB_INTERCEPT, @@ -123,8 +123,8 @@ def test_pension_accrual_no_income() -> None: def test_pension_accrual_positive() -> None: result = pensions.accrual( labor_income=jnp.array(10000.0), - period=20, - his=0, + period=jnp.int32(20), + his=jnp.int32(0), accrual_intercept=ACCRUAL_INTERCEPT, accrual_log_earnings=ACCRUAL_LOG_EARNINGS, accrual_prob_intercept=ACCRUAL_PROB_INTERCEPT, @@ -148,7 +148,7 @@ def test_pension_wealth_next_accrual_only() -> None: pension_accrual=jnp.array(accrual), rate_of_return=r, unconditional_survival_prob=SURVIVAL_PROBS, - period=28, + period=jnp.int32(28), ) assert jnp.isclose(result, accrual / 0.99, atol=ATOL) @@ -164,7 +164,7 @@ def test_pension_wealth_next_with_benefit() -> None: pension_accrual=jnp.array(accrual), rate_of_return=r, unconditional_survival_prob=SURVIVAL_PROBS, - period=29, + period=jnp.int32(29), ) expected = ((1 + r) * 3000 + accrual - 2000) / 0.98 assert jnp.isclose(result, expected, atol=ATOL) @@ -177,8 +177,8 @@ def test_convert_total_ben_to_pia_below_kink_0() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -190,8 +190,8 @@ def test_convert_total_ben_to_pia_below_kink_0() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -210,8 +210,8 @@ def test_convert_total_ben_to_pia_between_kinks() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -223,8 +223,8 @@ def test_convert_total_ben_to_pia_between_kinks() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -243,8 +243,8 @@ def test_convert_total_ben_to_pia_above_kink_1() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -256,8 +256,8 @@ def test_convert_total_ben_to_pia_above_kink_1() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, @@ -276,8 +276,8 @@ def test_convert_total_ben_to_pia_zero_mtr() -> None: pb = pensions.benefit( pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, imp_pia_kink_0_coeff=PW_IMP_PIA_KINK_0_COEFF, @@ -289,8 +289,8 @@ def test_convert_total_ben_to_pia_zero_mtr() -> None: recovered = pensions.total_to_pia( pension_benefit=pb, pia=pia_input, - period=29, - his=0, + period=jnp.int32(29), + his=jnp.int32(0), marginal_tax_rate=mtr, imp_intercept=PW_IMP_INTERCEPT, imp_pia_coeff=PW_IMP_PIA, diff --git a/tests/test_preferences.py b/tests/test_preferences.py index 8c1921c..1b5107f 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -12,20 +12,11 @@ TIME_DISCOUNT_FACTOR = 0.85 TIME_ENDOWMENT = 5000.0 FIXED_COST_INTERCEPT = 0.0 -FIXED_COST_AGE_TREND = 50.0 AVERAGE_CONSUMPTION = 10000.0 RATE_OF_RETURN = 0.01 BEQUEST_WEIGHT = 0.02 BEQUEST_SHIFTER = 500_000.0 -SCALE_REFERENCE_HOURS = 500.0 -REFERENCE_AGE = 50 -SCALE_REFERENCE_AGE = 60 - -# Pref-type-indexed params: three identical entries so pref_type=0 selects -# the struct-ret scalar value used by the regression tests. -WEIGHT_BY_TYPE = jnp.array([CONSUMPTION_WEIGHT, CONSUMPTION_WEIGHT, CONSUMPTION_WEIGHT]) -RRA_5_BY_TYPE = jnp.array([5.0, 5.0, 5.0]) -RRA_1_BY_TYPE = jnp.array([1.0, 1.0, 1.0]) +REFERENCE_HOURS = 1000.0 # --- utility_scale_factor --- @@ -33,32 +24,26 @@ def test_utility_scale_factor_crra() -> None: result = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result[0], 9_233_279_397_806_166.0, rtol=1e-6) + assert jnp.isclose(result, 9_233_279_397_806_166.0, rtol=1e-6) def test_utility_scale_factor_log() -> None: result = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - assert jnp.isclose(result[0], 0.113_073_257_794_546_72, rtol=1e-6) + assert jnp.isclose(result, 0.113_073_257_794_546_72, rtol=1e-6) # --- scaled_bequest_weight --- @@ -105,23 +90,18 @@ def test_scaled_bequest_weight_zero() -> None: def test_utility_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - result = preferences.utility( - consumption=jnp.array(50000.0), + result = preferences.u_can_work( + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 1.005_046_313_660_588_5, rtol=1e-5) @@ -129,57 +109,45 @@ def test_utility_log_regression() -> None: def test_utility_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - result = preferences.utility( - consumption=jnp.array(50000.0), + result = preferences.u_can_work( + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -0.836_511_642_073_019_1, rtol=1e-5) def test_utility_married_equivalence() -> None: - """Married with equiv-scaled consumption should equal single utility.""" + """Married with equiv-scaled consumption_dollars should equal single utility.""" scale = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) - single = preferences.utility( - consumption=jnp.array(50000.0), + single = preferences.u_can_work( + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(1.0), + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) - married = preferences.utility( - consumption=jnp.array(50000.0 * 2**0.7), + married = preferences.u_can_work( + consumption_equiv=jnp.array(50000.0), leisure=jnp.array(400.0), - pref_type=jnp.array(0), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, - equivalence_scale=jnp.array(2**0.7), + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(single, married, rtol=1e-5) @@ -190,15 +158,12 @@ def test_utility_married_equivalence() -> None: def test_bequest_log_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, @@ -210,11 +175,10 @@ def test_bequest_log_regression() -> None: ) result = preferences.bequest( assets=jnp.array(10000.0), - pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_1_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(1.0), utility_scale_factor=scale, ) assert jnp.isclose(result, 86.539_249_963_643_88, rtol=1e-5) @@ -222,15 +186,12 @@ def test_bequest_log_regression() -> None: def test_bequest_crra_regression() -> None: scale = preferences.utility_scale_factor( - average_consumption=AVERAGE_CONSUMPTION, - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + average_consumption_dollars=AVERAGE_CONSUMPTION, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), time_endowment=TIME_ENDOWMENT, fixed_cost_of_work_intercept=FIXED_COST_INTERCEPT, - fixed_cost_of_work_age_trend=FIXED_COST_AGE_TREND, - scale_reference_hours=SCALE_REFERENCE_HOURS, - reference_age=REFERENCE_AGE, - scale_reference_age=SCALE_REFERENCE_AGE, + reference_hours=REFERENCE_HOURS, ) bwt = preferences.scaled_bequest_weight( bequest_weight=BEQUEST_WEIGHT, @@ -242,11 +203,10 @@ def test_bequest_crra_regression() -> None: ) result = preferences.bequest( assets=jnp.array(10000.0), - pref_type=jnp.array(0), bequest_shifter=BEQUEST_SHIFTER, scaled_bequest_weight=bwt.item(), - consumption_weight=WEIGHT_BY_TYPE, - coefficient_rra=RRA_5_BY_TYPE, + consumption_weight=jnp.array(CONSUMPTION_WEIGHT), + coefficient_rra=jnp.array(5.0), utility_scale_factor=scale, ) assert jnp.isclose(result, -37.932_748_117_035_63, rtol=1e-5) diff --git a/tests/test_social_security.py b/tests/test_social_security.py index d75e458..b8ac44a 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -5,11 +5,12 @@ import jax.numpy as jnp import numpy as np +import pandas as pd +from helpers.social_security import compute_di_dropout_scale, compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_di_dropout_scale, compute_pia_table ATOL = 0.01 @@ -56,7 +57,12 @@ RATIO = jnp.array(_RATIO_NP) DI_SCALE = jnp.array( - compute_di_dropout_scale(_RATIO_NP, AIME_ACCRUAL_FACTOR, start_age=0, n_periods=100) + compute_di_dropout_scale( + pd.Series(_RATIO_NP), + AIME_ACCRUAL_FACTOR, + start_age=jnp.int32(0), + n_periods=100, + ) ) # Pre-computed PIA lookup table (4-point exact grid) @@ -119,13 +125,13 @@ def test_next_aime_indexing_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(20000.0), - period=58, - age=58, + period=jnp.int32(58), + age=jnp.int32(58), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -140,13 +146,13 @@ def test_next_aime_indexing_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(10000.0), labor_income=jnp.array(510.0), - period=58, - age=58, + period=jnp.int32(58), + age=jnp.int32(58), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -160,13 +166,13 @@ def test_next_aime_no_indexing_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(20000.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -181,13 +187,13 @@ def test_next_aime_no_indexing_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(1000.0), labor_income=jnp.array(99.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -201,13 +207,13 @@ def test_next_aime_cap_high_aime_high_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(40000.0), labor_income=jnp.array(20000.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -221,13 +227,13 @@ def test_next_aime_cap_high_aime_low_income() -> None: result = social_security.next_aime( benefit_withheld_fraction=jnp.array(0.0), earnings_test_credited_back=jnp.zeros(100), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, aime=jnp.array(40000.0), labor_income=jnp.array(3500.0), - period=62, - age=62, + period=jnp.int32(62), + age=jnp.int32(62), aime_accrual_factor=AIME_ACCRUAL_FACTOR, aggregate_wage_growth=AGGREGATE_WAGE_GROWTH, aime_last_age_with_indexing=AIME_LAST_AGE_WITH_INDEXING, @@ -256,7 +262,7 @@ def test_pia_lookup_matches_formula() -> None: def test_ssdi_pia_matches_dropout_adjusted() -> None: """ssdi_pia lookup matches aime_to_pia(aime * di_dropout_scale[period]).""" aime = jnp.array(5000.0) - period = 55 + period = jnp.int32(55) adjusted_aime = aime * DI_SCALE[period] lookup = social_security.ssdi_pia( @@ -292,17 +298,17 @@ def test_benefit_choose_post65_below_et_threshold() -> None: ) result = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(4000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([10000.0]), earnings_test_fraction=jnp.array([0.0]), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), ) assert jnp.isclose(result, pia_val, atol=ATOL) @@ -316,17 +322,17 @@ def test_benefit_choose_post65_partially_reduced() -> None: ) result = social_security.benefit_choose_post65( pia=pia_val, - age=60, - period=0, + age=jnp.int32(60), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(6000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([2000.0]), earnings_test_fraction=jnp.array([0.2]), - earnings_test_repealed_age=70, + earnings_test_repealed_age=jnp.int32(70), ) expected = pia_val - (6000 - 2000) * 0.2 assert jnp.isclose(result, expected, atol=ATOL) @@ -336,7 +342,7 @@ def test_benefit_inelig_pre65_disabled_below_sga() -> None: """Disabled agent below SGA: benefit = ssdi_pia.""" ssdi_val = social_security.ssdi_pia( aime=jnp.array(5000.0), - period=55, + period=jnp.int32(55), di_dropout_scale=DI_SCALE, pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -354,7 +360,7 @@ def test_benefit_inelig_pre65_disabled_above_sga() -> None: """Disabled agent above SGA: benefit = 0.""" ssdi_val = social_security.ssdi_pia( aime=jnp.array(5000.0), - period=55, + period=jnp.int32(55), di_dropout_scale=DI_SCALE, pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -385,12 +391,16 @@ def test_benefit_inelig_pre65_not_disabled() -> None: def test_di_dropout_round_trip_zero_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[52] - round_tripped = social_security.adjust_aime_di_dropout_inv(52, scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv( + jnp.int32(52), scaled, DI_SCALE + ) assert jnp.isclose(aime, round_tripped, atol=ATOL) def test_di_dropout_round_trip_positive_years() -> None: aime = jnp.array(10000.0) scaled = aime * DI_SCALE[62] - round_tripped = social_security.adjust_aime_di_dropout_inv(62, scaled, DI_SCALE) + round_tripped = social_security.adjust_aime_di_dropout_inv( + jnp.int32(62), scaled, DI_SCALE + ) assert jnp.isclose(aime, round_tripped, rtol=0.0002) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 0e77ea5..5e74e9a 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -5,11 +5,11 @@ """ import jax.numpy as jnp +from helpers.social_security import compute_pia_table from aca_model.agent.labor_market import LaborSupply from aca_model.environment import social_security from aca_model.environment.social_security import ClaimedSS -from tests.helpers.social_security import compute_pia_table ATOL = 0.01 @@ -43,7 +43,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: n_periods = 45 ssdi_pia_val = social_security.ssdi_pia( aime=jnp.array(3000.0), - period=12, + period=jnp.int32(12), di_dropout_scale=jnp.ones(n_periods + 1), pia_table=PIA_TABLE, pia_aime_grid=PIA_AIME_GRID, @@ -52,36 +52,36 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: benefit_working = social_security.benefit_choose_pre65( pia=pia_val, ssdi_pia=ssdi_pia_val, - age=63, - period=0, + age=jnp.int32(63), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), health=jnp.array(2), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(30000.0), early_ret_adjustment=jnp.array([0.75]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ssdi_substantial_gainful_activity=13560.0, ) benefit_not_working = social_security.benefit_choose_pre65( pia=pia_val, ssdi_pia=ssdi_pia_val, - age=63, - period=0, + age=jnp.int32(63), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), health=jnp.array(2), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([0.75]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ssdi_substantial_gainful_activity=13560.0, ) @@ -99,32 +99,32 @@ def test_earnings_test_not_applied_after_fra() -> None: benefit_post65 = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(50000.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ) benefit_not_working = social_security.benefit_choose_post65( pia=pia_val, - age=67, - period=0, + age=jnp.int32(67), + period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([1.0]), - normal_retirement_age=66, + normal_retirement_age=jnp.int32(66), earnings_test_threshold=jnp.array([17640.0]), earnings_test_fraction=jnp.array([0.5]), - earnings_test_repealed_age=66, + earnings_test_repealed_age=jnp.int32(66), ) assert jnp.isclose(benefit_post65, benefit_not_working, atol=ATOL)