register_tensor_transform

pymc_extras.prior.register_tensor_transform(name: str, transform: Transform) -> None

Register a tensor transform function under a name so that the Prior class can apply it through its transform argument.

Parameters:
  • name (str) – The name of the transform, passed later as Prior(..., transform=name).

  • transform (Callable[[pt.TensorLike], pt.TensorLike]) – The function to apply to the tensor.

Examples

Register a custom transform that squares its input, then use it in a Prior.

from pymc_extras.prior import (
    Prior,
    register_tensor_transform,
)

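# Elementwise square of the input tensor.
def custom_transform(x):
    return x**2


# Make the function available under the name "square".
register_tensor_transform("square", custom_transform)

# Refer to the registered transform by name.
custom_distribution = Prior("Normal", transform="square")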
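The example above assumes the transform is applied to values drawn from the base distribution, so a "square" transform over a Normal should only produce non-negative values. The plain-Python sketch below illustrates that idea with random.gauss standing in for PyMC sampling; the helper sample_transformed is hypothetical and does not call Prior or PyMC.

```python
import random


def custom_transform(x):
    # Same elementwise square as in the example above.
    return x**2


def sample_transformed(n: int, seed: int = 0) -> list[float]:
    """Draw n values from Normal(0, 1), then apply the transform to each."""
    rng = random.Random(seed)
    raw = [rng.gauss(0.0, 1.0) for _ in range(n)]  # base-distribution draws
    return [custom_transform(r) for r in raw]      # transformed values


draws = sample_transformed(1000)
print(min(draws) >= 0.0)  # True: squares are never negative
```

Because the transform runs after sampling, its output range (here, non-negative reals) constrains the values the variable can take, not the base distribution itself.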