Author: Sayak Paul
Date created: 2021/04/13
Last modified: 2026/04/30
Description: Training with consistency regularization for robustness against data distribution shifts.
Deep learning models excel in many image recognition tasks when the data is independent and identically distributed (i.i.d.). However, they can suffer from performance degradation caused by subtle distribution shifts in the input data (such as random noise, contrast change, and blurring). So, naturally, a question arises: why does this happen? As discussed in A Fourier Perspective on Model Robustness in Computer Vision, there's no reason for deep learning models to be robust against such shifts. Standard training procedures (such as a standard image classification workflow) don't enable a model to learn beyond what's fed to it in the form of training data.
In this example, we will be training an image classification model while enforcing a sense of consistency inside it by doing the following:

- First, train a standard image classification model (the teacher) on clean images.
- Then, train a second model (the student) that only sees noisy versions of the same images, while encouraging it to match the teacher's predictions on the corresponding clean images (in addition to the ground-truth labels).
This overall training workflow finds its roots in works like FixMatch, Unsupervised Data Augmentation for Consistency Training, and Noisy Student Training. Since this training process encourages a model to yield consistent predictions for clean as well as noisy images, it's often referred to as consistency training or training with consistency regularization. Although the example focuses on using consistency training to enhance the robustness of models to common corruptions, it can also serve as a template for performing weakly supervised learning.
import keras
from keras import layers
from keras import ops
import numpy as np
import random
import matplotlib.pyplot as plt
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam
from keras.utils import PyDataset
# Set seeds for reproducibility.
np.random.seed(42)
random.seed(42)
keras.utils.set_random_seed(42)
BATCH_SIZE = 128
EPOCHS = 5
NUM_CLASSES = 10
CROP_TO = 72
RESIZE_TO = 96
TEMPERATURE = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Use the first 49,500 training images for training and the remaining 500 for validation.
train_samples = 49500
train_x, train_y = x_train[:train_samples], y_train[:train_samples]
val_x, val_y = x_train[train_samples:], y_train[train_samples:]
train_y = train_y.reshape(-1)
val_y = val_y.reshape(-1)
y_test = y_test.reshape(-1)
Dataset objects

We first define the stochastic augmentation pipeline that will be used to generate the noisy versions of the images.

augment = keras.Sequential(
[
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
layers.RandomContrast(0.1),
]
)
For training the teacher model, we will only be using two geometric augmentation transforms: random horizontal flip and random crop.
def preprocess_train(image, label, noisy=True):
image = ops.cast(image, "float32")
# We first resize the original image to a larger dimension
# and then we take random crops from it.
image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
x = keras.random.randint((), 0, RESIZE_TO - CROP_TO)
y = keras.random.randint((), 0, RESIZE_TO - CROP_TO)
image = image[x : x + CROP_TO, y : y + CROP_TO, :]
if keras.random.uniform(()) > 0.5:
image = ops.flip(image, axis=1)
    if noisy:
        # Run the augmentation pipeline in training mode so the random layers are active.
        image = augment(image, training=True)
return np.array(image), label
def preprocess_test(image, label):
image = ops.cast(image, "float32")
image = ops.image.resize(image, (CROP_TO, CROP_TO))
return image, label
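As a quick sanity check (purely illustrative, not part of the training pipeline), we can run a single image through `preprocess_train` and confirm that it comes out at the crop resolution:

```python
# Illustrative check: preprocess one training image the way the teacher will see it.
sample_img, sample_lbl = preprocess_train(train_x[0], train_y[0], noisy=False)
print(sample_img.shape, sample_lbl)  # Expected: (72, 72, 3) and an integer label.
```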
Below we define the dataset objects. `TeacherDataset` yields clean images and labels for
training the first (teacher) model. For training the student model we need each image in
both a clean and a noisy form, and the two views must stay aligned; the `ConsistencyDataset`
defined afterwards handles this by returning both views of the same image in a single batch,
concatenated along the channel axis.
# This dataset will be used to train the first model.
class TeacherDataset(PyDataset):
def __init__(self, x, y, batch_size=128, training=True, **kwargs):
super().__init__(**kwargs)
self.x = x
self.y = y
self.batch_size = batch_size
self.training = training
self.indices = np.arange(len(x))
def __len__(self):
return int(np.ceil(len(self.x) / self.batch_size))
def on_epoch_end(self):
if self.training:
np.random.shuffle(self.indices)
def __getitem__(self, idx):
ids = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
images, labels = [], []
for i in ids:
img, lbl = preprocess_train(self.x[i], self.y[i], noisy=False)
images.append(img)
labels.append(lbl)
return np.array(images), np.array(labels)
class ConsistencyDataset(PyDataset):
def __init__(self, x, y, batch_size=128, training=True, **kwargs):
super().__init__(**kwargs)
self.x = x
self.y = y
self.batch_size = batch_size
self.training = training
self.indices = np.arange(len(x))
def __len__(self):
return int(np.ceil(len(self.x) / self.batch_size))
def on_epoch_end(self):
if self.training:
np.random.shuffle(self.indices)
def __getitem__(self, idx):
ids = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
clean, noisy, labels = [], [], []
for i in ids:
img, lbl = self.x[i], self.y[i]
c, _ = preprocess_train(img, lbl, noisy=False)
n, _ = preprocess_train(img, lbl, noisy=True)
clean.append(c)
noisy.append(n)
labels.append(lbl)
clean_batch = np.array(clean, dtype="float32")
noisy_batch = np.array(noisy, dtype="float32")
combined_x = np.concatenate([clean_batch, noisy_batch], axis=-1)
return combined_x, np.array(labels)
Eval Dataset
class EvalDataset(PyDataset):
def __init__(self, x, y, batch_size=128, **kwargs):
super().__init__(**kwargs)
self.x = x
self.y = y
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / self.batch_size))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size : (idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]
images = [preprocess_test(x, y)[0] for x, y in zip(batch_x, batch_y)]
return np.array(images), np.array(batch_y)
# Instantiate the dataset objects.
train_clean_ds = TeacherDataset(train_x, train_y, BATCH_SIZE, True)
consistency_training_ds = ConsistencyDataset(train_x, train_y, BATCH_SIZE, True)
validation_ds = EvalDataset(val_x, val_y, BATCH_SIZE)
test_ds = EvalDataset(x_test, y_test, BATCH_SIZE)
batch_inputs, labels = next(iter(consistency_training_ds))
clean_imgs = batch_inputs[..., :3]
noisy_imgs = batch_inputs[..., 3:]
plt.figure(figsize=(10, 10))
# Clean images
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(clean_imgs[i].astype("uint8"))
plt.axis("off")
# Noisy images
plt.figure(figsize=(10, 10))
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(noisy_imgs[i].astype("uint8"))
plt.axis("off")
plt.tight_layout()
plt.show()

We now define our model building utility. Our model is based on the ResNet50V2 architecture.
def get_training_model():
base = keras.applications.ResNet50V2(
weights=None,
include_top=False,
input_shape=(CROP_TO, CROP_TO, 3),
)
return keras.Sequential(
[
layers.Input((CROP_TO, CROP_TO, 3)),
layers.Rescaling(1 / 127.5, offset=-1),
base,
layers.GlobalAveragePooling2D(),
layers.Dense(NUM_CLASSES),
]
)
In the interest of reproducibility, we serialize the initial random weights of the teacher network.
initial_model = get_training_model()
initial_model.save("initial_teacher_model.keras")
initial_weights = initial_model.get_weights()
As noted in Noisy Student Training, training the teacher model with geometric ensembling and forcing the student model to mimic it leads to better performance. The original work uses Stochastic Depth and Dropout to bring in the ensembling part. To keep this example simple, we train both models with a plain Adam optimizer; a weight-averaging scheme such as Stochastic Weight Averaging (SWA), which also resembles geometric ensembling, could be added on top.
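If you want to experiment with that idea, below is a minimal sketch (not used in the rest of this example) of how SWA-style weight averaging could be wired up as a custom callback. The class name and the `start_epoch` argument are illustrative choices, and the plain running mean is a simplification: a full SWA setup would typically also use a cyclical or constant learning rate schedule and recompute the BatchNormalization statistics after assigning the averaged weights.

```python
class AverageWeights(keras.callbacks.Callback):
    """Keeps a running average of the model weights and assigns it when training ends."""

    def __init__(self, start_epoch=2):
        super().__init__()
        self.start_epoch = start_epoch
        self.averaged_weights = None
        self.num_averaged = 0

    def on_epoch_end(self, epoch, logs=None):
        if epoch < self.start_epoch:
            return
        current = self.model.get_weights()
        if self.averaged_weights is None:
            self.averaged_weights = [w.copy() for w in current]
        else:
            # Running mean over the weights collected so far.
            self.averaged_weights = [
                (avg * self.num_averaged + w) / (self.num_averaged + 1)
                for avg, w in zip(self.averaged_weights, current)
            ]
        self.num_averaged += 1

    def on_train_end(self, logs=None):
        if self.averaged_weights is not None:
            self.model.set_weights(self.averaged_weights)
```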
# Define the callbacks.
reduce_lr = ReduceLROnPlateau(patience=3)
early_stopping = EarlyStopping(patience=10, restore_best_weights=True)
# Compile and train the teacher model.
teacher_model = get_training_model()
teacher_model.set_weights(initial_weights)
teacher_model.compile(
optimizer=Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=["accuracy"],
)
teacher_model.fit(
train_clean_ds,
validation_data=validation_ds,
epochs=EPOCHS,
    callbacks=[reduce_lr, early_stopping],
)
teacher_model.save("teacher_model.keras")
# Evaluate the teacher model on the test set.
_, acc = teacher_model.evaluate(test_ds, verbose=0)
print(f"Test accuracy: {acc*100}%")
The `DistillationModel` is a custom Keras model that consumes both views at once: its input stacks the clean and the noisy images along the channel axis (6 channels in total). In `call()`, it splits the two views, runs the frozen teacher on the clean images, and returns the student's predictions on the noisy images. When a plain 3-channel batch is passed (for example at evaluation time), the same images are used for both views.
class DistillationModel(keras.Model):
def __init__(self, student, teacher, **kwargs):
super().__init__(**kwargs)
self.student = student
self.teacher = teacher
self.teacher.trainable = False
self.teacher_logits = None
def call(self, inputs, training=False):
inputs = ops.cast(inputs, "float32")
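        # A 6-channel input means the batch carries the clean view (first 3 channels)
        # and the noisy view (last 3 channels) stacked together.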
if ops.shape(inputs)[-1] == 6:
clean = inputs[:, :, :, 0:3]
noisy = inputs[:, :, :, 3:6]
else:
clean = inputs
noisy = inputs
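        # Cache the teacher's predictions on the clean view so the custom
        # distillation loss defined below can read them.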
self.teacher_logits = self.teacher(clean, training=False)
return self.student(noisy, training=training)
def get_config(self):
config = super().get_config()
config.update(
{
"student": keras.utils.serialize_keras_object(self.student),
"teacher": keras.utils.serialize_keras_object(self.teacher),
}
)
return config
@classmethod
def from_config(cls, config):
student_config = config.pop("student")
teacher_config = config.pop("teacher")
student = keras.utils.deserialize_keras_object(student_config)
teacher = keras.utils.deserialize_keras_object(teacher_config)
return cls(student, teacher, **config)
def distillation_loss(y_true, student_logits):
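    # The teacher's logits were cached on `distill_model` during its forward pass.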
teacher_logits = distill_model.teacher_logits
student_loss = keras.losses.sparse_categorical_crossentropy(
y_true, student_logits, from_logits=True
)
t_soft = ops.softmax(teacher_logits / TEMPERATURE, axis=-1)
s_soft = ops.softmax(student_logits / TEMPERATURE, axis=-1)
distill_kl = keras.losses.kl_divergence(t_soft, s_soft)
return ops.mean(0.5 * student_loss + 0.5 * distill_kl)
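To build intuition for the role of `TEMPERATURE`, here is a tiny illustration with made-up logits (purely illustrative, not used anywhere else): dividing the logits by a temperature larger than 1 flattens the softmax, giving the student softer targets to match.

```python
# Illustration only: temperature scaling softens a sharp logit distribution.
dummy_logits = np.array([[6.0, 2.0, -2.0]], dtype="float32")
print(ops.softmax(dummy_logits, axis=-1))                # sharp, nearly one-hot
print(ops.softmax(dummy_logits / TEMPERATURE, axis=-1))  # much softer targets
```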
# Define the callbacks.
# We use a larger factor (0.5 instead of the default 0.1) so the learning rate is
# reduced more gently, which helps stabilize the training.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
patience=3, factor=0.5, monitor="val_accuracy"
)
early_stopping = keras.callbacks.EarlyStopping(
patience=10, restore_best_weights=True, monitor="val_accuracy"
)
# Compile and train the student model.
student = get_training_model()
student.set_weights(initial_weights)
distill_model = DistillationModel(student, teacher_model)
distill_model.compile(optimizer=Adam(), loss=distillation_loss, metrics=["accuracy"])
history = distill_model.fit(
consistency_training_ds,
epochs=EPOCHS,
validation_data=validation_ds,
callbacks=[reduce_lr, early_stopping],
)
student.save("student_model_final.keras")
# Evaluate the student model.
_, acc = distill_model.evaluate(test_ds, verbose=0)
print(f"Test accuracy from student model: {acc*100}%")