Some Out-Of-Sample predictions require manually removing certain sites #2158

@kylejcaron

Bug Description

When a site's shape needs to change (for example, for out-of-sample predictions), Predictive fails with a broadcasting error if that site's posterior samples are passed to it.

While this isn't really a bug, it feels like less-than-ideal behavior.

Steps to Reproduce

import numpy as np
import numpyro
from numpyro import handlers
from numpyro.infer.reparam import LocScaleReparam
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random

y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def eight_schools_model(sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", sigma.shape[0]):
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)

nc_model = handlers.reparam(eight_schools_model, config={"theta": LocScaleReparam(0)})

nuts = NUTS(nc_model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc.run(random.PRNGKey(0), sigma=sigma, y=y_obs)


post_samples = mcmc.get_samples()
# post_samples.pop("theta")  # this fixes it
predictive_oos = Predictive(nc_model, post_samples)
predictions_oos = predictive_oos(random.PRNGKey(2), sigma=np.array([3, 7, 10]))

ValueError: Incompatible shapes for broadcasting: shapes=[(), (8,), (3,)]
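
For context, the shape clash in that message can be reproduced with plain NumPy. This is only a sketch under assumptions: I am assuming the (8,) comes from the stored theta draws and the (3,) from the new sigma, and that the () is some scalar argument; the variable names are just for illustration.

```python
import numpy as np

# Assumed shapes: one posterior draw of theta for the 8 training schools,
# and the 3 new sigma values passed at prediction time.
theta_draw = np.zeros(8)
sigma_new = np.array([3.0, 7.0, 10.0])

try:
    # (), (8,) and (3,) have no common broadcast shape.
    np.broadcast_shapes((), theta_draw.shape, sigma_new.shape)
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```

Since 8 and 3 differ and neither is 1, no broadcasting rule can reconcile them.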

Expected Behavior

Naively, I think many newer users would expect successful predictions with shape (3,) and a theta site with sample shape (3,), which is what happens when theta is excluded from the posterior samples provided to Predictive.

To be fair, I think most PPLs behave this way, but are there any good workflow solutions that 1) use the same model (as opposed to writing a second model for predictions) and 2) don't require popping sites from the posterior samples?
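
One stopgap I can think of is to filter the samples without mutating the original dictionary. The sketch below uses a hypothetical helper, without_sites, which is not part of NumPyro, and a stand-in dictionary in place of mcmc.get_samples(). It still removes the site, so it only partly addresses point 2.

```python
import numpy as np


def without_sites(samples, exclude):
    """Return a shallow copy of `samples` without the named sites.

    Hypothetical helper, not a NumPyro API.
    """
    excluded = set(exclude)
    return {name: value for name, value in samples.items() if name not in excluded}


# Stand-in for mcmc.get_samples(): 4 draws, 8 schools (illustrative shapes only).
post_samples = {
    "mu": np.zeros(4),
    "tau": np.ones(4),
    "theta": np.zeros((4, 8)),
}

# The original dictionary is left untouched for in-sample use.
oos_samples = without_sites(post_samples, ["theta"])
print(sorted(oos_samples))
```

The filtered dictionary would then be passed to Predictive in place of post_samples. Whether other sites also need to be excluded depends on the model.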

This might be a good opportunity for more examples or a new pattern altogether?
