pymc.sampling.jax.sample_numpyro_nuts

pymc.sampling.jax.sample_numpyro_nuts(draws=1000, *, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, jitter=True, model=None, var_names=None, nuts_kwargs=None, progressbar=True, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, postprocessing_vectorize=None, postprocessing_chunks=None, idata_kwargs=None, compute_convergence_checks=True, nuts_sampler='numpyro')

Draw samples from the posterior using a JAX NUTS method.

Parameters:
draws: int, default 1000

The number of samples to draw. Tuning samples are discarded by default.

tune: int, default 1000

Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples are drawn in addition to the number specified in the draws argument and are discarded.

chains: int, default 4

The number of chains to sample.

target_accept: float in [0, 1], default 0.8

The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors.

random_seed: int, RandomState or Generator, optional

Random seed used by the sampling steps.

initvals: StartDict or Sequence[Optional[StartDict]], optional

Initial values for random variables provided as a dictionary (or sequence of dictionaries) mapping the random variable (by name or reference) to desired starting values.

jitter: bool, default True

If True, add jitter to initial points.

model: Model, optional

Model to sample from. The model needs to have free random variables. When inside a with model context, it defaults to that model, otherwise the model must be passed explicitly.

var_names: sequence of str, optional

Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior.

nuts_kwargs: dict, optional

Keyword arguments for the underlying NUTS sampler.

progressbar: bool, default True

If True, display a progress bar while sampling.

keep_untransformed: bool, default False

Include untransformed variables in the posterior samples.

chain_method: str, default “parallel”

Specify how samples should be drawn. The choices include “parallel” and “vectorized”.

postprocessing_backend: Optional[Literal[“cpu”, “gpu”]], default None

Device on which postprocessing is computed: “cpu” or “gpu”.

postprocessing_vectorize: Literal[“vmap”, “scan”], default “scan”

How to vectorize the postprocessing: “vmap” or a sequential “scan”.

postprocessing_chunks: None

This argument is deprecated.

idata_kwargs: dict, optional

Keyword arguments for arviz.from_dict(). It also accepts a boolean as the value for the log_likelihood key to indicate whether the pointwise log likelihood should be included in the returned object. Values for observed_data, constant_data, coords, and dims are inferred from the model argument if not provided in idata_kwargs. If coords and dims are provided, they are used to update the inferred dictionaries.

compute_convergence_checks: bool, default True

If True, compute ess and rhat values and warn if they indicate potential sampling issues.

nuts_sampler: Literal[“numpyro”, “blackjax”]

NUTS sampler library to use. Do not change this argument; call sample_numpyro_nuts or sample_blackjax_nuts as appropriate.

Returns:
InferenceData

ArviZ InferenceData object that contains the posterior samples, together with their respective sample stats and pointwise log likelihood values (unless skipped with idata_kwargs).