
Video Classification with a CNN-RNN Architecture

Author: Sayak Paul
Date created: 2021/05/28
Last modified: 2023/08/28
Description: Training a video classifier with transfer learning and a recurrent model on the UCF101 dataset.

ⓘ This example uses Keras 2


This example demonstrates video classification, an important use-case with applications in recommendations, security, and so on. We will be using the UCF101 dataset to build our video classifier. The dataset consists of videos categorized into different actions, like cricket shot, punching, biking, etc. This dataset is commonly used to build action recognizers, which are an application of video classification.

A video consists of an ordered sequence of frames. Each frame contains spatial information, and the sequence of those frames contains temporal information. To model both of these aspects, we use a hybrid architecture that consists of convolutions (for spatial processing) as well as recurrent layers (for temporal processing). Specifically, we'll use a Convolutional Neural Network (CNN) and a Recurrent Neural Network (RNN) consisting of GRU layers. This kind of hybrid architecture is popularly known as a CNN-RNN.

This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be installed using the following command:

!pip install -q git+https://github.com/tensorflow/docs

Data collection

To keep the runtime of this example relatively short, we will be using a subsampled version of the original UCF101 dataset. You can refer to this notebook to see how the subsampling was done.

!wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz
!tar xf ucf101_top5.tar.gz

Setup

from tensorflow_docs.vis import embed
from tensorflow import keras
from imutils import paths

import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
import imageio
import cv2
import os

Define hyperparameters

IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 10

MAX_SEQ_LENGTH = 20
NUM_FEATURES = 2048

Data preparation

train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")

print(f"Total videos for training: {len(train_df)}")
print(f"Total videos for testing: {len(test_df)}")

train_df.sample(10)
Total videos for training: 594
Total videos for testing: 224
     video_name                  tag
149  v_PlayingCello_g12_c05.avi  PlayingCello
317  v_Punch_g19_c05.avi         Punch
438  v_ShavingBeard_g20_c03.avi  ShavingBeard
559  v_TennisSwing_g20_c02.avi   TennisSwing
368  v_ShavingBeard_g09_c03.avi  ShavingBeard
241  v_Punch_g08_c04.avi         Punch
398  v_ShavingBeard_g14_c03.avi  ShavingBeard
111  v_CricketShot_g25_c01.avi   CricketShot
119  v_PlayingCello_g08_c02.avi  PlayingCello
249  v_Punch_g09_c05.avi         Punch

One of the many challenges of training video classifiers is figuring out a way to feed the videos to a network. This blog post discusses five such methods. Since a video is an ordered sequence of frames, we could just extract the frames and put them in a 4D tensor of shape (frames, height, width, channels). However, the number of frames may differ from video to video, which would prevent us from stacking them into batches (unless we use padding). As an alternative, we can save video frames at a fixed interval until a maximum frame count is reached. In this example we will do the following:

  1. Capture the frames of a video.
  2. Extract frames from the video until a maximum frame count is reached.
  3. If a video has fewer frames than the maximum frame count, pad it with zeros.
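The steps above can be sketched in NumPy as follows. This is a minimal illustration, not the code used later in this example: `pad_or_truncate` is a hypothetical helper, and it pads raw frames, whereas the data processing utility below applies the same idea to extracted frame features.

```python
import numpy as np

MAX_SEQ_LENGTH = 20  # Same value as in the hyperparameters above.


def pad_or_truncate(frames, max_len=MAX_SEQ_LENGTH):
    """Return (fixed_length_frames, mask) for a single video.

    `frames` has shape (num_frames, height, width, channels).
    """
    length = min(len(frames), max_len)
    padded = np.zeros((max_len,) + frames.shape[1:], dtype=frames.dtype)
    padded[:length] = frames[:length]
    mask = np.zeros(max_len, dtype=bool)
    mask[:length] = True  # True = real frame, False = zero padding.
    return padded, mask


short_video = np.ones((12, 8, 8, 3), dtype="uint8")  # Fewer frames than the maximum.
long_video = np.ones((35, 8, 8, 3), dtype="uint8")  # More frames than the maximum.

short_padded, short_mask = pad_or_truncate(short_video)
long_padded, long_mask = pad_or_truncate(long_video)
print(short_padded.shape, int(short_mask.sum()))
print(long_padded.shape, int(long_mask.sum()))
```

The mask records which timesteps hold real frames, so that a downstream sequence model can ignore the padded ones.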

Note that this workflow is identical to the one used for problems involving text sequences. Videos in the UCF101 dataset are known not to contain extreme variations in objects and actions across frames. Because of this, it may be okay to consider only a few frames for the learning task. But this approach may not generalize well to other video classification problems. We will be using OpenCV's VideoCapture class to read frames from videos.

# The following two methods are taken from this tutorial:
# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub


def crop_center_square(frame):
    y, x = frame.shape[0:2]
    min_dim = min(y, x)
    start_x = (x // 2) - (min_dim // 2)
    start_y = (y // 2) - (min_dim // 2)
    return frame[start_y : start_y + min_dim, start_x : start_x + min_dim]


def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    # `max_frames=0` means "read every frame": the length check below never matches 0.
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:  # End of the video (or a read failure).
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # OpenCV returns BGR; reorder to RGB.
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        cap.release()
    return np.array(frames)
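Note that `load_video()` reads consecutive frames from the start of the video. The fixed-interval idea mentioned earlier could be sketched as below. Both helpers are hypothetical and are not part of this example; they only compute which frame indices to keep.

```python
import numpy as np


def fixed_interval_indices(total_frames, max_frames, step):
    """Indices of every `step`-th frame, capped at `max_frames` entries."""
    return list(range(0, total_frames, step))[:max_frames]


def evenly_spaced_indices(total_frames, max_frames):
    """Up to `max_frames` indices spread across the whole video."""
    if total_frames <= max_frames:
        return list(range(total_frames))
    return np.linspace(0, total_frames - 1, num=max_frames).astype(int).tolist()


every_third = fixed_interval_indices(total_frames=100, max_frames=20, step=3)
spread = evenly_spaced_indices(total_frames=100, max_frames=5)
print(every_third)
print(spread)
```

With all frames of a video loaded into an array, indexing it with one of these lists would select the sampled frames. Whether this helps depends on how much the action changes over the course of a video.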

We can use a pre-trained network to extract meaningful features from the extracted frames. The Keras Applications module provides a number of state-of-the-art models pre-trained on the ImageNet-1k dataset. We will be using the InceptionV3 model for this purpose.

def build_feature_extractor():
    feature_extractor = keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    preprocess_input = keras.applications.inception_v3.preprocess_input

    inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
    preprocessed = preprocess_input(inputs)

    outputs = feature_extractor(preprocessed)
    return keras.Model(inputs, outputs, name="feature_extractor")


feature_extractor = build_feature_extractor()
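Two details of the feature extractor are worth spelling out. With `include_top=False` and `pooling="avg"`, InceptionV3 returns one 2048-dimensional vector per frame, which is where `NUM_FEATURES = 2048` comes from. The `preprocess_input` function for InceptionV3 rescales pixel values from [0, 255] to [-1, 1]. The NumPy sketch below mimics that scaling; `scale_to_minus_one_one` is a hypothetical stand-in for illustration, not the Keras function itself.

```python
import numpy as np


def scale_to_minus_one_one(frames):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return frames.astype("float32") / 127.5 - 1.0


frame = np.array([[[0, 127, 255]]], dtype="uint8")  # A single RGB pixel.
scaled = scale_to_minus_one_one(frame)
print(scaled)
```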

The labels of the videos are strings. Neural networks do not understand string values, so they must be converted to some numerical form before they are fed to the model. Here we will use the StringLookup layer to encode the class labels as integers.

label_processor = keras.layers.StringLookup(
    num_oov_indices=0, vocabulary=np.unique(train_df["tag"])
)
print(label_processor.get_vocabulary())
['CricketShot', 'PlayingCello', 'Punch', 'ShavingBeard', 'TennisSwing']
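With `num_oov_indices=0`, each label maps to its position in the vocabulary, starting from 0. The same mapping can be reproduced in plain NumPy, which may help when inspecting or debugging labels. This is only an illustrative equivalent with made-up sample tags, not a replacement for the layer.

```python
import numpy as np

# Hypothetical sample of labels, standing in for `train_df["tag"]`.
tags = np.array(["Punch", "CricketShot", "TennisSwing", "Punch", "PlayingCello"])

vocabulary = np.unique(tags)  # Sorted unique class names.
indices = np.searchsorted(vocabulary, tags)  # Position of each tag in the vocabulary.

print(vocabulary)
print(indices)
```

Indexing the vocabulary with the integer labels (`vocabulary[indices]`) recovers the original strings.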

Finally, we can put all the pieces together to create our data processing utility.

def prepare_all_videos(df, root_dir):
    num_samples = len(df)
    video_paths = df["video_name"].values.tolist()
    labels = df["tag"].values
    labels = label_processor(labels[..., None]).numpy()

    # `frame_masks` and `frame_features` are what we will feed to our sequence model.
    # `frame_masks` will contain a bunch of booleans denoting if a timestep is
    # masked with padding or not.
    frame_masks = np.zeros(shape=(num_samples, MAX_SEQ_LENGTH), dtype="bool")
    frame_features = np.zeros(
        shape=(num_samples, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
    )

    # For each video.
    for idx, path in enumerate(video_paths):
        # Gather all its frames and add a batch dimension.
        frames = load_video(os.path.join(root_dir, path))
        frames = frames[None, ...]

        # Initialize placeholders to store the masks and features of the current video.
        temp_frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype="bool")
        temp_frame_features = np.zeros(
            shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
        )

        # Extract features from the frames of the current video.
        for i, batch in enumerate(frames):
            video_length = batch.shape[0]
            length = min(MAX_SEQ_LENGTH, video_length)
            for j in range(length):
                temp_frame_features[i, j, :] = feature_extractor.predict(
                    batch[None, j, :]
                )
            temp_frame_mask[i, :length] = 1  # 1 = not masked, 0 = masked

        frame_features[idx,] = temp_frame_features.squeeze()
        frame_masks[idx,] = temp_frame_mask.squeeze()

    return (frame_features, frame_masks), labels


train_data, train_labels = prepare_all_videos(train_df, "train")
test_data, test_labels = prepare_all_videos(test_df, "test")

print(f"Frame features in train set: {train_data[0].shape}")
print(f"Frame masks in train set: {train_data[1].shape}")

Frame features in train set: (594, 20, 2048)
Frame masks in train set: (594, 20)

The above code block will take ~20 minutes to execute, depending on the machine it's being executed on.
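Part of that runtime likely comes from calling `feature_extractor.predict()` once per frame. A possible alternative is to pass several frames per call. The sketch below shows the batching logic only: `extract_in_batches` is a hypothetical helper, and `dummy_extractor` is a stand-in so that the snippet runs without a real model.

```python
import numpy as np


def extract_in_batches(frames, extract_fn, batch_size=32):
    """Apply `extract_fn` to `frames` in chunks and stack the results.

    Assumes `frames` contains at least one frame.
    """
    outputs = []
    for start in range(0, len(frames), batch_size):
        chunk = frames[start : start + batch_size]
        outputs.append(np.asarray(extract_fn(chunk)))
    return np.concatenate(outputs, axis=0)


# Stand-in for `feature_extractor.predict`: returns one 4-value vector per frame.
def dummy_extractor(batch):
    return batch.reshape(len(batch), -1)[:, :4].astype("float32")


frames = np.random.rand(20, 8, 8, 3)
features = extract_in_batches(frames, dummy_extractor, batch_size=8)
print(features.shape)  # (20, 4)
```

In this example, `extract_fn` would be `feature_extractor.predict` and `frames` would be the first `MAX_SEQ_LENGTH` frames of a video. The actual speed-up depends on the hardware and has not been measured here.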


The sequence model

Now, we can feed this data to a sequence model consisting of recurrent layers like GRU.
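Before defining the model, it may help to see what a mask does in a recurrent layer: masked timesteps are skipped and the previous state is carried forward, so zero padding does not alter the final state. The toy recurrence below illustrates this idea. It is a deliberately simplified stand-in (a scalar `tanh` update), not a GRU and not the Keras implementation.

```python
import numpy as np


def masked_recurrence(inputs, mask):
    """Toy recurrent loop: h_t = tanh(x_t + h_{t-1}), skipping masked steps."""
    state = 0.0
    for x_t, keep in zip(inputs, mask):
        if keep:
            state = np.tanh(x_t + state)
        # When `keep` is False, the previous state is carried forward unchanged.
    return state


real = np.array([0.5, -0.2, 0.8])  # Three real timesteps.
padded = np.concatenate([real, np.zeros(2)])  # Plus two padded timesteps.
mask = np.array([True, True, True, False, False])

with_mask = masked_recurrence(padded, mask)
without_mask = masked_recurrence(padded, np.ones(5, dtype=bool))
unpadded = masked_recurrence(real, np.ones(3, dtype=bool))
print(with_mask, without_mask, unpadded)
```

With the mask, the padded sequence ends in the same state as the unpadded one. Without it, the two padded steps keep updating the state.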

# Utility for our sequence model.
def get_sequence_model():
    class_vocab = label_processor.get_vocabulary()

    frame_features_input = keras.Input((MAX_SEQ_LENGTH, NUM_FEATURES))
    mask_input = keras.Input((MAX_SEQ_LENGTH,), dtype="bool")

    # Refer to the following tutorial to understand the significance of using `mask`:
    # https://keras.io/api/layers/recurrent_layers/gru/
    x = keras.layers.GRU(16, return_sequences=True)(
        frame_features_input, mask=mask_input
    )
    x = keras.layers.GRU(8)(x)
    x = keras.layers.Dropout(0.4)(x)
    x = keras.layers.Dense(8, activation="relu")(x)
    output = keras.layers.Dense(len(class_vocab), activation="softmax")(x)

    rnn_model = keras.Model([frame_features_input, mask_input], output)

    rnn_model.compile(
        loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )
    return rnn_model


# Utility for running experiments.
def run_experiment():
    filepath = "/tmp/video_classifier"
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath, save_weights_only=True, save_best_only=True, verbose=1
    )

    seq_model = get_sequence_model()
    history = seq_model.fit(
        [train_data[0], train_data[1]],
        train_labels,
        validation_split=0.3,
        epochs=EPOCHS,
        callbacks=[checkpoint],
    )

    seq_model.load_weights(filepath)
    _, accuracy = seq_model.evaluate([test_data[0], test_data[1]], test_labels)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    return history, seq_model


_, sequence_model = run_experiment()
Epoch 1/10
13/13 [==============================] - 4s 101ms/step - loss: 1.5259 - accuracy: 0.3157 - val_loss: 1.4732 - val_accuracy: 0.3408
Epoch 00001: val_loss improved from inf to 1.47325, saving model to /tmp/video_classifier
Epoch 2/10
13/13 [==============================] - 0s 21ms/step - loss: 1.3087 - accuracy: 0.5880 - val_loss: 1.4751 - val_accuracy: 0.3408
Epoch 00002: val_loss did not improve from 1.47325
Epoch 3/10
13/13 [==============================] - 0s 20ms/step - loss: 1.1532 - accuracy: 0.6795 - val_loss: 1.5020 - val_accuracy: 0.3408
Epoch 00003: val_loss did not improve from 1.47325
Epoch 4/10
13/13 [==============================] - 0s 20ms/step - loss: 1.0586 - accuracy: 0.7325 - val_loss: 1.5205 - val_accuracy: 0.3464
Epoch 00004: val_loss did not improve from 1.47325
Epoch 5/10
13/13 [==============================] - 0s 21ms/step - loss: 0.9556 - accuracy: 0.7422 - val_loss: 1.5748 - val_accuracy: 0.3464
Epoch 00005: val_loss did not improve from 1.47325
Epoch 6/10
13/13 [==============================] - 0s 21ms/step - loss: 0.8988 - accuracy: 0.7783 - val_loss: 1.6144 - val_accuracy: 0.3464
Epoch 00006: val_loss did not improve from 1.47325
Epoch 7/10
13/13 [==============================] - 0s 21ms/step - loss: 0.8242 - accuracy: 0.8072 - val_loss: 1.7030 - val_accuracy: 0.3408
Epoch 00007: val_loss did not improve from 1.47325
Epoch 8/10
13/13 [==============================] - 0s 20ms/step - loss: 0.7479 - accuracy: 0.8434 - val_loss: 1.7466 - val_accuracy: 0.3464
Epoch 00008: val_loss did not improve from 1.47325
Epoch 9/10
13/13 [==============================] - 0s 20ms/step - loss: 0.6740 - accuracy: 0.8627 - val_loss: 1.8800 - val_accuracy: 0.3464
Epoch 00009: val_loss did not improve from 1.47325
Epoch 10/10
13/13 [==============================] - 0s 20ms/step - loss: 0.6519 - accuracy: 0.8265 - val_loss: 1.9150 - val_accuracy: 0.3464
Epoch 00010: val_loss did not improve from 1.47325
7/7 [==============================] - 1s 5ms/step - loss: 1.3806 - accuracy: 0.6875
Test accuracy: 68.75%
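In the log above, the validation accuracy stays near 34% while the training accuracy keeps rising. One possible contributor is that `validation_split` takes the last fraction of the training arrays before any shuffling, and the rows sampled from `train.csv` earlier suggest that the file may be grouped by class. If that is the case, the validation set would contain only some of the classes. A minimal sketch of shuffling the arrays together before calling `fit()` follows; `shuffle_in_unison` is a hypothetical helper, and the small arrays are stand-ins for `train_data[0]`, `train_data[1]`, and `train_labels`.

```python
import numpy as np


def shuffle_in_unison(features, masks, labels, seed=42):
    """Apply one random permutation to all three arrays."""
    order = np.random.default_rng(seed).permutation(len(labels))
    return features[order], masks[order], labels[order]


# Small stand-ins for the real training arrays.
features = np.arange(6, dtype="float32").reshape(6, 1, 1)
masks = np.ones((6, 1), dtype=bool)
labels = np.arange(6).reshape(6, 1)  # Ordered, like a CSV grouped by class.

features_s, masks_s, labels_s = shuffle_in_unison(features, masks, labels)
print(labels_s.ravel())
```

Whether this changes the validation curve for this dataset has not been verified here.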

Note: To keep the runtime of this example relatively short, we used only a few training examples. That is a small number relative to the sequence model, which has 99,909 trainable parameters. You are encouraged to sample more data from the UCF101 dataset using the notebook mentioned above and train the same model.
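The parameter count can be checked by hand. The sketch below assumes the GRU layers use the TensorFlow 2 default `reset_after=True`, which keeps separate input and recurrent bias vectors. The helper names are hypothetical.

```python
def gru_params(input_dim, units):
    """Parameter count of a GRU layer, assuming `reset_after=True`."""
    kernel = input_dim * 3 * units  # Input weights for the three gates.
    recurrent_kernel = units * 3 * units  # Recurrent weights for the three gates.
    bias = 2 * 3 * units  # Separate input and recurrent biases.
    return kernel + recurrent_kernel + bias


def dense_params(input_dim, units):
    """Parameter count of a Dense layer with a bias term."""
    return input_dim * units + units


total = (
    gru_params(2048, 16)  # GRU(16) on 2048-dimensional frame features.
    + gru_params(16, 8)  # GRU(8) on the 16-dimensional sequence output.
    + dense_params(8, 8)  # Dense(8).
    + dense_params(8, 5)  # Dense(5), one unit per class.
)
print(total)  # 99909
```

Almost all of the parameters (99,168) sit in the first GRU layer, because of the 2048-dimensional input.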


Inference

def prepare_single_video(frames):
    frames = frames[None, ...]
    frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype="bool")
    frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32")

    for i, batch in enumerate(frames):
        video_length = batch.shape[0]
        length = min(MAX_SEQ_LENGTH, video_length)
        for j in range(length):
            frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
        frame_mask[i, :length] = 1  # 1 = not masked, 0 = masked

    return frame_features, frame_mask


def sequence_prediction(path):
    class_vocab = label_processor.get_vocabulary()

    frames = load_video(os.path.join("test", path))
    frame_features, frame_mask = prepare_single_video(frames)
    probabilities = sequence_model.predict([frame_features, frame_mask])[0]

    for i in np.argsort(probabilities)[::-1]:
        print(f"  {class_vocab[i]}: {probabilities[i] * 100:5.2f}%")
    return frames


# This utility is for visualization.
# Referenced from:
# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub
def to_gif(images):
    converted_images = images.astype(np.uint8)
    imageio.mimsave("animation.gif", converted_images, duration=100)
    return embed.embed_file("animation.gif")


test_video = np.random.choice(test_df["video_name"].values.tolist())
print(f"Test video path: {test_video}")
test_frames = sequence_prediction(test_video)
to_gif(test_frames[:MAX_SEQ_LENGTH])
Test video path: v_PlayingCello_g05_c03.avi
  PlayingCello: 25.61%
  CricketShot: 24.82%
  ShavingBeard: 19.38%
  TennisSwing: 17.43%
  Punch: 12.77%


Next steps

  • In this example, we made use of transfer learning for extracting meaningful features from video frames. You could also fine-tune the pre-trained network to see how that affects the end results.
  • For speed-accuracy trade-offs, you can try out other models present inside tf.keras.applications.
  • Try different values of MAX_SEQ_LENGTH to observe how that affects the performance.
  • Train on a higher number of classes and see if you are able to get good performance.
  • Following this tutorial, try a pre-trained action recognition model from DeepMind.
  • Rolling-averaging can be a useful technique for video classification, and it can be combined with a standard image classification model to infer on videos. This tutorial will help you understand how to use rolling-averaging with an image classifier.
  • When there are variations between the frames of a video, not all the frames might be equally important for deciding its category. In those situations, putting a self-attention layer in the sequence model will likely yield better results.
  • Following this book chapter, you can implement Transformers-based models for processing videos.
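As an illustration of the rolling-averaging idea mentioned in the list above, the sketch below smooths per-frame class probabilities over a sliding window. The helper and the probabilities are hypothetical; in practice the per-frame values would come from an image classifier applied to each frame.

```python
from collections import deque

import numpy as np


def rolling_average_predictions(frame_probs, window=5):
    """Average each frame's probabilities with up to `window - 1` earlier frames."""
    recent = deque(maxlen=window)
    smoothed = []
    for probs in frame_probs:
        recent.append(probs)
        smoothed.append(np.mean(list(recent), axis=0))
    return np.array(smoothed)


# Hypothetical per-frame outputs of an image classifier over 3 classes.
frame_probs = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.6, 0.3, 0.1],
        [0.1, 0.8, 0.1],  # A single noisy frame.
        [0.7, 0.2, 0.1],
        [0.8, 0.1, 0.1],
    ]
)
smoothed = rolling_average_predictions(frame_probs, window=3)
print(frame_probs.argmax(axis=1))  # [0 0 1 0 0]
print(smoothed.argmax(axis=1))  # [0 0 0 0 0]
```

The isolated misprediction on the third frame disappears after smoothing, at the cost of reacting more slowly when the action really does change.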

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.