pymc.init_nuts
- pymc.init_nuts(*, init='auto', chains=1, n_init=500000, model=None, random_seed=None, progressbar=True, jitter_max_retries=10, tune=None, initvals=None, **kwargs)
Set up the mass matrix initialization for NUTS.
NUTS convergence and sampling speed are extremely dependent on the choice of mass (scaling) matrix. This function implements different methods for choosing or adapting the mass matrix.
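As background for the adapt_diag and adapt_full init methods: in the usual HMC formulation the inverse mass matrix is set to an estimate of the posterior covariance. A diagonal estimate keeps only the per-parameter variances, while a dense estimate also captures correlations between parameters. The following is a minimal NumPy sketch of that batch estimate, assuming the tuning draws are already available as an (n_draws, n_params) array. estimate_inverse_mass_matrix is a hypothetical helper, not a PyMC function, and PyMC's own adaptation (which updates its estimate during tuning) is more involved than this.

```python
import numpy as np


def estimate_inverse_mass_matrix(tuning_samples, dense=False):
    """Batch estimate of an inverse mass matrix from tuning draws (hypothetical helper).

    tuning_samples: array of shape (n_draws, n_params).
    Returns per-parameter variances (1-D) or a full covariance matrix (2-D).
    """
    samples = np.asarray(tuning_samples, dtype=float)
    if dense:
        # Dense estimate: sample covariance, including correlations between parameters
        return np.cov(samples, rowvar=False)
    # Diagonal estimate: per-parameter sample variance only
    return samples.var(axis=0, ddof=1)


# Simulated "tuning draws" from a strongly correlated 2-D Gaussian (illustrative only)
rng = np.random.default_rng(0)
true_cov = np.array([[1.0, 0.9], [0.9, 1.0]])
draws = rng.multivariate_normal(mean=[0.0, 0.0], cov=true_cov, size=5000)

diag_inv_mass = estimate_inverse_mass_matrix(draws)
full_inv_mass = estimate_inverse_mass_matrix(draws, dense=True)
print(diag_inv_mass)  # approximately [1, 1]
print(full_inv_mass)  # approximately true_cov
```

Both estimates agree on the variances; only the dense one retains the strong 0.9 correlation, which is the extra information adapt_full uses.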
- Parameters:
- init
str
Initialization method to use.
  - auto: Choose a default initialization method automatically. Currently, this is jitter+adapt_diag, but this can change in the future. If you depend on the exact behaviour, choose an initialization method explicitly.
  - adapt_diag: Start with an identity mass matrix and then adapt a diagonal based on the variance of the tuning samples. All chains use the test value (usually the prior mean) as starting point.
  - jitter+adapt_diag: Same as adapt_diag, but use the test value plus a uniform jitter in [-1, 1] as the starting point in each chain.
  - jitter+adapt_diag_grad: An experimental initialization method that uses information from gradients and samples during tuning.
  - advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the sample variance of the tuning samples.
  - advi: Run ADVI to estimate the posterior mean and a diagonal mass matrix.
  - advi_map: Initialize ADVI with MAP and use MAP as starting point.
  - map: Use the MAP as starting point. This is discouraged.
  - adapt_full: Adapt a dense mass matrix using the sample covariances. All chains use the test value (usually the prior mean) as starting point.
  - jitter+adapt_full: Same as adapt_full, but use the test value plus a uniform jitter in [-1, 1] as the starting point in each chain.
- chains
int
Number of chains to initialize; one starting point is returned per chain.
- initvals
optional, dict or list of dicts
Dict or list of dicts with initial values to use instead of the defaults from Model.initial_values. The keys should be names of transformed random variables.
- n_init
int
Number of iterations of the initializer. Only used by the ADVI-based init methods.
- model
Model (optional if in a with context)
- random_seed
int, array_like of int, RandomState or Generator, optional
Seed for the random number generator.
- progressbar
bool
Whether to display a progress bar for ADVI fitting.
- jitter_max_retries
int
Maximum number of repeated attempts (per chain) at creating an initial point with uniform jitter that yields a finite probability. This applies to the jitter+adapt_diag and jitter+adapt_full init methods.
- **kwargs
keyword arguments
Extra keyword arguments are forwarded to pymc.NUTS.
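The jitter behaviour described for jitter+adapt_diag, jitter+adapt_full and jitter_max_retries can be illustrated with a small NumPy sketch: add uniform jitter to the test value and redraw while the log-probability is not finite. jitter_start and toy_logp are hypothetical names, the flat start vector stands in for a point in transformed space, and the "one attempt plus max_retries retries" counting is an assumption; PyMC's internal implementation works on point dicts and is not reproduced here.

```python
import numpy as np


def jitter_start(start, logp, rng, max_retries=10):
    """Jittered starting point with finite log-probability (hypothetical helper).

    start: 1-D array standing in for the test value in transformed space.
    logp: callable returning the log density at a point.
    """
    start = np.asarray(start, dtype=float)
    # Assumed counting convention: one initial attempt plus up to max_retries retries
    for _ in range(max_retries + 1):
        # Test value plus uniform jitter drawn from [-1, 1) in every coordinate
        candidate = start + rng.uniform(-1.0, 1.0, size=start.shape)
        if np.isfinite(logp(candidate)):
            return candidate
    raise RuntimeError("no jittered starting point with finite log-probability found")


def toy_logp(x):
    # Toy log density: standard normal restricted to the positive orthant, -inf elsewhere
    return -0.5 * float(np.sum(x**2)) if np.all(x > 0) else -np.inf


rng = np.random.default_rng(1)
# One independent jittered start per chain, all derived from the same test value
chain_starts = [jitter_start([0.8, 0.8], toy_logp, rng) for _ in range(4)]
```

Because each chain draws its own jitter, the chains start at different points instead of all at the test value; the retry loop only guards against starting where the density is zero.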
- Returns:
- initial_points
list
Starting points for each chain.
- nuts_sampler
pymc.step_methods.NUTS
Instantiated and initialized NUTS sampler object.