pymc.SVGD

class pymc.SVGD(n_particles=100, jitter=1, model=None, start=None, random_seed=None, estimator=<class 'pymc.variational.operators.KSD'>, kernel=<pymc.variational.test_functions.RBF object>, **kwargs)

Stein Variational Gradient Descent

This inference method is based on the Kernelized Stein Discrepancy (KSD). Its main idea is to move an initial set of noisy particles so that they best fit the target distribution.
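For context, Liu and Wang (2016) connect the KSD to the particle update as follows. With the Stein operator \(\mathcal{A}_p\), the KSD between the current particle distribution \(q\) and the target \(p\) (written \(\mathbb{S}(q, p)\) here) is a maximum over the unit ball of a reproducing kernel Hilbert space \(\mathcal{H}^d\) with kernel \(k\), and the maximizing direction has a closed form:

\[\begin{split}\mathcal{A}_p \phi(x) &= \phi(x) \nabla_x \log p(x)^\top + \nabla_x \phi(x) \\ \mathbb{S}(q, p) &= \max_{\phi \in \mathcal{H}^d,\ \|\phi\|_{\mathcal{H}^d} \le 1} \mathbb{E}_{x \sim q}\left[\operatorname{trace}\left(\mathcal{A}_p \phi(x)\right)\right] \\ \phi^{*}_{q,p}(\cdot) &\propto \mathbb{E}_{x \sim q}\left[k(x, \cdot) \nabla_x \log p(x) + \nabla_x k(x, \cdot)\right]\end{split}\]

In the paper, \(\mathbb{E}_{q}[\operatorname{trace}(\mathcal{A}_p \phi)]\) is the rate at which \(\mathrm{KL}(q \,\|\, p)\) decreases under the small transform \(x \mapsto x + \epsilon \phi(x)\), so \(\phi^{*}_{q,p}\) is the steepest-descent direction for the KL divergence. Replacing the expectation over \(q\) with an average over the current particles gives \(\hat{\phi}^{*}\) in the update of the algorithm.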

The algorithm is outlined below:

Input: A target distribution with density function \(p(x)\) and a set of initial particles \(\{x^0_i\}^n_{i=1}\).

Output: A set of particles \(\{x^{*}_i\}^n_{i=1}\) that approximates the target distribution.

\[\begin{split}x_i^{l+1} &\leftarrow x_i^{l} + \epsilon_l \hat{\phi}^{*}(x_i^l) \\ \hat{\phi}^{*}(x) &= \frac{1}{n}\sum^{n}_{j=1}\left[k(x^l_j,x) \nabla_{x^l_j} \log p(x^l_j) + \nabla_{x^l_j} k(x^l_j,x)\right]\end{split}\]

where \(\epsilon_l\) is the step size at iteration \(l\) and \(k\) is the kernel.

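The update above can be sketched in NumPy as follows. This is a minimal, self-contained illustration, not PyMC's implementation: it assumes a Gaussian RBF kernel with the median-heuristic bandwidth, a fixed step size \(\epsilon_l = \epsilon\) (PyMC's `fit` may drive the update with an adaptive optimizer instead), and initial particles drawn as `start` plus Gaussian noise with standard deviation `jitter`. The function names `rbf_kernel`, `svgd_direction` and `run_svgd` are hypothetical.

```python
import numpy as np


def rbf_kernel(x):
    """Return K[j, i] = k(x_j, x_i) and dK[j, i] = grad_{x_j} k(x_j, x_i).

    Assumption: Gaussian RBF kernel k(a, b) = exp(-||a - b||^2 / h) with the
    median-heuristic bandwidth h = median(squared distances) / log(n + 1).
    """
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]        # diff[j, i] = x_j - x_i, shape (n, n, d)
    sq_dist = np.sum(diff**2, axis=-1)           # squared pairwise distances, shape (n, n)
    h = max(np.median(sq_dist) / np.log(n + 1), 1e-8)
    k = np.exp(-sq_dist / h)
    grad_k = -2.0 / h * diff * k[:, :, None]     # gradient w.r.t. the first argument x_j
    return k, grad_k


def svgd_direction(x, grad_logp):
    """phi_hat(x_i) = 1/n * sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]."""
    n = x.shape[0]
    k, grad_k = rbf_kernel(x)
    driving = k.T @ grad_logp          # kernel-weighted score: pulls towards high density
    repulsive = grad_k.sum(axis=0)     # sum over j: pushes particles away from each other
    return (driving + repulsive) / n


def run_svgd(grad_logp_fn, start, n_particles=100, jitter=1.0,
             step_size=0.1, n_iter=1000, seed=None):
    """Hypothetical driver: x_i <- x_i + step_size * phi_hat(x_i), n_iter times."""
    rng = np.random.default_rng(seed)
    start = np.atleast_1d(np.asarray(start, dtype=float))
    # Initial particles: the start point plus Gaussian noise with sd = jitter.
    x = start + jitter * rng.standard_normal((n_particles, start.size))
    for _ in range(n_iter):
        x = x + step_size * svgd_direction(x, grad_logp_fn(x))
    return x


if __name__ == "__main__":
    # Target: 2-D standard normal centred at mu, so grad log p(x) = -(x - mu).
    mu = np.array([2.0, -1.0])
    particles = run_svgd(lambda x: -(x - mu), start=np.zeros(2), seed=0)
    print(particles.mean(axis=0), particles.std(axis=0))
```

Only the score \(\nabla_x \log p(x)\) is needed, so the target density never has to be normalized; this is the same property that lets SVGD work with an unnormalized model log-probability.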
Parameters
n_particles: `int`

number of particles to use for approximation

jitter: `float`

standard deviation of the noise added to the initial point

model: `pymc.Model`

PyMC model for inference

kernel: `callable`

kernel function for KSD \(f(\mathrm{histogram}) \rightarrow (k(x,\cdot), \nabla_x k(x,\cdot))\)
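The default `kernel` is an RBF test function. As a sketch of the pair such a kernel returns, assuming the Gaussian RBF form and the median-heuristic bandwidth suggested by Liu and Wang (2016) (PyMC's `RBF` may compute its bandwidth slightly differently, e.g. with \(\log(n+1)\)):

\[\begin{split}k(x, x') &= \exp\left(-\frac{\|x - x'\|^2}{h}\right) \\ \nabla_{x} k(x, x') &= -\frac{2}{h}(x - x')\, k(x, x') \\ h &= \frac{\operatorname{med}^2}{\log n}\end{split}\]

where \(\operatorname{med}\) is the median pairwise distance between the current particles. Substituting into \(\hat{\phi}^{*}\), the term \(\nabla_{x_j} k(x_j, x) = \frac{2}{h}(x - x_j)\, k(x_j, x)\) points from \(x_j\) towards \(x\), so it acts as a repulsive force that keeps the particles from collapsing onto a single mode.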

temperature: `float`

parameter controlling exploration; a higher temperature gives a broader posterior estimate
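One way to see why a higher temperature broadens the estimate, assuming (this is an assumption, not stated above) that the temperature \(T\) rescales only the driving term of \(\hat{\phi}^{*}\):

\[\begin{split}\hat{\phi}^{*}_T(x) &= \frac{1}{n}\sum^{n}_{j=1}\left[\frac{1}{T}\, k(x^l_j,x) \nabla_{x^l_j} \log p(x^l_j) + \nabla_{x^l_j} k(x^l_j,x)\right] \\ \frac{1}{T}\nabla_{x} \log p(x) &= \nabla_{x} \log \left(p(x)^{1/T}\right)\end{split}\]

Under this assumption the tempered update is the ordinary update for the flattened target \(p^{1/T}\). For a Gaussian target \(\mathcal{N}(\mu, \sigma^2)\), \(p^{1/T} \propto \mathcal{N}(\mu, T\sigma^2)\), so \(T > 1\) widens the spread of the particles and \(T = 1\) recovers the original algorithm.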

start: `dict`

initial point for inference

random_seed: `None` or `int`

leave as `None` to use the package-global RandomStream, or pass another valid value to create an instance-specific one

kwargs: other keyword arguments passed to the estimator

References

  • Qiang Liu, Dilin Wang (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. arXiv:1608.04471

  • Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017). Stein Variational Policy Gradient. arXiv:1704.02399

Methods

  • SVGD.__init__([n_particles, jitter, model, ...])

  • SVGD.fit([n, score, callbacks, progressbar]): perform Operator Variational Inference

  • SVGD.refine(n[, progressbar]): refine the solution using the last compiled step function

  • SVGD.run_profiling([n, score])

Attributes

approx