Author: Kenneth Borup
Date created: 2020/09/01
Last modified: 2020/09/01
Description: Implementation of classical Knowledge Distillation.
Knowledge Distillation is a procedure for model compression, in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher model to the student by minimizing a loss function, aimed at matching softened teacher logits as well as ground-truth labels.
The logits are softened by applying a "temperature" scaling function in the softmax, effectively smoothing out the probability distribution and revealing inter-class relationships learned by the teacher.
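To make the effect of the temperature concrete, here is a minimal stdlib-only sketch of a temperature-scaled softmax, softmax(z / T). The logits are made-up values for a hypothetical 3-class teacher, not outputs of the model trained later, and the helper name is illustrative only.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 3-class problem.
teacher_logits = [6.0, 2.0, 1.0]

for temperature in (1.0, 3.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, temperature)
    # Higher temperatures move probability mass from the top class to the others.
    print(temperature, [round(p, 3) for p in probs])
```

At T = 1 almost all of the mass sits on the first class. As T grows, the ranking of the remaining classes becomes visible in the probabilities, while the most likely class stays the same.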
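Reference: Hinton, Vinyals, and Dean (2015), "Distilling the Knowledge in a Neural Network", https://arxiv.org/abs/1503.02531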
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
The custom Distiller() class overrides the Model methods train_step, test_step, and compile(). In order to use the distiller, we need:
- A trained teacher model
- A student model to train
- A student loss function on the difference between student predictions and ground truth
- A distillation loss function, along with a temperature, on the difference between the soft student predictions and the soft teacher labels
- An alpha factor to weight the student and distillation loss
- An optimizer for the student and metrics to evaluate performance

In the train_step method, we perform a forward pass of both the teacher and the student, calculate the loss by weighting the student_loss and distillation_loss by alpha and 1 - alpha, respectively, and perform the backward pass. Note: only the student weights are updated, and therefore we only calculate the gradients for the student weights.
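Written out, the objective that train_step minimizes for student logits z_s, teacher logits z_t, integer label y, and temperature T is the following, where σ denotes the softmax (this notation is introduced here for illustration; the code itself uses student_loss_fn and distillation_loss_fn):

$$
\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{CE}}\big(y, \sigma(z_s)\big) + (1 - \alpha)\, T^2 \, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big)
$$

The first term corresponds to student_loss and the second to distillation_loss, which in the code already includes the T^2 factor. Gradients of this loss are taken only with respect to the student's weights.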
In the test_step method, we evaluate the student model on the provided dataset.
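In this example the distiller is later compiled with SparseCategoricalAccuracy, which only checks whether the index of the largest prediction equals the integer label. Because the softmax preserves ordering, this can be computed directly on the student's raw logits. A small stdlib-only sketch of that comparison, with a hypothetical helper name and made-up logits and labels:

```python
def sparse_categorical_accuracy(labels, logits):
    """Fraction of examples whose argmax over logits equals the integer label."""
    correct = 0
    for label, row in zip(labels, logits):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax of this row
        correct += int(predicted == label)
    return correct / len(labels)


# Hypothetical batch of 4 examples with 3 classes each.
labels = [0, 2, 1, 1]
logits = [
    [2.0, 0.1, -1.0],  # predicts 0 (correct)
    [0.3, 0.2, 1.5],   # predicts 2 (correct)
    [1.2, 0.4, 0.0],   # predicts 0 (wrong, label is 1)
    [-0.5, 0.9, 0.1],  # predicts 1 (correct)
]
print(sparse_categorical_accuracy(labels, logits))  # 0.75
```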
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher (inference mode, never updated)
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, so multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients for the student weights only
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
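As a sanity check on train_step, the same weighted loss can be computed by hand for a single example. This is a stdlib-only sketch, assuming one example, an integer label, and the KL divergence definition sum(p * log(p / q)) with the teacher distribution as p; it omits the small epsilon clipping that keras.losses.KLDivergence applies. The defaults alpha=0.1 and temperature=3 mirror those of Distiller.compile(), while the helper names and all numbers are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, softmax(logits / temperature)."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_cross_entropy_from_logits(label, logits):
    """-log softmax(logits)[label], the per-example hard-label loss."""
    return -math.log(softmax(logits)[label])


def kl_divergence(p, q):
    """KL(p || q) = sum p * log(p / q); p plays the role of the teacher."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def distillation_objective(label, student_logits, teacher_logits, alpha=0.1, temperature=3.0):
    student_loss = sparse_cross_entropy_from_logits(label, student_logits)
    distillation_loss = kl_divergence(
        softmax(teacher_logits, temperature),
        softmax(student_logits, temperature),
    ) * temperature ** 2  # same T^2 rescaling as in train_step
    total = alpha * student_loss + (1 - alpha) * distillation_loss
    return total, student_loss, distillation_loss


# Hypothetical logits for one 3-class example whose true label is 0.
total, student_loss, distillation_loss = distillation_objective(
    label=0, student_logits=[1.0, 0.5, -0.2], teacher_logits=[3.0, 0.5, 0.1]
)
print(total, student_loss, distillation_loss)
```

Setting alpha=1.0 reduces the objective to the plain student loss, and a student whose logits match the teacher's has zero distillation loss.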
Initially, we create a teacher model and a smaller student model. Both models are convolutional neural networks created with Sequential(), but they could be any Keras model.
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
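To quantify how much smaller the student is, the parameter counts of the two Sequential models can be worked out by hand: a Conv2D layer with a k x k kernel has k * k * in_channels * filters weights plus one bias per filter, a Dense layer has inputs * units weights plus units biases, and LeakyReLU, MaxPooling2D, and Flatten have no parameters. The stdlib sketch below does this arithmetic for the two architectures above; the helper names are hypothetical, and the totals should agree with what model.summary() reports.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights plus biases of a Conv2D layer with a kernel x kernel window."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights plus biases of a Dense layer."""
    return inputs * units + units


def count_params(filters_1, filters_2, num_classes=10):
    # 28x28 input -> stride-2 "same" conv -> 14x14; the stride-1 "same" max pool
    # keeps 14x14; the second stride-2 "same" conv -> 7x7, then Flatten.
    flat_features = 7 * 7 * filters_2
    return (
        conv2d_params(3, 1, filters_1)
        + conv2d_params(3, filters_1, filters_2)
        + dense_params(flat_features, num_classes)
    )


teacher_params = count_params(256, 512)  # 1433610
student_params = count_params(16, 32)    # 20490
print(teacher_params, student_params, round(teacher_params / student_params, 1))
```

With these architectures the teacher has about 1.43 million parameters and the student about 20.5 thousand, roughly 70 times fewer.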
MNIST is used both to train the teacher and to distill it into the student; the procedure would be the same for any other dataset, e.g. CIFAR-10, with a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.
# Prepare the train and test dataset.
batch_size = 64  # Note: not passed to fit()/evaluate() below, so Keras uses its default batch size of 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize pixel values to [0, 1] and add a channel dimension
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
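The progress bars in the logs below count batches, not examples. Since fit() and evaluate() are called without a batch_size argument, Keras falls back to its default of 32, and a quick calculation reproduces the step counts in the logs. The helper name is hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size=32):
    """Number of batches per pass, counting a final partial batch as a full step."""
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(60_000))      # 1875: MNIST training set, default batch size
print(steps_per_epoch(10_000))      # 313: MNIST test set
print(steps_per_epoch(60_000, 64))  # 938: if batch_size=64 were passed to fit()
```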
In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start by training the teacher model on the training set in the usual way.
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 [==============================] - 248s 132ms/step - loss: 0.2438 - sparse_categorical_accuracy: 0.9220
Epoch 2/5
1875/1875 [==============================] - 263s 140ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9738
Epoch 3/5
1875/1875 [==============================] - 245s 131ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9811
Epoch 5/5
363/1875 [====>.........................] - ETA: 3:18 - loss: 0.0555 - sparse_categorical_accuracy: 0.9839
The teacher is now trained, so we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher into the student.
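The temperature=10 used below is far from 1, which is where the T^2 factor in train_step matters. For the soft-target term, the gradient of KL(softmax(z_t/T) || softmax(z_s/T)) with respect to the student logits z_s is (softmax(z_s/T) - softmax(z_t/T)) / T, and for large T its magnitude shrinks roughly as 1/T^2 (Hinton et al., 2015). The stdlib sketch below, with made-up logits and hypothetical helper names, evaluates that gradient for several temperatures with and without the T^2 factor.

```python
import math


def softmax(logits, temperature):
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_gradient(student_logits, teacher_logits, temperature):
    """Gradient of KL(softmax(t/T) || softmax(s/T)) w.r.t. the student logits s."""
    q = softmax(student_logits, temperature)
    p = softmax(teacher_logits, temperature)
    return [(qi - pi) / temperature for qi, pi in zip(q, p)]


def norm(vector):
    return math.sqrt(sum(v * v for v in vector))


# Hypothetical logits for one 3-class example.
student_logits = [0.5, 1.5, -0.5]
teacher_logits = [2.0, 1.0, 0.1]

for temperature in (1.0, 3.0, 10.0, 50.0, 100.0):
    grad = soft_target_gradient(student_logits, teacher_logits, temperature)
    unscaled = norm(grad)
    scaled = unscaled * temperature ** 2  # effect of multiplying the loss by T^2
    print(f"T={temperature:>5}: unscaled={unscaled:.6f}  scaled by T^2={scaled:.4f}")
```

At large temperatures the unscaled norm drops by roughly a factor of four each time T doubles, while the T^2-scaled norm levels off, so the balance between the hard and soft terms set by alpha stays roughly unchanged when T is varied.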
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 [==============================] - 242s 129ms/step - sparse_categorical_accuracy: 0.9761 - student_loss: 0.1526 - distillation_loss: 0.0226
Epoch 2/3
1875/1875 [==============================] - 281s 150ms/step - sparse_categorical_accuracy: 0.9863 - student_loss: 0.1384 - distillation_loss: 0.0185
Epoch 3/3
399/1875 [=====>........................] - ETA: 3:27 - sparse_categorical_accuracy: 0.9896 - student_loss: 0.1300 - distillation_loss: 0.0182
We can also train an equivalent student model from scratch without the teacher, in order to evaluate the performance gain obtained by knowledge distillation.
# Train student as usual
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4731 - sparse_categorical_accuracy: 0.8550
Epoch 2/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0966 - sparse_categorical_accuracy: 0.9710
Epoch 3/3
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9773
313/313 [==============================] - 0s 963us/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9778
[0.06905383616685867, 0.9778000116348267]
If the teacher is trained for 5 full epochs and the student is distilled from this teacher for 3 full epochs, in this example you should see a performance boost compared to training the same student model from scratch, and even compared to the teacher itself. You should expect the teacher to reach an accuracy of around 97.6%, the student trained from scratch around 97.6%, and the distilled student around 98.1%. No random seed is fixed in this example, so results vary between runs with different weight initializations; fix a seed, or try several, to compare runs.