handle_dims
- pymc_extras.prior.handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable
Take a tensor whose dimensions are dims and align it to desired_dims.
Does not check whether the dims are valid.
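The alignment idea can be illustrated in plain NumPy. The sketch below is a minimal illustration, not the library's implementation: align_dims is a hypothetical helper, and it assumes every name in dims also appears in desired_dims. It reorders the existing axes into desired_dims order and inserts a length-1 axis for each desired dimension the input lacks.

```python
import numpy as np


def align_dims(x, dims, desired_dims):
    """Hypothetical NumPy analogue of handle_dims (not the library code).

    Assumes dims is a subset of desired_dims. Reorders the axes of x
    into desired_dims order and adds a length-1 axis for each missing dim.
    """
    x = np.asarray(x)
    # Accept a single dim name as shorthand for a 1-tuple.
    dims = (dims,) if isinstance(dims, str) else tuple(dims)
    desired_dims = (desired_dims,) if isinstance(desired_dims, str) else tuple(desired_dims)

    # Put the axes x already has into the order they take in desired_dims.
    present = [d for d in desired_dims if d in dims]
    x = np.transpose(x, [dims.index(d) for d in present])

    # Walk desired_dims left to right, inserting a length-1 axis
    # wherever a desired dim is absent from x.
    for axis, d in enumerate(desired_dims):
        if d not in dims:
            x = np.expand_dims(x, axis)
    return x
```

For example, align_dims(np.zeros((2, 3)), ("a", "c"), ("a", "b", "c")) has shape (2, 1, 3), and passing ("group", "channel") with desired ("channel", "group") simply transposes the input.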
- Parameters:
x (pt.TensorLike) – The tensor to align.
dims (Dims) – The current dimensions of the tensor.
desired_dims (Dims) – The desired dimensions of the tensor.
- Returns:
The aligned tensor.
- Return type:
pt.TensorVariable
Examples
Align a 1D tensor to 2D, adding the new dimension "group".
import numpy as np
from pymc_extras.prior import handle_dims

x = np.array([1, 2, 3])
dims = "channel"
desired_dims = ("channel", "group")
handle_dims(x, dims, desired_dims)
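Why the alignment matters shows up once the result is combined with another array. The NumPy sketch below assumes that aligning "channel" to ("channel", "group") amounts to adding a trailing length-1 "group" axis, and uses a hypothetical (channel, group) array named other with two groups.

```python
import numpy as np

x = np.array([1, 2, 3])      # dims: ("channel",), shape (3,)
# Assumption: the aligned tensor equals x with a trailing length-1 "group" axis.
aligned = x[:, np.newaxis]   # dims: ("channel", "group"), shape (3, 1)

# Hypothetical (channel, group) array with 2 groups.
other = np.ones((3, 2))

# The length-1 "group" axis broadcasts, so x is repeated for every group.
total = aligned * other      # shape (3, 2)
```

Without the new axis, x * other raises a ValueError, because NumPy broadcasting matches trailing axes and compares 3 against 2.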