Bug Description
When a site's shape needs to change (for example, for out-of-sample predictions), Predictive fails with a broadcasting error if that site's posterior samples are passed to it.
While this isn't really a bug, it feels like less-than-ideal behavior.
Steps to Reproduce
Steps to reproduce the behavior.
import numpy as np
import numpyro
from numpyro import handlers
from numpyro.infer.reparam import LocScaleReparam
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools_model(sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    # The plate size follows the length of sigma, so it changes for out-of-sample inputs
    with numpyro.plate("J", sigma.shape[0]):
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
nc_model = handlers.reparam(eight_schools_model, config={"theta": LocScaleReparam(0)})
nuts = NUTS(nc_model)
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
mcmc.run(random.PRNGKey(0), sigma=sigma, y=y_obs)
post_samples = mcmc.get_samples()
# post_samples.pop("theta") # this fixes it
predictive_oos = Predictive(nc_model, post_samples)
predictions_oos = predictive_oos(random.PRNGKey(2), sigma=np.array([3, 7, 10]))
ValueError: Incompatible shapes for broadcasting: shapes=[(), (8,), (3,)]
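For anyone reading the error above, here is a minimal sketch of why those shapes cannot be combined. It uses plain NumPy rather than NumPyro. The helper name `try_broadcast` is made up for illustration, and the mapping of shapes to sites in the comments is my assumption, not something the traceback confirms.

```python
import numpy as np


def try_broadcast(*shapes):
    """Return the broadcast shape, or None if the shapes are incompatible."""
    try:
        return np.broadcast_shapes(*shapes)
    except ValueError:
        return None


# Shapes from the error message. Assumption: (8,) is a stored draw of the
# plate-sized site from the 8-school fit, and (3,) is the new sigma.
print(try_broadcast((), (8,), (3,)))  # None

# If the plate-sized site is resampled at the new size, everything lines up.
print(try_broadcast((), (3,), (3,)))  # (3,)
```

A scalar broadcasts against anything, but trailing dimensions of 8 and 3 can never be reconciled, which matches the reported failure.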
Expected Behavior
Naively, I think a lot of newer users would expect successful predictions with shape (3,) and a theta site with sample shape (3,), and that is what happens when theta is excluded from the posterior samples provided to Predictive.
To be fair, I think most PPLs will have this behavior, but are there any good workflow solutions for this that 1) use the same model (as opposed to making a second model for predictions) and 2) don't require popping sites from the posterior samples?
This might be a good opportunity for more examples or a new pattern altogether?
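One possible stopgap, sketched below, is to build a filtered copy of the samples instead of popping from the original dict. This is only a sketch: `without_sites` is a hypothetical helper, not part of NumPyro, and the dict here is a stand-in for the output of `mcmc.get_samples()`.

```python
def without_sites(samples, drop):
    """Return a new dict of samples without the named sites.

    The input dict is left untouched, so the full posterior stays
    available for in-sample checks.
    """
    drop = set(drop)
    return {name: value for name, value in samples.items() if name not in drop}


# Stand-in for mcmc.get_samples(); real values would be arrays with a
# leading draw dimension, e.g. theta with shape (num_draws, 8).
post_samples = {
    "mu": [4.2, 3.9],
    "tau": [2.1, 3.3],
    "theta": [[0.0] * 8, [0.0] * 8],
}

global_samples = without_sites(post_samples, ["theta"])
print(sorted(global_samples))  # ['mu', 'tau']
print(sorted(post_samples))    # ['mu', 'tau', 'theta']
```

With the reproduction above, `global_samples` would be passed as `Predictive(nc_model, global_samples)`, which should behave like the `pop("theta")` line without mutating `post_samples`. Reparameterized models may also carry auxiliary plate-sized sites (LocScaleReparam adds `theta_decentered`), so it is worth checking `post_samples.keys()` when deciding what to drop.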