# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial, reduce
import numpy as np
import pytensor
import pytensor.sparse
import pytensor.tensor as pt
import pytensor.tensor.slinalg
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import (
abs,
and_,
arccos,
arccosh,
arcsin,
arcsinh,
arctan,
arctanh,
broadcast_to,
ceil,
clip,
concatenate,
constant,
cos,
cosh,
cumprod,
cumsum,
dot,
eq,
erf,
erfc,
erfcinv,
erfinv,
exp,
flatten,
floor,
full,
full_like,
ge,
gt,
le,
log,
log1pexp,
logaddexp,
logsumexp,
lt,
matmul,
max,
maximum,
mean,
min,
minimum,
neq,
ones,
ones_like,
or_,
prod,
round,
sgn,
sigmoid,
sin,
sinh,
sqr,
sqrt,
stack,
sum,
switch,
tan,
tanh,
where,
zeros,
zeros_like,
)
from pytensor.tensor.linalg import solve_triangular
from pytensor.tensor.nlinalg import matrix_inverse
from pytensor.tensor.special import log_softmax, softmax
from pymc.pytensorf import floatX
# Public API of pymc.math. Note: "round" was previously listed twice; the
# duplicate entry has been removed (each name should appear exactly once).
__all__ = [
    "abs",
    "and_",
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "broadcast_to",
    "ceil",
    "clip",
    "concatenate",
    "constant",
    "cos",
    "cosh",
    "cumprod",
    "cumsum",
    "dot",
    "eq",
    "erf",
    "erfc",
    "erfcinv",
    "erfinv",
    "exp",
    "full",
    "full_like",
    "flatten",
    "floor",
    "ge",
    "gt",
    "le",
    "log",
    "log1pexp",
    "logaddexp",
    "logsumexp",
    "lt",
    "matmul",
    "max",
    "maximum",
    "mean",
    "min",
    "minimum",
    "neq",
    "ones",
    "ones_like",
    "or_",
    "prod",
    "round",
    "sgn",
    "sigmoid",
    "sin",
    "sinh",
    "sqr",
    "sqrt",
    "stack",
    "sum",
    "switch",
    "tan",
    "tanh",
    "where",
    "zeros",
    "zeros_like",
    "kronecker",
    "cartesian",
    "kron_dot",
    "kron_solve_lower",
    "kron_solve_upper",
    "kron_diag",
    "flat_outer",
    "logdiffexp",
    "invlogit",
    "softmax",
    "log_softmax",
    "logbern",
    "logit",
    "log1mexp",
    "flatten_list",
    "logdet",
    "probit",
    "invprobit",
    "expand_packed_triangular",
    "batched_diag",
    "block_diagonal",
]
[docs]
def kronecker(*Ks):
    r"""Return the Kronecker product of the argument matrices.

    math:`K_1 \otimes K_2 \otimes ... \otimes K_D`
    Parameters
    ----------
    Ks : Iterable of 2D array-like
        Arrays of which to take the product.
    Returns
    -------
    np.ndarray :
        Block matrix Kronecker product of the argument matrices.
    """
    # Fold the pairwise Kronecker product left-to-right over all factors.
    return reduce(lambda left, right: pt.slinalg.kron(left, right), Ks)
[docs]
def cartesian(*arrays):
    """Make the Cartesian product of arrays.

    Parameters
    ----------
    arrays: N-D array-like
        N-D arrays where earlier arrays loop more slowly than later ones
    """
    n_arrays = len(arrays)
    # Normalize every input to a 2D column block so rows can be fancy-indexed.
    as_columns = []
    for a in arrays:
        a = np.asarray(a)
        as_columns.append(a[:, None] if a.ndim == 1 else a)
    # Build the Cartesian product over row indices ("ij" keeps earlier arrays
    # looping more slowly), then gather the actual rows from each input.
    index_grids = np.meshgrid(*(np.arange(len(col)) for col in as_columns), indexing="ij")
    combo_idx = np.stack(index_grids, -1).reshape(-1, n_arrays)
    gathered = [col[combo_idx[:, i]] for i, col in enumerate(as_columns)]
    return np.concatenate(gathered, axis=-1)
