Source code for pytensor.tensor.special

from textwrap import dedent

import numpy as np
import scipy.special

from pytensor.graph.basic import Apply
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.math import gamma, gammaln, log, neg, sum


class SoftmaxGrad(COp):
    """
    Gradient wrt x of the Softmax Op.

    """

    nin = 2
    nout = 1
    __props__ = ("axis",)

    def __init__(self, axis):
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, dy, sm):
        dy = as_tensor_variable(dy)
        sm = as_tensor_variable(sm)

        if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
            raise ValueError(
                f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
            )

        return Apply(self, [dy, sm], [sm.type()])

    def perform(self, node, input_storage, output_storage):
        dy, sm = input_storage

        dy_times_sm = dy * sm
        dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
        output_storage[0][0] = dx

    def grad(self, inp, grads):
        dy, sm = inp
        (g,) = grads

        tmp = g + neg(sum(g * sm, axis=self.axis, keepdims=True))
        g_dy = tmp * sm

        tmp2 = sum(dy * sm, axis=self.axis, keepdims=True)
        g_sm = tmp * dy - g * tmp2

        return g_dy, g_sm

    def infer_shape(self, fgraph, node, shape):
        return [shape[1]]

    def c_code_cache_version(self):
        return (4,)

    def c_code(self, node, name, inp, out, sub):
        dy, sm = inp
        (dx,) = out
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]

        return dedent(
            f"""
            PyArrayObject* op[3];
            npy_uint32 op_flags[3];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int sm_ndim = PyArray_NDIM({sm});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
                (PyArray_TYPE({dy}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "types should be float or float64");
                {fail};
            }}
            if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
                (PyArray_TYPE({sm}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "types should be float or float64");
                {fail};
            }}

            if (axis < 0) axis = sm_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
                {fail};
            }}

            if (({dx} == NULL)
                || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
            {{
                Py_XDECREF({dx});
                {dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
                                                         PyArray_DIMS({sm}),
                                                         PyArray_TYPE({sm}));
                if (!{dx})
                {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftmaxGrad dx output");
                    {fail};
                }}
            }}

            // Create numpy iterator
            op[0] = {dy};
            op[1] = {sm};
            op[2] = {dx};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READONLY;
            op_flags[2] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                3,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
                {fail};
            }}

            // SoftmaxGrad is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftmaxGrad GetIterNext");
                    {fail};
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Compute and accumulate dy * sm
                dtype_{dx} sum_dy_times_sm = 0.0;
                do
                {{
                    dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];

                    *dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
                    sum_dy_times_sm += *dx_ptr;
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
                    {fail};
                }}

                // Subtract sum(dy*sm) * sm
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
                    *dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
                }} while(get_next(iter));
            }}

            // SoftmaxGrad is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({sm}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if  (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
                    {fail};
                }}
                npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
                npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
                    {fail};
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftmaxGrad GetIterNext");
                    {fail};
                }}

                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];

                    // Compute and accumulate dy * sm
                    dtype_{dx} sum_dy_times_sm = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
                        sum_dy_times_sm += dx_axis[i * dx_axis_stride];
                    }}

                    // Subtract sum(dy*sm) * sm
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
                    }}

                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )
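
# Illustrative sketch (plain NumPy/SciPy; not part of the Op implementation
# above): `perform` computes the softmax vector-Jacobian product
# dx = dy * sm - sum(dy * sm, axis, keepdims=True) * sm, which can be checked
# against a finite-difference estimate:
#
#     import numpy as np
#     import scipy.special
#
#     rng = np.random.default_rng(0)
#     x = rng.normal(size=5)
#     dy = rng.normal(size=5)
#     sm = scipy.special.softmax(x)
#
#     # vector-Jacobian product, as in SoftmaxGrad.perform
#     dx = dy * sm - np.sum(dy * sm, keepdims=True) * sm
#
#     # finite-difference check of d(dy . softmax(x)) / dx_i
#     eps = 1e-6
#     fd = np.array(
#         [(dy @ (scipy.special.softmax(x + eps * np.eye(5)[i]) - sm)) / eps for i in range(5)]
#     )
#     assert np.allclose(dx, fd, atol=1e-4)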


class Softmax(COp):
    r"""
    Softmax activation function
    :math:`\varphi(\mathbf{x})_j =
    \frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}`
    where :math:`K` is the total number of neurons in the layer. This
    activation function gets applied row-wise.

    """

    nin = 1
    nout = 1
    __props__ = ("axis",)

    def __init__(self, axis):
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, x):
        x = as_tensor_variable(x)

        if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
            raise ValueError(
                f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
            )

        return Apply(self, [x], [x.type()])

    def perform(self, node, input_storage, output_storage):
        (x,) = input_storage
        (z,) = output_storage
        z[0] = scipy.special.softmax(x, axis=self.axis)

    def L_op(self, inp, outputs, grads):
        (x,) = inp
        (g_sm,) = grads
        return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]

    def R_op(self, inputs, eval_points):
        # I think the Jacobian is symmetric so the R_op
        # is the same as the grad
        if None in eval_points:
            return [None]
        return self.L_op(inputs, [self(*inputs)], eval_points)

    def infer_shape(self, fgraph, node, shape):
        return shape

    def c_headers(self, **kwargs):
        return ["<iostream>", "<cmath>"]

    def c_code(self, node, name, inp, out, sub):
        (x,) = inp
        (sm,) = out
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]
        # dtype = node.inputs[0].type.dtype_specs()[1]
        # TODO: put this into a templated function, in the support code
        # TODO: declare the max of each row as an Op output
        # TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
        return dedent(
            f"""
            PyArrayObject* op[2];
            npy_uint32 op_flags[2];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int x_ndim = PyArray_NDIM({x});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
                (PyArray_TYPE({x}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "not a float");
                {fail}
            }}

            if (axis < 0) axis = x_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
                {fail}
            }}

            // Allocate Output Array
            if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
            {{
                Py_XDECREF({sm});
                {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
                if(!{sm}) {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
                    {fail}
                }}
            }}

            // Create numpy iterator
            op[0] = {x};
            op[1] = {sm};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                2,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
                {fail}
            }}

            // Softmax is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Find axis max
                dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                dtype_{x} max = *x_ptr;
                if (get_next(iter))
                {{
                    do
                    {{
                        dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                        max = (*x_ptr > max)? *x_ptr : max;
                    }} while(get_next(iter));
                }}

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
                    {fail}
                }}

                // Compute and accumulate exp(x-max(x)) exponent
                double sum_exp_dev = 0.0;
                do
                {{
                    dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
                    sum_exp_dev += *sm_ptr;
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
                    {fail}
                }}

                // Divide by sum(exp(x-max(x)))
                double inv_sum_exp_dev = 1.0 / sum_exp_dev;
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr *= inv_sum_exp_dev;
                }} while(get_next(iter));
            }}

            // Softmax is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({x}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if  (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
                    {fail}
                }}
                npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
                    {fail}
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
                    {fail}
                }}

                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];

                    // Find axis max
                    dtype_{x} max = x_axis[0];
                    for (npy_intp i = 1; i < axis_size; i++)
                    {{
                        dtype_{x} x_val = x_axis[i * x_axis_stride];
                        max = (x_val > max)? x_val : max;
                    }}

                    // Compute and accumulate exp(x-max(x)) exponent
                    dtype_{sm} sum_exp_dev = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
                        sum_exp_dev += sm_axis[i * sm_axis_stride];
                    }}

                    // Divide by sum(exp(x-max(x)))
                    dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
                    }}

                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )

    @staticmethod
    def c_code_cache_version():
        return (4,)
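
# Illustrative sketch (plain NumPy; not part of the Op above): the C code uses
# the standard max-subtraction trick,
# softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))),
# so that exp() cannot overflow for large inputs.  A hypothetical reference
# helper mirroring that loop:
#
#     import numpy as np
#
#     def softmax_ref(x, axis=None):
#         # subtract the max before exponentiating, then normalize
#         xdev = x - np.max(x, axis=axis, keepdims=True)
#         e = np.exp(xdev)
#         return e / np.sum(e, axis=axis, keepdims=True)
#
#     assert np.isfinite(softmax_ref(np.array([1e3, 1.0, -1e3]))).all()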


def softmax(c, axis=None):
    c = as_tensor_variable(c)
    return Softmax(axis=axis)(c)
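
# Usage sketch (assumes a standard PyTensor install; not from the original
# module):
#
#     import numpy as np
#     import pytensor
#     import pytensor.tensor as pt
#     from pytensor.tensor.special import softmax
#
#     x = pt.matrix("x")
#     f = pytensor.function([x], softmax(x, axis=-1))
#     out = f(np.random.default_rng(0).normal(size=(3, 4)))
#     assert np.allclose(out.sum(axis=-1), 1.0)  # rows sum to one
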

class LogSoftmax(COp):
    r"""
    LogSoftmax activation function
    :math:`\varphi(\mathbf{x})_j =
    \mathbf{x}_j - \log \sum_{k=1}^K e^{\mathbf{x}_k}`
    where :math:`K` is the total number of neurons in the layer. This
    activation function gets applied row-wise.

    """

    nin = 1
    nout = 1
    __props__ = ("axis",)

    def __init__(self, axis):
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, x):
        x = as_tensor_variable(x)

        if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
            raise ValueError(
                f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
            )

        return Apply(self, [x], [x.type()])

    def perform(self, node, input_storage, output_storage):
        (x,) = input_storage
        (z,) = output_storage
        z[0] = scipy.special.log_softmax(x, axis=self.axis)

    def grad(self, inp, grads):
        (x,) = inp
        sm = Softmax(axis=self.axis)(x)
        return [grads[0] - sum(grads[0], axis=self.axis, keepdims=True) * sm]

    def R_op(self, inputs, eval_points):
        # I think the Jacobian is symmetric so the R_op
        # is the same as the grad
        if None in eval_points:
            return [None]
        return self.grad(inputs, eval_points)

    def infer_shape(self, fgraph, node, shape):
        return shape

    def c_headers(self, **kwargs):
        return ["<cmath>"]

    def c_code(self, node, name, inp, out, sub):
        (x,) = inp
        (sm,) = out
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]

        return dedent(
            f"""
            PyArrayObject* op[2];
            npy_uint32 op_flags[2];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int x_ndim = PyArray_NDIM({x});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
                (PyArray_TYPE({x}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "not a float");
                {fail}
            }}

            if (axis < 0) axis = x_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
                {fail}
            }}

            // Allocate Output Array
            if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
            {{
                Py_XDECREF({sm});
                {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
                if(!{sm}) {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
                    {fail}
                }}
            }}

            // Create numpy iterator
            op[0] = {x};
            op[1] = {sm};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                2,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
                {fail}
            }}

            // LogSoftmax is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Find axis max
                dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                dtype_{x} max = *x_ptr;
                if (get_next(iter))
                {{
                    do
                    {{
                        dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                        max = (*x_ptr > max)? *x_ptr : max;
                    }} while(get_next(iter));
                }}

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
                    {fail}
                }}

                // Compute xdev and sum(exp(xdev))
                dtype_{sm} sum_exp_xdev = 0.0;
                do
                {{
                    dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr = (dtype_{sm})((*x_ptr) - max);
                    sum_exp_xdev += exp(*sm_ptr);
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
                    {fail}
                }}

                // Subtract log(sum(exp(xdev)))
                dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr -= log_sum_exp_xdev;
                }} while(get_next(iter));
            }}

            // LogSoftmax is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({x}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
                    {fail}
                }}
                npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
                    {fail}
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
                    {fail}
                }}

                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];

                    // Find axis max
                    dtype_{x} max = x_axis[0];
                    for (npy_intp i = 1; i < axis_size; i++)
                    {{
                        dtype_{x} x_val = x_axis[i * x_axis_stride];
                        max = (x_val > max)? x_val : max;
                    }}

                    // Compute xdev and sum(exp(xdev))
                    dtype_{sm} sum_exp_xdev = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
                        sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
                    }}

                    // Subtract log(sum(exp(xdev)))
                    dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
                    }}

                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )

    @staticmethod
    def c_code_cache_version():
        return (1,)

def log_softmax(c, axis=None):
    c = as_tensor_variable(c)
    return LogSoftmax(axis=axis)(c)
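
# Usage sketch (assumes a standard PyTensor install; not from the original
# module): log_softmax agrees with log(softmax(x)) but stays finite for
# extreme inputs because it works with x - max(x) directly.
#
#     import numpy as np
#     import pytensor
#     import pytensor.tensor as pt
#     from pytensor.tensor.special import log_softmax
#
#     x = pt.vector("x")
#     f = pytensor.function([x], log_softmax(x))
#     assert np.isfinite(f(np.array([1e3, 0.0, -1e3]))).all()
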

@_vectorize_node.register(Softmax)
@_vectorize_node.register(LogSoftmax)
def vectorize_softmax_node(op, node, batched_x):
    """
    Vectorize Softmax and LogSoftmax nodes.
    """
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batched_x.type.ndim - core_ndim

    if not batch_ndim:
        return op.make_node(batched_x)

    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)

    if len(batch_axes) > 1:
        from pytensor.tensor.blockwise import vectorize_node_fallback

        # The softmax Ops only allow a specific axis (integer) or all axis (None).
        # If the vectorized operation requires more than one axis we have to default to a Blockwise
        return vectorize_node_fallback(op, node, batched_x)

    [batch_axis] = batch_axes
    return type(op)(axis=batch_axis).make_node(batched_x)


def poch(z, m):
    """
    Pochhammer symbol (rising factorial) function.
    """
    return gamma(z + m) / gamma(z)


def factorial(n):
    """
    Factorial function of a scalar or array of numbers.
    """
    return gamma(n + 1)


def logit(x):
    """
    Logit function.
    """
    return log(x / (1 - x))


def beta(a, b):
    """
    Beta function.
    """
    return (gamma(a) * gamma(b)) / gamma(a + b)


def betaln(a, b):
    """
    Log beta function.
    """
    return gammaln(a) + gammaln(b) - gammaln(a + b)


__all__ = [
    "softmax",
    "log_softmax",
    "poch",
    "factorial",
    "logit",
    "beta",
    "betaln",
]
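
# Numerical sanity check for the helpers above (illustrative only; the graph
# built below is a hypothetical example, not part of the module):
#
#     import numpy as np
#     import pytensor
#     import pytensor.tensor as pt
#     from scipy import special as sp
#     from pytensor.tensor.special import beta, factorial, poch
#
#     z, m = pt.dscalars("z", "m")
#     f = pytensor.function([z, m], [poch(z, m), beta(z, m), factorial(z)])
#     p, b, fact = f(3.0, 4.0)
#     assert np.isclose(p, sp.poch(3.0, 4.0))
#     assert np.isclose(b, sp.beta(3.0, 4.0))
#     assert np.isclose(fact, sp.gamma(4.0))  # factorial(3) == 3! == 6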