# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
import numpy as np
import pytensor.tensor as pt
# ignore mypy error because it somehow considers that
# "numpy.core.numeric has no attribute normalize_axis_tuple"
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from pytensor.graph import Op
from pytensor.tensor import TensorVariable
from pymc.logprob.transforms import (
CircularTransform,
IntervalTransform,
LogOddsTransform,
LogTransform,
RVTransform,
SimplexTransform,
)
# Names exported by ``from <this module> import *``.
# NOTE(review): "ZeroSumTransform" is listed here but its definition is not
# visible in this excerpt — confirm it is defined elsewhere in the file.
__all__ = [
    "RVTransform",
    "simplex",
    "logodds",
    "Interval",
    "log_exp_m1",
    "univariate_ordered",
    "multivariate_ordered",
    "log",
    "sum_to_1",
    "univariate_sum_to_1",
    "multivariate_sum_to_1",
    "circular",
    "CholeskyCovPacked",
    "Chain",
    "ZeroSumTransform",
]
@singledispatch
def _default_transform(op: Op, rv: TensorVariable):
    """Return the default transform for a distribution ``Op``.

    This is the generic ``singledispatch`` fallback. It ignores both ``op``
    and ``rv`` and reports that there is no default transform by returning
    ``None``.
    """
    return None
class LogExpM1(RVTransform):
    """Transform whose ``backward`` is softplus and whose ``forward`` inverts it."""

    name = "log_exp_m1"

    def backward(self, value, *inputs):
        """Map ``value`` through ``softplus``; ``inputs`` are not used."""
        constrained = pt.softplus(value)
        return constrained

    def forward(self, value, *inputs):
        """Inverse operation of softplus.

        y = Log(Exp(x) - 1)
          = Log(1 - Exp(-x)) + x
        """
        # Written in the second form from the docstring: log(1 - exp(-x)) + x.
        log_one_minus_decay = pt.log(1.0 - pt.exp(-value))
        return log_one_minus_decay + value

    def log_jac_det(self, value, *inputs):
        """Return ``-softplus(-value)``; ``inputs`` are not used."""
        negated = -value
        return -pt.softplus(negated)
class Ordered(RVTransform):
    """Transform between an increasing sequence and an unconstrained vector.

    The transform acts along the last axis. ``forward`` keeps the first entry
    and replaces every later entry with the log of its difference from the
    previous entry. ``backward`` undoes this by exponentiating all but the
    first entry and taking a cumulative sum.

    Parameters
    ----------
    ndim_supp : int, default 0
        Number of core dimensions of the transformed variable; must not
        exceed 1. Within this class it only decides whether ``log_jac_det``
        keeps the reduced last axis (``ndim_supp == 0``) or drops it.

    Raises
    ------
    ValueError
        If ``ndim_supp`` is greater than 1.
    """

    name = "ordered"

    def __init__(self, ndim_supp=0):
        if ndim_supp > 1:
            # The trailing space matters: the two literals are concatenated.
            raise ValueError(
                "For Ordered transformation number of core dimensions "
                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
            )
        self.ndim_supp = ndim_supp

    def backward(self, value, *inputs):
        """Map an unconstrained vector to increasing values along the last axis."""
        x = pt.zeros(value.shape)
        # The first entry is taken as is; the rest become positive increments.
        x = pt.inc_subtensor(x[..., 0], value[..., 0])
        x = pt.inc_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
        return pt.cumsum(x, axis=-1)

    def forward(self, value, *inputs):
        """Map increasing values to the first entry plus log-differences."""
        y = pt.zeros(value.shape)
        y = pt.inc_subtensor(y[..., 0], value[..., 0])
        # Log of the gap between each entry and its predecessor.
        y = pt.inc_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
        return y

    def log_jac_det(self, value, *inputs):
        """Sum every entry after the first along the last axis.

        The reduced axis is kept (with length 1) when ``ndim_supp == 0`` and
        dropped otherwise.
        """
        if self.ndim_supp == 0:
            return pt.sum(value[..., 1:], axis=-1, keepdims=True)
        else:
            return pt.sum(value[..., 1:], axis=-1)
class SumTo1(RVTransform):
    """
    Transforms K values that sum to 1 along the last dimension into the
    first K - 1 of those values, and back.

    ``forward`` drops the last entry of the last dimension. ``backward``
    appends ``1 - sum`` of the remaining entries so the result sums to 1 again.
    This Transformation operates on the last dimension of the input tensor.

    Parameters
    ----------
    ndim_supp : int, default 0
        Number of core dimensions of the transformed variable; must not
        exceed 1. Within this class it only decides whether ``log_jac_det``
        keeps the reduced last axis (``ndim_supp == 0``) or drops it.

    Raises
    ------
    ValueError
        If ``ndim_supp`` is greater than 1.
    """

    name = "sumto1"

    def __init__(self, ndim_supp=0):
        if ndim_supp > 1:
            # The trailing space matters: the two literals are concatenated.
            raise ValueError(
                "For SumTo1 transformation number of core dimensions "
                f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
            )
        self.ndim_supp = ndim_supp

    def backward(self, value, *inputs):
        """Append ``1 - sum(value)`` as a new last entry along the last axis."""
        remaining = 1 - pt.sum(value, axis=-1, keepdims=True)
        return pt.concatenate([value, remaining], axis=-1)

    def forward(self, value, *inputs):
        """Drop the last entry along the last axis."""
        return value[..., :-1]

    def log_jac_det(self, value, *inputs):
        """Return zeros, reduced over the last axis of ``value``.

        The reduced axis is kept (with length 1) when ``ndim_supp == 0`` and
        dropped otherwise.
        """
        y = pt.zeros(value.shape)
        if self.ndim_supp == 0:
            return pt.sum(y, axis=-1, keepdims=True)
        else:
            return pt.sum(y, axis=-1)
