
Distilling Vision Transformers

Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2026/02/10
Description: Distillation of Vision Transformers through attention.

This example uses Keras 3 APIs (keras.ops, keras.random, keras.layers.TFSMLayer) together with TensorFlow. Because the custom training step relies on tf.GradientTape, it requires the TensorFlow backend.



Introduction

In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that ViTs need to be pre-trained on larger datasets to perform on par with Convolutional Neural Networks (CNNs), and that larger datasets work better. This is mainly due to the lack of inductive biases in the ViT architecture: unlike CNNs, ViTs don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.

Many groups have proposed ways to deal with the data-intensiveness of ViT training. One such approach was presented in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.

In this example, we implement the distillation recipe proposed in DeiT. This requires slightly tweaking the original ViT architecture and writing a custom training step (an overridden train_step) that implements the distillation recipe.

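To comfortably navigate through this example, you should already know how ViTs and knowledge distillation work. If you need a refresher, the Keras knowledge distillation example (https://keras.io/examples/vision/knowledge_distillation/), which the trainer in this example builds on, is a good place to start.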


Imports

from typing import List

import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers

tfds.disable_progress_bar()
keras.utils.set_random_seed(42)

Constants

# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001

# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5

You probably noticed that DROPOUT_RATE has been set to 0.0. The dropout layers are still included in the implementation for completeness. Smaller models (like the one used in this example) don't need dropout, but bigger models can benefit from it.


Load the tf_flowers dataset and prepare preprocessing utilities

The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them.

def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = keras.ops.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            # Perform random crop using TensorFlow ops for graph compatibility
            # Get random crop coordinates (0 to 20 pixels offset)
            crop_top = tf.random.uniform((), 0, 21, dtype=tf.int32)
            crop_left = tf.random.uniform((), 0, 21, dtype=tf.int32)
            image = tf.image.crop_to_bounding_box(
                image,
                offset_height=crop_top,
                offset_width=crop_left,
                target_height=RESOLUTION,
                target_width=RESOLUTION,
            )
            # Random horizontal flip with probability 0.5 (graph-friendly, no
            # Python branching on a tensor).
            image = tf.image.random_flip_left_right(image)
        else:
            image = keras.ops.image.resize(image, (RESOLUTION, RESOLUTION))
        label = keras.ops.one_hot(label, num_classes=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
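The training-time augmentation in preprocess_dataset resizes each image to RESOLUTION + 20 pixels per side, takes a random RESOLUTION x RESOLUTION crop, and flips it horizontally about half of the time. The NumPy sketch below mirrors that logic on a single array so the index arithmetic is easy to inspect. It assumes the image has already been resized, and the helper name random_crop_and_flip is hypothetical rather than part of the example.

```python
import numpy as np

RESOLUTION = 224
PAD = 20  # the example resizes to RESOLUTION + 20 before cropping


def random_crop_and_flip(image, rng):
    """NumPy sketch of the crop/flip in preprocess_dataset (hypothetical helper).

    `image` is assumed to be already resized to
    (RESOLUTION + PAD, RESOLUTION + PAD, channels).
    """
    # Offsets are drawn from [0, PAD] inclusive, like tf.random.uniform(..., 0, 21).
    top = rng.integers(0, PAD + 1)
    left = rng.integers(0, PAD + 1)
    crop = image[top : top + RESOLUTION, left : left + RESOLUTION]
    # Flip along the width axis with probability 0.5.
    if rng.random() > 0.5:
        crop = crop[:, ::-1]
    return crop, (top, left)


rng = np.random.default_rng(0)
resized = rng.random((RESOLUTION + PAD, RESOLUTION + PAD, 3)).astype("float32")
crop, offsets = random_crop_and_flip(resized, rng)
print(crop.shape, offsets)
```

Because the crop window can start anywhere in a 21 x 21 grid of offsets, the model sees slightly shifted views of each image across epochs.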

Implementing the DeiT variants of ViT

Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.

First, we'll implement a layer for Stochastic Depth (Huang et al.), which is used in DeiT for regularization.

# Adapted from github.com/rwightman/pytorch-image-models (timm).
class StochasticDepth(layers.Layer):
    """Drops each sample's residual branch with probability `drop_prob`."""

    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(keras.ops.shape(x)) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            random_tensor = keras.ops.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
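To see what StochasticDepth does to a batch, here is a NumPy restatement of its call method (a sketch for inspection, not a drop-in replacement for the layer). Each sample's residual branch is either zeroed entirely or scaled by 1 / keep_prob, so the output matches the input in expectation. The function name stochastic_depth and the (8, 198, 192) input shape are illustrative assumptions.

```python
import numpy as np


def stochastic_depth(x, drop_prob, rng, training=True):
    """NumPy sketch mirroring StochasticDepth.call (hypothetical helper)."""
    if not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One draw per sample, broadcast over every other axis.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = np.floor(keep_prob + rng.uniform(0.0, 1.0, size=shape))
    # Rescale kept samples so the expected output equals x.
    return (x / keep_prob) * mask


rng = np.random.default_rng(1337)
x = np.ones((8, 198, 192), dtype="float32")
y = stochastic_depth(x, drop_prob=0.25, rng=rng)
per_sample = y.reshape(8, -1)
print(per_sample.max(axis=1))  # each entry is either 0.0 or 1 / 0.75
```

Dropping whole samples (rather than individual activations, as dropout does) means the block acts as an identity mapping for the dropped samples, since the residual connection still carries their input forward.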

Now, we'll implement the MLP and Transformer blocks.

def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation="gelu" if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)
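In keras.layers.MultiHeadAttention, key_dim is the size of each attention head, not the total width. The transformer() block above passes key_dim=PROJECTION_DIM, so each of its 3 heads is 192-dimensional, whereas DeiT-Tiny splits its 192 dimensions across 3 heads of 64. The plain-Python sketch below counts one attention layer's parameters for both settings, assuming the layer defaults use_bias=True and value_dim equal to key_dim. It is arithmetic only and does not call Keras; the helper mha_params is hypothetical.

```python
def mha_params(d_model, num_heads, key_dim):
    """Parameter count of one MultiHeadAttention layer (hypothetical helper).

    Assumes use_bias=True and value_dim == key_dim, with query, key, value
    and output all projected from/to d_model features.
    """
    inner = num_heads * key_dim
    qkv = 3 * (d_model * inner + inner)  # query/key/value kernels + biases
    out = inner * d_model + d_model  # output projection kernel + bias
    return qkv + out


PROJECTION_DIM, NUM_HEADS = 192, 3
print(mha_params(PROJECTION_DIM, NUM_HEADS, PROJECTION_DIM))  # 444288
print(mha_params(PROJECTION_DIM, NUM_HEADS, PROJECTION_DIM // NUM_HEADS))  # 148224
```

The training logs in this example were produced with key_dim=PROJECTION_DIM, so switching to PROJECTION_DIM // NUM_HEADS would bring each block closer to DeiT-Tiny but may change the reported numbers.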

We'll now implement a ViTClassifier class that builds on top of the components we just developed. Here we'll follow the original pooling strategy used in the ViT paper: prepend a class token and use the feature representation corresponding to it for classification.

class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Transformer blocks.
        dpr = [x for x in keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def build(self, input_shape):
        # Positional embedding.
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )

        # CLS token.
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        super().build(input_shape)

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.

        output = self.head(encoded_patches)

        return output
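ViTClassifier.__init__ assigns each Transformer block its own stochastic depth rate with keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS), so deeper blocks are dropped more often. The plain-Python sketch below reproduces that linear schedule for the constants used here and shows which blocks actually receive a StochasticDepth layer, given that transformer() skips the layer when drop_prob is zero.

```python
DROP_PATH_RATE = 0.1
NUM_LAYERS = 12

# Same values as keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS):
# the first block gets 0.0 and the last gets DROP_PATH_RATE.
dpr = [DROP_PATH_RATE * i / (NUM_LAYERS - 1) for i in range(NUM_LAYERS)]

# transformer() only adds StochasticDepth when drop_prob is truthy,
# so block 0 has no stochastic depth at all.
blocks_with_sd = [i for i, p in enumerate(dpr) if p]

print([round(p, 4) for p in dpr])
print(blocks_with_sd)
```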

This class can be used standalone as a ViT and is end-to-end trainable. Just remove the "distilled" phrase from MODEL_TYPE (the transformer blocks read it to decide how many tokens to expect) and instantiate it with vit_tiny = ViTClassifier(). Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken from the DeiT paper):

Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.
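The NumPy sketch below traces this token bookkeeping for the constants used here: the sequence fed to the Transformer blocks is [cls, dist, patch_0, ..., patch_195], position 0 feeds the classification head, and position 1 feeds the distillation head. The patch embeddings and the head weights are random stand-ins (hypothetical, not the example's layers), and the Transformer blocks are omitted, so the two tokens still hold their zero initial values here; in the real model they would have attended to the patches.

```python
import numpy as np

RESOLUTION, PATCH_SIZE, PROJECTION_DIM, NUM_CLASSES = 224, 16, 192, 5
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2  # 14 * 14 = 196 patches

rng = np.random.default_rng(0)
batch = 2
patches = rng.normal(size=(batch, NUM_PATCHES, PROJECTION_DIM))  # stand-in
cls_token = np.zeros((1, 1, PROJECTION_DIM))  # zero-initialized, as in build()
dist_token = np.zeros((1, 1, PROJECTION_DIM))

# Token layout used by ViTDistilled: [cls, dist, patch_0, ..., patch_195].
tokens = np.concatenate(
    [
        np.tile(cls_token, (batch, 1, 1)),
        np.tile(dist_token, (batch, 1, 1)),
        patches,
    ],
    axis=1,
)
print(tokens.shape)  # (2, 198, 192), i.e. NUM_PATCHES + 2 tokens

# Hypothetical random weights standing in for the two Dense heads.
w_cls = rng.normal(size=(PROJECTION_DIM, NUM_CLASSES))
w_dist = rng.normal(size=(PROJECTION_DIM, NUM_CLASSES))
cls_logits = tokens[:, 0] @ w_cls  # class token: compared with true labels
dist_logits = tokens[:, 1] @ w_dist  # distillation token: compared with teacher
# Outside distilled training, ViTDistilled returns the average of both heads.
averaged_logits = (cls_logits + dist_logits) / 2
print(averaged_logits.shape)  # (2, 5)
```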

class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def build(self, input_shape):
        # CLS token.
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )

        # Distillation token.
        self.dist_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="dist_token",
        )

        # Positional embedding (for NUM_PATCHES + 2 tokens: cls + dist).
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + self.num_tokens, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        dist_token = keras.ops.tile(self.dist_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        dist_token = keras.ops.cast(dist_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        # Only return separate classification predictions when training in
        # distilled mode.
        if training and not self.regular_training:
            return x, x_dist
        # During regular (non-distilled) training / fine-tuning and at
        # inference, average the predictions of the two heads.
        return (x + x_dist) / 2

Let's verify that the ViTDistilled class can be initialized and called as expected.

deit_tiny_distilled = ViTDistilled()

dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)

Implementing the trainer

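Unlike standard knowledge distillation (Hinton et al.), which uses a temperature-scaled softmax together with the KL divergence, the DeiT authors use the following loss function:

$$\mathcal{L} = \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s), y\big) + \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s), y_t\big)$$

Here,

  • CE is cross-entropy
  • psi is the softmax function
  • Z_s denotes student predictions
  • y denotes true labels
  • y_t denotes teacher predictions

With the distillation token, the first term is computed on the class-token logits and the second on the distillation-token logits, which is why the trainer below averages student_loss and distillation_loss. In the paper, y_t is the teacher's hard decision (the argmax of its logits); the trainer below instead uses the teacher's softmax probabilities as soft targets.
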
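The NumPy sketch below evaluates this loss on a toy batch of two images with five classes. It computes both the hard-label target (one-hot argmax of the teacher logits, as in the DeiT paper) and the soft target that the DeiT trainer in this example uses (the teacher's softmax). All logits are made-up numbers, and the label smoothing applied to the student loss during training is omitted.

```python
import numpy as np


def log_softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def cross_entropy(targets, logits):
    """Batch mean of CE(psi(logits), targets); targets are probability vectors."""
    return float(-(targets * log_softmax(logits)).sum(axis=-1).mean())


num_classes = 5
y = np.eye(num_classes)[[0, 3]]  # one-hot true labels for two images
z_cls = np.array([[2.0, 0.5, 0.1, -1.0, 0.0], [0.2, 0.1, 0.0, 1.5, -0.5]])
z_dist = np.array([[1.5, 1.0, 0.0, -0.5, 0.0], [0.0, 0.3, 0.2, 1.0, 0.1]])
z_teacher = np.array([[3.0, 0.0, 0.0, 0.0, 0.0], [0.0, 2.5, 0.0, 2.0, 0.0]])

# Hard-label targets (paper): y_t = one_hot(argmax(teacher logits)).
y_t_hard = np.eye(num_classes)[z_teacher.argmax(axis=-1)]
# Soft targets (this example's trainer): y_t = softmax(teacher logits).
y_t_soft = np.exp(log_softmax(z_teacher))

student_loss = cross_entropy(y, z_cls)  # class token vs. true labels
loss_hard = 0.5 * student_loss + 0.5 * cross_entropy(y_t_hard, z_dist)
loss_soft = 0.5 * student_loss + 0.5 * cross_entropy(y_t_soft, z_dist)
print(round(loss_hard, 4), round(loss_soft, 4))
```

In the second image the teacher's top class (1) disagrees with the true label (3), so the two terms pull the classification head and the distillation head toward different classes.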
class DeiT(keras.Model):
    # Reference:
    # https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
        self.accuracy_metric = keras.metrics.CategoricalAccuracy(name="accuracy")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        metrics.append(self.accuracy_metric)
        return metrics

    def compile(
        self,
        optimizer,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Normalize for student (ViT expects [0, 1])
        x_student = keras.ops.cast(x, "float32") / 255.0

        # Teacher expects raw [0, 255] float32 (no normalization)
        x_teacher = keras.ops.cast(x, "float32")

        # Forward pass of teacher
        # TFSMLayer returns a dictionary, extract the output
        teacher_output = self.teacher(x_teacher, training=False)
        if isinstance(teacher_output, dict):
            # Get the first (and likely only) output from the dictionary
            teacher_output = list(teacher_output.values())[0]
        # Use soft targets (probabilities) for distillation
        teacher_predictions = keras.ops.nn.softmax(teacher_output, -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions = self.student(x_student, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.accuracy_metric.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance - include loss
        return {
            "loss": loss,
            "student_loss": self.student_loss_tracker.result(),
            "distillation_loss": self.dist_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Convert to float32 and normalize for student
        x_normalized = keras.ops.cast(x, "float32") / 255.0

        # Compute predictions.
        y_prediction = self.student(x_normalized, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.accuracy_metric.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance
        return {
            "loss": student_loss,
            "student_loss": self.student_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }

    def call(self, inputs):
        # Convert to float32 and normalize for student
        inputs_normalized = keras.ops.cast(inputs, "float32") / 255.0
        return self.student(inputs_normalized, training=False)

Load the teacher model

This model is based on the BiT family of ResNets (Kolesnikov et al.) fine-tuned on the tf_flowers dataset. You can refer to this notebook to see how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.

!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.layers.TFSMLayer(
    "bit_teacher_flowers", call_endpoint="serving_default"
)

Training through distillation

deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)

lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
    optimizer=keras.optimizers.AdamW(
        weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
    ),
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 130s 8s/step - accuracy: 0.2150 - distillation_loss: 2.1021 - loss: 0.0000e+00 - student_loss: 1.8120 - val_accuracy: 0.2616 - val_loss: 1.6223 - val_student_loss: 1.6278

Epoch 2/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2416 - distillation_loss: 1.6185 - loss: 0.0000e+00 - student_loss: 1.6297 - val_accuracy: 0.1662 - val_loss: 1.6018 - val_student_loss: 1.6075

Epoch 3/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 104s 8s/step - accuracy: 0.2467 - distillation_loss: 1.6028 - loss: 0.0000e+00 - student_loss: 1.6087 - val_accuracy: 0.2316 - val_loss: 1.5954 - val_student_loss: 1.6009

Epoch 4/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2349 - distillation_loss: 1.5968 - loss: 0.0000e+00 - student_loss: 1.6022 - val_accuracy: 0.2289 - val_loss: 1.5922 - val_student_loss: 1.6017

Epoch 5/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2634 - distillation_loss: 1.5902 - loss: 0.0000e+00 - student_loss: 1.5928 - val_accuracy: 0.3025 - val_loss: 1.5703 - val_student_loss: 1.5795

Epoch 6/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3279 - distillation_loss: 1.5441 - loss: 0.0000e+00 - student_loss: 1.5456 - val_accuracy: 0.3515 - val_loss: 1.4880 - val_student_loss: 1.4937

Epoch 7/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3966 - distillation_loss: 1.4085 - loss: 0.0000e+00 - student_loss: 1.4534 - val_accuracy: 0.3706 - val_loss: 1.4348 - val_student_loss: 1.4335

Epoch 8/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3890 - distillation_loss: 1.3647 - loss: 0.0000e+00 - student_loss: 1.4229 - val_accuracy: 0.3297 - val_loss: 1.4575 - val_student_loss: 1.4463

Epoch 9/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4223 - distillation_loss: 1.3332 - loss: 0.0000e+00 - student_loss: 1.3850 - val_accuracy: 0.4114 - val_loss: 1.3888 - val_student_loss: 1.3763

Epoch 10/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4475 - distillation_loss: 1.2577 - loss: 0.0000e+00 - student_loss: 1.3548 - val_accuracy: 0.4441 - val_loss: 1.3202 - val_student_loss: 1.3331

Epoch 11/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4717 - distillation_loss: 1.2107 - loss: 0.0000e+00 - student_loss: 1.2995 - val_accuracy: 0.4632 - val_loss: 1.3016 - val_student_loss: 1.2872

Epoch 12/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5017 - distillation_loss: 1.1562 - loss: 0.0000e+00 - student_loss: 1.2542 - val_accuracy: 0.5395 - val_loss: 1.2761 - val_student_loss: 1.2575

Epoch 13/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5328 - distillation_loss: 1.1119 - loss: 0.0000e+00 - student_loss: 1.2223 - val_accuracy: 0.5068 - val_loss: 1.2102 - val_student_loss: 1.2383

Epoch 14/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 102s 8s/step - accuracy: 0.5655 - distillation_loss: 1.0595 - loss: 0.0000e+00 - student_loss: 1.1837 - val_accuracy: 0.5722 - val_loss: 1.1773 - val_student_loss: 1.1774

Epoch 15/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5998 - distillation_loss: 1.0133 - loss: 0.0000e+00 - student_loss: 1.1465 - val_accuracy: 0.5204 - val_loss: 1.2519 - val_student_loss: 1.2340

Epoch 16/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6110 - distillation_loss: 0.9992 - loss: 0.0000e+00 - student_loss: 1.1359 - val_accuracy: 0.6104 - val_loss: 1.0947 - val_student_loss: 1.1090

Epoch 17/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6191 - distillation_loss: 0.9635 - loss: 0.0000e+00 - student_loss: 1.1101 - val_accuracy: 0.6076 - val_loss: 1.0678 - val_student_loss: 1.0952

Epoch 18/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6400 - distillation_loss: 0.9460 - loss: 0.0000e+00 - student_loss: 1.0902 - val_accuracy: 0.6076 - val_loss: 1.0256 - val_student_loss: 1.0681

Epoch 19/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6340 - distillation_loss: 0.9411 - loss: 0.0000e+00 - student_loss: 1.0943 - val_accuracy: 0.6213 - val_loss: 1.0353 - val_student_loss: 1.0702

Epoch 20/20

13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6506 - distillation_loss: 0.9121 - loss: 0.0000e+00 - student_loss: 1.0674 - val_accuracy: 0.6376 - val_loss: 1.0027 - val_student_loss: 1.0602
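The learning rate passed to AdamW above is computed as (BASE_LR / 512) * BATCH_SIZE. This is the linear scaling rule used in the DeiT paper: the base rate is defined for a batch size of 512 and is scaled in proportion to the actual batch size. A minimal sketch of that arithmetic follows; the helper scaled_lr is hypothetical.

```python
BASE_LR = 0.0005  # base learning rate, defined for a batch size of 512
BATCH_SIZE = 256


def scaled_lr(base_lr, batch_size, reference_batch_size=512):
    """Linear scaling rule: the learning rate is proportional to batch size."""
    return base_lr / reference_batch_size * batch_size


lr_scaled = scaled_lr(BASE_LR, BATCH_SIZE)
print(lr_scaled)  # 0.00025
```

Halving the batch size from the reference 512 to 256 therefore halves the learning rate.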

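The loss: 0.0000e+00 entries in the logs above do not reflect the training objective: the custom train_step never updates the built-in loss metric that Keras creates in compile(), so student_loss and distillation_loss are the values to watch. If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, it would have scored about 59% accuracy. You can adapt the following code to reproduce this result: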

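# Any MODEL_TYPE without "distilled" works: transformer() reads it to size its
# inputs (NUM_PATCHES + 1 tokens for the plain ViT).
MODEL_TYPE = "vit_tiny_patch16_224"
vit_tiny = ViTClassifier()

inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)  # same [0, 1] input range
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)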

Notes

  • Through the use of distillation, we're effectively transferring the inductive biases of a CNN-based teacher model.
  • As shown in the paper, this distillation strategy works better with a CNN as the teacher model than with a Transformer.
  • Regularization is very important when training DeiT models.
  • ViT models are initialized with a combination of different initializers including truncated normal, random normal, Glorot uniform, etc. If you're looking for end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
  • If you want to explore the pre-trained DeiT models in Keras with code for fine-tuning, check out these models on TF-Hub.

Acknowledgements

  • Ross Wightman for keeping timm updated with readable implementations. I referred to the implementations of ViT and DeiT a lot while implementing them in Keras.
  • Aritra Roy Gosthipaty, who implemented some portions of the ViTClassifier in another project.
  • The Google Developer Experts program for supporting me with GCP credits, which were used to run experiments for this example.
