pymc.sample

pymc.sample(draws=1000, step=None, init='auto', n_init=200000, initvals=None, trace=None, chain_idx=0, chains=None, cores=None, tune=1000, progressbar=True, model=None, random_seed=None, discard_tuned_samples=True, compute_convergence_checks=True, callback=None, jitter_max_retries=10, *, return_inferencedata=True, idata_kwargs=None, mp_ctx=None, **kwargs)

Draw samples from the posterior using the given step methods.

Multiple step methods are supported via compound step methods.

Parameters
draws : int

The number of samples to draw. Defaults to 1000. Tuning samples are discarded by default; see discard_tuned_samples.

init : str

Initialization method to use for auto-assigned NUTS samplers. See pm.init_nuts for a list of all options. This argument is ignored when manually passing the NUTS step method.

step : function or iterable of functions

A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step method will be used, if appropriate to the model.

n_init : int

Number of iterations of initializer. Only works for ‘ADVI’ init methods.

initvals : optional, dict, array of dict

Dict or list of dicts with initial value strategies to use instead of the defaults from Model.initial_values. The keys should be names of transformed random variables. Initialization methods for NUTS (see init keyword) can overwrite the default.
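To make the single-dict versus list-of-dicts distinction concrete, here is a minimal sketch assuming that a list supplies one dict per chain and must match the number of chains. The helper initvals_per_chain and the key "mu" are hypothetical; this is not PyMC's internal code.

```python
from typing import Dict, List, Optional, Union

InitDict = Dict[str, float]


def initvals_per_chain(
    initvals: Optional[Union[InitDict, List[InitDict]]], chains: int
) -> List[Optional[InitDict]]:
    """Return one initial-value dict (or None) per chain (hypothetical helper)."""
    if initvals is None or isinstance(initvals, dict):
        # A single dict (or None) is reused for every chain.
        return [initvals] * chains
    if len(initvals) != chains:
        # Assumption: a list must provide exactly one dict per chain.
        raise ValueError(f"Got {len(initvals)} initval dicts for {chains} chains")
    return list(initvals)


print(initvals_per_chain({"mu": 0.0}, chains=2))  # [{'mu': 0.0}, {'mu': 0.0}]
```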

trace : backend or list

This should be a backend instance, or a list of variables to track. If None or a list of variables, the NDArray backend is used.

chain_idx : int

Chain number used to store samples in the backend. If chains is greater than one, chain numbers will start here.
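Read literally, chain_idx is the first chain number and each further chain counts up by one. A tiny sketch of that numbering; chain_numbers is a hypothetical helper.

```python
from typing import List


def chain_numbers(chain_idx: int, chains: int) -> List[int]:
    """Chain IDs used in the backend when sampling `chains` chains (sketch)."""
    # Numbering starts at chain_idx and increments by one per chain.
    return list(range(chain_idx, chain_idx + chains))


print(chain_numbers(2, 3))  # [2, 3, 4]
```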

chains : int

The number of chains to sample. Running independent chains is important for some convergence statistics and can also reveal multiple modes in the posterior. If None, then set to either cores or 2, whichever is larger.

cores : int

The number of chains to run in parallel. If None, set to the number of CPUs in the system, but at most 4.
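The documented defaults for chains and cores can be written as a small function. This sketch encodes only the two rules stated above; default_chains_and_cores is hypothetical, and PyMC may apply further adjustments that are not described here.

```python
import os
from typing import Optional, Tuple


def default_chains_and_cores(
    chains: Optional[int], cores: Optional[int], cpu_count: Optional[int] = None
) -> Tuple[int, int]:
    """Apply the documented defaults for `chains` and `cores` (sketch)."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    if cores is None:
        # Number of CPUs in the system, but at most 4.
        cores = min(cpu_count, 4)
    if chains is None:
        # Either cores or 2, whichever is larger.
        chains = max(cores, 2)
    return chains, cores


print(default_chains_and_cores(None, None, cpu_count=8))  # (4, 4)
print(default_chains_and_cores(None, None, cpu_count=1))  # (2, 1)
```

On a single-CPU machine this yields two chains on one core, which is why chains and cores are separate settings.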

tune : int

Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the draws argument, and will be discarded unless discard_tuned_samples is set to False.
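Because tuning iterations are drawn before the requested draws, each chain runs tune + draws iterations and, by default, keeps only the last draws. The sketch below assumes a chain is a flat list of values; kept_samples is a hypothetical helper.

```python
from typing import List, Sequence


def kept_samples(
    chain: Sequence[float], tune: int, discard_tuned_samples: bool = True
) -> List[float]:
    """Return the part of one chain's raw output that is kept (sketch)."""
    # Tuning iterations come first, so they occupy the first `tune` positions.
    return list(chain[tune:]) if discard_tuned_samples else list(chain)


tune, draws = 1000, 1000
raw_chain = [float(i) for i in range(tune + draws)]  # stand-in for sampler output
print(len(raw_chain), len(kept_samples(raw_chain, tune)))  # 2000 1000
```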

progressbar : bool, optional, default=True

Whether or not to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion (“expected time of arrival”; ETA).
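The SPS and ETA figures follow from simple arithmetic. This sketch assumes a constant average rate; the actual progress bar may smooth its estimates differently, and progress_stats is a hypothetical helper.

```python
from typing import Tuple


def progress_stats(completed: int, total: int, elapsed_s: float) -> Tuple[float, float, float]:
    """Percent done, samples per second (SPS) and ETA in seconds (sketch).

    Assumes completed > 0 and elapsed_s > 0.
    """
    percent = 100.0 * completed / total
    sps = completed / elapsed_s  # average rate so far
    eta_s = (total - completed) / sps  # remaining samples at that rate
    return percent, sps, eta_s


print(progress_stats(500, 2000, 10.0))  # (25.0, 50.0, 30.0)
```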

model : Model (optional if in with context)

Model to sample from. The model needs to have free random variables.

random_seed : int, array_like of int, RandomState or Generator, optional

Random seed(s) used by the sampling steps. If a list, tuple or array of ints is passed, each entry will be used to seed each chain. A ValueError will be raised if the length does not match the number of chains.
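A minimal sketch of per-chain seeding. The list branch mirrors the documented length check; the single-int branch is an assumption for illustration, since PyMC's own scheme for deriving per-chain seeds from one seed is not described here and likely differs. seeds_per_chain is hypothetical.

```python
import random
from typing import List, Sequence, Union


def seeds_per_chain(random_seed: Union[int, Sequence[int]], chains: int) -> List[int]:
    """One integer seed per chain (sketch)."""
    if isinstance(random_seed, int):
        # Assumption: derive per-chain seeds from a single master seed.
        master = random.Random(random_seed)
        return [master.randrange(2**32) for _ in range(chains)]
    seeds = list(random_seed)
    if len(seeds) != chains:
        # Mirrors the documented ValueError for a length mismatch.
        raise ValueError(f"Got {len(seeds)} seeds for {chains} chains")
    return seeds


print(seeds_per_chain([11, 22, 33, 44], chains=4))  # [11, 22, 33, 44]
```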

discard_tuned_samples : bool

Whether to discard posterior samples of the tune interval.

compute_convergence_checks : bool, default=True

Whether to compute convergence diagnostics such as the Gelman-Rubin statistic (R-hat) and the effective sample size (effective_n).
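The classic Gelman-Rubin statistic compares between-chain and within-chain variance; values near 1 suggest the chains agree. Below is a stdlib sketch of the original (non-split) formula for a scalar parameter. Modern tools such as ArviZ default to a rank-normalized, split-chain variant, so results will not match those diagnostics exactly; gelman_rubin is a hypothetical helper.

```python
from statistics import mean, variance
from typing import Sequence


def gelman_rubin(chains: Sequence[Sequence[float]]) -> float:
    """Classic (non-split) potential scale reduction factor R-hat (sketch)."""
    m = len(chains)     # number of chains
    n = len(chains[0])  # draws per chain (assumed equal across chains)
    chain_means = [mean(c) for c in chains]
    grand_mean = mean(chain_means)
    # Between-chain variance B and mean within-chain variance W.
    b = n / (m - 1) * sum((cm - grand_mean) ** 2 for cm in chain_means)
    w = mean([variance(c) for c in chains])
    var_hat = (n - 1) / n * w + b / n  # pooled estimate of the posterior variance
    return (var_hat / w) ** 0.5


mixed = [[0.1, -0.2, 0.3, 0.0], [0.2, -0.1, 0.0, 0.1]]
stuck = [[0.0, 1.0, 0.0, 1.0], [10.0, 11.0, 10.0, 11.0]]
print(gelman_rubin(stuck) > 1.1)  # True: the chains disagree
```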

callback : function, default=None

A function that is called for every sample drawn in a chain. It is called with the trace and the current draw; the trace contains all samples of a single chain. The draw.chain attribute can be used to determine which of the active chains the sample was drawn from. Sampling can be interrupted by raising a KeyboardInterrupt in the callback.
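The callback pattern can be illustrated with a toy loop. This sketch assumes the callback receives keyword arguments named trace and draw, as the description above suggests; Draw and run_chain are hypothetical stand-ins rather than PyMC objects, and only the draw.chain attribute comes from the text.

```python
from collections import namedtuple

# Hypothetical stand-in for the draw object; only `.chain` is taken from the docs.
Draw = namedtuple("Draw", ["chain", "draw_idx"])


def stop_after(n_draws: int):
    """Build a callback that interrupts chain 0 once it holds n_draws samples."""
    def callback(trace, draw):
        if draw.chain == 0 and len(trace) >= n_draws:
            raise KeyboardInterrupt  # documented way to stop sampling from a callback
    return callback


def run_chain(chain: int, n: int, callback) -> list:
    """Toy sampling loop that calls `callback` after every draw (not PyMC code)."""
    trace = []
    try:
        for i in range(n):
            trace.append(float(i))  # placeholder "sample"
            callback(trace=trace, draw=Draw(chain=chain, draw_idx=i))
    except KeyboardInterrupt:
        pass  # stop early and keep what has been drawn so far
    return trace


print(len(run_chain(0, 100, stop_after(10))))  # 10
print(len(run_chain(1, 100, stop_after(10))))  # 100 (callback only watches chain 0)
```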

jitter_max_retries : int

Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter that yields a finite probability. This applies to jitter+adapt_diag and jitter+adapt_full init methods.
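A sketch of the retry loop described above, under two assumptions not stated in the text: the jitter is Uniform(-1, 1) noise added to each coordinate, and an error is raised after the last retry fails. jittered_start, half_normal_logp and the choice of RuntimeError are illustrative, not PyMC's API.

```python
import math
import random
from typing import Callable, Dict

Point = Dict[str, float]


def jittered_start(
    point: Point,
    logp: Callable[[Point], float],
    max_retries: int,
    rng: random.Random,
) -> Point:
    """Jitter `point` until logp is finite, retrying at most max_retries times (sketch)."""
    for _ in range(1 + max_retries):  # one initial attempt plus the retries
        # Assumed jitter: add Uniform(-1, 1) noise to every coordinate.
        candidate = {name: value + rng.uniform(-1.0, 1.0) for name, value in point.items()}
        if math.isfinite(logp(candidate)):
            return candidate
    raise RuntimeError(f"No finite log-probability after {max_retries} retries")


def half_normal_logp(p: Point) -> float:
    # Toy density supported only on x > 0.
    return -0.5 * p["x"] ** 2 if p["x"] > 0 else -math.inf


start = jittered_start({"x": 0.5}, half_normal_logp, max_retries=10, rng=random.Random(0))
print(start["x"] > 0)  # True
```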

return_inferencedata : bool

Whether to return the trace as an arviz.InferenceData object (True) or a MultiTrace (False). Defaults to True.

idata_kwargs : dict, optional

Keyword arguments for pymc.to_inference_data().

mp_ctx : multiprocessing.context.BaseContext

A multiprocessing context for parallel sampling. See multiprocessing documentation for details.
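The expected value is a standard-library multiprocessing context object, such as the one returned by multiprocessing.get_context. A minimal sketch of creating one; passing it to sample as mp_ctx=ctx is the intended use per the parameter description, and "spawn" is just an example choice.

```python
import multiprocessing
import multiprocessing.context

# "spawn" is available on every platform; "fork" and "forkserver" are POSIX-only.
ctx = multiprocessing.get_context("spawn")
print(isinstance(ctx, multiprocessing.context.BaseContext))  # True
print(ctx.get_start_method())  # spawn
```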

Returns
trace : pymc.backends.base.MultiTrace or arviz.InferenceData

A MultiTrace or ArviZ InferenceData object that contains the samples.

Notes

Optional keyword arguments can be passed to sample to be delivered to the step_methods used during sampling.

For example:

  1. target_accept to NUTS: nuts={'target_accept': 0.9}

  2. transit_p to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p': 0.7}

Note that available step names are:

nuts, hmc, metropolis, binary_metropolis, binary_gibbs_metropolis, categorical_gibbs_metropolis, DEMetropolis, DEMetropolisZ, slice
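Conceptually, any keyword argument whose name matches one of the step names above carries that step method's options. A minimal sketch of that split, assuming plain name matching; split_step_kwargs and unrelated_option are hypothetical, and PyMC's own routing may differ in detail.

```python
from typing import Any, Dict, Tuple

STEP_NAMES = {
    "nuts", "hmc", "metropolis", "binary_metropolis", "binary_gibbs_metropolis",
    "categorical_gibbs_metropolis", "DEMetropolis", "DEMetropolisZ", "slice",
}


def split_step_kwargs(
    kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """Separate per-step-method option dicts from other keyword arguments (sketch)."""
    step_kwargs = {k: v for k, v in kwargs.items() if k in STEP_NAMES}
    other = {k: v for k, v in kwargs.items() if k not in STEP_NAMES}
    return step_kwargs, other


step_kwargs, other = split_step_kwargs(
    {
        "nuts": {"target_accept": 0.9},
        "binary_gibbs_metropolis": {"transit_p": 0.7},
        "unrelated_option": 1,  # hypothetical non-step keyword
    }
)
print(sorted(step_kwargs))  # ['binary_gibbs_metropolis', 'nuts']
```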

The NUTS step method has several options including:

  • target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. This argument can be passed directly to sample.

  • max_treedepth : The maximum depth of the trajectory tree.

  • step_scale : float, default 0.25. The initial guess for the step size, scaled down by \(1/n^{1/4}\), where n is the dimensionality of the parameter space.
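As a worked example of the step_scale formula, the initial step size is step_scale / n^{1/4}, so it shrinks slowly as the dimensionality grows. initial_step_size is a hypothetical helper that only evaluates this expression.

```python
def initial_step_size(n: int, step_scale: float = 0.25) -> float:
    """Initial NUTS step size: step_scale scaled down by 1 / n**(1/4) (sketch)."""
    return step_scale / n ** 0.25


for n in (1, 16, 10_000):
    print(n, f"{initial_step_size(n):.4f}")
```

For a 16-dimensional model, 16^{1/4} = 2, so the initial guess is half the default scale (0.125).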

Alternatively, if you manually declare the step methods via the step kwarg, you can address the step method kwargs directly. For example, for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, you could pass:

step=[pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
      pm.BinaryGibbsMetropolis([freeRV3], transit_p=.7)]

You can find a full list of arguments in the docstring of the step methods.

Examples

In [1]: import arviz as az
   ...: import pymc as pm
   ...: n = 100
   ...: h = 61
   ...: alpha = 2
   ...: beta = 2

In [2]: with pm.Model() as model: # context management
   ...:     p = pm.Beta("p", alpha=alpha, beta=beta)
   ...:     y = pm.Binomial("y", n=n, p=p, observed=h)
   ...:     idata = pm.sample()

In [3]: az.summary(idata, kind="stats")

Out[3]:
    mean     sd  hdi_3%  hdi_97%
p  0.609  0.047   0.528    0.699