def kron_matrix_op(krons, m, op):
    r"""Apply op to krons and m in a way that reproduces ``op(kronecker(*krons), m)``.

    The full Kronecker product :math:`A = A_1 \otimes ... \otimes A_D` is
    never materialized; ``op`` is applied one factor at a time against
    reshapings of ``m`` (the "vec trick").
    Parameters
    ----------
    krons : list of square 2D array-like objects
        D square matrices :math:`[A_1, A_2, ..., A_D]` to be Kronecker'ed
        :math:`A = A_1 \otimes A_2 \otimes ... \otimes A_D`
        Product of column dimensions must be :math:`N`
    m : NxM array or 1D array (treated as Nx1)
        Object that krons act upon
    op : callable
        Binary operation, e.g. ``pt.dot`` or a triangular solve, called as
        ``op(kron_factor, reshaped_m)``.
    Returns
    -------
    numpy array
    """
    def flat_matrix_op(flat_mat, mat):
        # Apply ``op`` with a single Kronecker factor ``mat``: reshape the
        # running flat result so that ``mat``'s columns align with its rows,
        # apply ``op``, then transpose so the next factor acts on the next
        # axis when it is reshaped in turn.
        Nmat = mat.shape[1]
        flat_shape = flat_mat.shape
        mat2 = flat_mat.reshape((Nmat, -1))
        return op(mat, mat2).T.reshape(flat_shape)
    def kron_vector_op(v):
        # Fold the single-factor step over all Kronecker factors.
        return reduce(flat_matrix_op, krons, v)
    if m.ndim == 1:
        m = m[:, None]  # Treat 1D array as Nx1 matrix
    if m.ndim != 2:  # Has not been tested otherwise
        raise ValueError(f"m must have ndim <= 2, not {m.ndim}")
    result = kron_vector_op(m)
    result_shape = result.shape
    # Undo the transposition accumulated by the per-factor steps to restore
    # the (N, M) layout of ``op(kronecker(*krons), m)``.
    return pt.reshape(result, (result_shape[1], result_shape[0])).T
# Partial applications of kron_matrix_op; each reproduces the corresponding
# operation against the full Kronecker product without materializing it, and
# works on 1D (treated as Nx1) and 2D right-hand sides.
kron_dot = partial(kron_matrix_op, op=pt.dot)  # kronecker(*krons) @ m
kron_solve_lower = partial(kron_matrix_op, op=partial(solve_triangular, lower=True))  # solve with lower-triangular factors
kron_solve_upper = partial(kron_matrix_op, op=partial(solve_triangular, lower=False))  # solve with upper-triangular factors
[docs]
def flat_outer(a, b):
    """Return the outer product of two vectors, flattened to one dimension."""
    outer = pt.outer(a, b)
    return outer.ravel()
[docs]
def kron_diag(*diags):
    """Return diagonal of a kronecker product.

    Uses the identity diag(A ⊗ B) = diag(A) ⊗ diag(B), so the full
    Kronecker product is never formed.
    Parameters
    ----------
    diags: 1D arrays
        The diagonals of matrices that are to be Kroneckered
    """
    return reduce(lambda acc, d: flat_outer(acc, d), diags)
[docs]
def logdiffexp(a, b):
    """Return log(exp(a) - exp(b)), evaluated stably in log space."""
    delta = b - a
    return a + pt.log1mexp(delta)
def logdiffexp_numpy(a, b):
    """Return log(exp(a) - exp(b)).

    .. deprecated:: use :func:`logdiffexp` instead.
    """
    warnings.warn(
        "pymc.math.logdiffexp_numpy is being deprecated.",
        FutureWarning,
        stacklevel=2,
    )
    shifted = b - a
    return a + log1mexp_numpy(shifted, negative_input=True)
# The inverse of the logit function is exactly the logistic sigmoid.
invlogit = sigmoid
[docs]
def logbern(log_p, rng=None):
    """Draw a Bernoulli sample in log space: True with probability exp(log_p).

    Parameters
    ----------
    log_p : float
        Log-probability of success; must not be NaN.
    rng : optional
        Random generator with a ``uniform()`` method; defaults to ``np.random``.
    """
    if np.isnan(log_p):
        raise FloatingPointError("log_p can't be nan.")
    generator = rng if rng is not None else np.random
    return np.log(generator.uniform()) < log_p
[docs]
def logit(p):
    """Return the log-odds log(p / (1 - p)) of ``p``."""
    one = floatX(1)
    return pt.log(p / (one - p))
