pymc.norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-07)

Max weight norm constraints and gradient clipping

This takes a TensorVariable and rescales it so that the norms of its incoming weight vectors do not exceed a specified maximum. Vectors that violate the constraint are scaled down so that they fall within the allowed range.
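The rescaling described above can be sketched in plain NumPy. This is only an illustration, not the library's implementation: the helper name `rescale_to_max_norm` is hypothetical, and the scale factor `clip(norm, 0, max_norm) / (epsilon + norm)` is an assumed form that matches the behaviour described here.

```python
import numpy as np

def rescale_to_max_norm(w, max_norm, axes=(0,), epsilon=1e-7):
    """Illustrative NumPy analogue; hypothetical helper, not part of pymc."""
    # Euclidean norm of each weight vector, kept broadcastable against w.
    norms = np.sqrt(np.sum(w ** 2, axis=axes, keepdims=True))
    # Norms above max_norm are capped; smaller norms are kept as they are.
    target = np.clip(norms, 0, max_norm)
    # The scale is about 1 for vectors within the limit,
    # and about max_norm / norm for vectors that exceed it.
    return w * (target / (epsilon + norms))

# Two weight vectors stored as columns, with norms 5.0 and 0.5.
w = np.array([[3.0, 0.3],
              [4.0, 0.4]])
constrained = rescale_to_max_norm(w, max_norm=1.0)
col_norms = np.sqrt(np.sum(constrained ** 2, axis=0))
```

In this sketch the first column is shrunk to roughly unit norm while keeping its direction, and the second column, which already satisfies the constraint, is left essentially unchanged.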

tensor_var: TensorVariable

PyTensor expression for update, gradient, or other quantity.

max_norm: scalar

This value sets the maximum allowed value of any norm in tensor_var.

norm_axes: sequence (list or tuple), optional

The axes over which to compute the norm. This overrides the default norm axes defined for the number of dimensions in tensor_var. When this is not specified and tensor_var is a matrix (2D), this is set to (0,). If tensor_var is a 3D, 4D or 5D tensor, it is set to a tuple listing all axes but axis 0. The former default is useful for dense layers, the latter for 1D, 2D and 3D convolutional layers.

epsilon: scalar, optional

Value used to prevent numerical instability when dividing by very small or zero norms.
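To see why epsilon matters, consider a weight vector that is exactly zero. The snippet below is a sketch under the assumption that the scale factor is computed as `clip(norm, 0, max_norm) / (epsilon + norm)`; that formula is an assumption made for illustration and is not stated in the text above.

```python
import numpy as np

max_norm = 1.0
epsilon = 1e-7

zero_vec = np.zeros(3)
norm = np.sqrt(np.sum(zero_vec ** 2))   # exactly 0.0
target = np.clip(norm, 0, max_norm)     # also 0.0

# Without epsilon the scale factor is 0 / 0, which evaluates to nan.
with np.errstate(invalid="ignore", divide="ignore"):
    unguarded = zero_vec * (target / norm)

# With epsilon the denominator is small but nonzero, so the scale is 0.
guarded = zero_vec * (target / (epsilon + norm))
```

Under this assumption the unguarded result is filled with nan, whereas the guarded result stays a zero vector.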


Returns a TensorVariable: the input tensor_var with rescaling applied to weight vectors that violate the specified constraint.


When norm_axes is not specified, the axes over which the norm is computed depend on the dimensionality of the input variable. If it is 2D, it is assumed to come from a dense layer, and the norm is computed over axis 0. If it is 3D, 4D or 5D, it is assumed to come from a convolutional layer and the norm is computed over all trailing axes beyond axis 0. For other uses, you should explicitly specify the axes over which to compute the norm using norm_axes.
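The default-axis rule can be written out as a small helper. The function `default_norm_axes` is hypothetical and exists only to restate the rule; raising `ValueError` for other dimensionalities is an assumption about how an unsupported input might be handled.

```python
def default_norm_axes(ndim):
    """Hypothetical helper restating the default norm axes by dimensionality."""
    if ndim == 2:
        # Dense layer: norm over axis 0.
        return (0,)
    if ndim in (3, 4, 5):
        # Convolutional layer: norm over every axis except axis 0.
        return tuple(range(1, ndim))
    # Assumed behaviour: other shapes need an explicit norm_axes.
    raise ValueError(
        f"No default norm axes for a {ndim}D tensor; pass norm_axes explicitly."
    )

dense_axes = default_norm_axes(2)
conv2d_axes = default_norm_axes(4)
```

For a 4D tensor this yields axes 1, 2 and 3, so one norm is computed per slice along axis 0.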


>>> import numpy as np
>>> import pytensor
>>> from pymc import norm_constraint
>>> param = pytensor.shared(
...     np.random.randn(100, 200).astype(pytensor.config.floatX))
>>> update = param + 100
>>> update = norm_constraint(update, 10)
>>> func = pytensor.function([], [], updates=[(param, update)])
>>> # Apply constrained update
>>> _ = func()
>>> # Column norms; axis 0 is the default norm axis for a 2D tensor
>>> norms = np.sqrt(np.sum(param.get_value() ** 2, axis=0))
>>> np.isclose(np.max(norms), 10)