handle_dims

pymc_extras.prior.handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) → pt.TensorVariable

Take a tensor whose dimensions are dims and align it to desired_dims, reordering its axes to match and adding any missing dimensions.

Does not check whether the dims are valid.
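The alignment idea can be sketched in plain NumPy. The function align_dims below is hypothetical and is not the pymc_extras implementation (which returns a PyTensor variable). It is a minimal sketch assuming that "align" means moving existing axes into the order given by desired_dims and inserting a length-1 axis for each desired dim the tensor lacks.

```python
import numpy as np


def align_dims(x, dims, desired_dims):
    """Hypothetical NumPy analogue of handle_dims, for illustration only.

    Reorders the axes of ``x`` (labelled by ``dims``) to follow
    ``desired_dims`` and inserts a length-1 axis for every desired dim
    that ``x`` lacks. No validation of ``dims`` is performed.
    """
    x = np.asarray(x)
    if x.ndim == 0:
        # Scalars broadcast against any shape; leave them unchanged.
        return x

    # Accept a single dim name as shorthand for a 1-tuple.
    dims = dims if isinstance(dims, tuple) else (dims,)
    desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)

    # Move the existing axes into the order they appear in desired_dims.
    present = [d for d in desired_dims if d in dims]
    x = np.transpose(x, [dims.index(d) for d in present])

    # Walk desired_dims left to right, inserting a length-1 axis
    # at the position of each dim that x does not have.
    for axis, d in enumerate(desired_dims):
        if d not in dims:
            x = np.expand_dims(x, axis)
    return x


y = np.arange(6).reshape(2, 3)  # dims ("group", "channel")
out = align_dims(y, ("group", "channel"), ("channel", "date", "group"))
print(out.shape)  # (3, 1, 2)
```

Transposing first and inserting afterwards keeps the bookkeeping simple: at each step of the loop, the axes before position axis already match desired_dims, so np.expand_dims(x, axis) puts the new axis exactly where it belongs.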

Parameters:
  • x (pt.TensorLike) – The tensor to align.

  • dims (Dims) – The current dimensions of the tensor.

  • desired_dims (Dims) – The desired dimensions of the tensor.

Returns:

The aligned tensor.

Return type:

pt.TensorVariable

Examples

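Align a 1D tensor to a 2D layout by adding a new dimension.

import numpy as np
from pymc_extras.prior import handle_dims

x = np.array([1, 2, 3])
dims = "channel"  # x currently has the single dim "channel"
desired_dims = ("channel", "group")  # "group" is missing from x and gets added

handle_dims(x, dims, desired_dims)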
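The point of this alignment is broadcasting. Assuming, as the example suggests, that handle_dims keeps the "channel" axis and appends "group" as a length-1 axis, the result has the same layout as x[:, None] in NumPy. The sketch below shows that layout in NumPy only. The array weights is a hypothetical parameter with dims ("channel", "group").

```python
import numpy as np

# Same data as the example: one value per "channel".
x = np.array([1, 2, 3])

# NumPy counterpart of aligning ("channel",) to ("channel", "group"):
# keep the channel axis and append a length-1 "group" axis.
aligned = x[:, np.newaxis]

# Hypothetical parameter with dims ("channel", "group"), shape (3, 2).
weights = np.ones((3, 2))

# The length-1 "group" axis broadcasts, so every group sees its channel's value.
product = aligned * weights
print(aligned.shape, product.shape)  # (3, 1) (3, 2)
```

Multiplying the unaligned x (shape (3,)) by weights would instead raise a broadcasting error, because NumPy matches trailing axes and compares 3 against 2.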