pymc.norm_constraint
- pymc.norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-07)
Max weight norm constraints and gradient clipping.
This takes a TensorVariable and rescales it so that incoming weight norms do not exceed a specified constraint value. Vectors whose norm exceeds max_norm are rescaled to have a norm of max_norm (up to epsilon); vectors already within the constraint are left essentially unchanged.
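Written out, and assuming the implementation clips each norm at max_norm and adds epsilon to the denominator (our reading of the behavior, not stated above), the rescaling applied to each vector v, with the norm taken over norm_axes and c = max_norm, is roughly:

```latex
\|v\| = \sqrt{\sum_{i \in \text{norm\_axes}} v_i^2},
\qquad
v' = v \cdot \frac{\min(\|v\|,\, c)}{\|v\| + \epsilon}
```

So v' has a norm of about c when \|v\| > c, and v' is approximately v otherwise.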
- Parameters:
- tensor_var: TensorVariable
PyTensor expression for update, gradient, or other quantity.
- max_norm: scalar
This value sets the maximum allowed value of any norm in tensor_var.
- norm_axes: sequence (list or tuple), optional
The axes over which to compute the norm. This overrides the default norm axes, which depend on the number of dimensions of tensor_var: if tensor_var is a matrix (2D), the default is (0,); if it is a 3D, 4D or 5D tensor, the default is all axes except axis 0. The first default suits dense layers; the second suits 1D, 2D and 3D convolutional layers.
- epsilon: scalar, optional
Small value added to the norms in the denominator to prevent numerical instability when dividing by very small or zero norms.
- Returns:
TensorVariable
Input tensor_var with rescaling applied to weight vectors that violate the specified constraints.
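To make the parameters concrete, here is a NumPy-only sketch of the same rescaling on plain arrays. The helper norm_constraint_np is hypothetical (it is not part of PyMC), and its clip-then-divide formula is an assumption about how the PyTensor version behaves; unlike the real function, it requires norm_axes explicitly instead of inferring a default.

```python
import numpy as np


def norm_constraint_np(x, max_norm, norm_axes, epsilon=1e-7):
    """Hypothetical NumPy sketch of max-norm rescaling (not the PyMC API).

    Norms are computed over `norm_axes`. Slices whose norm exceeds
    `max_norm` are scaled down to (approximately) `max_norm`; the others
    are left (approximately) unchanged.
    """
    x = np.asarray(x, dtype=float)
    axes = tuple(norm_axes)
    # keepdims=True so the norms broadcast back against x
    norms = np.sqrt(np.sum(x ** 2, axis=axes, keepdims=True))
    # Target norm: the current norm, capped at max_norm
    target = np.clip(norms, 0.0, max_norm)
    # epsilon in the denominator avoids 0/0 for all-zero slices
    return x * (target / (norms + epsilon))


# Dense-layer style weight matrix: one norm per column (axis 0)
rng = np.random.default_rng(0)
w = rng.normal(size=(100, 200)) + 100.0
w_c = norm_constraint_np(w, max_norm=10.0, norm_axes=(0,))
col_norms = np.sqrt(np.sum(w_c ** 2, axis=0))
print(np.allclose(col_norms, 10.0))  # True

# All-zero slices stay zero instead of becoming NaN
z = norm_constraint_np(np.zeros((3, 2)), max_norm=1.0, norm_axes=(0,))
print(np.all(z == 0.0))  # True
```

Each column of w starts with a norm of roughly 1000, so every column is scaled down to a norm of about 10; the zero case shows why epsilon matters.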
Notes
When norm_axes is not specified, the axes over which the norm is computed depend on the dimensionality of the input variable. If it is 2D, it is assumed to come from a dense layer, and the norm is computed over axis 0. If it is 3D, 4D or 5D, it is assumed to come from a convolutional layer, and the norm is computed over all axes except axis 0. For any other dimensionality, or for other uses, specify the axes explicitly with norm_axes; otherwise a ValueError is raised.
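The default-axis rule above can be sketched as a small pure-Python function. The name resolve_norm_axes is hypothetical and only mirrors the rule as described; it is not a PyMC function.

```python
def resolve_norm_axes(ndim, norm_axes=None):
    """Hypothetical sketch of the default norm-axis rule (not the PyMC API)."""
    if norm_axes is not None:
        # An explicit choice always takes precedence
        return tuple(norm_axes)
    if ndim == 2:
        # Dense layer weights: one norm per column
        return (0,)
    if ndim in (3, 4, 5):
        # Conv weights: one norm per slice along axis 0, over all other axes
        return tuple(range(1, ndim))
    raise ValueError(f"Unsupported tensor dimensionality {ndim}; specify norm_axes")


print(resolve_norm_axes(2))       # (0,)
print(resolve_norm_axes(4))       # (1, 2, 3)
print(resolve_norm_axes(1, [0]))  # (0,)
```

For example, a 4D kernel gets one norm per slice along axis 0, computed over axes 1 to 3, while a 1D vector needs norm_axes=(0,) passed explicitly.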
Examples
>>> import numpy as np
>>> import pytensor
>>> import pymc as pm
>>> param = pytensor.shared(
...     np.random.randn(100, 200).astype(pytensor.config.floatX))
>>> update = param + 100
>>> update = pm.norm_constraint(update, 10)
>>> func = pytensor.function([], [], updates=[(param, update)])
>>> # Apply constrained update
>>> _ = func()
>>> # Column norms (axis 0, the default norm axes for a 2D input)
>>> norms = np.sqrt(np.sum(param.get_value() ** 2, axis=0))
>>> bool(np.isclose(np.max(norms), 10))
True