# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import singledispatch
import numpy as np
import pytensor.tensor as pt
# ignore mypy error because it somehow considers that
# "numpy.core.numeric has no attribute normalize_axis_tuple"
from numpy.core.numeric import normalize_axis_tuple # type: ignore[attr-defined]
from pytensor.graph import Op
from pytensor.tensor import TensorVariable
from pymc.logprob.transforms import (
ChainedTransform,
CircularTransform,
IntervalTransform,
LogOddsTransform,
LogTransform,
SimplexTransform,
Transform,
)
# Public API of this module: transform classes and class aliases first, then the
# ready-made transform instances meant to be passed as ``transform=...``.
# NOTE(review): "ZeroSumTransform" is exported here but is not defined in the
# visible part of this module — confirm it is defined elsewhere in the file.
__all__ = [
    # Classes and class aliases
    "Chain",
    "CholeskyCovPacked",
    "Interval",
    "Transform",
    "ZeroSumTransform",
    # Transform instances
    "circular",
    "log",
    "log_exp_m1",
    "logodds",
    "ordered",
    "simplex",
    "sum_to_1",
]
def __getattr__(name):
    """Resolve deprecated module-level names (module ``__getattr__``, PEP 562).

    ``univariate_*`` / ``multivariate_*`` variants of ``ordered`` and ``sum_to_1``
    resolve to the unified instance, and ``RVTransform`` resolves to ``Transform``.
    Each of these emits a ``FutureWarning``; any other name raises ``AttributeError``.
    """
    if name == "RVTransform":
        warnings.warn("RVTransform has been renamed to Transform", FutureWarning)
        return Transform

    # "univariate_sum_to_1" -> ("univariate", "_", "sum_to_1")
    prefix, _, base = name.partition("_")
    if prefix in ("univariate", "multivariate") and base in ("ordered", "sum_to_1"):
        warnings.warn(f"{name} has been deprecated, use {base} instead.", FutureWarning)
        # The conditional expression keeps the lookup lazy: only the requested
        # module-level instance is read.
        return ordered if base == "ordered" else sum_to_1

    raise AttributeError(f"module {__name__} has no attribute {name}")
@singledispatch
def _default_transform(op: Op, rv: TensorVariable):
    """Return the default transform for a distribution ``Op`` (fallback: none).

    Dispatches on the type of ``op``. Specific Op types opt in by registering an
    implementation via ``_default_transform.register``; this base implementation
    handles every unregistered Op.
    """
    # Unregistered Ops get no default transform.
    return None
class LogExpM1(Transform):
    """Map positive values to the real line via the inverse softplus ``y = log(exp(x) - 1)``.

    The backward map is ``softplus``, whose output is always positive.
    """

    name = "log_exp_m1"

    def backward(self, value, *inputs):
        """Return ``softplus(value)``."""
        return pt.softplus(value)

    def forward(self, value, *inputs):
        """Inverse operation of softplus.

        y = Log(Exp(x) - 1)
          = Log(1 - Exp(-x)) + x

        ``log1mexp(-x)`` evaluates ``Log(1 - Exp(-x))`` without forming
        ``1 - exp(-x)`` explicitly. That subtraction cancels catastrophically for
        ``x`` close to 0, losing precision and even returning ``-inf`` once
        ``exp(-x)`` rounds to 1.
        """
        return pt.log1mexp(-value) + value

    def log_jac_det(self, value, *inputs):
        """Log-derivative of ``softplus`` at ``value``: ``log(sigmoid(value)) = -softplus(-value)``."""
        return -pt.softplus(-value)