class CholeskyCovPacked(RVTransform):
    """
    Puts the diagonal elements of the LKJCholeskyCov distribution on the log
    scale, leaving every other entry of the last axis untouched.
    """

    name = "cholesky-cov-packed"

    def __init__(self, n):
        """
        Parameters
        ----------
        n: int
            Number of diagonal entries in the LKJCholeskyCov distribution
        """
        # Cumulative sums of 1..n, shifted to zero-based: 0, 2, 5, 9, ...
        row_ends = pt.arange(1, n + 1).cumsum()
        self.diag_idxs = row_ends - 1

    def backward(self, value, *inputs):
        """Exponentiate the entries at ``diag_idxs`` along the last axis."""
        diagonal = value[..., self.diag_idxs]
        return pt.set_subtensor(diagonal, pt.exp(diagonal))

    def forward(self, value, *inputs):
        """Take the log of the entries at ``diag_idxs`` along the last axis."""
        diagonal = value[..., self.diag_idxs]
        return pt.set_subtensor(diagonal, pt.log(diagonal))

    def log_jac_det(self, value, *inputs):
        """Sum the entries at ``diag_idxs`` along the last axis."""
        diagonal = value[..., self.diag_idxs]
        return pt.sum(diagonal, axis=-1)
class Chain(RVTransform):
    """Composition of several transforms applied one after another.

    ``forward`` runs the transforms in list order and ``backward`` runs them
    in reverse order. ``name`` is the member names joined with ``"+"``.
    """

    __slots__ = ("param_extract_fn", "transform_list", "name")

    def __init__(self, transform_list):
        self.transform_list = transform_list
        member_names = [member.name for member in self.transform_list]
        self.name = "+".join(member_names)

    def forward(self, value, *inputs):
        """Apply each transform's ``forward`` in list order."""
        result = value
        for member in self.transform_list:
            # TODO:Needs proper discussion as to what should be
            # passed as inputs here
            result = member.forward(result, *inputs)
        return result

    def backward(self, value, *inputs):
        """Apply each transform's ``backward``, last transform first."""
        result = value
        for member in reversed(self.transform_list):
            result = member.backward(result, *inputs)
        return result

    def log_jac_det(self, value, *inputs):
        """Add up the members' log-Jacobian terms.

        Going from the last transform to the first, each term is evaluated at
        the current value, which is then mapped through that transform's
        ``backward``. Any term with more dimensions than the smallest one
        (the input's own ``ndim`` counts too) is summed over its last axis
        before being added.
        """
        current = pt.as_tensor_variable(value)
        start_ndim = current.ndim
        terms = []
        for member in reversed(self.transform_list):
            terms.append(member.log_jac_det(current, *inputs))
            current = member.backward(current, *inputs)

        # match the shape of the smallest log_jac_det
        target_ndim = min([start_ndim, *(term.ndim for term in terms)])
        return sum(
            (term.sum(axis=-1) if term.ndim > target_ndim else term for term in terms),
            0.0,
        )
