
Learning to tokenize in Vision Transformers

Authors: Aritra Roy Gosthipaty, Sayak Paul (equal contribution)
Date created: 2021/12/10
Last modified: 2021/12/15
Description: Adaptively generating a smaller number of tokens for Vision Transformers.



Introduction

Vision Transformers (Dosovitskiy et al.) and many other Transformer-based architectures (Liu et al., Yuan et al., etc.) have shown strong results in image recognition. The following provides a brief overview of the components involved in the Vision Transformer architecture for image classification:

  • Extract small patches from input images.
  • Linearly project those patches.
  • Add positional embeddings to these linear projections.
  • Run these projections through a series of Transformer (Vaswani et al.) blocks.
  • Finally, take the representation from the final Transformer block and add a classification head.
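The first two steps above can be sketched in NumPy. The sketch cuts a batch of images into non-overlapping patches, flattens them, and applies a linear projection. It then checks that this gives the same result as a convolution whose kernel size and stride both equal the patch size, which is how create_vit_classifier builds its patch projections later in this example. The sizes (48x48 inputs, 6x6 patches, 128-dimensional projections) match the hyperparameters used below. The random kernel stands in for learned weights, and extract_patches is a hypothetical helper.

```python
import numpy as np

def extract_patches(images, patch_size):
    """Split (B, H, W, C) images into flattened, non-overlapping patches."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size  # patch grid size
    x = images[:, : gh * patch_size, : gw * patch_size]  # drop any remainder
    x = x.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, P, P, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
images = rng.random((2, 48, 48, 3))
patches = extract_patches(images, patch_size=6)  # (2, 64, 108)

# Linear projection of every flattened patch (learned in a real model).
kernel = rng.normal(size=(6, 6, 3, 128))  # same layout as a Conv2D kernel
tokens = patches @ kernel.reshape(-1, 128)  # (2, 64, 128)

# The same result from an explicit convolution with kernel size = stride = 6.
conv_out = np.empty((2, 8, 8, 128))
for i in range(8):
    for j in range(8):
        window = images[:, i * 6 : (i + 1) * 6, j * 6 : (j + 1) * 6]
        conv_out[:, i, j] = np.einsum("bpqc,pqcd->bd", window, kernel)

print(tokens.shape)  # (2, 64, 128)
print(np.allclose(tokens, conv_out.reshape(2, 64, 128)))  # True
```

Because the two computations agree, a strided Conv2D is a compact way to implement "extract patches, then linearly project them".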

If we take 224x224 images and extract 16x16 patches, we get a total of 196 patches (also called tokens) for each image. The number of patches increases as we increase the resolution, leading to a higher memory footprint. Could we use a reduced number of patches without having to compromise performance? Ryoo et al. investigate this question in TokenLearner: Adaptive Space-Time Tokenization for Videos. They introduce a novel module called TokenLearner that can help reduce the number of patches used by a Vision Transformer (ViT) in an adaptive manner. With TokenLearner incorporated in the standard ViT architecture, they are able to reduce the amount of compute (measured in FLOPs) used by the model.
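The arithmetic behind the 196-patch figure, and why resolution matters, fits in a few lines. The number of patches grows quadratically with the image side. Self-attention then compares every token with every other token, so its score matrix grows quadratically in the number of tokens. The helper name below is illustrative only.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    return (image_size // patch_size) ** 2

for size in (224, 384, 512):
    n = num_patches(size, patch_size=16)
    # Each attention head scores every token against every other token,
    # so it builds an n x n attention matrix per layer.
    print(f"{size}x{size}: {n} tokens, {n * n:,} attention scores per head")

# 224x224: 196 tokens, 38,416 attention scores per head
# 384x384: 576 tokens, 331,776 attention scores per head
# 512x512: 1024 tokens, 1,048,576 attention scores per head
```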

In this example, we implement the TokenLearner module and demonstrate its performance with a mini ViT and the CIFAR-10 dataset. We make use of the following references:


Setup

We need to install TensorFlow Addons to run this example. To install it, execute the following:

pip install tensorflow-addons

Imports

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

import math

Hyperparameters

Please feel free to change the hyperparameters and check your results. The best way to develop intuition about the architecture is to experiment with it.

# DATA
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 20

# AUGMENTATION
IMAGE_SIZE = 48  # We will resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]

# TOKENLEARNER
NUM_TOKENS = 4
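IMAGE_SIZE and PATCH_SIZE together fix the token count. The patch-projection Conv2D in create_vit_classifier uses kernel size = stride = PATCH_SIZE with "VALID" padding, so its output grid has (IMAGE_SIZE - PATCH_SIZE) // PATCH_SIZE + 1 cells per side. With 48 and 6 this gives 8 x 8 = 64 patches, which matches NUM_PATCHES. The hypothetical helper below is a sanity-check sketch, not part of the example. It also flags pixels that would be silently ignored if you pick an IMAGE_SIZE that is not a multiple of PATCH_SIZE.

```python
def patch_grid(image_size: int, patch_size: int) -> int:
    """Patches produced by a VALID Conv2D with kernel = stride = patch_size."""
    per_side = (image_size - patch_size) // patch_size + 1  # VALID output size
    leftover = image_size - per_side * patch_size
    if leftover:
        # The strided convolution never covers these bottom/right pixels.
        print(f"note: the last {leftover} rows and columns are never used")
    return per_side ** 2

print(patch_grid(48, 6))  # 64, same as NUM_PATCHES
print(patch_grid(50, 6))  # prints a note about 2 unused rows/columns, then 64
```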

Load and prepare the CIFAR-10 dataset

# Load the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

# Convert to tf.data.Dataset objects.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
Training samples: 40000
Validation samples: 10000
Testing samples: 10000


Data augmentation

The augmentation pipeline consists of:

  • Rescaling
  • Resizing
  • Random cropping (fixed-sized or random sized)
  • Random horizontal flipping

data_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
    ],
    name="data_augmentation",
)

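Note that the random augmentation layers behave differently at inference time. When called with training=False, RandomFlip returns its inputs unchanged and RandomCrop deterministically resizes its inputs to the target size (center-cropping if the aspect ratios differ) instead of cropping at a random location. Rescaling and Resizing are applied in both modes. Refer to the documentation for more details.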
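For intuition, the NumPy sketch below mimics what RandomCrop and RandomFlip("horizontal") do in training mode on a single image. It assumes Resizing has already brought the image to 52x52 (INPUT_SHAPE[0] + 20). The function name is hypothetical, and its random-number handling differs from the Keras layers. It only illustrates the transformation itself.

```python
import numpy as np

def random_crop_and_flip(image, crop_size, rng):
    """Training-mode sketch of RandomCrop followed by RandomFlip("horizontal")."""
    height, width, _ = image.shape
    # Pick a random top-left corner so that the crop fits inside the image.
    top = rng.integers(0, height - crop_size + 1)
    left = rng.integers(0, width - crop_size + 1)
    crop = image[top : top + crop_size, left : left + crop_size]
    # Mirror along the width axis with probability 0.5.
    if rng.random() < 0.5:
        crop = crop[:, ::-1]
    return crop

rng = np.random.default_rng(0)
resized = rng.random((52, 52, 3))  # stands in for a 32x32 image resized to 52x52
augmented = random_crop_and_flip(resized, crop_size=48, rng=rng)
print(augmented.shape)  # (48, 48, 3)
```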


Positional embedding module

A Transformer architecture consists of multi-head self-attention layers and fully-connected feed-forward networks (MLPs) as its main components. Both of these components are permutation-equivariant: permuting the input tokens simply permutes the outputs in the same way, so neither is aware of token order.

To overcome this problem, we inject positional information into the tokens. The position_embedding function adds a learned positional embedding to the linearly projected tokens.

def position_embedding(
    projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
):
    # Build the positions.
    positions = tf.range(start=0, limit=num_patches, delta=1)

    # Encode the positions with an Embedding layer.
    encoded_positions = layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim
    )(positions)

    # Add encoded positions to the projected patches.
    return projected_patches + encoded_positions
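The NumPy sketch below makes the ordering problem concrete with a single-head self-attention layer that uses random weights. Shuffling its input tokens only shuffles its outputs. Once a fixed vector is added per position, the shuffled sequence produces genuinely different outputs. In position_embedding, those per-position vectors come from a learned Embedding layer, and the random positions array here only stands in for them. Multi-head attention, biases, and the output projection are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, dim = 5, 8
tokens = rng.normal(size=(num_tokens, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))

def self_attention(x):
    """Single-head scaled dot-product self-attention (no positional info)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(dim)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

perm = np.array([3, 0, 4, 1, 2])  # an arbitrary reordering of the tokens

# Shuffling the input only shuffles the output: the layer ignores order.
print(np.allclose(self_attention(tokens)[perm], self_attention(tokens[perm])))  # True

# Adding one vector per position ties content to its slot, so order matters.
positions = rng.normal(size=(num_tokens, dim))  # stand-in for learned embeddings
with_pos = self_attention(tokens + positions)
shuffled_with_pos = self_attention(tokens[perm] + positions)
print(np.allclose(with_pos[perm], shuffled_with_pos))  # False
```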

MLP block for Transformer

This serves as the Fully Connected Feed Forward block for our Transformer.

def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
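As a worked example, the sketch below counts the parameters of the MLP block that mlp(x, ..., hidden_units=MLP_UNITS) builds on PROJECTION_DIM-dimensional tokens. It also checks the one constraint the Transformer block imposes: the last entry of MLP_UNITS must equal PROJECTION_DIM, because the MLP output is added back to its input by a residual Add(). The helper names are hypothetical.

```python
def dense_params(in_dim: int, out_dim: int) -> int:
    """Weights of a Dense layer with a bias: kernel plus bias."""
    return in_dim * out_dim + out_dim

def mlp_params(input_dim: int, hidden_units) -> tuple:
    """Return (parameter count, output width) of the Dense => Dropout stack."""
    total, width = 0, input_dim
    for units in hidden_units:
        total += dense_params(width, units)  # Dropout adds no parameters
        width = units
    return total, width

PROJECTION_DIM = 128
MLP_UNITS = [PROJECTION_DIM * 2, PROJECTION_DIM]
params, out_width = mlp_params(PROJECTION_DIM, MLP_UNITS)
# The residual connection needs the MLP to return PROJECTION_DIM features.
assert out_width == PROJECTION_DIM
print(params)  # 65920
```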

TokenLearner module

The following figure presents a pictorial overview of the module (source).

[Figure: animated overview of the TokenLearner module]

The TokenLearner module takes as input an image-shaped tensor. It then passes this tensor through a small stack of convolutional layers. The last layer uses a sigmoid activation and outputs one spatial attention map per output token, with each map focusing on different parts of the input. These attention maps are then multiplied element-wise with the input, and the result is aggregated with global average pooling over the spatial locations. This pooled output can be treated as a summary of the input and has far fewer tokens (8, for example) than the original input (196, for example).

Using multiple convolution layers helps with expressivity. Imposing a form of spatial attention helps retain relevant information from the inputs. Both of these components are crucial to make TokenLearner work, especially when we are significantly reducing the number of patches.

def token_learner(inputs, number_of_tokens=NUM_TOKENS):
    # Layer normalize the inputs.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)  # (B, H, W, C)

    # Applying Conv2D => Reshape => Permute
    # The reshape and permute is done to help with the next steps of
    # multiplication and Global Average Pooling.
    attention_maps = keras.Sequential(
        [
            # 3 layers of conv with gelu activation as suggested
            # in the paper.
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            # This conv layer will generate the attention maps
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation="sigmoid",  # Note sigmoid for [0, 1] output
                padding="same",
                use_bias=False,
            ),
            # Reshape and Permute
            layers.Reshape((-1, number_of_tokens)),  # (B, H*W, num_of_tokens)
            layers.Permute((2, 1)),
        ]
    )(
        x
    )  # (B, num_of_tokens, H*W)

    # Reshape the input to align it with the output of the conv block.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, H*W, C)

    # Element-Wise multiplication of the attention maps and the inputs
    attended_inputs = (
        attention_maps[..., tf.newaxis] * inputs
    )  # (B, num_tokens, H*W, C)

    # Global average pooling the element wise multiplication result.
    outputs = tf.reduce_mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs
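The broadcast-multiply-then-average at the end of token_learner is mathematically a batched matrix product between the attention maps and the flattened inputs, divided by H*W. The NumPy check below reproduces the same reshapes on random data. A random sigmoid output stands in for the convolution stack. It confirms that both forms agree, and the matrix-product form avoids materializing the (B, num_tokens, H*W, C) intermediate tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W, C, S = 2, 8, 8, 128, 4  # batch, grid height/width, channels, tokens

inputs = rng.normal(size=(B, H, W, C))
# Stand-in for the conv stack's sigmoid output: one [0, 1] map per token.
maps = 1.0 / (1.0 + np.exp(-rng.normal(size=(B, H, W, S))))

# Same reshapes as token_learner: maps -> (B, S, H*W), inputs -> (B, 1, H*W, C).
maps_flat = maps.reshape(B, H * W, S).transpose(0, 2, 1)
inputs_flat = inputs.reshape(B, 1, H * W, C)

# Broadcast multiply, then average over the spatial axis (as in the example).
tokens_broadcast = (maps_flat[..., np.newaxis] * inputs_flat).mean(axis=2)

# Equivalent batched matrix product without the (B, S, H*W, C) intermediate.
tokens_matmul = np.einsum(
    "bsn,bnc->bsc", maps_flat, inputs.reshape(B, H * W, C)
) / (H * W)

print(tokens_broadcast.shape)  # (2, 4, 128)
print(np.allclose(tokens_broadcast, tokens_matmul))  # True
```

In the Keras code, a tf.einsum call with the same equation could express this step as well. Whether it is faster depends on the hardware, so treat it as an untested alternative rather than a drop-in improvement.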

Transformer block

def transformer(encoded_patches):
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
    )(x1, x1)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

    # Skip connection 2.
    encoded_patches = layers.Add()([x4, x2])
    return encoded_patches
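To see where the savings from fewer tokens come from, the following rough estimate counts the multiply-accumulate operations (MACs) in one Transformer block as defined above, as a function of the number of tokens N. Note that key_dim=PROJECTION_DIM gives each of the 4 heads its own 128-dimensional query, key, and value projections, so the projection width is 512. This is a back-of-the-envelope sketch. It ignores biases, layer normalization, softmax, and dropout, and it is not how the GFLOPs in the Results section were computed.

```python
def transformer_block_macs(
    num_tokens, dim=128, num_heads=4, key_dim=128, mlp_units=(256, 128)
):
    """Rough multiply-accumulate count for one Transformer block of this example.

    Ignores biases, layer normalization, softmax, dropout and residual adds.
    """
    inner = num_heads * key_dim  # width of the Q/K/V projections (512 here)
    qkv = 3 * num_tokens * dim * inner  # query, key and value projections
    scores = num_heads * num_tokens * num_tokens * key_dim  # Q @ K^T per head
    mixing = num_heads * num_tokens * num_tokens * key_dim  # weights @ V per head
    output = num_tokens * inner * dim  # projection back to `dim`
    mlp, prev = 0, dim
    for units in mlp_units:  # Dense layers of the MLP block
        mlp += num_tokens * prev * units
        prev = units
    return qkv + scores + mixing + output + mlp

print(transformer_block_macs(64))  # 25165824
print(transformer_block_macs(4))  # 1327104
```

Under this estimate, a block that sees the 4 TokenLearner tokens needs roughly 19x fewer MACs than one that sees all 64 patches.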

ViT model with the TokenLearner module

def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
    inputs = layers.Input(shape=INPUT_SHAPE)  # (B, H, W, C)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Create patches and project the patches.
    projected_patches = layers.Conv2D(
        filters=PROJECTION_DIM,
        kernel_size=(PATCH_SIZE, PATCH_SIZE),
        strides=(PATCH_SIZE, PATCH_SIZE),
        padding="VALID",
    )(augmented)
    _, h, w, c = projected_patches.shape
    projected_patches = layers.Reshape((h * w, c))(
        projected_patches
    )  # (B, number_patches, projection_dim)

    # Add positional embeddings to the projected patches.
    encoded_patches = position_embedding(
        projected_patches
    )  # (B, number_patches, projection_dim)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)

    # Iterate over the number of layers and stack up blocks of
    # Transformer.
    for i in range(NUM_LAYERS):
        # Add a Transformer block.
        encoded_patches = transformer(encoded_patches)

        # Add the TokenLearner layer after block NUM_LAYERS // 2
        # (the third of four blocks here). The paper suggests that
        # placing it anywhere between 1/2 and 3/4 of the way through
        # the network works well.
        if use_token_learner and i == NUM_LAYERS // 2:
            _, hh, c = encoded_patches.shape
            h = int(math.sqrt(hh))
            encoded_patches = layers.Reshape((h, h, c))(
                encoded_patches
            )  # (B, h, h, projection_dim)
            encoded_patches = token_learner(
                encoded_patches, token_learner_units
            )  # (B, num_tokens, c)

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

As shown in the TokenLearner paper, it is almost always advantageous to include the TokenLearner module in the middle of the network.
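With the hyperparameters above (NUM_PATCHES = 64, NUM_LAYERS = 4, NUM_TOKENS = 4), the condition i == NUM_LAYERS // 2 fires after the third block. Only the last Transformer block, plus the final normalization and pooling, therefore operates on the reduced set of tokens. The hypothetical helper below mirrors the loop in create_vit_classifier to make the sequence length entering each block explicit.

```python
NUM_LAYERS, NUM_TOKENS, NUM_PATCHES = 4, 4, 64

def tokens_per_block(num_layers, num_patches, num_tokens, use_token_learner=True):
    """Sequence length entering each Transformer block in create_vit_classifier."""
    lengths, current = [], num_patches
    for i in range(num_layers):
        lengths.append(current)
        if use_token_learner and i == num_layers // 2:
            current = num_tokens  # TokenLearner runs right after block i
    return lengths

print(tokens_per_block(NUM_LAYERS, NUM_PATCHES, NUM_TOKENS))  # [64, 64, 64, 4]
print(tokens_per_block(8, NUM_PATCHES, NUM_TOKENS))  # [64, 64, 64, 64, 64, 4, 4, 4]
```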


Training utility

def run_experiment(model):
    # Initialize the AdamW optimizer.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Compile the model with the optimizer, loss function
    # and the metrics.
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Define callbacks
    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # Train the model.
    _ = model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

Train and evaluate a ViT with TokenLearner

vit_token_learner = create_vit_classifier()
run_experiment(vit_token_learner)
Epoch 1/20


157/157 [==============================] - 20s 39ms/step - loss: 2.2716 - accuracy: 0.1396 - top-5-accuracy: 0.5908 - val_loss: 2.0672 - val_accuracy: 0.2004 - val_top-5-accuracy: 0.7632
Epoch 2/20
157/157 [==============================] - 5s 34ms/step - loss: 1.9780 - accuracy: 0.2488 - top-5-accuracy: 0.7917 - val_loss: 1.8621 - val_accuracy: 0.2986 - val_top-5-accuracy: 0.8391
Epoch 3/20
157/157 [==============================] - 5s 34ms/step - loss: 1.8168 - accuracy: 0.3138 - top-5-accuracy: 0.8437 - val_loss: 1.7044 - val_accuracy: 0.3680 - val_top-5-accuracy: 0.8793
Epoch 4/20
157/157 [==============================] - 5s 34ms/step - loss: 1.6765 - accuracy: 0.3701 - top-5-accuracy: 0.8820 - val_loss: 1.6490 - val_accuracy: 0.3857 - val_top-5-accuracy: 0.8809
Epoch 5/20
157/157 [==============================] - 5s 34ms/step - loss: 1.6091 - accuracy: 0.4058 - top-5-accuracy: 0.8978 - val_loss: 1.5899 - val_accuracy: 0.4221 - val_top-5-accuracy: 0.8989
Epoch 6/20
157/157 [==============================] - 5s 34ms/step - loss: 1.5386 - accuracy: 0.4340 - top-5-accuracy: 0.9097 - val_loss: 1.5434 - val_accuracy: 0.4321 - val_top-5-accuracy: 0.9098
Epoch 7/20
157/157 [==============================] - 5s 34ms/step - loss: 1.4944 - accuracy: 0.4481 - top-5-accuracy: 0.9171 - val_loss: 1.4914 - val_accuracy: 0.4674 - val_top-5-accuracy: 0.9146
Epoch 8/20
157/157 [==============================] - 5s 33ms/step - loss: 1.4767 - accuracy: 0.4586 - top-5-accuracy: 0.9179 - val_loss: 1.5280 - val_accuracy: 0.4528 - val_top-5-accuracy: 0.9090
Epoch 9/20
157/157 [==============================] - 5s 34ms/step - loss: 1.4331 - accuracy: 0.4751 - top-5-accuracy: 0.9248 - val_loss: 1.3996 - val_accuracy: 0.4857 - val_top-5-accuracy: 0.9298
Epoch 10/20
157/157 [==============================] - 5s 34ms/step - loss: 1.3990 - accuracy: 0.4925 - top-5-accuracy: 0.9291 - val_loss: 1.3888 - val_accuracy: 0.4872 - val_top-5-accuracy: 0.9308
Epoch 11/20
157/157 [==============================] - 5s 33ms/step - loss: 1.3646 - accuracy: 0.5019 - top-5-accuracy: 0.9355 - val_loss: 1.4330 - val_accuracy: 0.4811 - val_top-5-accuracy: 0.9208
Epoch 12/20
157/157 [==============================] - 5s 34ms/step - loss: 1.3607 - accuracy: 0.5037 - top-5-accuracy: 0.9354 - val_loss: 1.3242 - val_accuracy: 0.5149 - val_top-5-accuracy: 0.9415
Epoch 13/20
157/157 [==============================] - 5s 34ms/step - loss: 1.3303 - accuracy: 0.5170 - top-5-accuracy: 0.9384 - val_loss: 1.2934 - val_accuracy: 0.5295 - val_top-5-accuracy: 0.9437
Epoch 14/20
157/157 [==============================] - 5s 33ms/step - loss: 1.3038 - accuracy: 0.5259 - top-5-accuracy: 0.9426 - val_loss: 1.3102 - val_accuracy: 0.5187 - val_top-5-accuracy: 0.9422
Epoch 15/20
157/157 [==============================] - 5s 33ms/step - loss: 1.2926 - accuracy: 0.5304 - top-5-accuracy: 0.9441 - val_loss: 1.3220 - val_accuracy: 0.5234 - val_top-5-accuracy: 0.9428
Epoch 16/20
157/157 [==============================] - 5s 34ms/step - loss: 1.2724 - accuracy: 0.5346 - top-5-accuracy: 0.9458 - val_loss: 1.2670 - val_accuracy: 0.5370 - val_top-5-accuracy: 0.9491
Epoch 17/20
157/157 [==============================] - 5s 33ms/step - loss: 1.2515 - accuracy: 0.5450 - top-5-accuracy: 0.9462 - val_loss: 1.2837 - val_accuracy: 0.5349 - val_top-5-accuracy: 0.9474
Epoch 18/20
157/157 [==============================] - 5s 33ms/step - loss: 1.2427 - accuracy: 0.5505 - top-5-accuracy: 0.9492 - val_loss: 1.3425 - val_accuracy: 0.5180 - val_top-5-accuracy: 0.9371
Epoch 19/20
157/157 [==============================] - 5s 34ms/step - loss: 1.2129 - accuracy: 0.5605 - top-5-accuracy: 0.9514 - val_loss: 1.2297 - val_accuracy: 0.5590 - val_top-5-accuracy: 0.9536
Epoch 20/20
157/157 [==============================] - 5s 33ms/step - loss: 1.1994 - accuracy: 0.5667 - top-5-accuracy: 0.9523 - val_loss: 1.2390 - val_accuracy: 0.5577 - val_top-5-accuracy: 0.9528
40/40 [==============================] - 0s 11ms/step - loss: 1.2293 - accuracy: 0.5564 - top-5-accuracy: 0.9549
Test accuracy: 55.64%
Test top 5 accuracy: 95.49%

Results

We experimented with and without the TokenLearner inside the mini ViT we implemented (with the same hyperparameters presented in this example). Here are our results:

TokenLearner | # tokens in TokenLearner | Top-1 Acc (averaged across 5 runs) | GFLOPs | TensorBoard
N | - | 56.112% | 0.0184 | Link
Y | 8 | 56.55% | 0.0153 | Link
N | - | 56.37% | 0.0184 | Link
Y | 4 | 56.4980% | 0.0147 | Link
N | - (# Transformer layers: 8) | 55.36% | 0.0359 | Link

In our runs, the mini ViT with TokenLearner consistently outperforms the same model without the module, although by small margins (for example, 56.55% vs. 56.112% with 8 tokens). Interestingly, it also outperforms a deeper version of our mini ViT with 8 Transformer layers (55.36%). The authors report similar observations in the paper and attribute them to the adaptiveness of TokenLearner.

Note also that the FLOPs count decreases considerably with the addition of the TokenLearner module. It drops from 0.0184 GFLOPs to 0.0153 GFLOPs with 8 tokens (about 17% fewer) and to 0.0147 GFLOPs with 4 tokens (about 20% fewer). Despite the lower FLOPs count, the TokenLearner variants deliver better results. This aligns very well with the authors' findings.

Additionally, the authors introduced a newer version of the TokenLearner for smaller training data regimes. Quoting the authors:

Instead of using 4 conv. layers with small channels to implement spatial attention, this version uses 2 grouped conv. layers with more channels. It also uses softmax instead of sigmoid. We confirmed that this version works better when having limited training data, such as training with ImageNet1K from scratch.
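The sigmoid-versus-softmax change can be illustrated in isolation. The NumPy sketch below leaves out the grouped convolutions and assumes that the softmax is taken across the spatial locations for each token. That assumption reflects our reading of the description above and is not the authors' code. With softmax, every token becomes a weighted average of the input features whose weights sum to 1, so the locations compete for each token. With the sigmoid used in this example, each location is weighted independently and the result is then averaged.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, HW, C = 2, 4, 64, 128  # batch, tokens, spatial locations, channels
logits = rng.normal(size=(B, S, HW))  # stand-in for the conv stack's output
features = rng.normal(size=(B, HW, C))

# Original module (this example): independent sigmoid weights, then a mean.
sigmoid_maps = 1.0 / (1.0 + np.exp(-logits))
tokens_v1 = sigmoid_maps @ features / HW

# Newer version (assumed): softmax across locations, so each token is a
# weighted average of the input features with weights summing to 1.
shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
softmax_maps = shifted / shifted.sum(axis=-1, keepdims=True)
tokens_v11 = softmax_maps @ features

print(tokens_v1.shape, tokens_v11.shape)  # (2, 4, 128) (2, 4, 128)
print(np.allclose(softmax_maps.sum(axis=-1), 1.0))  # True
```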

We experimented with this module and in the following table we summarize the results:

# Groups | # Tokens | Top-1 Acc | GFLOPs | TensorBoard
4 | 4 | 54.638% | 0.0149 | Link
8 | 8 | 54.898% | 0.0146 | Link
4 | 8 | 55.196% | 0.0149 | Link

Please note that we used the same hyperparameters presented in this example. Our implementation is available in this notebook. We acknowledge that the results with this new TokenLearner module are slightly worse than expected; hyperparameter tuning might close the gap.

Note: To compute the FLOPs of our models we used this utility from this repository.


Number of parameters

You may have noticed that adding the TokenLearner module increases the number of parameters of the base network. But a higher parameter count does not necessarily make a model less efficient, as shown by Dehghani et al. Bello et al. reported similar findings. The TokenLearner module reduces the number of tokens that the subsequent Transformer blocks process. This lowers the FLOPs of the overall network and shrinks its intermediate activations, which in turn helps reduce the memory footprint.
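The extra parameters are few. As a rough check, the following sketch counts the weights of token_learner as written above. The module has one LayerNormalization, which holds a scale and an offset per channel, and four bias-free 3x3 convolutions. The first convolution sees all PROJECTION_DIM (128) input channels, and the other three see only num_tokens channels. Reshape and Permute have no weights. The helper names are hypothetical.

```python
def conv2d_params(kernel: int, in_ch: int, out_ch: int, use_bias: bool = False) -> int:
    """Weights of a square Conv2D layer."""
    return kernel * kernel * in_ch * out_ch + (out_ch if use_bias else 0)

def token_learner_params(channels: int = 128, num_tokens: int = 4, kernel: int = 3) -> int:
    layer_norm = 2 * channels  # gamma and beta
    convs = conv2d_params(kernel, channels, num_tokens)  # first conv sees C channels
    convs += 3 * conv2d_params(kernel, num_tokens, num_tokens)  # remaining three
    return layer_norm + convs

print(token_learner_params(num_tokens=4))  # 5296
print(token_learner_params(num_tokens=8))  # 11200
```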


Final notes

  • TokenFuser: The authors of the paper also propose another module named TokenFuser. This module remaps the representation of the TokenLearner output back to its original spatial resolution. To reuse the TokenLearner in the ViT architecture, the TokenFuser is required. We first learn the tokens with the TokenLearner, build a representation of the tokens with a Transformer layer, and then remap that representation to the original spatial resolution so that a TokenLearner can consume it again. Without the TokenFuser, the TokenLearner module can be used only once in the entire ViT model.
  • Use of these modules for video: The authors also suggest that TokenFuser goes really well with Vision Transformers for Videos (Arnab et al.).
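For intuition, here is a simplified NumPy sketch of the TokenFuser remapping step, based on our reading of the paper's description. It is not the authors' implementation. A linear layer first mixes the S tokens (token_mix). A sigmoid-activated projection of the original feature map (remap_proj) then gives every spatial location a weight for each token. The weighted sum of tokens is added back to the original map. token_mix and remap_proj are hypothetical stand-ins for learned weights, and the normalization layers are omitted.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def token_fuser_sketch(tokens, original, token_mix, remap_proj):
    """Map (B, S, C) tokens back onto the (B, H, W, C) grid of `original`.

    token_mix:  (S, S) hypothetical learned matrix that mixes the token axis.
    remap_proj: (C, S) hypothetical learned projection that scores each token
                at each spatial location.
    """
    mixed = np.einsum("bsc,st->btc", tokens, token_mix)  # linear layer over tokens
    weights = sigmoid(original @ remap_proj)  # (B, H, W, S), values in (0, 1)
    fused = np.einsum("bhws,bsc->bhwc", weights, mixed)  # per-location token mix
    return fused + original  # residual connection back onto the input grid

rng = np.random.default_rng(0)
B, H, W, C, S = 2, 8, 8, 128, 4
original = rng.normal(size=(B, H, W, C))
tokens = rng.normal(size=(B, S, C))  # e.g. Transformer output on TokenLearner tokens
fused = token_fuser_sketch(
    tokens, original, rng.normal(size=(S, S)), rng.normal(size=(C, S))
)
print(fused.shape)  # (2, 8, 8, 128)
```

The output has the original H x W layout again, which is what allows another TokenLearner to be applied to it.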

We are grateful to JarvisLabs and Google Developers Experts program for helping with GPU credits. Also, we are thankful to Michael Ryoo (first author of TokenLearner) for fruitful discussions.
