Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2026/05/13
Description: Distillation of Vision Transformers through attention.
In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture – unlike CNNs, they don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To comfortably navigate through this example, you'll be expected to know how a ViT and knowledge distillation work. The following are good resources in case you need a refresher:
from pathlib import Path
from typing import List
import numpy as np
import keras
from keras import layers
keras.utils.set_random_seed(42)
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [PROJECTION_DIM * 4, PROJECTION_DIM]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
# Data
BATCH_SIZE = 256
NUM_CLASSES = 5
You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout layers are still
included in the implementation for completeness. For smaller models (like the one used in
this example), you don't need dropout, but for bigger models, it helps.
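A small sketch that recomputes the sizes implied by the configuration above may help: a 224x224 image split into 16x16 patches gives a 14x14 grid, i.e. 196 patch tokens. Prepending the class token gives 197 tokens for ViT, and the extra distillation token gives 198 for DeiT. The helper `derived_sizes` is hypothetical and is not used anywhere else in this example.

```python
def derived_sizes(resolution=224, patch_size=16, projection_dim=192):
    """Hypothetical helper: recompute sizes implied by the config constants."""
    grid = resolution // patch_size          # patches per side
    num_patches = grid**2                    # same formula as NUM_PATCHES
    return {
        "grid": grid,
        "num_patches": num_patches,
        "vit_tokens": num_patches + 1,       # + class token
        "deit_tokens": num_patches + 2,      # + class and distillation tokens
        "mlp_units": [projection_dim * 4, projection_dim],  # same as MLP_UNITS
    }


print(derived_sizes())
```

The 198-token count is why `transformer()` later checks whether MODEL_TYPE contains "distilled" before choosing its input length.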
The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them.
We use keras.utils.PyDataset to build a fully backend-agnostic data pipeline that
works with JAX, PyTorch, and TensorFlow alike.
A couple of practical details are important here:
- keras.utils.get_file(untar=True) may return the extraction cache directory, so we explicitly resolve the inner flower_photos/ folder when present.
- Each image is resized to a fixed target_size before stacking into a NumPy batch.
FLOWERS_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
class FlowersDataset(keras.utils.PyDataset):
    """Backend-agnostic flowers dataset that loads images from disk each epoch."""

    def __init__(
        self,
        image_paths,
        labels,
        augmenter,
        batch_size=BATCH_SIZE,
        shuffle=False,
        seed=42,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_paths = np.array(image_paths)
        self.labels = np.array(labels, dtype="int32")
        self.augmenter = augmenter
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(self.image_paths))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            self.rng.shuffle(self.indices)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, len(self.image_paths))
        batch_indices = self.indices[start:end]
        # When shuffling (training), load images 20 px larger so that
        # RandomCrop can crop them back to RESOLUTION.
        target_size = (
            (RESOLUTION + 20, RESOLUTION + 20)
            if self.shuffle
            else (RESOLUTION, RESOLUTION)
        )
        images = []
        for i in batch_indices:
            image = keras.utils.load_img(self.image_paths[i], target_size=target_size)
            images.append(keras.utils.img_to_array(image))
        images = np.array(images, dtype="float32")
        if self.augmenter is not None:
            images = self.augmenter(images, training=self.shuffle)
        labels = keras.ops.one_hot(self.labels[batch_indices], num_classes=NUM_CLASSES)
        return images, labels
def get_augmenter(is_training=True):
    if is_training:
        return keras.Sequential(
            [
                layers.RandomCrop(RESOLUTION, RESOLUTION),
                layers.RandomFlip("horizontal"),
            ],
            name="train_augmentation",
        )
    return None


def load_flower_file_paths(validation_split=0.1):
    extracted = Path(keras.utils.get_file(origin=FLOWERS_URL, untar=True))
    # get_file may return the cache directory; use the inner folder if present.
    data_dir = (
        extracted / "flower_photos"
        if (extracted / "flower_photos").is_dir()
        else extracted
    )
    class_names = sorted([p.name for p in data_dir.iterdir() if p.is_dir()])
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    train_paths, train_labels = [], []
    val_paths, val_labels = [], []
    rng = np.random.default_rng(42)
    # Hold out the same fraction of every class for validation.
    for class_name in class_names:
        class_files = sorted((data_dir / class_name).glob("*.jpg"))
        class_files = np.array([str(path) for path in class_files])
        rng.shuffle(class_files)
        num_val = int(len(class_files) * validation_split)
        val_paths.extend(class_files[:num_val])
        val_labels.extend([class_to_index[class_name]] * num_val)
        train_paths.extend(class_files[num_val:])
        train_labels.extend([class_to_index[class_name]] * (len(class_files) - num_val))
    return train_paths, train_labels, val_paths, val_labels
train_paths, train_labels, val_paths, val_labels = load_flower_file_paths()
print(f"Number of training examples: {len(train_paths)}")
print(f"Number of validation examples: {len(val_paths)}")
train_dataset = FlowersDataset(
    train_paths,
    train_labels,
    augmenter=get_augmenter(is_training=True),
    shuffle=True,
    workers=4,
)
val_dataset = FlowersDataset(
    val_paths,
    val_labels,
    augmenter=get_augmenter(is_training=False),
    shuffle=False,
    workers=4,
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step
Number of training examples: 3306
Number of validation examples: 364
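To see how FlowersDataset slices the data, the sketch below repeats its `__len__` and `__getitem__` index arithmetic in plain Python for the split sizes printed above. The helper `batch_bounds` is hypothetical. With 3306 training images and a batch size of 256, there are 13 batches and the last one holds 234 images, which matches the 13 steps per epoch in the training logs below.

```python
import math


def batch_bounds(num_examples, batch_size):
    """Hypothetical helper: (start, end) pairs computed like FlowersDataset."""
    num_batches = math.ceil(num_examples / batch_size)  # same as __len__
    return [
        # Same start/end arithmetic as __getitem__; the last batch may be short.
        (i * batch_size, min((i + 1) * batch_size, num_examples))
        for i in range(num_batches)
    ]


train_bounds = batch_bounds(3306, 256)
val_bounds = batch_bounds(364, 256)
print(len(train_bounds), train_bounds[-1])  # 13 (3072, 3306)
print(len(val_bounds), val_bounds[-1])      # 2 (256, 364)
```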
Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.) which is used in DeiT for regularization.
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over all remaining axes.
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(keras.ops.shape(x)) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            # floor() turns this into a 0/1 mask that is 1 with probability keep_prob.
            random_tensor = keras.ops.floor(random_tensor)
            # Rescale the kept samples so the expected output equals the input.
            return (x / keep_prob) * random_tensor
        return x
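To make the per-sample behavior of StochasticDepth concrete, here is the same computation restated in NumPy. This is an illustrative sketch, not part of the model. With `drop_prob=0.1`, each example's entire residual branch is either zeroed or scaled by 1/0.9, so the average over many examples stays close to the input.

```python
import numpy as np


def drop_path(x, drop_prob, rng):
    """NumPy restatement of StochasticDepth.call with training=True."""
    keep_prob = 1.0 - drop_prob
    # One draw per sample; trailing axes of size 1 broadcast over the rest.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = np.floor(keep_prob + rng.uniform(0.0, 1.0, size=shape))
    # Survivors are scaled by 1 / keep_prob so the expectation is unchanged.
    return (x / keep_prob) * mask


rng = np.random.default_rng(0)
x = np.ones((10_000, 4, 8))
out = drop_path(x, drop_prob=0.1, rng=rng)
print(out.mean())  # close to 1.0
```

Dropping whole samples rather than individual activations is what distinguishes drop path from ordinary dropout.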
Now, we'll implement the MLP and Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(units, activation="gelu" if idx == 0 else None)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    # Class token only for ViT; class + distillation tokens for DeiT.
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=DROPOUT_RATE
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )
    x2 = layers.Add()([attention_output, encoded_patches])
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4
    outputs = layers.Add()([x2, x4])
    return keras.Model(encoded_patches, outputs, name=name)
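Each call to `transformer()` receives its own `drop_prob`. The ViTClassifier defined next assigns these with a linear schedule from 0.0 to DROP_PATH_RATE across the NUM_LAYERS blocks. The sketch below reproduces that schedule with NumPy and shows which blocks actually get a StochasticDepth layer, since `transformer()` skips the layer whenever `drop_prob` is falsy.

```python
import numpy as np

DROP_PATH_RATE = 0.1
NUM_LAYERS = 12

# Linear schedule: block i gets drop rate dpr[i]; deeper blocks drop more often.
dpr = [float(x) for x in np.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
# transformer() only inserts StochasticDepth when drop_prob is truthy,
# so the first block (rate 0.0) has no drop path at all.
with_drop_path = [i for i, p in enumerate(dpr) if p]
print([round(p, 4) for p in dpr])
print(with_drop_path)
```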
We'll now implement a ViTClassifier class building on top of the components we just
developed. Here we'll be following the original pooling strategy used in the ViT paper –
use a class token and use the feature representations corresponding to it for
classification.
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )
        # Stochastic depth rate grows linearly from 0 to DROP_PATH_RATE.
        # Plain Python floats keep `if drop_prob` in transformer() backend-independent.
        dpr = [float(x) for x in np.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(NUM_CLASSES, name="classification_head")

    def build(self, input_shape):
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        super().build(input_shape)

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]
        # Strided convolution + reshape: (n, H, W, 3) -> (n, NUM_PATCHES, PROJECTION_DIM).
        projected_patches = self.projection(inputs)
        # Prepend one copy of the learned class token per example.
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, projected_patches], axis=1
        )
        encoded_patches = self.positional_embedding + projected_patches
        encoded_patches = self.dropout(encoded_patches)
        for transformer_module in self.transformer_blocks:
            encoded_patches = transformer_module(encoded_patches)
        representation = self.layer_norm(encoded_patches)
        # Classify from the class-token representation only.
        encoded_patches = representation[:, 0]
        output = self.head(encoded_patches)
        return output
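The token bookkeeping in `ViTClassifier.call` can be traced with NumPy arrays standing in for the Keras weights. This sketch only checks shapes and skips the transformer blocks and final LayerNorm (marked in a comment): 196 patch embeddings plus one class token give 197 tokens, and pooling keeps the first token.

```python
import numpy as np

n, num_patches, dim = 2, 196, 192
rng = np.random.default_rng(0)

projected = rng.normal(size=(n, num_patches, dim))   # stands in for self.projection(inputs)
cls_token = np.zeros((1, 1, dim))                    # learned weight, zeros at init
pos_embedding = np.zeros((1, num_patches + 1, dim))  # learned weight, zeros at init

# Tile the single class token across the batch and prepend it.
tokens = np.concatenate([np.tile(cls_token, (n, 1, 1)), projected], axis=1)
# The positional embedding broadcasts over the batch axis.
encoded = pos_embedding + tokens
# ... transformer blocks and the final LayerNorm would run here ...
pooled = encoded[:, 0]  # class-token representation fed to the head
print(tokens.shape, pooled.shape)  # (2, 197, 192) (2, 192)
```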
This class can be used standalone as a ViT and is end-to-end trainable. Just remove the word
"distilled" from MODEL_TYPE and it should work with vit_tiny = ViTClassifier().
Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken
from the DeiT paper):

Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2  # class token + distillation token
        self.regular_training = regular_training
        self.head = layers.Dense(NUM_CLASSES, name="classification_head")
        self.head_dist = layers.Dense(NUM_CLASSES, name="distillation_head")

    def build(self, input_shape):
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        self.dist_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="dist_token",
        )
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + self.num_tokens, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]
        projected_patches = self.projection(inputs)
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        dist_token = keras.ops.tile(self.dist_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        dist_token = keras.ops.cast(dist_token, projected_patches.dtype)
        # Token order: [class, distillation, patches...].
        projected_patches = keras.ops.concatenate(
            [cls_token, dist_token, projected_patches], axis=1
        )
        encoded_patches = self.positional_embedding + projected_patches
        encoded_patches = self.dropout(encoded_patches)
        for transformer_module in self.transformer_blocks:
            encoded_patches = transformer_module(encoded_patches)
        representation = self.layer_norm(encoded_patches)
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )
        # During distillation, return both logits so each head gets its own loss.
        if training and not self.regular_training:
            return x, x_dist
        # Otherwise, average the predictions of the two heads.
        return (x + x_dist) / 2
Let's verify if the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = keras.ops.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
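The `(2, 5)` shape comes from the inference branch, where the two heads are averaged. The sketch below mimics the tail of `ViTDistilled.call` with plain matrix products to show both branches side by side. The function name `deit_heads` is hypothetical, and the Dense biases are omitted for brevity.

```python
import numpy as np


def deit_heads(representation, w_cls, w_dist, training, regular_training=False):
    """Mimics the end of ViTDistilled.call; Dense biases are omitted."""
    x = representation[:, 0] @ w_cls        # class token -> classification head
    x_dist = representation[:, 1] @ w_dist  # distillation token -> distillation head
    if training and not regular_training:
        return x, x_dist                    # separate logits for the two loss terms
    return (x + x_dist) / 2                 # averaged logits at inference


rng = np.random.default_rng(0)
representation = rng.normal(size=(2, 198, 192))  # [cls, dist, 196 patches]
w_cls = rng.normal(size=(192, 5))
w_dist = rng.normal(size=(192, 5))

train_out = deit_heads(representation, w_cls, w_dist, training=True)
infer_out = deit_heads(representation, w_cls, w_dist, training=False)
print(type(train_out).__name__, infer_out.shape)  # tuple (2, 5)
```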
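Unlike what happens in standard knowledge distillation (Hinton et al.), where a temperature-scaled softmax is used together with KL divergence, the DeiT authors use the following hard-label loss function:
$$
\mathcal{L}_{\text{global}}^{\text{hardDistill}} = \frac{1}{2}\,\mathcal{L}_{\text{CE}}\big(\psi(Z_s), y\big) + \frac{1}{2}\,\mathcal{L}_{\text{CE}}\big(\psi(Z_s), y_t\big), \qquad y_t = \operatorname{argmax}_c Z_t(c)
$$
Here,
- CE is cross-entropy
- psi is the softmax function
- Z_s and Z_t are the student and teacher logits; with a distillation token, the first term uses the class-token logits and the second uses the distillation-token logits
- y is the true label and y_t is the teacher's hard decision
class DeiT(keras.Model):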
    # Reference: https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
        self.accuracy_metric = keras.metrics.CategoricalAccuracy(name="accuracy")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        metrics.append(self.accuracy_metric)
        return metrics

    def compile(self, optimizer, student_loss_fn, distillation_loss_fn):
        super().compile(optimizer=optimizer)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # The student is trained on [0, 1] inputs; the teacher's backbone takes raw [0, 255].
        x_normalized = keras.ops.cast(x, "float32") / 255.0
        cls_predictions, dist_predictions = self.student(x_normalized, training=True)
        teacher_logits = self.teacher(keras.ops.cast(x, "float32"), training=False)
        teacher_predictions = keras.ops.softmax(teacher_logits, axis=-1)
        # Class head vs. true labels; distillation head vs. teacher probabilities.
        student_loss = self.student_loss_fn(y, cls_predictions)
        distillation_loss = self.distillation_loss_fn(
            teacher_predictions, dist_predictions
        )
        self.student_loss_tracker.update_state(student_loss)
        self.dist_loss_tracker.update_state(distillation_loss)
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.accuracy_metric.update_state(y, student_predictions)
        return (student_loss + distillation_loss) / 2

    def test_step(self, data):
        x, y = data
        x_normalized = keras.ops.cast(x, "float32") / 255.0
        # With training=False the student returns the averaged logits of both heads.
        y_prediction = self.student(x_normalized, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)
        self.accuracy_metric.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)
        return {
            "loss": student_loss,
            "student_loss": self.student_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }

    def call(self, inputs, training=False):
        inputs_normalized = keras.ops.cast(inputs, "float32") / 255.0
        return self.student(inputs_normalized, training=False)
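Note that `compute_loss` above feeds the teacher's softmax probabilities (soft targets) to the distillation loss, whereas the DeiT paper's hard-label variant uses the one-hot argmax of the teacher's logits. The NumPy sketch below compares the two kinds of targets on made-up logits. The helper names are hypothetical, and `cross_entropy_from_logits` mirrors the batch-mean cross-entropy that `CategoricalCrossentropy(from_logits=True)` computes by default.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def cross_entropy_from_logits(targets, logits):
    """Mean over the batch of -sum(targets * log_softmax(logits))."""
    return float(-(targets * log_softmax(logits)).sum(axis=-1).mean())


teacher_logits = np.array([[2.0, 1.0, 0.1], [0.2, 0.3, 3.0]])
dist_logits = np.array([[1.5, 0.5, 0.0], [0.0, 0.5, 2.0]])  # student distillation head

# Soft targets: teacher softmax probabilities, as in DeiT.compute_loss above.
soft_targets = np.exp(log_softmax(teacher_logits))
# Hard targets: one-hot of the teacher's argmax, as in the paper's hard distillation.
hard_targets = np.eye(teacher_logits.shape[-1])[teacher_logits.argmax(axis=-1)]

soft_loss = cross_entropy_from_logits(soft_targets, dist_logits)
hard_loss = cross_entropy_from_logits(hard_targets, dist_logits)
print(soft_loss, hard_loss)
```

In the Keras code, the hard-label variant would presumably replace the softmax line with `keras.ops.one_hot(keras.ops.argmax(teacher_logits, axis=-1), NUM_CLASSES)`; that change was not run for the results reported here.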
For full backend portability in Keras 3, we build a teacher with standard Keras layers
instead of using a TensorFlow-only SavedModel loader. We use EfficientNetV2B0 (pretrained on
ImageNet) as the backbone, freeze it, and fine-tune only a small classification head on
the flowers dataset. In practice you could swap in any compatible Keras model as the
teacher.
EfficientNetV2B0 includes preprocessing by default (include_preprocessing=True),
so it expects raw [0, 255] image values. We therefore avoid adding an extra
Rescaling(1/255) layer in the teacher path to prevent double normalization.
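The arithmetic below shows why double normalization is a problem. It assumes, for illustration only, that the backbone's built-in preprocessing starts with a 1/255 rescale; the exact internal preprocessing steps of EfficientNetV2B0 are not shown in this example. Applying an extra Rescaling(1/255) in front squeezes all inputs into a tiny range near zero.

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])  # raw [0, 255] inputs
once = pixels / 255.0                    # a single rescale: spans [0, 1]
twice = once / 255.0                     # an extra Rescaling(1/255) in front
print(once)
print(twice)  # everything now lies below 0.004
```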
teacher_backbone = keras.applications.EfficientNetV2B0(
    include_top=False, pooling="avg", weights="imagenet"
)
teacher_backbone.trainable = False
teacher_model = keras.Sequential(
    [teacher_backbone, layers.Dense(NUM_CLASSES)], name="teacher"
)
teacher_model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.CategoricalAccuracy(name="accuracy")],
)
print("Fine-tuning teacher head on flowers dataset...")
teacher_model.fit(train_dataset, validation_data=val_dataset, epochs=5)
teacher_model.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b0_notop.h5
24274472/24274472 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Fine-tuning teacher head on flowers dataset...
Epoch 1/5
13/13 ━━━━━━━━━━━━━━━━━━━━ 100s 5s/step - accuracy: 0.6031 - loss: 1.1777 - val_accuracy: 0.7363 - val_loss: 0.8961
Epoch 2/5
13/13 ━━━━━━━━━━━━━━━━━━━━ 5s 303ms/step - accuracy: 0.7919 - loss: 0.7255 - val_accuracy: 0.8049 - val_loss: 0.6591
Epoch 3/5
13/13 ━━━━━━━━━━━━━━━━━━━━ 5s 308ms/step - accuracy: 0.8436 - loss: 0.5433 - val_accuracy: 0.8462 - val_loss: 0.5433
Epoch 4/5
13/13 ━━━━━━━━━━━━━━━━━━━━ 5s 309ms/step - accuracy: 0.8724 - loss: 0.4523 - val_accuracy: 0.8626 - val_loss: 0.4762
Epoch 5/5
13/13 ━━━━━━━━━━━━━━━━━━━━ 5s 309ms/step - accuracy: 0.8905 - loss: 0.3986 - val_accuracy: 0.8764 - val_loss: 0.4336
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=teacher_model)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
    optimizer=keras.optimizers.AdamW(
        weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
    ),
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
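Two numbers in the compile call above are worth unpacking. The learning rate follows a linear scaling rule: BASE_LR is defined for a reference batch size of 512 and scaled by BATCH_SIZE / 512, which gives 0.00025 for a batch size of 256. The student loss uses label smoothing with 0.1, which (as Keras's CategoricalCrossentropy applies it) turns a one-hot target y into y * (1 - 0.1) + 0.1 / NUM_CLASSES. The sketch recomputes both in plain Python and NumPy.

```python
import numpy as np

BASE_LR, BATCH_SIZE, NUM_CLASSES = 0.0005, 256, 5

# Linear scaling rule: BASE_LR is specified for a batch size of 512.
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
print(lr_scaled)  # 0.00025

# Label smoothing with eps = 0.1 applied to a one-hot label for class 2.
eps = 0.1
y = np.eye(NUM_CLASSES)[2]
y_smooth = y * (1 - eps) + eps / NUM_CLASSES
print(y_smooth)
```

The smoothed target still sums to 1, but the true class gets 0.92 instead of 1.0, which discourages the class head from becoming overconfident.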
Epoch 1/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 189s 7s/step - accuracy: 0.2257 - distillation_loss: 2.1809 - loss: 2.0463 - student_loss: 1.9302 - val_accuracy: 0.1896 - val_loss: 1.4346 - val_student_loss: 1.5736
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 367ms/step - accuracy: 0.2184 - distillation_loss: 1.6308 - loss: 1.6227 - student_loss: 1.6147 - val_accuracy: 0.2445 - val_loss: 1.5118 - val_student_loss: 1.5775
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 378ms/step - accuracy: 0.2151 - distillation_loss: 1.6084 - loss: 1.6084 - student_loss: 1.6084 - val_accuracy: 0.2445 - val_loss: 1.5851 - val_student_loss: 1.6011
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 375ms/step - accuracy: 0.2426 - distillation_loss: 1.6073 - loss: 1.6051 - student_loss: 1.6028 - val_accuracy: 0.2170 - val_loss: 1.5452 - val_student_loss: 1.5852
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 368ms/step - accuracy: 0.2387 - distillation_loss: 1.6037 - loss: 1.6033 - student_loss: 1.6029 - val_accuracy: 0.2445 - val_loss: 1.5618 - val_student_loss: 1.5873
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 369ms/step - accuracy: 0.2489 - distillation_loss: 1.6007 - loss: 1.6008 - student_loss: 1.6008 - val_accuracy: 0.2445 - val_loss: 1.6880 - val_student_loss: 1.6270
Epoch 7/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 361ms/step - accuracy: 0.2308 - distillation_loss: 1.6025 - loss: 1.6014 - student_loss: 1.6003 - val_accuracy: 0.2527 - val_loss: 1.5244 - val_student_loss: 1.5746
Epoch 8/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 363ms/step - accuracy: 0.2480 - distillation_loss: 1.5985 - loss: 1.5973 - student_loss: 1.5964 - val_accuracy: 0.2857 - val_loss: 1.5183 - val_student_loss: 1.5697
Epoch 9/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 361ms/step - accuracy: 0.3073 - distillation_loss: 1.5623 - loss: 1.5619 - student_loss: 1.5617 - val_accuracy: 0.3104 - val_loss: 1.5522 - val_student_loss: 1.5427
Epoch 10/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 363ms/step - accuracy: 0.3657 - distillation_loss: 1.4807 - loss: 1.4769 - student_loss: 1.4729 - val_accuracy: 0.3407 - val_loss: 1.5153 - val_student_loss: 1.4909
Epoch 11/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 369ms/step - accuracy: 0.3941 - distillation_loss: 1.4111 - loss: 1.4111 - student_loss: 1.4115 - val_accuracy: 0.3544 - val_loss: 1.5445 - val_student_loss: 1.4759
Epoch 12/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 377ms/step - accuracy: 0.4250 - distillation_loss: 1.3679 - loss: 1.3678 - student_loss: 1.3676 - val_accuracy: 0.3929 - val_loss: 1.3185 - val_student_loss: 1.3640
Epoch 13/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 367ms/step - accuracy: 0.4604 - distillation_loss: 1.3078 - loss: 1.3064 - student_loss: 1.3049 - val_accuracy: 0.4478 - val_loss: 1.5273 - val_student_loss: 1.4022
Epoch 14/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 364ms/step - accuracy: 0.5094 - distillation_loss: 1.2820 - loss: 1.2736 - student_loss: 1.2653 - val_accuracy: 0.4753 - val_loss: 1.2482 - val_student_loss: 1.3120
Epoch 15/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 367ms/step - accuracy: 0.5088 - distillation_loss: 1.2707 - loss: 1.2553 - student_loss: 1.2394 - val_accuracy: 0.5302 - val_loss: 1.3324 - val_student_loss: 1.2801
Epoch 16/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 366ms/step - accuracy: 0.5544 - distillation_loss: 1.2321 - loss: 1.2154 - student_loss: 1.1985 - val_accuracy: 0.5385 - val_loss: 1.4207 - val_student_loss: 1.2887
Epoch 17/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 377ms/step - accuracy: 0.5693 - distillation_loss: 1.2090 - loss: 1.1889 - student_loss: 1.1681 - val_accuracy: 0.5549 - val_loss: 1.1973 - val_student_loss: 1.2029
Epoch 18/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 365ms/step - accuracy: 0.5941 - distillation_loss: 1.1869 - loss: 1.1691 - student_loss: 1.1510 - val_accuracy: 0.5797 - val_loss: 1.1819 - val_student_loss: 1.1792
Epoch 19/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 372ms/step - accuracy: 0.6062 - distillation_loss: 1.1702 - loss: 1.1475 - student_loss: 1.1242 - val_accuracy: 0.5934 - val_loss: 1.1396 - val_student_loss: 1.1463
Epoch 20/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 358ms/step - accuracy: 0.6243 - distillation_loss: 1.1581 - loss: 1.1337 - student_loss: 1.1085 - val_accuracy: 0.5934 - val_loss: 1.2590 - val_student_loss: 1.1801
In this Keras 3 setup, distillation is expected to improve over training the same backbone from scratch under the same budget; the non-distilled baseline below is not run in this example. In the run shown above, the distilled model reaches about 59.3% validation accuracy after 20 epochs.
You can adapt the following code to reproduce a non-distilled baseline:
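# ViTClassifier needs a MODEL_TYPE without "distilled" so that transformer()
# builds blocks for NUM_PATCHES + 1 tokens. Any such name works; this one is an example.
MODEL_TYPE = "vit_tiny_patch16_224"
vit_tiny = ViTClassifier()
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)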
EfficientNetV2B0) provides a strong signal and stabilizes DeiT training on the flowers dataset.
timm updated with readable implementations. I referred to the implementations of ViT and DeiT a lot while implementing them in Keras.
ViTClassifier in another project.
Example available on HuggingFace: