Layer weight constraints

Usage of constraints

Classes from the tf.keras.constraints module allow setting constraints (e.g. non-negativity) on model parameters during training. They are per-variable projection functions applied to the target variable after each gradient update (when using fit()).

The exact API will depend on the layer, but the layers Dense, Conv1D, Conv2D and Conv3D have a unified API.

These layers expose two keyword arguments:

  • kernel_constraint for the main weights matrix
  • bias_constraint for the bias.

from tensorflow.keras.layers import Dense
from tensorflow.keras.constraints import max_norm

model.add(Dense(64, kernel_constraint=max_norm(2.)))
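
To verify that the constraint holds after training, you can inspect the kernel directly. A quick sketch, assuming the model above has been built and trained with fit():

import numpy as np

kernel = model.layers[-1].kernel.numpy()
col_norms = np.linalg.norm(kernel, axis=0)  # norm of each incoming weight vector
assert np.all(col_norms <= 2.0 + 1e-5)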

Available weight constraints


MaxNorm class

tf.keras.constraints.MaxNorm(max_value=2, axis=0)

MaxNorm weight constraint.

Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value.

Also available via the shortcut function tf.keras.constraints.max_norm.

Arguments

  • max_value: the maximum norm value for the incoming weights.
  • axis: integer, axis along which to calculate weight norms. For instance, in a Dense layer the weight matrix has shape (input_dim, output_dim); set axis to 0 to constrain each weight vector of length (input_dim,). In a Conv2D layer with data_format="channels_last", the weight tensor has shape (rows, cols, input_depth, output_depth); set axis to [0, 1, 2] to constrain the weights of each filter tensor of size (rows, cols, input_depth).
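
Since constraints are plain callables, you can also apply one directly to a tensor to see its effect. A minimal sketch with illustrative values:

import tensorflow as tf

w = tf.constant([[3.0, 0.3], [4.0, 0.4]])  # column L2 norms: 5.0 and 0.5
constrained = tf.keras.constraints.MaxNorm(max_value=2.0, axis=0)(w)
# The first column is rescaled to norm 2.0; the second already satisfies
# the constraint and is left (essentially) unchanged.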


MinMaxNorm class

tf.keras.constraints.MinMaxNorm(
    min_value=0.0, max_value=1.0, rate=1.0, axis=0
)

MinMaxNorm weight constraint.

Constrains the weights incident to each hidden unit to have a norm between a lower bound and an upper bound.

Also available via the shortcut function tf.keras.constraints.min_max_norm.

Arguments

  • min_value: the minimum norm for the incoming weights.
  • max_value: the maximum norm for the incoming weights.
  • rate: rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step to slowly move towards a value inside the desired interval.
  • axis: integer, axis along which to calculate weight norms. For instance, in a Dense layer the weight matrix has shape (input_dim, output_dim); set axis to 0 to constrain each weight vector of length (input_dim,). In a Conv2D layer with data_format="channels_last", the weight tensor has shape (rows, cols, input_depth, output_depth); set axis to [0, 1, 2] to constrain the weights of each filter tensor of size (rows, cols, input_depth).
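
To make the rate arithmetic concrete, here is a small worked sketch with illustrative values:

import tensorflow as tf

w = tf.constant([[3.0], [4.0]])  # column L2 norm: 5.0
constraint = tf.keras.constraints.MinMaxNorm(
    min_value=1.0, max_value=2.0, rate=0.5, axis=0)
# Target norm = (1 - 0.5) * 5.0 + 0.5 * clip(5.0, 1.0, 2.0) = 2.5 + 1.0 = 3.5,
# so the column is scaled by 3.5 / 5.0 = 0.7.
constrained = constraint(w)  # approximately [[2.1], [2.8]]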


NonNeg class

tf.keras.constraints.NonNeg()

Constrains the weights to be non-negative.

Also available via the shortcut function tf.keras.constraints.non_neg.
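
A minimal sketch of the projection with illustrative values:

import tensorflow as tf

w = tf.constant([[-1.0, 2.0], [0.5, -3.0]])
constrained = tf.keras.constraints.NonNeg()(w)
# [[0.0, 2.0], [0.5, 0.0]] -- negative entries are zeroed out.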



UnitNorm class

tf.keras.constraints.UnitNorm(axis=0)

Constrains the weights incident to each hidden unit to have unit norm.

Also available via the shortcut function tf.keras.constraints.unit_norm.

Arguments

  • axis: integer, axis along which to calculate weight norms. For instance, in a Dense layer the weight matrix has shape (input_dim, output_dim); set axis to 0 to constrain each weight vector of length (input_dim,). In a Conv2D layer with data_format="channels_last", the weight tensor has shape (rows, cols, input_depth, output_depth); set axis to [0, 1, 2] to constrain the weights of each filter tensor of size (rows, cols, input_depth).
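
A minimal sketch with illustrative values, rescaling each column of a Dense-style kernel to unit L2 norm:

import tensorflow as tf

w = tf.constant([[3.0, 0.0], [4.0, 2.0]])  # column L2 norms: 5.0 and 2.0
constrained = tf.keras.constraints.UnitNorm(axis=0)(w)
# Each column now has (approximately) unit norm:
# [[0.6, 0.0], [0.8, 1.0]]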


RadialConstraint class

tf.keras.constraints.RadialConstraint()

Constrains Conv2D kernel weights to be the same for each radius.

Also available via the shortcut function tf.keras.constraints.radial_constraint.

For example, the desired output for the following 4-by-4 kernel:

    kernel = [[v_00, v_01, v_02, v_03],
              [v_10, v_11, v_12, v_13],
              [v_20, v_21, v_22, v_23],
              [v_30, v_31, v_32, v_33]]

is this:

    kernel = [[v_11, v_11, v_11, v_11],
              [v_11, v_33, v_33, v_11],
              [v_11, v_33, v_33, v_11],
              [v_11, v_11, v_11, v_11]]

This constraint can be applied to any Conv2D layer version, including Conv2DTranspose and SeparableConv2D, and with either "channels_last" or "channels_first" data format. The method assumes the weight tensor is of shape (rows, cols, input_depth, output_depth).
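
For example, attaching the constraint to a layer (the filter count and kernel size here are arbitrary):

import tensorflow as tf

layer = tf.keras.layers.Conv2D(
    filters=8,
    kernel_size=(4, 4),
    kernel_constraint=tf.keras.constraints.RadialConstraint(),
)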


Creating custom weight constraints

A weight constraint can be any callable that takes a tensor and returns a tensor with the same shape and dtype. You would typically implement your constraints as subclasses of tf.keras.constraints.Constraint.

Here's a simple example: a constraint that forces weight tensors to be centered around a specific value on average.

import tensorflow as tf

class CenterAround(tf.keras.constraints.Constraint):
  """Constrains weight tensors to be centered around `ref_value`."""

  def __init__(self, ref_value):
    self.ref_value = ref_value

  def __call__(self, w):
    mean = tf.reduce_mean(w)
    return w - mean + self.ref_value

  def get_config(self):
    return {'ref_value': self.ref_value}

Optionally, you can also implement the method get_config and the class method from_config in order to support serialization -- just like with any Keras object. Note that we don't have to implement from_config in the example above, since the constructor arguments of the class and the keys in the config returned by get_config are the same. In this case, the default from_config works fine.
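
Using a custom constraint then works just like using a built-in one. A sketch reusing the CenterAround class above (the layer width and ref_value are arbitrary):

layer = tf.keras.layers.Dense(64, kernel_constraint=CenterAround(ref_value=0.0))

# The default from_config calls the constructor with the config dict,
# which round-trips here because the keys match the __init__ arguments.
config = CenterAround(ref_value=0.5).get_config()
restored = CenterAround.from_config(config)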