- pymc.sampling_jax.sample_blackjax_nuts(draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, model=None, var_names=None, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, idata_kwargs=None)
Draw samples from the posterior using the NUTS method from the blackjax library.
- draws : int, default 1000
The number of samples to draw. Tuning samples are discarded by default.
- tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples are drawn in addition to the number specified in the draws argument.
- chains : int, default 4
The number of chains to sample.
- target_accept : float, default 0.8
The step size is tuned so that the sampler approximates this acceptance rate. Higher values such as 0.9 or 0.95 often work better for problematic posteriors, at the cost of smaller steps and slower sampling.
- random_seed : optional
Random seed used by the sampling steps.
- model : Model, optional
Model to sample from. The model needs to have free random variables. When called inside a `with model:` context, it defaults to that model; otherwise the model must be passed explicitly.
- var_names : iterable of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.
- keep_untransformed : bool, default False
Include untransformed variables in the posterior samples.
- chain_method : str, default “parallel”
Specify how samples should be drawn. The choices are “parallel” and “vectorized”.
- postprocessing_backend : str, optional
Specify where postprocessing should be computed: “gpu” or “cpu”.
- idata_kwargs : dict, optional
Keyword arguments for arviz.from_dict(). The log_likelihood key also accepts a boolean; setting it to False excludes the pointwise log likelihood from the returned object.
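The chain_method options can be pictured as two ways of mapping a single-chain sampler over chains. As we understand PyMC's JAX samplers, “parallel” uses `jax.pmap` (one chain per device) and “vectorized” uses `jax.vmap` (all chains batched on one device). The sketch below shows only that mapping pattern; the random-walk Metropolis kernel and `log_density` are hypothetical stand-ins for blackjax's NUTS and a model's log density, so it runs without PyMC.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    # Hypothetical target: standard normal log density (up to a constant).
    return -0.5 * x ** 2


def run_chain(key, n_steps=500, step_size=1.0):
    """Run one chain of random-walk Metropolis (toy stand-in for one NUTS chain)."""

    def step(x, step_key):
        key_prop, key_acc = jax.random.split(step_key)
        proposal = x + step_size * jax.random.normal(key_prop)
        log_ratio = log_density(proposal) - log_density(x)
        accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (carry, per-step output)

    step_keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(step, jnp.zeros(()), step_keys)
    return samples


n_chains = 4

# "vectorized": all chains are batched into one computation on a single device.
vec_samples = jax.vmap(run_chain)(jax.random.split(jax.random.PRNGKey(0), n_chains))

# "parallel": one chain per device, so here the chain count is capped by the device count.
n_devices = jax.local_device_count()
par_samples = jax.pmap(run_chain)(jax.random.split(jax.random.PRNGKey(1), n_devices))
```

Both calls return arrays shaped (chains, steps). The vectorized form works on a single CPU or GPU, while the parallel form needs at least as many devices as chains.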
Returns an InferenceData object that contains the posterior samples, together with their respective sample stats and pointwise log likelihood values (unless skipped with idata_kwargs).
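The target_accept parameter controls how the step size is adapted during tuning. NUTS implementations commonly do this with the dual-averaging scheme described by Hoffman and Gelman (2014) in the NUTS paper; the sketch below implements that scheme in plain Python as an illustration, not as PyMC's or blackjax's code. The acceptance curve `toy_accept` is hypothetical and replaces the acceptance statistic that a real NUTS transition would report.

```python
import math


def dual_averaging_step_size(accept_prob, target_accept=0.8, initial_step=1.0,
                             n_iter=2000, gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size so the acceptance probability approaches target_accept.

    `accept_prob` is a hypothetical callable mapping a step size to an
    acceptance probability in [0, 1], standing in for one NUTS transition.
    """
    mu = math.log(10.0 * initial_step)  # shrinkage point for the log step size
    log_step = math.log(initial_step)
    log_step_avg = 0.0  # weighted average of iterates, returned after tuning
    h_bar = 0.0         # running average of (target - observed acceptance)
    for m in range(1, n_iter + 1):
        alpha = accept_prob(math.exp(log_step))
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target_accept - alpha)
        # Too little acceptance (h_bar > 0) pushes the step size down, and vice versa.
        log_step = mu - math.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_step_avg = eta * log_step + (1.0 - eta) * log_step_avg
    return math.exp(log_step_avg)


def toy_accept(step):
    # Hypothetical acceptance curve: acceptance falls as the step size grows.
    return math.exp(-step)


step_08 = dual_averaging_step_size(toy_accept, target_accept=0.8)
step_095 = dual_averaging_step_size(toy_accept, target_accept=0.95)
```

For this toy curve the exact solutions are -ln(0.8) ≈ 0.223 and -ln(0.95) ≈ 0.051, and the tuned step sizes should land close to them. Raising target_accept therefore yields a smaller step size, which is why higher values help with difficult posteriors but make each draw more expensive.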