class Ordered(Transform):
    """Map values that increase along the last axis to an unconstrained space.

    The first entry is kept as is. Every later entry is stored as the log of its
    gap to the previous entry, so the backward map (a cumulative sum of the first
    entry and the exponentiated gaps) always yields increasing values.
    """

    name = "ordered"

    def __init__(self, ndim_supp=None):
        # ``ndim_supp`` is still accepted so older call sites keep working.
        if ndim_supp is not None:
            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

    def backward(self, value, *inputs):
        """Cumulatively sum ``value[..., 0]`` and ``exp(value[..., 1:])`` along the last axis."""
        increments = pt.set_subtensor(pt.zeros(value.shape)[..., 0], value[..., 0])
        increments = pt.set_subtensor(increments[..., 1:], pt.exp(value[..., 1:]))
        return pt.cumsum(increments, axis=-1)

    def forward(self, value, *inputs):
        """Keep ``value[..., 0]``; replace each later entry by the log of its step from the previous one."""
        log_gaps = pt.log(value[..., 1:] - value[..., :-1])
        unconstrained = pt.set_subtensor(pt.zeros(value.shape)[..., 0], value[..., 0])
        return pt.set_subtensor(unconstrained[..., 1:], log_gaps)

    def log_jac_det(self, value, *inputs):
        """Return the log-determinant of the backward Jacobian.

        The Jacobian is lower triangular with diagonal ``1, exp(value[..., 1:])``,
        so the log-determinant is ``sum(value[..., 1:])``.
        """
        return pt.sum(value[..., 1:], axis=-1)
class SumTo1(Transform):
    """Drop the redundant last entry of values that sum to 1 along the last axis.

    ``forward`` maps K values in [0, 1] that sum to 1 to their first K - 1 entries.
    ``backward`` restores the dropped entry as one minus the sum of the kept ones.
    """

    name = "sumto1"

    def __init__(self, ndim_supp=None):
        # ``ndim_supp`` is still accepted so older call sites keep working.
        if ndim_supp is not None:
            warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)

    def backward(self, value, *inputs):
        """Append ``1 - sum(value)`` along the last axis so the result sums to 1."""
        dropped = 1 - pt.sum(value, axis=-1, keepdims=True)
        return pt.concatenate([value, dropped], axis=-1)

    def forward(self, value, *inputs):
        """Return all but the last entry along the last axis."""
        return value[..., :-1]

    def log_jac_det(self, value, *inputs):
        """Return a zero log-Jacobian term, reduced over the last axis."""
        return pt.sum(pt.zeros(value.shape), axis=-1)
class CholeskyCovPacked(Transform):
    """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.

    Only entries at the packed diagonal positions are transformed; all other
    entries pass through unchanged.
    """

    name = "cholesky-cov-packed"

    def __init__(self, n):
        """Create a CholeskyCovPacked object.

        Parameters
        ----------
        n : int
            Number of diagonal entries in the LKJCholeskyCov distribution
        """
        # Positions k * (k + 1) / 2 - 1 for k = 1..n, i.e. 0, 2, 5, 9, ...
        # (presumably a row-major packed lower triangle — confirm against LKJCholeskyCov).
        self.diag_idxs = pt.cumsum(pt.arange(1, n + 1)) - 1

    def backward(self, value, *inputs):
        """Exponentiate the packed diagonal entries of ``value``."""
        diag = value[..., self.diag_idxs]
        return pt.set_subtensor(diag, pt.exp(diag))

    def forward(self, value, *inputs):
        """Take the log of the packed diagonal entries of ``value``."""
        diag = value[..., self.diag_idxs]
        return pt.set_subtensor(diag, pt.log(diag))

    def log_jac_det(self, value, *inputs):
        """Sum of the (log-scale) diagonal entries: the log-Jacobian of the elementwise ``exp``."""
        diag = value[..., self.diag_idxs]
        return pt.sum(diag, axis=-1)
# ``Chain`` is a public alias of ChainedTransform.
Chain = ChainedTransform

