pymc.sampling.jax.sample_numpyro_nuts

pymc.sampling.jax.sample_numpyro_nuts(draws=1000, *, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, jitter=True, model=None, var_names=None, nuts_kwargs=None, progressbar=True, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, postprocessing_vectorize='scan', postprocessing_chunks=None, idata_kwargs=None, compute_convergence_checks=True, nuts_sampler='numpyro')