[docs]
def log1mexp(x, *, negative_input=False):
    r"""Return log(1 - exp(-x)).

    This function is numerically more stable than the naive approach.
    For details, see
    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    References
    ----------
    .. [Machler2012] Martin Mächler (2012).
        "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
    """
    # New-style callers pass the (negative) argument directly.
    if negative_input:
        return pt.log1mexp(x)
    # Legacy behavior: the input is negated before evaluation.
    warnings.warn(
        "pymc.math.log1mexp will expect a negative input in a future "
        "version of PyMC.\n To suppress this warning set `negative_input=True`",
        FutureWarning,
        stacklevel=2,
    )
    return pt.log1mexp(-x)
def log1mexp_numpy(x, *, negative_input=False):
    """Return log(1 - exp(x)).

    This function is numerically more stable than the naive approach.
    For details, see
    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    """
    warnings.warn(
        "pymc.math.log1mexp_numpy is being deprecated.",
        FutureWarning,
        stacklevel=2,
    )
    x = np.asarray(x, dtype="float")
    if not negative_input:
        warnings.warn(
            "pymc.math.log1mexp_numpy will expect a negative input in a future "
            "version of PyMC.\n To suppress this warning set `negative_input=True`",
            FutureWarning,
            stacklevel=2,
        )
        x = -x
    result = np.empty_like(x)
    # Machler (2012): below log(1/2) use log1p(-exp(x)), above use
    # log(-expm1(x)); each form is accurate on its side of the cutoff.
    small = x < -0.6931471805599453  # log(1/2)
    result[small] = np.log1p(-np.exp(x[small]))
    large = ~small
    result[large] = np.log(-np.expm1(x[large]))
    return result
[docs]
def flatten_list(tensors):
    """Ravel each tensor and join them into a single 1D tensor."""
    raveled = [tensor.ravel() for tensor in tensors]
    return pt.concatenate(raveled)
class LogDet(Op):
    r"""Compute the logarithm of the absolute determinant of a square matrix M, log(abs(det(M))) on the CPU.

    Works through the singular values of M, which avoids the overflow or
    underflow that evaluating det(M) directly would incur.
    Notes
    -----
    Once PR #3959 (https://github.com/Theano/Theano/pull/3959/) by harpone is merged,
    this must be removed.
    """
    def make_node(self, x):
        # Output is a scalar of the same dtype as the input matrix.
        x = pytensor.tensor.as_tensor_variable(x)
        o = pytensor.tensor.scalar(dtype=x.dtype)
        return Apply(self, [x], [o])
    def perform(self, node, inputs, outputs, params=None):
        (x,) = inputs
        (z,) = outputs
        # Keep only the fallible SVD inside the try, and chain the original
        # exception so the root cause is preserved in the traceback.
        try:
            s = np.linalg.svd(x, compute_uv=False)
        except Exception as e:
            raise ValueError(f"Failed to compute logdet of {x}.") from e
        # log|det(M)| equals the sum of the logs of the singular values.
        log_det = np.sum(np.log(np.abs(s)))
        z[0] = np.asarray(log_det, dtype=x.dtype)
    def grad(self, inputs, g_outputs):
        # d/dM log|det(M)| = (M^{-1})^T
        [gz] = g_outputs
        [x] = inputs
        return [gz * matrix_inverse(x).T]
    def __str__(self):
        return "LogDet"
# Module-level singleton instance; use as ``logdet(M)``.
logdet = LogDet()
[docs]
def probit(p):
    """Return the probit (standard-normal quantile) of ``p`` via erfcinv."""
    scale = sqrt(2.0)
    return -scale * erfcinv(2.0 * p)
[docs]
def invprobit(x):
    """Return the inverse probit (standard-normal CDF) of ``x`` via erfc."""
    scaled = -x / sqrt(2.0)
    return 0.5 * erfc(scaled)