# Module-level SimplexTransform instance; its docstring is assigned on the
# instance itself.
simplex = SimplexTransform()
simplex.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.SimplexTransform`
for use in the ``transform`` argument of a random variable."""

# Module-level LogOddsTransform instance, documented the same way.
logodds = LogOddsTransform()
logodds.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.LogOddsTransform`
for use in the ``transform`` argument of a random variable."""
class Interval(IntervalTransform):
    """Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the
    ``transform`` argument of a random variable.

    Parameters
    ----------
    lower : int or float, optional
        Lower bound of the interval transform. Must be a constant value.
        With ``lower=None`` (the default) or an infinite value, the interval
        is not bounded below.
    upper : int or float, optional
        Upper bound of the interval transform. Must be a constant value.
        With ``upper=None`` (the default) or an infinite value, the interval
        is not bounded above.
    bounds_fn : callable, optional
        Alternative to lower and upper. Must return a tuple of lower and upper bounds
        as a symbolic function of the respective distribution inputs. If one of lower or
        upper is ``None``, the interval is unbounded on that edge. When given,
        ``lower`` and ``upper`` are ignored.

        .. warning:: Expressions returned by `bounds_fn` should depend only on the
            distribution inputs or other constants. Expressions that depend on nonlocal
            variables, such as other distributions defined in the model context will
            likely break sampling.

    Raises
    ------
    ValueError
        If ``bounds_fn`` is not given and either bound cannot be converted to
        a constant scalar, or if both bounds end up unbounded.

    Examples
    --------
    Create an interval transform between -1 and +1

    .. code-block:: python

        with pm.Model():
            interval = pm.distributions.transforms.Interval(lower=-1, upper=1)
            x = pm.Normal("x", transform=interval)

    Create a lower-bounded interval transform at 0, using a callable

    .. code-block:: python

        def get_bounds(rng, size, dtype, mu, sigma):
            return 0, None

        with pm.Model():
            interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
            x = pm.Normal("x", transform=interval)

    Create a lower-bounded interval transform that depends on a distribution parameter

    .. code-block:: python

        def get_bounds(rng, size, dtype, mu, sigma):
            return mu - 1, None

        interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)

        with pm.Model():
            mu = pm.Normal("mu")
            x = pm.Normal("x", mu=mu, sigma=2, transform=interval)
    """

    def __init__(self, lower=None, upper=None, *, bounds_fn=None):
        if bounds_fn is None:
            try:
                # Force each bound into a 0-d constant and keep its raw data.
                bounds = tuple(
                    None if bound is None else pt.constant(bound, ndim=0).data
                    for bound in (lower, upper)
                )
            except (ValueError, TypeError) as err:
                # Point users at the keyword this class actually exposes.
                raise ValueError(
                    "Interval bounds must be constant values. If you need expressions that "
                    "depend on symbolic variables use `bounds_fn`"
                ) from err

            # An infinite bound means the same as no bound on that side.
            lower, upper = (
                None if (bound is None or np.isinf(bound)) else bound for bound in bounds
            )

            if lower is None and upper is None:
                raise ValueError("Lower and upper interval bounds cannot both be None")

            # Constant bounds are wrapped in a callable that ignores the
            # distribution inputs, so both paths hand over a function.
            def bounds_fn(*rv_inputs):
                return lower, upper

        super().__init__(args_fn=bounds_fn)
def extend_axis(array, axis):
    """Append one entry along ``axis`` so that the result sums to zero there.

    The returned tensor is one element longer than ``array`` along ``axis``.
    Every original entry is shifted by the same offset, and the appended entry
    equals ``-sum / sqrt(n)``, where ``n`` is the new length along ``axis``.
    """
    extended_len = array.shape[axis] + 1
    root_len = pt.sqrt(extended_len)
    total = array.sum(axis, keepdims=True)
    offset = total / (root_len + extended_len)
    # After the final shift by ``offset`` this entry becomes -total / root_len.
    appended = offset - total / root_len
    return pt.concatenate([array, appended], axis=axis) - offset
def extend_axis_rev(array, axis):
    """Drop the last entry along ``axis`` and shift the remaining entries.

    The shift is ``-last * sqrt(n) / (sqrt(n) + n)``, where ``last`` is the
    dropped entry and ``n`` is the length of ``array`` along ``axis``.
    """
    axis_index = normalize_axis_tuple(axis, array.ndim)[0]
    length = array.shape[axis_index]
    root_len = pt.sqrt(length)
    final_entry = pt.take(array, [-1], axis=axis_index)
    total = -final_entry * root_len
    offset = total / (root_len + length)
    # Full slices for every leading axis, then all but the last entry.
    keep_all_but_last = (slice(None),) * axis_index + (slice(None, -1),)
    return array[keep_all_but_last] + offset
# Module-level transform instances. Each one gets its docstring assigned on
# the instance itself.

# LogExpM1 takes no constructor arguments.
log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
for use in the ``transform`` argument of a random variable."""

# Ordered transform built with ndim_supp=0.
univariate_ordered = Ordered(ndim_supp=0)
univariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a univariate random variable."""

# Ordered transform built with ndim_supp=1.
multivariate_ordered = Ordered(ndim_supp=1)
multivariate_ordered.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.Ordered`
for use in the ``transform`` argument of a multivariate random variable."""

# LogTransform takes no constructor arguments.
log = LogTransform()
log.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.LogTransform`
for use in the ``transform`` argument of a random variable."""

# SumTo1 transform built with ndim_supp=0.
univariate_sum_to_1 = SumTo1(ndim_supp=0)
univariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a univariate random variable."""

# SumTo1 transform built with ndim_supp=1.
multivariate_sum_to_1 = SumTo1(ndim_supp=1)
multivariate_sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a multivariate random variable."""

# backwards compatibility: built with the same arguments as
# ``multivariate_sum_to_1`` (ndim_supp=1), but as a separate instance.
sum_to_1 = SumTo1(ndim_supp=1)
sum_to_1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
for use in the ``transform`` argument of a random variable.
This instantiation is for backwards compatibility only.
Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""

# CircularTransform takes no constructor arguments.
circular = CircularTransform()
circular.__doc__ = """
Instantiation of :class:`pymc.logprob.transforms.CircularTransform`
for use in the ``transform`` argument of a random variable."""