Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2026/02/10
Description: Distillation of Vision Transformers through attention.
In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture – unlike CNNs, they don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To comfortably navigate through this example, you'll be expected to know how a ViT and knowledge distillation work. The following are good resources in case you need a refresher:

- ViT on keras.io: https://keras.io/examples/vision/image_classification_with_vision_transformer/
- Knowledge distillation on keras.io: https://keras.io/examples/vision/knowledge_distillation/
from typing import List
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers
tfds.disable_progress_bar()
keras.utils.set_random_seed(42)
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
PROJECTION_DIM * 4,
PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5
You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout has been kept in the implementation to make it complete. For smaller models (like the one used in this example), you don't need it, but for bigger models, using dropout helps.
We now load the tf_flowers dataset and prepare our preprocessing utilities. The authors use an array of different augmentation techniques, including MixUp (Zhang et al.) and RandAugment (Cubuk et al.). However, to keep the example simple to work through, we'll discard them.
def preprocess_dataset(is_training=True):
def fn(image, label):
if is_training:
# Resize to a bigger spatial resolution and take the random
# crops.
image = keras.ops.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
# Perform random crop using TensorFlow ops for graph compatibility
# Get random crop coordinates (0 to 20 pixels offset)
crop_top = tf.random.uniform((), 0, 21, dtype=tf.int32)
crop_left = tf.random.uniform((), 0, 21, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(
image,
offset_height=crop_top,
offset_width=crop_left,
target_height=RESOLUTION,
target_width=RESOLUTION,
)
# Random horizontal flip
if tf.random.uniform(()) > 0.5:
image = tf.image.flip_left_right(image)
else:
image = keras.ops.image.resize(image, (RESOLUTION, RESOLUTION))
label = keras.ops.one_hot(label, num_classes=NUM_CLASSES)
return image, label
return fn
def prepare_dataset(dataset, is_training=True):
if is_training:
dataset = dataset.shuffle(BATCH_SIZE * 10)
dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
return dataset.batch(BATCH_SIZE).prefetch(AUTO)
train_dataset, val_dataset = tfds.load(
"tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
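As a quick sanity check (a small addition that isn't strictly necessary), we can pull a single batch and confirm the shapes and label encoding the pipeline produces:

# Fetch one batch: images should be (BATCH_SIZE, RESOLUTION, RESOLUTION, 3)
# and labels one-hot encoded with NUM_CLASSES entries.
sample_images, sample_labels = next(iter(train_dataset))
print(sample_images.shape)  # e.g. (256, 224, 224, 3)
print(sample_labels.shape)  # e.g. (256, 5)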
Since DeiT is an extension of ViT, it'd make sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.) which is used in DeiT for regularization.
# Referred from: https://github.com/rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, x, training=True):
if training:
keep_prob = 1 - self.drop_prob
shape = (keras.ops.shape(x)[0],) + (1,) * (len(keras.ops.shape(x)) - 1)
random_tensor = keep_prob + keras.random.uniform(
shape, 0, 1, seed=self.seed_generator
)
random_tensor = keras.ops.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
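To build intuition for what this layer does, here is a tiny illustrative snippet (not part of the original recipe): during training, each sample in the batch either keeps its residual branch (scaled by 1 / keep_prob) or has it dropped entirely.

# Illustration only: apply stochastic depth to a dummy residual branch.
sd = StochasticDepth(0.5)
dummy_branch = keras.ops.ones((4, 3, 2))
dropped = sd(dummy_branch, training=True)
# Roughly half of the 4 samples are zeroed out; the survivors are scaled by 2.
print(dropped[:, 0, 0])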
Now, we'll implement the MLP and Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
"""FFN for a Transformer block."""
# Iterate over the hidden units and
# add Dense => Dropout.
for idx, units in enumerate(hidden_units):
x = layers.Dense(
units,
activation="gelu" if idx == 0 else None,
)(x)
x = layers.Dropout(dropout_rate)(x)
return x
def transformer(drop_prob: float, name: str) -> keras.Model:
"""Transformer block with pre-norm."""
num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
encoded_patches = layers.Input((num_patches, PROJECTION_DIM))
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
# Multi Head Self Attention layer 1.
attention_output = layers.MultiHeadAttention(
num_heads=NUM_HEADS,
key_dim=PROJECTION_DIM,
dropout=DROPOUT_RATE,
)(x1, x1)
attention_output = (
StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
# MLP layer 1.
x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4
# Skip connection 2.
outputs = layers.Add()([x2, x4])
return keras.Model(encoded_patches, outputs, name=name)
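Each block maps a sequence of token embeddings to a sequence of the same shape, which is what lets us stack them. A quick optional shape check (an addition for illustration; with the distilled model type, the token count is NUM_PATCHES + 2 = 198):

# Optional shape check: one block keeps (batch, num_tokens, PROJECTION_DIM).
test_block = transformer(drop_prob=0.1, name="shape_check_block")
dummy_tokens = keras.ops.ones((2, NUM_PATCHES + 2, PROJECTION_DIM))
print(test_block(dummy_tokens).shape)  # (2, 198, 192)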
We'll now implement a ViTClassifier class building on top of the components we just developed. Here we'll follow the original pooling strategy used in the ViT paper: use a class token and take the feature representation corresponding to it for classification.
class ViTClassifier(keras.Model):
"""Vision Transformer base class."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Patchify + linear projection + reshaping.
self.projection = keras.Sequential(
[
layers.Conv2D(
filters=PROJECTION_DIM,
kernel_size=(PATCH_SIZE, PATCH_SIZE),
strides=(PATCH_SIZE, PATCH_SIZE),
padding="VALID",
name="conv_projection",
),
layers.Reshape(
target_shape=(NUM_PATCHES, PROJECTION_DIM),
name="flatten_projection",
),
],
name="projection",
)
# Transformer blocks.
dpr = [x for x in keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
self.transformer_blocks = [
transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
for i in range(NUM_LAYERS)
]
# Other layers.
self.dropout = layers.Dropout(DROPOUT_RATE)
self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
self.head = layers.Dense(
NUM_CLASSES,
name="classification_head",
)
def build(self, input_shape):
# Positional embedding.
self.positional_embedding = self.add_weight(
shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
initializer=keras.initializers.Zeros(),
trainable=True,
name="pos_embedding",
)
# CLS token.
self.cls_token = self.add_weight(
shape=(1, 1, PROJECTION_DIM),
initializer=keras.initializers.Zeros(),
trainable=True,
name="cls",
)
super().build(input_shape)
def call(self, inputs, training=True):
n = keras.ops.shape(inputs)[0]
# Create patches and project the patches.
projected_patches = self.projection(inputs)
cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
projected_patches = keras.ops.concatenate(
[cls_token, projected_patches], axis=1
)
# Add positional embeddings to the projected patches.
encoded_patches = (
self.positional_embedding + projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = self.dropout(encoded_patches)
# Iterate over the number of layers and stack up blocks of
# Transformer.
for transformer_module in self.transformer_blocks:
# Add a Transformer block.
encoded_patches = transformer_module(encoded_patches)
# Final layer normalization.
representation = self.layer_norm(encoded_patches)
# Pool representation.
encoded_patches = representation[:, 0]
# Classification head.
output = self.head(encoded_patches)
return output
This class can be used standalone as a ViT and is end-to-end trainable. Just remove the "distilled" phrase from MODEL_TYPE and it should work with vit_tiny = ViTClassifier(), as sketched below.
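Here's a minimal sketch of that, assuming any non-distilled model name works (the name below is just an example); the transformer() helper reads MODEL_TYPE, so it has to be switched before instantiation:

# Hypothetical standalone-ViT usage; "vit_tiny_patch16_224" is an assumed name.
_original_model_type = MODEL_TYPE
MODEL_TYPE = "vit_tiny_patch16_224"  # No "distilled" -> NUM_PATCHES + 1 tokens.

vit_tiny = ViTClassifier()
print(vit_tiny(tf.ones((2, 224, 224, 3)), training=False).shape)  # (2, 5)

MODEL_TYPE = _original_model_type  # Restore for the rest of the example.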
Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken
from the DeiT paper):
[Figure: Schematic of DeiT, showing the class and distillation tokens alongside the patch tokens (from the DeiT paper).]
Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.
class ViTDistilled(ViTClassifier):
def __init__(self, regular_training=False, **kwargs):
super().__init__(**kwargs)
self.num_tokens = 2
self.regular_training = regular_training
# Head layers.
self.head = layers.Dense(
NUM_CLASSES,
name="classification_head",
)
self.head_dist = layers.Dense(
NUM_CLASSES,
name="distillation_head",
)
def build(self, input_shape):
# CLS token.
self.cls_token = self.add_weight(
shape=(1, 1, PROJECTION_DIM),
initializer=keras.initializers.Zeros(),
trainable=True,
name="cls",
)
# Distillation token.
self.dist_token = self.add_weight(
shape=(1, 1, PROJECTION_DIM),
initializer=keras.initializers.Zeros(),
trainable=True,
name="dist_token",
)
# Positional embedding (for NUM_PATCHES + 2 tokens: cls + dist).
self.positional_embedding = self.add_weight(
shape=(1, NUM_PATCHES + self.num_tokens, PROJECTION_DIM),
initializer=keras.initializers.Zeros(),
trainable=True,
name="pos_embedding",
)
def call(self, inputs, training=True):
n = keras.ops.shape(inputs)[0]
# Create patches and project the patches.
projected_patches = self.projection(inputs)
# Append the tokens.
cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
dist_token = keras.ops.tile(self.dist_token, (n, 1, 1))
cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
dist_token = keras.ops.cast(dist_token, projected_patches.dtype)
projected_patches = keras.ops.concatenate(
[cls_token, dist_token, projected_patches], axis=1
)
# Add positional embeddings to the projected patches.
encoded_patches = (
self.positional_embedding + projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = self.dropout(encoded_patches)
# Iterate over the number of layers and stack up blocks of
# Transformer.
for transformer_module in self.transformer_blocks:
# Add a Transformer block.
encoded_patches = transformer_module(encoded_patches)
# Final layer normalization.
representation = self.layer_norm(encoded_patches)
# Classification heads.
x, x_dist = (
self.head(representation[:, 0]),
self.head_dist(representation[:, 1]),
)
# Only return separate classification predictions when training in distilled
# mode.
if training and not self.regular_training:
return x, x_dist
        # During standard training / fine-tuning and at inference, average the
        # predictions of both heads.
return (x + x_dist) / 2
Let's verify that the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
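When called with training=True (and regular_training left at its default of False), the model returns the two heads separately, which is what the distillation trainer below relies on. A quick extra check:

# In training mode the distilled model returns (cls_logits, dist_logits).
cls_out, dist_out = deit_tiny_distilled(dummy_inputs, training=True)
print(cls_out.shape, dist_out.shape)  # (2, 5) (2, 5)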
Unlike standard knowledge distillation (Hinton et al.), where a temperature-scaled softmax and KL divergence are used, the DeiT authors use the following loss function:
$$
\mathcal{L}_{\mathrm{global}}^{\mathrm{hardDistill}} = \frac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y\big) + \frac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y_t\big)
$$

Here,

- CE is the cross-entropy loss
- psi is the softmax function
- Z_s denotes the logits of the student
- y denotes the true labels
- y_t denotes the labels predicted by the teacher

We implement this recipe in a DeiT class with a custom training loop.

class DeiT(keras.Model):
# Reference:
# https://keras.io/examples/vision/knowledge_distillation/
def __init__(self, student, teacher, **kwargs):
super().__init__(**kwargs)
self.student = student
self.teacher = teacher
self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
self.accuracy_metric = keras.metrics.CategoricalAccuracy(name="accuracy")
@property
def metrics(self):
metrics = super().metrics
metrics.append(self.student_loss_tracker)
metrics.append(self.dist_loss_tracker)
metrics.append(self.accuracy_metric)
return metrics
def compile(
self,
optimizer,
student_loss_fn,
distillation_loss_fn,
):
super().compile(optimizer=optimizer)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
def train_step(self, data):
# Unpack data.
x, y = data
# Normalize for student (ViT expects [0, 1])
x_student = keras.ops.cast(x, "float32") / 255.0
# Teacher expects raw [0, 255] float32 (no normalization)
x_teacher = keras.ops.cast(x, "float32")
# Forward pass of teacher
# TFSMLayer returns a dictionary, extract the output
teacher_output = self.teacher(x_teacher, training=False)
if isinstance(teacher_output, dict):
# Get the first (and likely only) output from the dictionary
teacher_output = list(teacher_output.values())[0]
# Use soft targets (probabilities) for distillation
teacher_predictions = keras.ops.nn.softmax(teacher_output, -1)
with tf.GradientTape() as tape:
# Forward pass of student.
cls_predictions, dist_predictions = self.student(x_student, training=True)
# Compute losses.
student_loss = self.student_loss_fn(y, cls_predictions)
distillation_loss = self.distillation_loss_fn(
teacher_predictions, dist_predictions
)
loss = (student_loss + distillation_loss) / 2
# Compute gradients.
trainable_vars = self.student.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights.
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics.
student_predictions = (cls_predictions + dist_predictions) / 2
self.accuracy_metric.update_state(y, student_predictions)
self.dist_loss_tracker.update_state(distillation_loss)
self.student_loss_tracker.update_state(student_loss)
# Return a dict of performance - include loss
return {
"loss": loss,
"student_loss": self.student_loss_tracker.result(),
"distillation_loss": self.dist_loss_tracker.result(),
"accuracy": self.accuracy_metric.result(),
}
def test_step(self, data):
# Unpack the data.
x, y = data
# Convert to float32 and normalize for student
x_normalized = keras.ops.cast(x, "float32") / 255.0
# Compute predictions.
y_prediction = self.student(x_normalized, training=False)
# Calculate the loss.
student_loss = self.student_loss_fn(y, y_prediction)
# Update the metrics.
self.accuracy_metric.update_state(y, y_prediction)
self.student_loss_tracker.update_state(student_loss)
# Return a dict of performance
return {
"loss": student_loss,
"student_loss": self.student_loss_tracker.result(),
"accuracy": self.accuracy_metric.result(),
}
def call(self, inputs):
# Convert to float32 and normalize for student
inputs_normalized = keras.ops.cast(inputs, "float32") / 255.0
return self.student(inputs_normalized, training=False)
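To make the objective inside train_step concrete, here is a tiny standalone sketch (with made-up logits, not part of the original example) of how the two cross-entropy terms are averaged:

# Illustration with dummy values: average of the hard-label CE (class head)
# and the CE against the teacher's predictions (distillation head).
dummy_y = keras.ops.one_hot(keras.ops.arange(2), NUM_CLASSES)
dummy_cls_logits = keras.random.normal((2, NUM_CLASSES))
dummy_dist_logits = keras.random.normal((2, NUM_CLASSES))
dummy_teacher_probs = keras.ops.nn.softmax(keras.random.normal((2, NUM_CLASSES)), -1)

ce = keras.losses.CategoricalCrossentropy(from_logits=True)
student_term = ce(dummy_y, dummy_cls_logits)
distillation_term = ce(dummy_teacher_probs, dummy_dist_logits)
print((student_term + distillation_term) / 2)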
The teacher model is based on the BiT family of ResNets (Kolesnikov et al.), fine-tuned on the tf_flowers dataset. You can refer to this notebook to learn how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.
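If you'd like to verify the size gap yourself, you can count the student's parameters (an optional check; the exact number varies slightly with the head configuration):

# The distilled tiny student has on the order of a few million parameters,
# versus roughly 212 million for the BiT teacher.
print(f"Student parameters: {deit_tiny_distilled.count_params():,}")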
!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.layers.TFSMLayer(
"bit_teacher_flowers", call_endpoint="serving_default"
)
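Since TFSMLayer wraps a SavedModel serving signature, its output may come back as a dictionary keyed by the signature's output name, which is why train_step above handles both cases. A quick check under that assumption:

# Probe the teacher with a dummy batch in the raw [0, 255] range it expects.
teacher_out = bit_teacher_flowers(tf.ones((1, RESOLUTION, RESOLUTION, 3)) * 255.0)
if isinstance(teacher_out, dict):
    teacher_out = list(teacher_out.values())[0]
print(teacher_out.shape)  # (1, NUM_CLASSES)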
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
optimizer=keras.optimizers.AdamW(
weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
),
student_loss_fn=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 130s 8s/step - accuracy: 0.2150 - distillation_loss: 2.1021 - loss: 0.0000e+00 - student_loss: 1.8120 - val_accuracy: 0.2616 - val_loss: 1.6223 - val_student_loss: 1.6278
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2416 - distillation_loss: 1.6185 - loss: 0.0000e+00 - student_loss: 1.6297 - val_accuracy: 0.1662 - val_loss: 1.6018 - val_student_loss: 1.6075
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 104s 8s/step - accuracy: 0.2467 - distillation_loss: 1.6028 - loss: 0.0000e+00 - student_loss: 1.6087 - val_accuracy: 0.2316 - val_loss: 1.5954 - val_student_loss: 1.6009
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2349 - distillation_loss: 1.5968 - loss: 0.0000e+00 - student_loss: 1.6022 - val_accuracy: 0.2289 - val_loss: 1.5922 - val_student_loss: 1.6017
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.2634 - distillation_loss: 1.5902 - loss: 0.0000e+00 - student_loss: 1.5928 - val_accuracy: 0.3025 - val_loss: 1.5703 - val_student_loss: 1.5795
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3279 - distillation_loss: 1.5441 - loss: 0.0000e+00 - student_loss: 1.5456 - val_accuracy: 0.3515 - val_loss: 1.4880 - val_student_loss: 1.4937
Epoch 7/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3966 - distillation_loss: 1.4085 - loss: 0.0000e+00 - student_loss: 1.4534 - val_accuracy: 0.3706 - val_loss: 1.4348 - val_student_loss: 1.4335
Epoch 8/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.3890 - distillation_loss: 1.3647 - loss: 0.0000e+00 - student_loss: 1.4229 - val_accuracy: 0.3297 - val_loss: 1.4575 - val_student_loss: 1.4463
Epoch 9/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4223 - distillation_loss: 1.3332 - loss: 0.0000e+00 - student_loss: 1.3850 - val_accuracy: 0.4114 - val_loss: 1.3888 - val_student_loss: 1.3763
Epoch 10/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4475 - distillation_loss: 1.2577 - loss: 0.0000e+00 - student_loss: 1.3548 - val_accuracy: 0.4441 - val_loss: 1.3202 - val_student_loss: 1.3331
Epoch 11/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.4717 - distillation_loss: 1.2107 - loss: 0.0000e+00 - student_loss: 1.2995 - val_accuracy: 0.4632 - val_loss: 1.3016 - val_student_loss: 1.2872
Epoch 12/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5017 - distillation_loss: 1.1562 - loss: 0.0000e+00 - student_loss: 1.2542 - val_accuracy: 0.5395 - val_loss: 1.2761 - val_student_loss: 1.2575
Epoch 13/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5328 - distillation_loss: 1.1119 - loss: 0.0000e+00 - student_loss: 1.2223 - val_accuracy: 0.5068 - val_loss: 1.2102 - val_student_loss: 1.2383
Epoch 14/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 102s 8s/step - accuracy: 0.5655 - distillation_loss: 1.0595 - loss: 0.0000e+00 - student_loss: 1.1837 - val_accuracy: 0.5722 - val_loss: 1.1773 - val_student_loss: 1.1774
Epoch 15/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.5998 - distillation_loss: 1.0133 - loss: 0.0000e+00 - student_loss: 1.1465 - val_accuracy: 0.5204 - val_loss: 1.2519 - val_student_loss: 1.2340
Epoch 16/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6110 - distillation_loss: 0.9992 - loss: 0.0000e+00 - student_loss: 1.1359 - val_accuracy: 0.6104 - val_loss: 1.0947 - val_student_loss: 1.1090
Epoch 17/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6191 - distillation_loss: 0.9635 - loss: 0.0000e+00 - student_loss: 1.1101 - val_accuracy: 0.6076 - val_loss: 1.0678 - val_student_loss: 1.0952
Epoch 18/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6400 - distillation_loss: 0.9460 - loss: 0.0000e+00 - student_loss: 1.0902 - val_accuracy: 0.6076 - val_loss: 1.0256 - val_student_loss: 1.0681
Epoch 19/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6340 - distillation_loss: 0.9411 - loss: 0.0000e+00 - student_loss: 1.0943 - val_accuracy: 0.6213 - val_loss: 1.0353 - val_student_loss: 1.0702
Epoch 20/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 103s 8s/step - accuracy: 0.6506 - distillation_loss: 0.9121 - loss: 0.0000e+00 - student_loss: 1.0674 - val_accuracy: 0.6376 - val_loss: 1.0027 - val_student_loss: 1.0602
If we had trained the same model (the ViTClassifier) from scratch with the exact same
hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code
to reproduce this result:
vit_tiny = ViTClassifier()
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)
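One plausible way to fill in those calls, mirroring the optimizer and loss used for distillation above (this exact configuration is an assumption on my part, not necessarily what produced the 59% figure):

# Hypothetical baseline setup, reusing the scaled learning rate, weight decay,
# and label-smoothed cross-entropy from the distillation run above.
model.compile(
    optimizer=keras.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)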
I referred a lot to the ViT and DeiT implementations in timm, which is kept updated with readable implementations, while implementing them in Keras. The ViTClassifier implemented here can also be reused in another project.

Example available on HuggingFace:
| Trained Model | Demo |
|---|---|