# Source code for pymc.distributions.mixture

#   Copyright 2023 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
import warnings

import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import Node, equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable

from pymc.distributions import transforms
from pymc.distributions.continuous import Normal, get_tau_sigma
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
    Distribution,
    SymbolicRandomVariable,
    _moment,
    moment,
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
from pymc.distributions.transforms import _default_transform
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob, _logprob_helper
from pymc.logprob.transforms import IntervalTransform
from pymc.logprob.utils import ignore_logprob
from pymc.util import check_dist_not_registered
from pymc.vartypes import continuous_types, discrete_types

# Public names exported by a star-import of this module.
__all__ = ["Mixture", "NormalMixture"]


class MarginalMixtureRV(SymbolicRandomVariable):
    """Placeholder op wrapping a mixture sub-graph so a log-likelihood can be attached to it."""

    _print_name = ("MarginalMixture", "\\operatorname{MarginalMixture}")
    # Index of the output that is returned when the op is called.
    default_output = 1

    def update(self, node: Node):
        """Return the RNG update of the internal ``mix_indexes`` RV.

        The mapping pairs the node's first input with the node's first output.
        """
        rng_before = node.inputs[0]
        rng_after = node.outputs[0]
        return {rng_before: rng_after}


class Mixture(Distribution):
    R"""
    Mixture log-likelihood

    Often used to model subpopulation heterogeneity

    .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

    ========  ============================================
    Support   :math:`\cup_{i = 1}^n \textrm{support}(f_i)`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    ========  ============================================

    Parameters
    ----------
    w : tensor_like of float
        w >= 0 and w <= 1
        the mixture weights
    comp_dists : iterable of unnamed distributions or single batched distribution
        Distributions should be created via the `.dist()` API. If a single distribution
        is passed, the last size dimension (not shape) determines the number of mixture
        components (e.g. `pm.Poisson.dist(..., size=components)`)
        :math:`f_1, \ldots, f_n`

        .. warning:: comp_dists will be cloned, rendering them independent of the ones passed as input.


    Examples
    --------
    .. code-block:: python

        # Mixture of 2 Poisson variables
        with pm.Model() as model:
            w = pm.Dirichlet('w', a=np.array([1, 1]))  # 2 mixture weights

            lam1 = pm.Exponential('lam1', lam=1)
            lam2 = pm.Exponential('lam2', lam=1)

            # As we just need the logp, rather than add a RV to the model, we need to call `.dist()`
            # These two forms are equivalent, but the second benefits from vectorization
            components = [
                pm.Poisson.dist(mu=lam1),
                pm.Poisson.dist(mu=lam2),
            ]
            # `shape=(2,)` indicates 2 mixture components
            components = pm.Poisson.dist(mu=pm.math.stack([lam1, lam2]), shape=(2,))

            like = pm.Mixture('like', w=w, comp_dists=components, observed=data)


    .. code-block:: python

        # Mixture of Normal and StudentT variables
        with pm.Model() as model:
            w = pm.Dirichlet('w', a=np.array([1, 1]))  # 2 mixture weights

            mu = pm.Normal("mu", 0, 1)

            components = [
                pm.Normal.dist(mu=mu, sigma=1),
                pm.StudentT.dist(nu=4, mu=mu, sigma=1),
            ]

            like = pm.Mixture('like', w=w, comp_dists=components, observed=data)


    .. code-block:: python

        # Mixture of (5 x 3) Normal variables
        with pm.Model() as model:
            # w is a stack of 5 independent size 3 weight vectors
            # If shape was `(3,)`, the weights would be shared across the 5 replication dimensions
            w = pm.Dirichlet('w', a=np.ones(3), shape=(5, 3))

            # Each of the 3 mixture components has an independent mean
            mu = pm.Normal('mu', mu=np.arange(3), sigma=1, shape=3)

            # These two forms are equivalent, but the second benefits from vectorization
            components = [
                pm.Normal.dist(mu=mu[0], sigma=1, shape=(5,)),
                pm.Normal.dist(mu=mu[1], sigma=1, shape=(5,)),
                pm.Normal.dist(mu=mu[2], sigma=1, shape=(5,)),
            ]
            components = pm.Normal.dist(mu=mu, sigma=1, shape=(5, 3))

            # The mixture is an array of 5 elements
            # Each element can be thought of as an independent scalar mixture of 3
            # components with different means
            like = pm.Mixture('like', w=w, comp_dists=components, observed=data)


    .. code-block:: python

        # Mixture of 2 Dirichlet variables
        with pm.Model() as model:
            w = pm.Dirichlet('w', a=np.ones(2))  # 2 mixture weights

            # These two forms are equivalent, but the second benefits from vectorization
            components = [
                pm.Dirichlet.dist(a=[1, 10, 100], shape=(3,)),
                pm.Dirichlet.dist(a=[100, 10, 1], shape=(3,)),
            ]
            components = pm.Dirichlet.dist(a=[[1, 10, 100], [100, 10, 1]], shape=(2, 3))

            # The mixture is an array of 3 elements
            # Each element comes from only one of the two core Dirichlet components
            like = pm.Mixture('like', w=w, comp_dists=components, observed=data)
    """

    rv_type = MarginalMixtureRV

    @classmethod
    def dist(cls, w, comp_dists, **kwargs):
        """Validate the components and build an unregistered mixture variable.

        Parameters
        ----------
        w : tensor_like
            Mixture weights; converted with ``pt.as_tensor_variable``.
        comp_dists : distribution or list/tuple of distributions
            Component(s) created via the `.dist()` API. Anything that is not a
            list or tuple is treated as a single batched component.
        **kwargs
            Forwarded unchanged to the parent class ``dist``.

        Raises
        ------
        ValueError
            If several components mix discrete and continuous dtypes, if a
            component is not a variable produced by a ``RandomVariable`` or
            ``SymbolicRandomVariable`` op, or if the components do not all
            share the same ``ndim_supp``.

        Warns
        -----
        UserWarning
            If ``comp_dists`` is a list or tuple holding exactly one component.
        """
        if not isinstance(comp_dists, (tuple, list)):
            # comp_dists is a single component
            comp_dists = [comp_dists]
        elif len(comp_dists) == 1:
            warnings.warn(
                "Single component will be treated as a mixture across the last size dimension.\n"
                "To disable this warning do not wrap the single component inside a list or tuple",
                UserWarning,
            )

        if len(comp_dists) > 1:
            if not (
                all(comp_dist.dtype in continuous_types for comp_dist in comp_dists)
                or all(comp_dist.dtype in discrete_types for comp_dist in comp_dists)
            ):
                raise ValueError(
                    "All distributions in comp_dists must be either discrete or continuous.\n"
                    "See the following issue for more information: https://github.com/pymc-devs/pymc/issues/4511."
                )

        # Check that components are not associated with a registered variable in the model
        components_ndim_supp = set()
        for comp_dist in comp_dists:
            # TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
            #  and resize them
            # A TensorVariable without an owner (constant, shared variable or plain
            # input) has no `op`; reject it here with the intended ValueError
            # instead of failing with an AttributeError on `None.op`.
            if (
                not isinstance(comp_dist, TensorVariable)
                or comp_dist.owner is None
                or not isinstance(comp_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
            ):
                raise ValueError(
                    f"Component dist must be a distribution created via the `.dist()` API, got {type(comp_dist)}"
                )
            check_dist_not_registered(comp_dist)
            components_ndim_supp.add(comp_dist.owner.op.ndim_supp)

        if len(components_ndim_supp) > 1:
            raise ValueError(
                f"Mixture components must all have the same support dimensionality, got {components_ndim_supp}"
            )

        w = pt.as_tensor_variable(w)
        return super().dist([w, *comp_dists], **kwargs)

    @classmethod
    def rv_op(cls, weights, *components, size=None):
        """Build the ``MarginalMixtureRV`` variable for the given weights and components.

        Parameters
        ----------
        weights : TensorVariable
            Mixture weights.
        *components : TensorVariable
            One batched component, or several components that are stacked
            along the mixture axis.
        size : optional
            Requested size. When omitted and several components are given, the
            size is derived from the broadcast shape of the components.

        Returns
        -------
        The default output of the ``MarginalMixtureRV`` op (the mixture draw).
        """
        # Create new rng for the mix_indexes internal RV
        mix_indexes_rng = pytensor.shared(np.random.default_rng())

        single_component = len(components) == 1
        ndim_supp = components[0].owner.op.ndim_supp

        if size is not None:
            components = cls._resize_components(size, *components)
        elif not single_component:
            # We might need to broadcast components when size is not specified
            shape = tuple(pt.broadcast_shape(*components))
            size = shape[: len(shape) - ndim_supp]
            components = cls._resize_components(size, *components)

        # Extract replication ndims from components and weights
        ndim_batch = components[0].ndim - ndim_supp
        if single_component:
            # One dimension is taken by the mixture axis in the single component case
            ndim_batch -= 1

        # The weights may imply extra batch dimensions that go beyond what is already
        # implied by the component dimensions (ndim_batch)
        weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

        # If weights are large enough that they would broadcast the component distributions
        # we try to resize them. This is necessary to avoid duplicated values in the
        # random method and for equivalency with the logp method
        if weights_ndim_batch:
            new_size = pt.concatenate(
                [
                    weights.shape[:weights_ndim_batch],
                    components[0].shape[:ndim_batch],
                ]
            )
            components = cls._resize_components(new_size, *components)

            # Extract support and batch ndims from components and weights
            ndim_batch = components[0].ndim - ndim_supp
            if single_component:
                ndim_batch -= 1
            weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

            # Internal invariant: after resizing, the weights add no extra batch dims
            assert weights_ndim_batch == 0

        # Component RVs terms are accounted by the Mixture logprob, so they can be
        # safely ignored in the logprob graph
        components = [ignore_logprob(component) for component in components]

        # Create a OpFromGraph that encapsulates the random generating process
        # Create dummy input variables with the same type as the ones provided
        weights_ = weights.type()
        components_ = [component.type() for component in components]
        mix_indexes_rng_ = mix_indexes_rng.type()

        mix_axis = -ndim_supp - 1

        # Stack components across mixture axis
        if single_component:
            # If single component, we consider it as being already "stacked"
            stacked_components_ = components_[0]
        else:
            stacked_components_ = pt.stack(components_, axis=mix_axis)

        # Broadcast weights to (*batched dimensions, stack dimension), ignoring support dimensions
        weights_broadcast_shape_ = stacked_components_.shape[: ndim_batch + 1]
        weights_broadcasted_ = pt.broadcast_to(weights_, weights_broadcast_shape_)

        # Draw mixture indexes and append (stack + ndim_supp) broadcastable dimensions to the right
        mix_indexes_ = pt.random.categorical(weights_broadcasted_, rng=mix_indexes_rng_)
        mix_indexes_padded_ = pt.shape_padright(mix_indexes_, ndim_supp + 1)

        # Index components and squeeze mixture dimension
        mix_out_ = pt.take_along_axis(stacked_components_, mix_indexes_padded_, axis=mix_axis)
        mix_out_ = pt.squeeze(mix_out_, axis=mix_axis)

        # Output mix_indexes rng update so that it can be updated in place
        mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

        mix_op = MarginalMixtureRV(
            inputs=[mix_indexes_rng_, weights_, *components_],
            outputs=[mix_indexes_rng_next_, mix_out_],
            ndim_supp=components[0].owner.op.ndim_supp,
        )

        # Create the actual MarginalMixture variable
        mix_out = mix_op(mix_indexes_rng, weights, *components)

        return mix_out

    @classmethod
    def _resize_components(cls, size, *components):
        """Return the components resized to ``size`` via ``change_dist_size``.

        For a single component, the length of its mixture axis is appended to
        ``size`` so that the number of mixture components is preserved.
        """
        if len(components) == 1:
            # If we have a single component, we need to keep the length of the mixture
            # axis intact, because that's what determines the number of mixture components
            mix_axis = -components[0].owner.op.ndim_supp - 1
            mix_size = components[0].shape[mix_axis]
            size = tuple(size) + (mix_size,)

        return [change_dist_size(component, size) for component in components]
@_change_dist_size.register(MarginalMixtureRV)
def change_marginal_mixture_size(op, dist, new_size, expand=False):
    """Rebuild a mixture variable with a new size.

    With ``expand=True`` the new size is prepended to the size currently
    carried by the components; otherwise it is used as given.
    """
    # The first input is the internal rng; the rest are the weights and components
    _rng, weights, *components = dist.owner.inputs

    if expand:
        first = components[0]
        # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
        n_size_dims = first.ndim - first.owner.op.ndim_supp
        if len(components) == 1:
            # If we have a single component, new size should ignore the mixture axis
            # dimension, as that is not touched by `_resize_components`
            n_size_dims -= 1
        previous_size = first.shape[:n_size_dims]
        new_size = (*new_size, *previous_size)

    return Mixture.rv_op(weights, *components, size=new_size)


@_logprob.register(MarginalMixtureRV)
def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
    """Log-probability of the mixture: a weighted logsumexp of the component log-probabilities."""
    [value] = values

    if len(components) != 1:
        # Several components: evaluate each on the same value and stack on a new last axis
        per_component_logp = pt.stack(
            [_logprob_helper(comp, value) for comp in components],
            axis=-1,
        )
    else:
        # Single batched component: need to broadcast value across mixture axis
        only = components[0]
        mixture_axis = -only.owner.op.ndim_supp - 1
        per_component_logp = _logprob_helper(only, pt.expand_dims(value, mixture_axis))

    weighted_logp = pt.log(weights) + per_component_logp
    weight_conditions = (
        0 <= weights,
        weights <= 1,
        pt.isclose(pt.sum(weights, axis=-1), 1),
    )
    return check_parameters(
        pt.logsumexp(weighted_logp, axis=-1),
        *weight_conditions,
        msg="0 <= weights <= 1, sum(weights) == 1",
    )


@_logcdf.register(MarginalMixtureRV)
def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
    """Log-CDF of the mixture: a weighted logsumexp of the component log-CDFs."""
    if len(components) != 1:
        # Several components: evaluate each on the same value and stack on a new last axis
        per_component_logcdf = pt.stack(
            [_logcdf_helper(comp, value) for comp in components],
            axis=-1,
        )
    else:
        # Single batched component: need to broadcast value across mixture axis
        only = components[0]
        mixture_axis = -only.owner.op.ndim_supp - 1
        per_component_logcdf = _logcdf_helper(only, pt.expand_dims(value, mixture_axis))

    weighted_logcdf = pt.log(weights) + per_component_logcdf
    weight_conditions = (
        0 <= weights,
        weights <= 1,
        pt.isclose(pt.sum(weights, axis=-1), 1),
    )
    return check_parameters(
        pt.logsumexp(weighted_logcdf, axis=-1),
        *weight_conditions,
        msg="0 <= weights <= 1, sum(weights) == 1",
    )


@_moment.register(MarginalMixtureRV)
def marginal_mixture_moment(op, rv, rng, weights, *components):
    """Moment of the mixture: the weight-averaged component moments, rounded for discrete components."""
    first = components[0]
    ndim_supp = first.owner.op.ndim_supp
    mixture_axis = -ndim_supp - 1

    # Pad the weights on the right so they line up with the support dimensions
    padded_weights = pt.shape_padright(weights, ndim_supp)

    component_moments = (
        moment(first)
        if len(components) == 1
        else pt.stack([moment(comp) for comp in components], axis=mixture_axis)
    )

    result = pt.sum(padded_weights * component_moments, axis=mixture_axis)
    if first.dtype in discrete_types:
        return pt.round(result)
    return result


# List of transforms that can be used by Mixture, either because they do not require
# special handling or because we have custom logic to enable them. If new default
# transforms are implemented, this list and function should be updated
allowed_default_mixture_transforms = (
    transforms.CholeskyCovPacked,
    transforms.CircularTransform,
    transforms.IntervalTransform,
    transforms.LogTransform,
    transforms.LogExpM1,
    transforms.LogOddsTransform,
    transforms.Ordered,
    transforms.SimplexTransform,
    transforms.SumTo1,
)


class MixtureTransformWarning(UserWarning):
    """Warning issued when no safe default transform is found for a Mixture."""


@_default_transform.register(MarginalMixtureRV)
def marginal_mixture_default_transform(op, rv):
    """Pick a default transform for a mixture from the default transforms of its components.

    Returns ``None``, after a ``MixtureTransformWarning``, when the components
    disagree or the shared transform is not in
    ``allowed_default_mixture_transforms``. Returns ``None`` silently when the
    shared default is itself ``None``.
    """

    def transform_warning():
        warnings.warn(
            f"No safe default transform found for Mixture distribution {rv}. This can "
            "happen when components have different supports or default transforms.\n"
            "If appropriate, you can specify a custom transform for more efficient sampling.",
            MixtureTransformWarning,
            stacklevel=2,
        )

    _rng, _weights, *components = rv.owner.inputs

    candidates = [_default_transform(comp.owner.op, comp) for comp in components]

    # If there are more than one type of default transforms, we do not apply any
    distinct_types = {type(candidate) for candidate in candidates}
    if len(distinct_types) != 1:
        transform_warning()
        return None

    chosen = candidates[0]
    if chosen is None:
        return None

    if not isinstance(chosen, allowed_default_mixture_transforms):
        transform_warning()
        return None

    if not isinstance(chosen, IntervalTransform):
        return chosen

    # If there are more than one component, we need to check the IntervalTransform
    # of the components are actually equivalent (e.g., we don't have an
    # Interval(0, 1), and an Interval(0, 2)).
    if len(candidates) > 1:
        value = rv.type()
        backward_exprs = [
            candidate.backward(value, *comp.owner.inputs)
            for candidate, comp in zip(candidates, components)
        ]
        neighbours = zip(backward_exprs[:-1], backward_exprs[1:])
        if any(not equal_computations([left], [right]) for left, right in neighbours):
            transform_warning()
            return None

    # We need to create a new IntervalTransform that expects the Mixture inputs
    args_fn = chosen.args_fn

    def mixture_args_fn(rng, weights, *components):
        # We checked that the interval transforms of each component are equivalent,
        # so we can just pass the inputs of the first component
        return args_fn(*components[0].owner.inputs)

    return IntervalTransform(args_fn=mixture_args_fn)
class NormalMixture:
    R"""
    Normal mixture log-likelihood

    .. math::

        f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)

    ========  =======================================
    Support   :math:`x \in \mathbb{R}`
    Mean      :math:`\sum_{i = 1}^n w_i \mu_i`
    Variance  :math:`\sum_{i = 1}^n w_i (\sigma^2_i + \mu_i^2) - \left(\sum_{i = 1}^n w_i \mu_i\right)^2`
    ========  =======================================

    Parameters
    ----------
    w : tensor_like of float
        w >= 0 and w <= 1
        the mixture weights
    mu : tensor_like of float
        the component means
    sigma : tensor_like of float
        the component standard deviations
    tau : tensor_like of float
        the component precisions
    comp_shape : shape of the Normal component
        notice that it should be different than the shape
        of the mixture distribution, with the last axis representing
        the number of components.

    Notes
    -----
    You only have to pass in sigma or tau, but not both.

    Examples
    --------
    .. code-block:: python

        n_components = 3

        with pm.Model() as gauss_mix:
            μ = pm.Normal(
                "μ",
                mu=data.mean(),
                sigma=10,
                shape=n_components,
                transform=pm.transforms.ordered,
                initval=[1, 2, 3],
            )
            σ = pm.HalfNormal("σ", sigma=10, shape=n_components)
            weights = pm.Dirichlet("w", np.ones(n_components))

            y = pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data)
    """

    def __new__(cls, name, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
        """Build a named ``Mixture`` whose single batched component is a Normal."""
        # Only the sigma element of the (tau, sigma) pair is needed
        sigma = get_tau_sigma(tau=tau, sigma=sigma)[1]
        normal_components = Normal.dist(mu, sigma=sigma, size=comp_shape)
        return Mixture(name, w, normal_components, **kwargs)

    @classmethod
    def dist(cls, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
        """Build an unnamed mixture via ``Mixture.dist`` with a single batched Normal component."""
        # Only the sigma element of the (tau, sigma) pair is needed
        sigma = get_tau_sigma(tau=tau, sigma=sigma)[1]
        normal_components = Normal.dist(mu, sigma=sigma, size=comp_shape)
        return Mixture.dist(w, normal_components, **kwargs)