# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import config, graph_replace, scan
from pytensor.graph import Op
from pytensor.graph.basic import Node
from pytensor.raise_op import CheckAndRaise
from pytensor.scan import until
from pytensor.tensor import TensorConstant, TensorVariable
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
CustomSymbolicDistRV,
Distribution,
SymbolicRandomVariable,
_support_point,
support_point,
)
from pymc.distributions.shape_utils import (
_change_dist_size,
change_dist_size,
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered
class TruncatedRV(SymbolicRandomVariable):
"""
    An `Op` constructed from a PyTensor graph
that represents a truncated univariate random variable.
"""
default_output: int = 0
base_rv_op: Op
max_n_steps: int
def __init__(
self,
*args,
base_rv_op: Op,
max_n_steps: int,
**kwargs,
):
self.base_rv_op = base_rv_op
self.max_n_steps = max_n_steps
self._print_name = (
f"Truncated{self.base_rv_op._print_name[0]}",
f"\\operatorname{{{self.base_rv_op._print_name[1]}}}",
)
super().__init__(*args, **kwargs)
@classmethod
def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
# We don't accept rng because we don't have control over it when using a specialized Op
# and there may be a need for multiple RNGs in dist.
# Try to use specialized Op
try:
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
except NotImplementedError:
pass
lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf)
upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf)
if size is not None:
size = pt.as_tensor(size, dtype="int64", ndim=1)
if rv_size_is_none(size):
size = pt.broadcast_shape(dist, lower, upper)
dist = change_dist_size(dist, new_size=size)
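        # Replace the RNG inputs of the base RV with fresh shared generators, so the
        # truncated graph owns its own RNG state instead of reusing the original one.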
rv_inputs = [
inp
if not isinstance(inp.type, RandomType)
else pytensor.shared(np.random.default_rng())
for inp in dist.owner.inputs
]
graph_inputs = [*rv_inputs, lower, upper]
rv = dist.owner.op.make_node(*rv_inputs).default_output()
        # Try to use inverse CDF sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
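        # If U ~ Uniform(CDF(lower), CDF(upper)), then icdf(rv, U) is distributed as the
        # base RV restricted to [lower, upper], so one uniform draw per entry yields an
        # exact truncated sample without any rejection.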
try:
logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType))
uniform_next_rng, uniform = pt.random.uniform(
pt.exp(logcdf_lower),
pt.exp(logcdf_upper),
rng=uniform_rng,
size=rv.shape,
).owner.outputs
truncated_rv = icdf(rv, uniform, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, uniform_next_rng],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
except NotImplementedError:
pass
# Fallback to rejection sampling
# truncated_rv = zeros(rv.shape)
# reject_draws = ones(rv.shape, dtype=bool)
# while any(reject_draws):
# truncated_rv[reject_draws] = draw(rv)[reject_draws]
# reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
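        # The loop below implements this with `scan`: each iteration redraws the rejected
        # entries and stops early (via `until`) once every draw falls inside [lower, upper].
        # If convergence is not reached within `max_n_steps`, `TruncationCheck` raises a
        # `TruncationError`.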
def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
new_truncated_rv = dist.owner.op.make_node(*rv_inputs).default_output()
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
else:
truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))
return (
(truncated_rv, reject_draws),
collect_default_updates(new_truncated_rv, inputs=rv_inputs),
until(~pt.any(reject_draws)),
)
(truncated_rv, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
pt.zeros_like(rv),
pt.ones_like(rv, dtype=bool),
],
non_sequences=[lower, upper, *rv_inputs],
n_steps=max_n_steps,
strict=True,
)
truncated_rv = truncated_rv[-1]
convergence = ~pt.any(reject_draws_[-1])
truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv, convergence
)
        # Sort the RNG updates so that they appear in the same order as the input RNGs
def sort_updates(update):
rng, next_rng = update
return graph_inputs.index(rng)
next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
@staticmethod
def _create_logcdf_exprs(
base_rv: TensorVariable,
value: TensorVariable,
lower: TensorVariable,
upper: TensorVariable,
) -> tuple[TensorVariable, TensorVariable]:
"""Create lower and upper logcdf expressions for base_rv.
Uses `value` as a template for broadcasting.
"""
        # For left-truncated discrete RVs the lower bound itself must be included,
        # so the lower logcdf is evaluated at lower - 1 (the discrete CDF is inclusive).
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
upper_value = pt.full_like(value, upper, dtype=config.floatX)
lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
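        # Reuse the lower logcdf graph for the upper bound by swapping the value input,
        # instead of deriving the logcdf expression a second time.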
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf
def update(self, node: Node):
"""Return the update mapping for the internal RNGs.
        TruncatedRVs are created so that the RNG updates follow the same order as the input RNGs.
"""
rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
return dict(zip(rngs, next_rngs))
@singledispatch
def _truncated(op: Op, lower, upper, size, *params):
"""Return the truncated equivalent of another `RandomVariable`."""
raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
class TruncationCheck(CheckAndRaise):
"""Implements a check in truncated graphs.
Raises `TruncationError` if the check is not True.
"""
def __init__(self, msg=""):
super().__init__(TruncationError, msg)
def __str__(self):
return f"TruncationCheck{{{self.msg}}}"
class Truncated(Distribution):
r"""
    Truncated distribution.

    The pdf of a Truncated distribution is

    .. math::

        \begin{cases}
            0 & \text{for } x < \text{lower}, \\
            \frac{\text{PDF}(x, \text{dist})}{\text{CDF}(\text{upper}, \text{dist}) - \text{CDF}(\text{lower}, \text{dist})}
            & \text{for } \text{lower} \leq x \leq \text{upper}, \\
            0 & \text{for } x > \text{upper},
        \end{cases}

    Parameters
    ----------
    dist: unnamed distribution
        Univariate distribution created via the `.dist()` API, which will be truncated.
        This distribution must be a pure RandomVariable and have a logcdf method
        implemented for MCMC sampling.

        .. warning:: dist will be cloned, rendering it independent of the one passed as input.

    lower: tensor_like of float or None
        Lower (left) truncation point. If `None`, the distribution will not be left truncated.
    upper: tensor_like of float or None
        Upper (right) truncation point. If `None`, the distribution will not be right truncated.
    max_n_steps: int, defaults to 10_000
        Maximum number of resamples that are attempted when performing rejection sampling.
        A `TruncationError` is raised if convergence is not reached after that many steps.

    Returns
    -------
    truncated_distribution: TensorVariable
        Graph representing a truncated `RandomVariable`. A specialized `Op` may be used
        if the `Op` of the dist has a dispatched `_truncated` function. Otherwise, a
        `SymbolicRandomVariable` graph that implements the truncation via inverse CDF
        sampling (when the underlying dist has a logcdf implementation) or rejection
        sampling is returned.

    Examples
    --------
    .. code-block:: python

        with pm.Model():
            normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
            truncated_normal = pm.Truncated("truncated_normal", normal_dist, lower=-1, upper=1)
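            # Any univariate distribution with a logcdf can be truncated the same
            # way, e.g. a Poisson truncated from above (illustrative):
            poisson_dist = pm.Poisson.dist(mu=3.0)
            truncated_poisson = pm.Truncated("truncated_poisson", poisson_dist, upper=5)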
"""
rv_type = TruncatedRV
rv_op = rv_type.rv_op
@classmethod
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
if not (
isinstance(dist, TensorVariable)
and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
):
if isinstance(dist.owner.op, SymbolicRandomVariable):
raise NotImplementedError(
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
f"You can try wrapping the distribution inside a CustomDist instead."
)
raise ValueError(
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
)
if dist.owner.op.ndim_supp > 0:
raise NotImplementedError("Truncation not implemented for multivariate distributions")
check_dist_not_registered(dist)
if lower is None and upper is None:
raise ValueError("lower and upper cannot both be None")
return super().dist([dist, lower, upper, max_n_steps], **kwargs)
@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
*rv_inputs, lower, upper = truncated_rv.owner.inputs
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
if expand:
new_size = to_tuple(new_size) + tuple(truncated_rv.shape)
return Truncated.rv_op(
untruncated_rv,
lower=lower,
upper=upper,
size=new_size,
max_n_steps=op.max_n_steps,
)
@_support_point.register(TruncatedRV)
def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
*rv_inputs, lower, upper = inputs
# recreate untruncated rv and respective support_point
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
untruncated_support_point = support_point(untruncated_rv)
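    # If the support point of the untruncated RV falls outside [lower, upper], fall back
    # to a value that lies inside the bounds: the midpoint when both bounds are finite,
    # otherwise one unit inside the single finite bound.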
fallback_support_point = pt.switch(
pt.and_(pt.bitwise_not(pt.isinf(lower)), pt.bitwise_not(pt.isinf(upper))),
        (lower + upper) / 2,  # midpoint when lower and upper are both finite
pt.switch(
pt.isinf(upper),
lower + 1, # only lower is finite
upper - 1, # only upper is finite
),
)
return pt.switch(
pt.and_(pt.ge(untruncated_support_point, lower), pt.le(untruncated_support_point, upper)),
untruncated_support_point, # untruncated support_point is between lower and upper
fallback_support_point,
)
@_default_transform.register(TruncatedRV)
def truncated_default_transform(op, truncated_rv):
# Don't transform discrete truncated distributions
if truncated_rv.type.dtype.startswith("int"):
return None
# Lower and Upper are the arguments -2 and -1
return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
(value,) = values
*rv_inputs, lower, upper = inputs
base_rv_op = op.base_rv_op
base_rv = base_rv_op.make_node(*rv_inputs).default_output()
base_logp = logp(base_rv, value)
lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper)
if base_rv_op.name:
base_logp.name = f"{base_rv_op}_logprob"
lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
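    # The truncated density rescales the base density by the probability mass inside
    # [lower, upper]: lognorm = log(CDF(upper) - CDF(lower)), with the one-sided cases
    # reducing to log(1 - CDF(lower)) and log(CDF(upper)) respectively.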
lognorm = 0
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
elif is_upper_bounded:
lognorm = upper_logcdf
truncated_logp = base_logp - lognorm
if is_lower_bounded:
truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)
if is_upper_bounded:
truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)
if is_lower_bounded and is_upper_bounded:
truncated_logp = check_parameters(
truncated_logp,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)
return truncated_logp
@_logcdf.register(TruncatedRV)
def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
*rv_inputs, lower, upper = inputs
base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
base_logcdf = logcdf(base_rv, value)
lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper)
is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
lognorm = 0
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
elif is_upper_bounded:
lognorm = upper_logcdf
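    # CDF of the truncated variable: (CDF(x) - CDF(lower)) / (CDF(upper) - CDF(lower)),
    # computed in log space with logdiffexp for numerical stability.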
logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
logcdf_trunc = logcdf_numerator - lognorm
if is_lower_bounded:
logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)
if is_upper_bounded:
logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)
if is_lower_bounded and is_upper_bounded:
logcdf_trunc = check_parameters(
logcdf_trunc,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)
return logcdf_trunc
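# Specialized truncated Ops are provided by registering an implementation with the
# `_truncated` singledispatch. The registration below maps a truncated `NormalRV`
# directly to `TruncatedNormal`, bypassing the generic icdf/rejection machinery.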
@_truncated.register(NormalRV)
def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
return TruncatedNormal.dist(
mu=mu,
sigma=sigma,
lower=lower,
upper=upper,
rng=None, # Do not reuse rng to avoid weird dependencies
size=size,
dtype=dtype,
)