pymc.sampling.jax.sample_numpyro_nuts
- pymc.sampling.jax.sample_numpyro_nuts(draws=1000, *, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, jitter=True, model=None, var_names=None, nuts_kwargs=None, progressbar=True, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, postprocessing_vectorize=None, postprocessing_chunks=None, idata_kwargs=None, compute_convergence_checks=True, nuts_sampler='numpyro')
Draw samples from the posterior using a JAX-based NUTS sampler.
- Parameters:
- draws: int, default 1000
The number of samples to draw. Tuning samples are discarded by default.
- tune: int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples are drawn in addition to the number specified in the draws argument and are discarded.
- chains: int, default 4
The number of chains to sample.
- target_accept: float in [0, 1], default 0.8
The step size is tuned so that the sampler approximates this acceptance rate. Higher values such as 0.9 or 0.95 often work better for problematic posteriors.
- random_seed: int, RandomState or Generator, optional
Random seed used by the sampling steps.
- initvals: StartDict or Sequence[Optional[StartDict]], optional
Initial values for random variables provided as a dictionary (or sequence of dictionaries) mapping the random variable (by name or reference) to desired starting values.
- jitter: bool, default True
If True, add jitter to initial points.
- model: Model, optional
Model to sample from. The model needs to have free random variables. When inside a with model context, it defaults to that model; otherwise the model must be passed explicitly.
- var_names: sequence of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
- nuts_kwargs: dict, optional
Keyword arguments for the underlying NUTS sampler.
- progressbar: bool, default True
If True, display a progress bar while sampling.
- keep_untransformed: bool, default False
Include untransformed variables in the posterior samples.
- chain_method: str, default “parallel”
Specify how samples should be drawn. The choices include “parallel” and “vectorized”.
- postprocessing_backend: Optional[Literal[“cpu”, “gpu”]], default None
The backend on which postprocessing is computed, either “cpu” or “gpu”.
- postprocessing_vectorize: Literal[“vmap”, “scan”], default “scan”
How to vectorize the postprocessing: vmap or sequential scan.
- postprocessing_chunks: None
This argument is deprecated.
- idata_kwargs: dict, optional
Keyword arguments for arviz.from_dict(). It also accepts a boolean as the value for the log_likelihood key to indicate whether the pointwise log likelihood should be included in the returned object. Values for observed_data, constant_data, coords, and dims are inferred from the model argument if not provided in idata_kwargs. If coords and dims are provided, they are used to update the inferred dictionaries.
- compute_convergence_checks: bool, default True
If True, compute ess and rhat values and warn if they indicate potential sampling issues.
- nuts_sampler: Literal[“numpyro”, “blackjax”], default “numpyro”
NUTS sampler library to use. Do not change this value; call sample_numpyro_nuts or sample_blackjax_nuts as appropriate.
- Returns:
InferenceData
ArviZ InferenceData object that contains the posterior samples, together with their respective sample stats and pointwise log likelihood values (unless skipped with idata_kwargs).
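The parameters above can be exercised with a small sketch such as the following. It is illustrative rather than canonical: the toy normal model, the variable names mu, sigma and y, the synthetic data, and the helper names SAMPLER_KWARGS and fit_toy_model are all hypothetical and do not come from the library. It also assumes that pymc, numpyro and jax are installed. Only keyword arguments that appear in the signature above are passed.

```python
# Keyword arguments chosen from the documented signature (values are illustrative).
SAMPLER_KWARGS = dict(
    draws=500,                  # posterior draws kept per chain
    tune=500,                   # tuning draws, discarded afterwards
    chains=2,
    target_accept=0.9,          # higher than the 0.8 default
    random_seed=42,
    chain_method="vectorized",  # alternative to the default "parallel"
    progressbar=False,
    idata_kwargs={"log_likelihood": False},  # skip pointwise log likelihood
)


def fit_toy_model(**overrides):
    """Fit a hypothetical normal model and return the InferenceData."""
    # Imported lazily so this module loads even without the JAX stack.
    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_numpyro_nuts

    rng = np.random.default_rng(0)
    y_obs = rng.normal(loc=1.0, scale=2.0, size=50)  # synthetic data

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
        # Inside the model context, the model argument can be omitted.
        return sample_numpyro_nuts(**{**SAMPLER_KWARGS, **overrides})
```

Calling fit_toy_model() should return an InferenceData object whose posterior group holds mu and sigma; a call such as fit_toy_model(chains=4) overrides a single setting while keeping the rest.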