prior_from_idata

pymc_experimental.utils.prior.prior_from_idata(idata: InferenceData, name='trace_prior_', *, var_names: Sequence[str] = (), **kwargs: ParamCfg | Transform | str | Tuple) -> Dict[str, TensorVariable]

Create a prior from a posterior using a multivariate normal (MvNormal) approximation.

Because the selected variables are approximated jointly by a single MvNormal distribution, this function works well only for unimodal posteriors and will fail when the variables interact in complicated (for example, strongly nonlinear) ways.

Moreover, if a retrieved variable is constrained, specify a transform for it, e.g. pymc.distributions.transforms.log for the posterior of a standard deviation, so that the normal approximation is built on the unconstrained scale.
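To illustrate both caveats, here is a NumPy-only sketch of the general idea: flatten the posterior draws into one matrix, move positive variables to the log scale, fit a mean vector and covariance matrix, then sample and map back. This is an illustration under our own assumptions, not pymc_experimental's implementation; `fit_mvnormal_prior` and `sample_prior` are hypothetical helpers, and plain `np.log`/`np.exp` stand in for the log transform.

```python
import numpy as np


def fit_mvnormal_prior(draws, log_transformed=()):
    """Hypothetical helper: fit one MvNormal to flattened posterior draws.

    draws maps a variable name to an array of shape (n_draws, *var_shape).
    Names in log_transformed are assumed positive and fitted on the log scale.
    """
    columns, layout, start = [], [], 0
    for name, values in draws.items():
        values = np.asarray(values, dtype=float)
        flat = values.reshape(values.shape[0], -1)  # one column per scalar element
        if name in log_transformed:
            flat = np.log(flat)  # unconstrained scale for a positive variable
        columns.append(flat)
        layout.append((name, slice(start, start + flat.shape[1]), values.shape[1:]))
        start += flat.shape[1]
    data = np.concatenate(columns, axis=1)
    mean = data.mean(axis=0)
    cov = np.atleast_2d(np.cov(data, rowvar=False))  # joint covariance captures correlations
    return mean, cov, layout


def sample_prior(mean, cov, layout, n, log_transformed=(), seed=0):
    """Hypothetical helper: draw from the fitted MvNormal and unflatten."""
    rng = np.random.default_rng(seed)
    flat = rng.multivariate_normal(mean, cov, size=n)
    out = {}
    for name, sl, shape in layout:
        values = flat[:, sl].reshape((n, *shape))
        if name in log_transformed:
            values = np.exp(values)  # back to the positive, constrained scale
        out[name] = values
    return out


# Synthetic "posterior" draws standing in for idata.posterior
rng = np.random.default_rng(1)
posterior = {
    "mu": rng.normal(1.0, 0.1, size=2000),
    "b": rng.normal(0.0, 0.5, size=(2000, 4)),
    "sigma": rng.lognormal(0.0, 0.2, size=2000),
}
mean, cov, layout = fit_mvnormal_prior(posterior, log_transformed={"sigma"})
prior = sample_prior(mean, cov, layout, 500, log_transformed={"sigma"})
```

Because a single Gaussian has one mode, a multimodal posterior would be smeared into one broad bump; and without the log step, the Gaussian for `sigma` could place mass below zero.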

Parameters:
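  • idata (arviz.InferenceData) – Inference data with a posterior group.

  • name (str) – Name of the auxiliary variable created for the joint MvNormal approximation in the new model; defaults to 'trace_prior_'.

  • var_names (Sequence[str]) – Names of variables to take from the posterior as is.

  • kwargs (Union[ParamCfg, Transform, str, Tuple]) – Variables to take from the posterior with additional configuration, passed as name=config. A config can be a new name (str), dims (tuple), a transform, or a dict combining name, dims, and transform; see Examples.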
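The Examples describe each kwargs shorthand as similar to an explicit dict (e.g. e="new_e" is like dict(name="new_e")). The pure-Python sketch below makes that mapping explicit. `normalize_configs` and the `Transform` class are hypothetical stand-ins, not part of pymc_experimental, and the real function may validate or merge configs differently.

```python
from typing import Any, Dict, Sequence


class Transform:
    """Hypothetical stand-in for a pymc.distributions.transforms transform."""


def normalize_configs(var_names: Sequence[str] = (), **kwargs: Any) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: expand shorthand configs into explicit dicts."""
    # Variables listed in var_names are taken as is, with no extra config.
    configs: Dict[str, Dict[str, Any]] = {name: {} for name in var_names}
    for name, cfg in kwargs.items():
        if isinstance(cfg, str):
            configs[name] = {"name": cfg}  # new name
        elif isinstance(cfg, tuple):
            configs[name] = {"dims": cfg}  # dims
        elif isinstance(cfg, Transform):
            configs[name] = {"transform": cfg}  # transform
        elif isinstance(cfg, dict):
            configs[name] = dict(cfg)  # already explicit
        else:
            raise TypeError(f"unsupported config for {name!r}: {type(cfg).__name__}")
    return configs


log = Transform()  # placeholder for transforms.log
cfg = normalize_configs(
    ["a", "d"],
    e="new_e",
    b=("test",),
    c=log,
    f=dict(name="new_f", dims="options"),
)
```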

Examples

>>> import pymc as pm
>>> import pymc.distributions.transforms as transforms
>>> import numpy as np
>>> from pymc_experimental.utils.prior import prior_from_idata
>>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
...     a = pm.Normal("a")
...     b = pm.Normal("b", dims="test")
...     c = pm.HalfNormal("c")
...     d = pm.Normal("d")
...     e = pm.Normal("e")
...     f = pm.Dirichlet("f", np.ones(3), dims="options")
...     trace = pm.sample(progressbar=False)

You can then reuse this posterior as a prior in a new model.

>>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
...     priors = prior_from_idata(
...         trace,                  # the old trace (posterior)
...         var_names=["a", "d"],   # take variables as is
...
...         e="new_e",              # assign new name "new_e" for a variable
...                                 # similar to dict(name="new_e")
...
...         b=("test", ),           # set a dim to "test"
...                                 # similar to dict(dims=("test", ))
...
...         c=transforms.log,       # apply log transform to a positive variable
...                                 # similar to dict(transform=transforms.log)
...
...                                 # set a name, assign a dim and apply a simplex transform
...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
...     )
...     trace1 = pm.sample_prior_predictive(100)