[docs]
def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
    r"""Convert a packed triangular matrix into a two dimensional array.

    Triangular matrices can be stored with better space efficiency by
    storing the non-zero values in a one-dimensional array. We number
    the elements by row like this (for lower or upper triangular matrices):
        [[0 - - -]     [[0 1 2 3]
         [1 2 - -]      [- 4 5 6]
         [3 4 5 -]      [- - 7 8]
         [6 7 8 9]]     [- - - 9]
    Parameters
    ----------
    n: int
        The number of rows of the triangular matrix.
    packed: pytensor.vector
        The matrix in packed format.
    lower: bool, default=True
        If true, assume that the matrix is lower triangular.
    diagonal_only: bool
        If true, return only the diagonal of the matrix.
    """
    if packed.ndim != 1:
        raise ValueError("Packed triangular is not one dimensional.")
    if not isinstance(n, int):
        raise TypeError("n must be an integer")
    if diagonal_only:
        if lower:
            # Lower triangle packed by row: the diagonal entry ends each row,
            # so its position is the cumulative row length minus one.
            positions = np.arange(1, n + 1).cumsum() - 1
        else:
            # Upper triangle packed by row: rows have lengths n, n-1, ..., 1
            # and the diagonal entry starts each row.
            positions = np.arange(2, n + 2)[::-1].cumsum() - n - 1
        return packed[positions]
    dense = pt.zeros((n, n), dtype=pytensor.config.floatX)
    if lower:
        dense = pt.set_subtensor(dense[np.tril_indices(n)], packed)
        # tag as lower triangular to enable pytensor rewrites
        dense.tag.lower_triangular = True
    else:
        dense = pt.set_subtensor(dense[np.triu_indices(n)], packed)
        # tag as upper triangular to enable pytensor rewrites
        dense.tag.upper_triangular = True
    return dense
class BatchedDiag(Op):
    """Fast BatchedDiag allocation.

    Turns a (batch, dim) matrix of diagonals into a (batch, dim, dim)
    tensor of diagonal matrices.
    """
    __props__ = ()
    def make_node(self, diag):
        # Input must be a 2D tensor of per-batch diagonals; the output is a
        # rank-3 tensor with the same dtype.
        diag = pt.as_tensor_variable(diag)
        if diag.type.ndim != 2:
            raise TypeError("data argument must be a matrix", diag.type)
        return Apply(self, [diag], [pt.tensor3(dtype=diag.dtype)])
    def perform(self, node, ins, outs, params=None):
        (C,) = ins
        (z,) = outs
        bc = C.shape[0]  # batch size
        dim = C.shape[-1]  # length of each diagonal
        Cd = np.zeros((bc, dim, dim), C.dtype)
        # Index every (b, i, i) position at once and fill it with the
        # corresponding diagonal entry in a single vectorized assignment.
        bidx = np.repeat(np.arange(bc), dim)
        didx = np.tile(np.arange(dim), bc)
        Cd[bidx, didx, didx] = C.flatten()
        z[0] = Cd
    def grad(self, inputs, gout):
        # The adjoint of embedding a diagonal is extracting the diagonal.
        (gz,) = gout
        idx = pt.arange(gz.shape[-1])
        return [gz[..., idx, idx]]
    def infer_shape(self, fgraph, nodes, shapes):
        # (batch, dim) -> (batch, dim, dim)
        return [(shapes[0][0],) + (shapes[0][1],) * 2]
[docs]
def batched_diag(C):
    """Go between batched diagonals and batched diagonal matrices.

    A 2D input is treated as a batch of diagonals and expanded into a
    batch of diagonal matrices; a 3D input is reduced back to the
    per-matrix diagonals.
    """
    C = pt.as_tensor(C)
    dim = C.shape[-1]
    if C.ndim == 2:
        # diag -> matrices
        return BatchedDiag()(C)
    if C.ndim == 3:
        # matrices -> diag
        idx = pt.arange(dim)
        return C[..., idx, idx]
    raise ValueError("Input should be 2 or 3 dimensional")
[docs]
def block_diagonal(matrices, sparse=False, format="csr"):
    r"""Build a block-diagonal matrix from the given matrices.

    See pt.slinalg.block_diag or pytensor.sparse.basic.block_diag for reference.
    Parameters
    ----------
    matrices: tensors
    format: str (default 'csr')
        must be one of: 'csr', 'csc'; only used when ``sparse=True``
    sparse: bool (default False)
        if True return sparse format
    Returns
    -------
    matrix
    """
    # Use FutureWarning with stacklevel=2 for consistency with the other
    # deprecation warnings in this module (e.g. log1mexp_numpy).
    warnings.warn(
        "pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release",
        FutureWarning,
        stacklevel=2,
    )
    if len(matrices) == 1:  # graph optimization: a single block needs no work
        return matrices[0]
    if sparse:
        return pytensor.sparse.basic.block_diag(*matrices, format=format)
    else:
        return pt.slinalg.block_diag(*matrices)