# Ready-made instances of transforms imported from pymc.logprob.transforms, each
# carrying a ``__doc__`` that points back to its class.
simplex = SimplexTransform()
simplex.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.logprob.transforms.SimplexTransform`\n"
    "for use in the ``transform`` argument of a random variable."
)

logodds = LogOddsTransform()
logodds.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.logprob.transforms.LogOddsTransform`\n"
    "for use in the ``transform`` argument of a random variable."
)
class Interval(IntervalTransform):
    """Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the ``transform`` argument of a random variable.

    Parameters
    ----------
    lower : int or float, optional
        Lower bound of the interval transform. Must be a constant value; an
        infinite value is treated like ``None``.
        By default (``lower=None``), the interval is not bounded below.
    upper : int or float, optional
        Upper bound of the interval transform. Must be a constant value; an
        infinite value is treated like ``None``.
        By default (``upper=None``), the interval is not bounded above.
    bounds_fn : callable, optional
        Alternative to lower and upper. Must return a tuple of lower and upper bounds
        as a symbolic function of the respective distribution inputs. If one of lower or
        upper is ``None``, the interval is unbounded on that edge. When ``bounds_fn``
        is given, ``lower`` and ``upper`` are ignored.

        .. warning:: Expressions returned by `bounds_fn` should depend only on the
            distribution inputs or other constants. Expressions that depend on nonlocal
            variables, such as other distributions defined in the model context will
            likely break sampling.

    Raises
    ------
    ValueError
        If ``bounds_fn`` is not given and ``lower``/``upper`` cannot be converted to
        constant scalars, or if both end up ``None`` (missing or infinite).

    Examples
    --------
    Create an interval transform between -1 and +1

    .. code-block:: python

        with pm.Model():
            interval = pm.distributions.transforms.Interval(lower=-1, upper=1)
            x = pm.Normal("x", transform=interval)

    Create a lower-bounded interval transform at 0, using a callable

    .. code-block:: python

        def get_bounds(rng, size, mu, sigma):
            return 0, None


        with pm.Model():
            interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
            x = pm.Normal("x", transform=interval)

    Create a lower-bounded interval transform that depends on a distribution parameter

    .. code-block:: python

        def get_bounds(rng, size, mu, sigma):
            return mu - 1, None


        interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)

        with pm.Model():
            mu = pm.Normal("mu")
            x = pm.Normal("x", mu=mu, sigma=2, transform=interval)
    """

    def __init__(self, lower=None, upper=None, *, bounds_fn=None):
        if bounds_fn is None:
            # Fixed bounds: coerce each one to a scalar constant. pt.constant(..., ndim=0)
            # raises for inputs it cannot turn into one (e.g. symbolic or non-scalar values).
            try:
                bounds = tuple(
                    None if bound is None else pt.constant(bound, ndim=0).data
                    for bound in (lower, upper)
                )
            except (ValueError, TypeError) as err:
                # Point users at the actual keyword argument, ``bounds_fn``.
                raise ValueError(
                    "Interval bounds must be constant values. If you need expressions that "
                    "depend on symbolic variables use `bounds_fn`"
                ) from err
            # An infinite bound is equivalent to no bound on that edge.
            lower, upper = (
                None if (bound is None or np.isinf(bound)) else bound for bound in bounds
            )
            if lower is None and upper is None:
                raise ValueError("Lower and upper interval bounds cannot both be None")

            def bounds_fn(*rv_inputs):
                # Constant bounds: the distribution inputs are ignored.
                return lower, upper

        super().__init__(args_fn=bounds_fn)
# Ready-made transform instances for direct use as ``transform=...``. The first,
# second and fourth wrap classes defined in this module; ``log`` and ``circular``
# wrap classes imported from pymc.logprob.transforms.
log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.distributions.transforms.LogExpM1`\n"
    "for use in the ``transform`` argument of a random variable."
)

ordered = Ordered()
ordered.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.distributions.transforms.Ordered`\n"
    "for use in the ``transform`` argument of a random variable."
)

log = LogTransform()
log.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.logprob.transforms.LogTransform`\n"
    "for use in the ``transform`` argument of a random variable."
)

sum_to_1 = SumTo1()
sum_to_1.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.distributions.transforms.SumTo1`\n"
    "for use in the ``transform`` argument of a random variable."
)

circular = CircularTransform()
circular.__doc__ = (
    "\n"
    "Instantiation of :class:`pymc.logprob.transforms.CircularTransform`\n"
    "for use in the ``transform`` argument of a random variable."
)