
When Recurrence meets Transformers

Author: Aritra Roy Gosthipaty, Suvaditya Mukherjee
Date created: 2023/03/12
Last modified: 2023/03/12
Description: Image Classification with Temporal Latent Bottleneck Networks.



A simple Recurrent Neural Network (RNN) displays a strong inductive bias towards learning temporally compressed representations. Equation 1 shows the recurrence formula, where h_t is the compressed representation (a single vector) of the entire input sequence x.

Equation 1: The recurrence equation. (Source: Aritra and Suvaditya)
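The equation image is not reproduced in this text. The sketch below assumes the common Elman form of the recurrence, h_t = tanh(W_x x_t + W_h h_{t-1} + b), and is for illustration only. It uses NumPy with random weights, and `rnn_forward` is a hypothetical helper that is not part of the example. Whatever the length of the input sequence, the entire history ends up in one fixed-size vector h.

```python
import numpy as np


def rnn_forward(x, w_x, w_h, b):
    """Run a vanilla (Elman) RNN over x of shape (timesteps, input_dim).

    Returns the final hidden state of shape (hidden_dim,): a single
    fixed-size vector that summarizes the whole sequence.
    """
    h = np.zeros(w_h.shape[0])
    for x_t in x:
        # h_t = tanh(W_x x_t + W_h h_{t-1} + b)
        h = np.tanh(w_x @ x_t + w_h @ h + b)
    return h


rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 5
w_x = rng.normal(size=(hidden_dim, input_dim))
w_h = rng.normal(size=(hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)

short_seq = rng.normal(size=(4, input_dim))
long_seq = rng.normal(size=(400, input_dim))
print(rnn_forward(short_seq, w_x, w_h, b).shape)  # (5,)
print(rnn_forward(long_seq, w_x, w_h, b).shape)   # (5,)
```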

On the other hand, Transformers (Vaswani et al.) have little inductive bias towards learning temporally compressed representations. Thanks to its pairwise attention mechanism, the Transformer has achieved state-of-the-art (SoTA) results in Natural Language Processing (NLP) and vision tasks.

While the Transformer can attend to any part of the input sequence, the cost of attention is quadratic in the sequence length, because every token is compared with every other token.
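The NumPy sketch below shows where the quadratic cost comes from. It computes single-head scaled dot-product attention weights, and as a simplification it omits the learned query/key projections. The weight matrix holds one entry per pair of tokens, so doubling the sequence length quadruples its size.

```python
import numpy as np


def attention_weights(q, k):
    """Softmax(q k^T / sqrt(d)) for q and k of shape (n, d)."""
    logits = q @ k.T / np.sqrt(q.shape[-1])       # (n, n): one score per token pair
    logits -= logits.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
for n in (64, 128, 256):
    x = rng.normal(size=(n, 16))
    w = attention_weights(x, x)
    print(n, w.shape, w.size)
# 64 (64, 64) 4096
# 128 (128, 128) 16384
# 256 (256, 256) 65536
```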

Didolkar et al. argue that a more compressed representation of a sequence may be beneficial for generalization, as it can be easily reused and repurposed with fewer irrelevant details. While compression is good, they also note that too much of it can harm expressiveness.

The authors propose a solution that divides computation into two streams: a slow stream that is recurrent in nature and a fast stream that is parameterized as a Transformer. Introducing separate processing streams to preserve and process latent states is novel, but the method has parallels in other works such as the Perceiver Mechanism (Jaegle et al.) and Grounded Language Learning Fast and Slow (Hill et al.).

The following example explores how we can use the Temporal Latent Bottleneck mechanism to perform image classification on the CIFAR-10 dataset. We implement the recurrence as a custom RNN cell driven by the Keras RNN API instead of an explicit Python loop, which keeps the design performant and vectorized.

Note: This example uses TensorFlow 2.12.0, which must be installed on your system.

Setup imports

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW

import random
from matplotlib import pyplot as plt

# Set seed for reproducibility.
tf.keras.utils.set_random_seed(42)


Setting required configuration

We set a few configuration parameters that are needed within the pipeline we have designed. The current parameters are for use with the CIFAR-10 dataset.

The model also supports mixed precision. With the mixed_float16 policy, most computations run in 16-bit floating point, while variables are kept in 32-bit floating point for numerical stability. On GPUs with compute capability 7.0 or higher, such as the A100 used here, this reduces memory usage and speeds up both training and inference.
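Plain NumPy scalars show why variables are kept in 32-bit. This is an illustration, not Keras code. A float16 value takes half the storage of a float32 value. However, float16's machine epsilon is about 1e-3, so a small update such as 1e-4 disappears when added to a weight of 1.0.

```python
import numpy as np

# Storage: a float16 value takes half the bytes of a float32 value.
print(np.dtype(np.float16).itemsize, np.dtype(np.float32).itemsize)  # 2 4

# Precision: float16 cannot represent 1.0 + 1e-4, so the update is rounded away.
w16 = np.float16(1.0) + np.float16(1e-4)
w32 = np.float32(1.0) + np.float32(1e-4)
print(w16 == np.float16(1.0))  # True: the small update vanished
print(w32 == np.float32(1.0))  # False: float32 keeps it
```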

config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 2048,
    "buffer_size": 2048 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 30,
    "patch_size": 4,
    "embed_dim": 64,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 4,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}

if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-PCIE-40GB, compute capability 8.0

Loading the CIFAR-10 dataset

We are going to use the CIFAR-10 dataset for running our experiments. This dataset contains a training set of 50,000 images for 10 classes with the standard image size of (32, 32, 3).

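It also has a separate test set of 10,000 images with the same characteristics. We further split the original training set into 40,000 images for training and 10,000 for validation (config["train_slice"]). More information about the dataset can be found on the official site for the dataset and in the keras.datasets.cifar10 API reference.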

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)

Define data augmentation for the training and validation/test pipelines

We define separate pipelines for preprocessing and augmenting our data. Augmentation makes the model more robust to small changes in the input, helping it generalize better. The preprocessing and augmentation steps we perform are as follows:

  • Rescaling (training, test): Normalizes pixel values from the [0, 255] range to [0, 1], which helps keep training numerically stable.
  • Resizing (training, test): For training, we resize the images from their original size of (32, 32) to (52, 52), leaving room for the random crop. For validation and testing, we resize directly to (48, 48). This complies with the specifications of the data given in the paper.
  • RandomCrop (training): This layer randomly selects a crop/sub-region of the image with size (48, 48).
  • RandomFlip (training): This layer randomly flips images horizontally, keeping their size unchanged.
# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)

# We define named functions instead of lambda functions to run the data
# through `keras.Sequential`, in order to avoid this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)

def train_map_fn(image, label):
    return train_augmentation(image), label

def test_map_fn(image, label):
    return test_augmentation(image), label
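The training-time steps can be approximated in NumPy for intuition. This sketch rests on several assumptions. It starts from a square image that has already been resized to (52, 52, 3), so it skips the interpolation step. It also uses its own random generator, so it will not reproduce the exact crops or flips made by the Keras layers. `augment` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)


def augment(image, crop_size=48):
    """Approximate the training-time ops on an already-resized, square uint8 image."""
    x = image.astype("float32") / 255.0                    # Rescaling: [0, 255] -> [0, 1]
    max_offset = x.shape[0] - crop_size                    # 52 - 48 = 4
    top, left = rng.integers(0, max_offset + 1, size=2)    # random crop origin
    x = x[top : top + crop_size, left : left + crop_size]  # RandomCrop
    if rng.random() < 0.5:
        x = x[:, ::-1]                                     # RandomFlip("horizontal")
    return x


resized = rng.integers(0, 256, size=(52, 52, 3), dtype=np.uint8)
out = augment(resized)
print(out.shape, out.min() >= 0.0, out.max() <= 1.0)  # (48, 48, 3) True True
```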

Load dataset into tf.data.Dataset object

AUTO = tf.data.AUTOTUNE

# Training: augment, shuffle, batch, and prefetch.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(train_map_fn, num_parallel_calls=AUTO)
    .shuffle(config["buffer_size"])
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

# Validation and test: deterministic preprocessing only, no shuffling.
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
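As a quick check on the batching: `batch()` keeps the final partial batch unless `drop_remainder=True`, so the number of steps per epoch rounds up. With 40,000 training images and a batch size of 2048, that gives 20 steps, which matches the 20/20 shown in the training log.

```python
import math

train_images, val_images, test_images = 40_000, 10_000, 10_000
batch_size = 2048

print(math.ceil(train_images / batch_size))  # 20 steps per training epoch
print(math.ceil(val_images / batch_size))    # 5 validation steps
print(test_images % batch_size)              # 1808 images in the last test batch
```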

Temporal Latent Bottleneck

An excerpt from the paper:

In the brain, short-term and long-term memory have developed in a specialized way. Short-term memory is allowed to change very quickly to react to immediate sensory inputs and perception. By contrast, long-term memory changes slowly, is highly selective and involves repeated consolidation.

Inspired by short-term and long-term memory, the authors introduce fast-stream and slow-stream computation. The fast stream has a short-term memory with a high capacity that reacts quickly to sensory input (Transformers). The slow stream has a long-term memory which updates at a slower rate and summarizes the most relevant information (Recurrence).

To implement this idea we need to:

  • Take a sequence of data.
  • Divide the sequence into fixed-size chunks.
  • Process each chunk with the fast stream, which provides fine-grained local information.
  • Consolidate and aggregate information across chunks with the slow stream, which provides coarse-grained distant information.

The fast and slow streams induce what is called information asymmetry. The two streams interact with each other through an attention bottleneck. Figure 1 shows the architecture of the model.

Figure 1: Architecture of the model. (Source: https://arxiv.org/abs/2205.14794)

The authors also provide PyTorch-style pseudocode, shown in Algorithm 1.

Algorithm 1: PyTorch style pseudocode. (Source: https://arxiv.org/abs/2205.14794)
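The toy NumPy sketch below traces the two-stream data flow for a single sequence. It is neither Algorithm 1 nor the model trained here. `attend` uses identity projections, and there are no feed-forward layers, normalization, or learned weights. The usage also assumes the slow state has as many latent vectors as a chunk has tokens, which is also what the Keras cell in this example uses. Each attention call involves only one chunk and the latents, so the cost grows linearly with the number of chunks instead of quadratically with the sequence length.

```python
import numpy as np


def attend(query, context):
    """Toy single-head attention: identity projections, no learned weights."""
    logits = query @ context.T / np.sqrt(query.shape[-1])
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ context


def two_stream(tokens, chunk_size, num_latents):
    """Run the toy fast/slow recurrence over tokens of shape (num_tokens, dim)."""
    num_tokens, dim = tokens.shape
    chunks = tokens.reshape(num_tokens // chunk_size, chunk_size, dim)
    slow = np.zeros((num_latents, dim))      # slow stream: long-term latent state
    for chunk in chunks:                     # recurrence over chunks
        fast = chunk + attend(chunk, chunk)  # fast stream: self-attention in the chunk
        fast = fast + attend(fast, slow)     # fast stream reads the slow stream
        slow = slow + attend(slow, fast)     # slow stream consolidates the chunk
    return fast, slow


tokens = np.random.default_rng(0).normal(size=(144, 64))
fast, slow = two_stream(tokens, chunk_size=8, num_latents=8)
print(fast.shape, slow.shape)  # (8, 64) (8, 64)
```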

PatchEmbedding layer

This custom keras.layers.Layer splits the image into patches and projects each patch into a higher-dimensional embedding space. Patching uses a keras.layers.Conv2D whose kernel size and stride both equal the patch size, rather than the traditional tf.image.extract_patches. This allows the operation to be vectorized.

Once the patches are computed, we reshape them into a flat sequence of tokens of shape (num_patches, embed_dim). At this stage, we also add learned positional embeddings (keras.layers.Embedding) to the tokens.

After we obtain the tokens, we chunk them. The sequence is split into fixed-size groups of chunk_size consecutive tokens, and these chunks are the final input to the model.
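The shape bookkeeping for this configuration (48×48 images, 4×4 patches, chunks of 8) can be checked with the NumPy sketch below. `patchify` is a hypothetical helper. Multiplying each flattened patch by one shared weight matrix is what a Conv2D with kernel_size and strides equal to the patch size computes, up to how the kernel is reshaped (bias omitted here). The 144 patches must divide evenly into chunks, which gives 18 chunks of 8 tokens.

```python
import numpy as np

image_size, patch_size, chunk_size, embed_dim = 48, 4, 8, 64


def patchify(image, p):
    """Split an (H, W, C) image into non-overlapping p x p patches -> (N, p*p*C)."""
    h, w, c = image.shape
    x = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p * c)  # patches in row-major (row, col) order


rng = np.random.default_rng(0)
image = rng.random((image_size, image_size, 3))
patches = patchify(image, patch_size)               # (144, 48)
# One shared linear map per patch: what a Conv2D with
# kernel_size == strides == patch_size computes (bias omitted).
kernel = rng.normal(size=(patch_size * patch_size * 3, embed_dim))
tokens = patches @ kernel                           # (144, 64)
chunks = tokens.reshape(-1, chunk_size, embed_dim)  # (18, 8, 64)
print(patches.shape, tokens.shape, chunks.shape)
```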

class PatchEmbedding(layers.Layer):
    """Image to Patch Embedding.

    Args:
        image_size (`Tuple[int]`): Size of the input image.
        patch_size (`Tuple[int]`): Size of the patch.
        embed_dim (`int`): Dimension of the embedding.
        chunk_size (`int`): Number of patches to be chunked.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        chunk_size,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Compute the patch resolution.
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]

        # Store the parameters.
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]

        # Define the positions of the patches.
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

        # Create the layers. The Conv2D (kernel_size == strides == patch_size)
        # projects each non-overlapping patch to `embed_dim` channels.
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            name="projection",
        )
        self.flatten = layers.Reshape(
            target_shape=(-1, embed_dim),
            name="flatten",
        )
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=embed_dim,
            name="position_embedding",
        )
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
            name="chunking_layer",
        )

    def call(self, inputs):
        # Project the inputs to the embedding dimension.
        x = self.projection(inputs)

        # Flatten the patches and add position embedding.
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)

        # Normalize the embeddings.
        x = self.layernorm(x)

        # Chunk the tokens.
        x = self.chunking_layer(x)

        return x

FeedForwardNetwork Layer

This custom keras.layers.Layer defines a generic pre-norm feed-forward network (FFN) with dropout and a residual connection.

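class FeedForwardNetwork(layers.Layer):
    """Feed Forward Network.

    Args:
        dims (`int`): Number of units in FFN.
        dropout (`float`): Dropout probability for FFN.
    """

    def __init__(self, dims, dropout, **kwargs):
        super().__init__(**kwargs)

        # Create the layers: expand to 4 * dims, project back to dims, dropout.
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ],
            name="ffn",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )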

    def call(self, inputs):
        # Apply the FFN.
        x = self.layernorm(inputs)
        x = inputs + self.ffn(x)
        return x

BaseAttention layer

This custom keras.layers.Layer is a base class that wraps a keras.layers.MultiHeadAttention layer. It layer-normalizes the query, key, and value before attention and adds a residual connection afterwards. It provides the common functionality shared by all the attention modules in our model, and it stores the latest attention scores for later visualization.

class BaseAttention(layers.Layer):
    """Base Attention Module.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        dropout (`float`): Dropout probability for attention module.
    """

    def __init__(self, num_heads, key_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
        )
        self.query_layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.key_layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.value_layernorm = layers.LayerNormalization(epsilon=1e-5)

        self.attention_scores = None

    def call(self, input_query, key, value):
        # Apply the attention module.
        query = self.query_layernorm(input_query)
        key = self.key_layernorm(key)
        value = self.value_layernorm(value)
        (attention_outputs, attention_scores) = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )

        # Save the attention scores for later visualization.
        self.attention_scores = attention_scores

        # Add the input to the attention output.
        x = input_query + attention_outputs
        return x

Attention with FeedForwardNetwork layer

This custom keras.layers.Layer combines the BaseAttention and FeedForwardNetwork components into one block, which is used repeatedly within the model. The FFN width and dropout, the number of heads, the key dimension, and the attention dropout are all configurable.

class AttentionWithFFN(layers.Layer):
    """Attention with Feed Forward Network.

    Args:
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Create the layers.
        self.attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
        )
        self.ffn = FeedForwardNetwork(
            dims=ffn_dims,
            dropout=ffn_dropout,
        )

        self.attention_scores = None

    def call(self, query, key, value):
        # Apply the attention module.
        x = self.attention(query, key, value)

        # Save the attention scores for later visualization.
        self.attention_scores = self.attention.attention_scores

        # Apply the FFN.
        x = self.ffn(x)
        return x
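Taken together, AttentionWithFFN computes x = q + Attention(LN(q), LN(k), LN(v)) and then x = x + FFN(LN(x)). The NumPy sketch below mirrors this pre-norm, residual structure for a single head, with several simplifications. LayerNorm has no learned scale or shift, and dropout is omitted (as at inference). GELU uses the tanh approximation, and the weight names in `params` are hypothetical. With 8 query tokens and 16 key/value tokens, the attention scores for one example and one head have shape (8, 16).

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without a learned scale or shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def attention_with_ffn(query, key, value, p):
    """Single-head pre-norm attention + pre-norm FFN, each with a residual."""
    q, k, v = layer_norm(query), layer_norm(key), layer_norm(value)
    logits = (q @ p["wq"]) @ (k @ p["wk"]).T / np.sqrt(p["wq"].shape[1])
    logits -= logits.max(axis=-1, keepdims=True)
    scores = np.exp(logits)
    scores /= scores.sum(axis=-1, keepdims=True)     # (len_query, len_key)
    x = query + (scores @ (v @ p["wv"])) @ p["wo"]   # attention residual
    x = x + gelu(layer_norm(x) @ p["w1"]) @ p["w2"]  # FFN residual (4x expansion)
    return x, scores


rng = np.random.default_rng(0)
dims = 64
shapes = {"wq": (dims, dims), "wk": (dims, dims), "wv": (dims, dims),
          "wo": (dims, dims), "w1": (dims, 4 * dims), "w2": (4 * dims, dims)}
params = {name: rng.normal(scale=0.1, size=s) for name, s in shapes.items()}

fast = rng.normal(size=(8, dims))   # 8 query tokens
slow = rng.normal(size=(16, dims))  # 16 key/value tokens
out, scores = attention_with_ffn(fast, slow, slow, params)
print(out.shape, scores.shape)  # (8, 64) (8, 16)
```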

Custom RNN Cell for Temporal Latent Bottleneck and Perceptual Module

Algorithm 1 (the pseudocode) depicts recurrence with the help of for loops. Looping makes the implementation simpler, but it slows down training. In this section we wrap the custom recurrence logic inside the CustomRecurrentCell. This custom cell is then wrapped with the Keras RNN API, which runs the recurrence for us and keeps the rest of the computation vectorized.
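Conceptually, keras.layers.RNN calls the cell once per timestep (here, once per chunk) and threads the state through. With the default return_sequences=False, it returns only the last output. The sketch below shows this contract for a single example without the batch dimension. `run_rnn` and `toy_cell` are hypothetical helpers; Keras implements the loop itself.

```python
import numpy as np


def run_rnn(cell, inputs, initial_states, return_sequences=False):
    """Conceptual, single-example sketch of how an RNN layer drives a cell."""
    states = initial_states
    outputs = []
    for x_t in inputs:                    # iterate over the time (chunk) axis
        output, states = cell(x_t, states)
        outputs.append(output)
    return np.stack(outputs) if return_sequences else outputs[-1]


def toy_cell(x_t, states):
    """Hypothetical cell: the state accumulates every chunk it has seen."""
    new_state = states[0] + x_t
    return new_state, [new_state]


chunks = np.ones((18, 8, 64))             # (num_chunks, chunk_size, dims)
last = run_rnn(toy_cell, chunks, [np.zeros((8, 64))])
print(last.shape, last[0, 0])             # (8, 64) 18.0
```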

This custom cell, implemented as a keras.layers.Layer, is the integral part of the logic for the model. The cell's functionality can be divided into 2 parts:

  • Slow Stream (Temporal Latent Bottleneck):
      • This module consists of a single AttentionWithFFN layer. It takes the output of the previous Slow Stream as the Query and the output of the latest Fast Stream as the Key and Value. The Slow Stream output is an intermediate hidden representation, which is the latent in Temporal Latent Bottleneck. This layer can also be construed as a CrossAttention layer.
  • Fast Stream (Perceptual Module):
      • This module consists of interleaved AttentionWithFFN layers: num_layers SelfAttention layers, with one CrossAttention layer after every r-th of them.
      • The SelfAttention layers take the fast-stream tokens (initially the chunked input) as the Query, Key and Value.
      • The CrossAttention layers take the output of the preceding SelfAttention layer as the Query. They use the intermediate state of the Temporal Latent Bottleneck module as the Key and Value.
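Based on the description above and the docstring of r ("One Cross Attention per r Self Attention"), a hypothetical helper can list the order of fast-stream blocks inside one chunk. This sketch assumes that a cross-attention block follows every r-th self-attention block, starting with the first. With num_layers=4 and r=2 from the config, two of the four self-attention blocks are followed by a cross-attention block, and the TLB cross-attention runs after all of them.

```python
def perceptual_schedule(num_layers, r):
    """Hypothetical helper: order of fast-stream blocks applied to one chunk."""
    schedule = []
    for layer_idx in range(num_layers):
        schedule.append(f"self_attn_{layer_idx}")
        if layer_idx % r == 0:
            # Fast stream (query) attends to the slow stream (key/value).
            schedule.append(f"cross_attn_{layer_idx}")
    return schedule


print(perceptual_schedule(num_layers=4, r=2))
# ['self_attn_0', 'cross_attn_0', 'self_attn_1', 'self_attn_2', 'cross_attn_2', 'self_attn_3']
```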
class CustomRecurrentCell(layers.Layer):
    """Custom Recurrent Cell.

    Args:
        chunk_size (`int`): Number of tokens in a chunk.
        r (`int`): One Cross Attention per **r** Self Attention.
        num_layers (`int`): Number of layers.
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        chunk_size,
        r,
        num_layers,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Save the arguments.
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_dropout = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # Create the state_size and output_size. This is important for
        # custom recurrence logic.
        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        self.get_attention_scores = False
        self.attention_scores = []

        def make_block(name):
            return AttentionWithFFN(
                ffn_dims=ffn_dims,
                ffn_dropout=ffn_dropout,
                num_heads=num_heads,
                key_dim=key_dim,
                attn_dropout=attn_dropout,
                name=name,
            )

        # Perceptual Module: one self-attention block per layer, plus one
        # cross-attention block after every `r`-th self-attention block.
        self.self_attention_blocks = [
            make_block(f"pm_self_attn_{idx}") for idx in range(num_layers)
        ]
        self.cross_attention_blocks = [
            make_block(f"pm_cross_attn_{idx}") for idx in range(0, num_layers, r)
        ]

        # Temporal Latent Bottleneck Module
        self.tlb_module = make_block("tlb_cross_attn")

    def call(self, inputs, states):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]
        slow_stream = states[0]
        fast_stream = inputs

        for layer_idx, self_attn in enumerate(self.self_attention_blocks):
            # Self-attention within the current chunk.
            fast_stream = self_attn(query=fast_stream, key=fast_stream, value=fast_stream)

            if layer_idx % self.r == 0:
                # Cross-attention: the fast stream reads from the slow stream.
                cross_attn = self.cross_attention_blocks[layer_idx // self.r]
                fast_stream = cross_attn(
                    query=fast_stream, key=slow_stream, value=slow_stream
                )

        # The slow stream (latent) reads from the updated fast stream.
        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        # Save the attention scores for later visualization.
        if self.get_attention_scores:
            self.attention_scores.append(self.tlb_module.attention_scores)

        return fast_stream, [slow_stream]

TemporalLatentBottleneckModel to encapsulate full model

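Here, we wrap the patch embedding layer, an RNN built from the custom cell, global average pooling, and a classification head into a single keras.Model, so that it can be trained with model.fit().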

class TemporalLatentBottleneckModel(keras.Model):
    """Model Trainer.

    Args:
        patch_layer (`keras.layers.Layer`): Patching layer.
        custom_cell (`keras.layers.Layer`): Custom Recurrent Cell.
    """

    def __init__(self, patch_layer, custom_cell, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell, name="rnn")
        self.gap = layers.GlobalAveragePooling1D(name="gap")
        # Keep the classification head in float32 for numerical stability
        # under the mixed_float16 policy.
        self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs

Build the model

To begin training, we define the components individually and pass them as arguments to our wrapper class, which prepares the final model for training. We define a PatchEmbedding layer and a CustomRecurrentCell, which the model wraps in a keras.layers.RNN.

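# Build the model.
patch_layer = PatchEmbedding(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"], chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
    chunk_size=config["chunk_size"], r=config["r"], num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"], ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"], key_dim=config["embed_dim"], attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(patch_layer=patch_layer, custom_cell=custom_rnn_cell)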

Optimizer, loss, and metrics

We use the AdamW optimizer since it has been shown to perform well on several benchmark tasks. It is a variant of the keras.optimizers.Adam optimizer with decoupled weight decay: the weights are shrunk directly at each step instead of an L2 penalty being added to the loss.
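A single update shows the difference from Adam with an L2 penalty. The NumPy sketch below is written for illustration and is not the Keras source. Its weight-decay term shrinks the weights directly and never passes through Adam's moment estimates. The defaults mirror this example's config (learning_rate=1e-4, weight_decay=1e-4) and Keras's documented Adam defaults for beta_1, beta_2, and epsilon.

```python
import numpy as np


def adamw_step(theta, grad, m, v, t, lr=1e-4, weight_decay=1e-4,
               beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One AdamW update at step t (t starts at 1); returns (theta, m, v)."""
    m = beta_1 * m + (1 - beta_1) * grad                # first-moment estimate
    v = beta_2 * v + (1 - beta_2) * grad**2             # second-moment estimate
    m_hat = m / (1 - beta_1**t)                         # bias correction
    v_hat = v / (1 - beta_2**t)
    theta = theta - lr * weight_decay * theta           # decoupled weight decay
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps) # Adam step
    return theta, m, v


# With a zero gradient, only weight decay acts: theta shrinks by lr * weight_decay.
theta, m, v = adamw_step(np.ones(3), np.zeros(3), np.zeros(3), np.zeros(3), t=1)
print(1.0 - theta)  # each entry is about 1e-08
```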

For the loss function, we use keras.losses.SparseCategoricalCrossentropy. It computes the cross-entropy between the predicted class probabilities (the softmax output of the head) and the integer class labels. We also track accuracy as a sanity check.
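A chance-level baseline helps in reading the loss values. If the model predicted a uniform distribution over the 10 classes, the sparse categorical cross-entropy would be -log(1/10) = ln 10 ≈ 2.3026. The NumPy sketch below uses a hypothetical helper, not the Keras implementation, to compute the mean of -log p[label]. The first-epoch val_loss of 2.2835 in the training log is only slightly below this chance level.

```python
import numpy as np


def sparse_categorical_crossentropy(probs, labels):
    """Mean of -log(probs[i, labels[i]]) over the batch."""
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())


num_classes = 10
labels = np.array([3, 8, 0, 6])  # integer class ids
uniform = np.full((len(labels), num_classes), 1.0 / num_classes)
print(sparse_categorical_crossentropy(uniform, labels))  # about 2.3026, i.e. ln(10)
```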

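optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)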
Train the model with model.fit()

We pass the training and validation datasets and train the model for config["epochs"] epochs.

history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
)

Epoch 1/30
20/20 [==============================] - 104s 3s/step - loss: 2.6284 - accuracy: 0.1010 - val_loss: 2.2835 - val_accuracy: 0.1251
Epoch 2/30
20/20 [==============================] - 35s 2s/step - loss: 2.2797 - accuracy: 0.1542 - val_loss: 2.1721 - val_accuracy: 0.1846
Epoch 3/30
20/20 [==============================] - 34s 2s/step - loss: 2.1989 - accuracy: 0.1883 - val_loss: 2.1288 - val_accuracy: 0.2241
Epoch 4/30
20/20 [==============================] - 34s 2s/step - loss: 2.1267 - accuracy: 0.2192 - val_loss: 2.0919 - val_accuracy: 0.2477
Epoch 5/30
20/20 [==============================] - 33s 2s/step - loss: 2.0653 - accuracy: 0.2393 - val_loss: 2.0134 - val_accuracy: 0.2671
Epoch 6/30
20/20 [==============================] - 34s 2s/step - loss: 2.0327 - accuracy: 0.2524 - val_loss: 2.0258 - val_accuracy: 0.2665
Epoch 7/30
20/20 [==============================] - 34s 2s/step - loss: 2.0047 - accuracy: 0.2598 - val_loss: 1.9871 - val_accuracy: 0.2831
Epoch 8/30
20/20 [==============================] - 34s 2s/step - loss: 1.9765 - accuracy: 0.2781 - val_loss: 1.9550 - val_accuracy: 0.2968
Epoch 9/30
20/20 [==============================] - 34s 2s/step - loss: 1.9432 - accuracy: 0.2883 - val_loss: 1.9559 - val_accuracy: 0.2969
Epoch 10/30
20/20 [==============================] - 33s 2s/step - loss: 1.9062 - accuracy: 0.3020 - val_loss: 1.8967 - val_accuracy: 0.3200
Epoch 11/30
20/20 [==============================] - 33s 2s/step - loss: 1.8741 - accuracy: 0.3158 - val_loss: 1.8648 - val_accuracy: 0.3330
Epoch 12/30
20/20 [==============================] - 33s 2s/step - loss: 1.8336 - accuracy: 0.3282 - val_loss: 1.7863 - val_accuracy: 0.3464
Epoch 13/30
20/20 [==============================] - 33s 2s/step - loss: 1.7931 - accuracy: 0.3434 - val_loss: 1.7364 - val_accuracy: 0.3669
Epoch 14/30
20/20 [==============================] - 34s 2s/step - loss: 1.7491 - accuracy: 0.3558 - val_loss: 1.7104 - val_accuracy: 0.3710
Epoch 15/30
20/20 [==============================] - 34s 2s/step - loss: 1.7182 - accuracy: 0.3686 - val_loss: 1.6883 - val_accuracy: 0.3866
Epoch 16/30
20/20 [==============================] - 33s 2s/step - loss: 1.6819 - accuracy: 0.3790 - val_loss: 1.6493 - val_accuracy: 0.3933
Epoch 17/30
20/20 [==============================] - 33s 2s/step - loss: 1.6594 - accuracy: 0.3873 - val_loss: 1.6021 - val_accuracy: 0.4161
Epoch 18/30
20/20 [==============================] - 33s 2s/step - loss: 1.6279 - accuracy: 0.3946 - val_loss: 1.5949 - val_accuracy: 0.4170
Epoch 19/30
20/20 [==============================] - 34s 2s/step - loss: 1.6127 - accuracy: 0.4015 - val_loss: 1.5672 - val_accuracy: 0.4239
Epoch 20/30
20/20 [==============================] - 33s 2s/step - loss: 1.5995 - accuracy: 0.4079 - val_loss: 1.5795 - val_accuracy: 0.4223
Epoch 21/30
20/20 [==============================] - 34s 2s/step - loss: 1.5809 - accuracy: 0.4167 - val_loss: 1.5294 - val_accuracy: 0.4390
Epoch 22/30
20/20 [==============================] - 34s 2s/step - loss: 1.5572 - accuracy: 0.4254 - val_loss: 1.5192 - val_accuracy: 0.4455
Epoch 23/30
20/20 [==============================] - 33s 2s/step - loss: 1.5468 - accuracy: 0.4291 - val_loss: 1.5243 - val_accuracy: 0.4424
Epoch 24/30
20/20 [==============================] - 34s 2s/step - loss: 1.5347 - accuracy: 0.4335 - val_loss: 1.4920 - val_accuracy: 0.4532
Epoch 25/30
20/20 [==============================] - 33s 2s/step - loss: 1.5245 - accuracy: 0.4387 - val_loss: 1.4805 - val_accuracy: 0.4584
Epoch 26/30
20/20 [==============================] - 33s 2s/step - loss: 1.5057 - accuracy: 0.4469 - val_loss: 1.4754 - val_accuracy: 0.4592
Epoch 27/30
20/20 [==============================] - 34s 2s/step - loss: 1.5013 - accuracy: 0.4457 - val_loss: 1.4688 - val_accuracy: 0.4619
Epoch 28/30
20/20 [==============================] - 33s 2s/step - loss: 1.4852 - accuracy: 0.4548 - val_loss: 1.4543 - val_accuracy: 0.4704
Epoch 29/30
20/20 [==============================] - 34s 2s/step - loss: 1.4728 - accuracy: 0.4570 - val_loss: 1.4437 - val_accuracy: 0.4751
Epoch 30/30
20/20 [==============================] - 34s 2s/step - loss: 1.4652 - accuracy: 0.4606 - val_loss: 1.4546 - val_accuracy: 0.4726

Visualize training metrics

model.fit() returns a History object whose history attribute stores the values of the metrics recorded during training. It lives only in memory, so it needs to be saved manually if you want to keep it.

We now display the Loss and Accuracy curves for the training and validation sets.

plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()



Visualize attention maps from the Temporal Latent Bottleneck

Now that we have trained our model, it is time for some visualizations. The Fast Stream (Transformers) processes a chunk of tokens. The Slow Stream processes each chunk and attends to tokens that are useful for the task.

In this section we visualize the attention map of the Slow Stream. While get_attention_scores is enabled, the cell appends the attention scores of the TLB layer to a list at every recurrent step, that is, once per chunk. We then reduce each chunk's scores to one value per token and reshape them into the patch grid. Finally, we 'balloon' (upsample) them back to the image resolution.
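The shape bookkeeping of this procedure can be followed with random NumPy arrays. The sketch assumes a batch of 2, one head, and 18 chunks of 8 tokens (144 patches on a 12×12 grid). It also assumes a list that already holds exactly one score tensor per chunk. For simplicity it uses nearest-neighbour upsampling, whereas the code that follows uses bilinear keras.layers.UpSampling2D.

```python
import numpy as np

batch, heads, chunk_size, num_chunks = 2, 1, 8, 18
grid, patch_size = 12, 4  # 48 // 4 patches per side, 4 pixels per patch

rng = np.random.default_rng(0)
# One TLB attention tensor per chunk: (batch, heads, queries, keys).
chunk_scores = [rng.random((batch, heads, chunk_size, chunk_size)) for _ in range(num_chunks)]


def score_to_viz(chunk_score):
    per_key = chunk_score.max(axis=-2)  # strongest attention each key token receives
    return per_key.mean(axis=1)         # average over heads -> (batch, chunk_size)


tokens = np.concatenate([score_to_viz(s) for s in chunk_scores], axis=-1)  # (2, 144)
heat = tokens.reshape(batch, grid, grid)                                   # (2, 12, 12)
heat_up = heat.repeat(patch_size, axis=1).repeat(patch_size, axis=2)      # (2, 48, 48)
print(tokens.shape, heat.shape, heat_up.shape)
```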

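def score_to_viz(chunk_score):
    # chunk_score: (batch, num_heads, query_tokens, key_tokens)
    # For each key token, take the strongest attention it receives from any query.
    chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
    # Average across attention heads -> (batch, key_tokens).
    chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
    return chunk_viz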

# Get a batch of images and labels from the testing dataset
images, labels = next(iter(test_ds))

# Set the get_attention_scores flag to True
model.rnn.cell.get_attention_scores = True

# Run the model with the testing images and grab the
# attention scores.
outputs = model(images)
list_chunk_scores = model.rnn.cell.attention_scores

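# Process the attention scores in order to visualize them. The first entry is
# skipped; it appears to come from an extra cell call that keras.layers.RNN
# makes to infer the output shape, not from a real chunk.
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
    chunk_viz,
    (
        -1,
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)
upsampled_heat_map = layers.UpSampling2D(
    size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)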

Run the following code snippet to get different images and their attention maps.

# Sample a random image
index = random.randint(0, images.shape[0] - 1)
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

# Plot the visualization
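fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
# Overlay the heat map; the colormap and transparency are arbitrary choices.
ax[1].imshow(overlay_image, cmap="inferno", alpha=0.6, extent=image.get_extent())
ax[1].set_title("TLB Attention:")
ax[1].axis("off")

plt.show()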



This example has demonstrated an implementation of the Temporal Latent Bottleneck mechanism. It shows how historical state can be compressed and stored in a Temporal Latent Bottleneck that is regularly updated by a Perceptual Module.

In the original paper, the authors conducted extensive experiments across modalities, ranging from supervised image classification to reinforcement learning.

While we have only displayed a method to apply this mechanism to Image Classification, it can be extended to other modalities too with minimal changes.

Note: While building this example we did not have the official code to refer to. This means that our implementation is inspired by the paper with no claims of being a complete reproduction. For more details on the training process one can head over to our GitHub repository.


Thanks to Aniket Didolkar (the first author) and Anirudh Goyal (the third author) for reviewing our work.

We would like to thank PyImageSearch for a Colab Pro account and JarvisLabs.ai for the GPU credits.