# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Sequence, Tuple, Union

import numpy as np
import pymc as pm
import pytensor.tensor as pt

__all__ = ["R2D2M2CP"]


def _psivar2musigma(
psi: pt.TensorVariable,
explained_var: pt.TensorVariable,
psi_mask: Union[pt.TensorLike, None],
) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
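    """Convert sign probabilities and explained variances into Normal parameters.

    Given ``psi``, the probability that a coefficient is positive, and its share of the
    explained variance, return ``(mu, sigma)`` of a Normal prior such that
    ``P(Normal(mu, sigma) > 0) == psi`` and the second moment ``mu**2 + sigma**2`` equals
    ``explained_var``. Entries excluded by ``psi_mask`` (``psi`` exactly 0 or 1) collapse
    to the deterministic limit ``mu = sign * explained_var**0.5``, ``sigma = 0``.
    """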
sign = pt.sign(psi - 0.5)
if psi_mask is not None:
# any computation might be ignored for ~psi_mask
# sign and explained_var are used
psi = pt.where(psi_mask, psi, 0.5)
pi = pt.erfinv(2 * psi - 1)
f = (1 / (2 * pi**2 + 1)) ** 0.5
sigma = explained_var**0.5 * f
mu = sigma * pi * 2**0.5
if psi_mask is not None:
return (
pt.where(psi_mask, mu, sign * explained_var**0.5),
pt.where(psi_mask, sigma, 0),
)
else:
return mu, sigma


def _R2D2M2CP_beta(
name: str,
output_sigma: pt.TensorVariable,
input_sigma: pt.TensorVariable,
r2: pt.TensorVariable,
phi: pt.TensorVariable,
psi: pt.TensorVariable,
*,
psi_mask,
dims: Union[str, Sequence[str]],
centered=False,
) -> pt.TensorVariable:
"""R2D2M2CP beta prior.
Parameters
----------
name: str
Name for the distribution
output_sigma: tensor
standard deviation of the outcome
input_sigma: tensor
standard deviation of the explanatory variables
r2: tensor
expected R2 for the linear regression
phi: tensor
variance weights that sums up to 1
psi: tensor
probability of a coefficients to be positive
"""
explained_variance = phi * pt.expand_dims(r2 * output_sigma**2, (-1,))
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
if not centered:
with pm.Model(name):
if psi_mask is not None and psi_mask.any():
# limit case where some probs are not 1 or 0
# setsubtensor is required
r_idx = psi_mask.nonzero()
with pm.Model("raw"):
raw = pm.Normal("masked", shape=len(r_idx[0]))
raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw)
raw = pm.Deterministic("raw", raw, dims=dims)
elif psi_mask is not None:
# all variables are deterministic
raw = pt.zeros_like(mu_param)
else:
raw = pm.Normal("raw", dims=dims)
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
else:
if psi_mask is not None and psi_mask.any():
# limit case where some probs are not 1 or 0
# setsubtensor is required
r_idx = psi_mask.nonzero()
with pm.Model(name):
mean = (mu_param / input_sigma)[r_idx]
sigma = (std_param / input_sigma)[r_idx]
masked = pm.Normal(
"masked",
mean,
sigma,
shape=len(r_idx[0]),
)
beta = pt.set_subtensor(mean, masked)
beta = pm.Deterministic(name, beta, dims=dims)
elif psi_mask is not None:
# all variables are deterministic
beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims)
else:
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
return beta


def _broadcast_as_dims(
*values: np.ndarray,
dims: Sequence[str],
) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
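    """Broadcast the given values to the shape implied by the model coordinates ``dims``.

    When a single value is passed, the broadcasted array is returned directly instead of a
    one-element tuple.
    """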
model = pm.modelcontext(None)
shape = [len(model.coords[d]) for d in dims]
ret = tuple(np.broadcast_to(v, shape) for v in values)
# strip output
if len(values) == 1:
ret = ret[0]
return ret


def _psi_masked(
positive_probs: pt.TensorLike,
positive_probs_std: pt.TensorLike,
*,
dims: Sequence[str],
) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
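    """Construct the sign-probability vector ``psi`` when uncertainty over it is requested.

    Entries of ``positive_probs`` that are exactly 0 or 1 stay deterministic; the remaining
    entries receive a Beta prior parametrized by mean and standard deviation. Returns the
    boolean mask of the stochastic entries (``None`` if all entries are stochastic) together
    with ``psi``.
    """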
if not (
isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
):
raise TypeError(
"Only constant values for positive_probs and positive_probs_std are accepted"
)
positive_probs, positive_probs_std = _broadcast_as_dims(
positive_probs.data, positive_probs_std.data, dims=dims
)
mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0)
if np.bitwise_and(~mask, positive_probs_std != 0).any():
raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
if (~mask).any() and mask.any():
# limit case where some probs are not 1 or 0
# setsubtensor is required
r_idx = mask.nonzero()
with pm.Model("psi"):
psi = pm.Beta(
"masked",
mu=positive_probs[r_idx],
sigma=positive_probs_std[r_idx],
shape=len(r_idx[0]),
)
psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi)
psi = pm.Deterministic("psi", psi, dims=dims)
elif (~mask).all():
# limit case where all the probs are limit case
psi = pt.as_tensor(positive_probs)
else:
psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
mask = None
return mask, psi


def _psi(
positive_probs: pt.TensorLike,
positive_probs_std: Union[pt.TensorLike, None],
*,
dims: Sequence[str],
) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
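    """Construct ``psi``, the per-variable probability of a positive coefficient.

    When ``positive_probs_std`` is given, delegates to ``_psi_masked``; otherwise
    ``positive_probs`` must be constant and is broadcast to ``dims``. Returns a boolean mask
    of the entries that are not exactly 0 or 1 (``None`` when no masking is needed) together
    with ``psi``.
    """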
if positive_probs_std is not None:
mask, psi = _psi_masked(
positive_probs=pt.as_tensor(positive_probs),
positive_probs_std=pt.as_tensor(positive_probs_std),
dims=dims,
)
else:
positive_probs = pt.as_tensor(positive_probs)
if not isinstance(positive_probs, pt.Constant):
raise TypeError("Only constant values for positive_probs are allowed")
psi = _broadcast_as_dims(positive_probs.data, dims=dims)
mask = np.atleast_1d(~np.bitwise_or(psi == 1, psi == 0))
if mask.all():
mask = None
return mask, psi


def _phi(
variables_importance: Union[pt.TensorLike, None],
variance_explained: Union[pt.TensorLike, None],
importance_concentration: Union[pt.TensorLike, None],
*,
dims: Sequence[str],
) -> pt.TensorVariable:
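    """Construct ``phi``, the split of the explained variance across variables.

    ``variables_importance`` places a Dirichlet prior on the weights, ``variance_explained``
    fixes them to constants, and with neither argument the weights are equal;
    ``importance_concentration`` scales the Dirichlet concentration.
    """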
*broadcast_dims, dim = dims
model = pm.modelcontext(None)
if variables_importance is not None:
if variance_explained is not None:
raise TypeError("Can't use variable importance with variance explained")
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variable importance with less than two variables")
variables_importance = pt.as_tensor(variables_importance)
if importance_concentration is not None:
variables_importance *= importance_concentration
return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim])
elif variance_explained is not None:
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variance explained with less than two variables")
phi = pt.as_tensor(variance_explained)
else:
        # equal shares by default so that the weights sum up to 1
        phi = _broadcast_as_dims(1 / len(model.coords[dim]), dims=dims)
if importance_concentration is not None:
return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
else:
return phi


R2D2M2CPOut = namedtuple("R2D2M2CPOut", ["eps", "beta"])


def R2D2M2CP(
name: str,
output_sigma: pt.TensorLike,
input_sigma: pt.TensorLike,
*,
dims: Sequence[str],
r2: pt.TensorLike,
variables_importance: Union[pt.TensorLike, None] = None,
variance_explained: Union[pt.TensorLike, None] = None,
importance_concentration: Union[pt.TensorLike, None] = None,
r2_std: Union[pt.TensorLike, None] = None,
positive_probs: Union[pt.TensorLike, None] = 0.5,
positive_probs_std: Union[pt.TensorLike, None] = None,
centered: bool = False,
) -> R2D2M2CPOut:
"""R2D2M2CP Prior.
Parameters
----------
name : str
Name for the distribution
output_sigma : Tensor
Output standard deviation
input_sigma : Tensor
Input standard deviation
dims : Union[str, Sequence[str]]
Dims for the distribution
r2 : Tensor
:math:`R^2` estimate
variables_importance : Tensor, optional
Optional estimate for variables importance, positive, by default None
    variance_explained : Tensor, optional
        Alternative estimate for variables importance: a point estimate of the share of
        variance explained by each variable, should sum up to one, by default None
importance_concentration : Tensor, optional
Confidence around variance explained or variable importance estimate
r2_std : Tensor, optional
Optional uncertainty over :math:`R^2`, by default None
    positive_probs : Tensor, optional
        Optional probability of a variable's contribution being positive, by default 0.5
positive_probs_std : Tensor, optional
Optional uncertainty over effect direction probability, by default None
    centered : bool, optional
        Centered or non-centered parametrization of the distribution, non-centered by default.
        It is advised to check both.

    Returns
    -------
    residual_sigma, coefficients
        The output variance (sigma squared) is split into residual variance and explained
        variance; the residual standard deviation equals ``(1 - r2) ** 0.5 * output_sigma``.
Raises
------
TypeError
If parametrization is wrong.
Notes
-----
    The R2D2M2CP prior is a modification of the R2D2M2 prior.

- ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
- R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
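
    Sketching the construction implemented here: each variable receives a share ``phi[i]``
    of the explained variance ``r2 * output_sigma**2``, and the prior for the scaled
    coefficient ``beta[i] * input_sigma[i]`` is a Normal whose mean and standard deviation
    are chosen so that its second moment equals ``phi[i] * r2 * output_sigma**2`` and its
    probability of being positive equals ``positive_probs[i]``; the returned residual
    standard deviation is ``(1 - r2) ** 0.5 * output_sigma``.
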
Examples
--------
    The arguments are explained in a synthetic example below.
.. warning::
To use the prior in a linear regression
- make sure :math:`X` is centered around zero
- intercept represents prior predictive mean when :math:`X` is centered
- setting named dims is required
.. code-block:: python
import pymc_experimental as pmx
import pymc as pm
import numpy as np
X = np.random.randn(10, 3)
b = np.random.randn(3)
y = X @ b + np.random.randn(10) * 0.04 + 5
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
eps, beta = pmx.distributions.R2D2M2CP(
"beta",
y.std(),
X.std(0),
dims="variables",
# NOTE: global shrinkage
r2=0.8,
# NOTE: if you are unsure about r2
r2_std=0.2,
# NOTE: if you know where a variable should go
# if you do not know, leave as 0.5
positive_probs=[0.8, 0.5, 0.1],
# NOTE: if you have different opinions about
# where a variable should go.
# NOTE: if you put 0.5 previously,
# just put 0.1 there, but other
# sigmas should work fine too
positive_probs_std=[0.3, 0.1, 0.2],
# NOTE: variable importances are relative to each other,
                # but larger numbers put "more" weight in the relation
# use
# * 1-10 for small confidence
# * 10-30 for moderate confidence
# * 30+ for high confidence
# EXAMPLE:
# "a" - is likely to be useful
# "b" - no idea if it is useful
# "c" - a must have in the relation
variables_importance=[10, 1, 34],
# NOTE: try both
centered=True
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
# regressors should be centered around zero
Xc = X - X.mean(0)
obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
    Special cases can be obtained by choosing a specific set of arguments.
Here the prior distribution of beta is ``Normal(0, y.std() * r2 ** .5)``
.. code-block:: python
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
eps, beta = pmx.distributions.R2D2M2CP(
"beta",
y.std(),
X.std(0),
dims="variables",
# NOTE: global shrinkage
r2=0.8,
# NOTE: if you are unsure about r2
r2_std=0.2,
# NOTE: if you know where a variable should go
# if you do not know, leave as 0.5
centered=False
)
# intercept prior centering should be around prior predictive mean
intercept = y.mean()
# regressors should be centered around zero
Xc = X - X.mean(0)
obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
    It is fine to leave some of the ``_std`` arguments unspecified.
    You can also specify only ``positive_probs``, in which case all
    the variables are assumed to explain the same amount of variance (same importance).
.. code-block:: python
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
eps, beta = pmx.distributions.R2D2M2CP(
"beta",
y.std(),
X.std(0),
dims="variables",
# NOTE: global shrinkage
r2=0.8,
# NOTE: if you are unsure about r2
r2_std=0.2,
# NOTE: if you know where a variable should go
# if you do not know, leave as 0.5
positive_probs=[0.8, 0.5, 0.1],
# NOTE: try both
centered=True
)
intercept = y.mean()
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
Notes
-----
    To reference the R2D2M2CP implementation, you can use the following BibTeX entry:

.. code-block::
@misc{pymc-experimental-r2d2m2cp,
title = {pymc-devs/pymc-experimental: {P}ull {R}equest 137, {R2D2M2CP}},
url = {https://github.com/pymc-devs/pymc-experimental/pull/137},
author = {Max Kochurov},
howpublished = {GitHub},
year = {2023}
}
"""
if not isinstance(dims, (list, tuple)):
dims = (dims,)
*broadcast_dims, dim = dims
input_sigma = pt.as_tensor(input_sigma)
output_sigma = pt.as_tensor(output_sigma)
with pm.Model(name):
if r2_std is not None:
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
phi = _phi(
variables_importance=variables_importance,
variance_explained=variance_explained,
importance_concentration=importance_concentration,
dims=dims,
)
mask, psi = _psi(
positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims
)
beta = _R2D2M2CP_beta(
name,
output_sigma,
input_sigma,
r2,
phi,
psi,
dims=broadcast_dims + [dim],
centered=centered,
psi_mask=mask,
)
resid_sigma = (1 - r2) ** 0.5 * output_sigma
return R2D2M2CPOut(resid_sigma, beta)