β–Ί Code examples / Computer Vision / Monocular depth estimation

Monocular depth estimation

Author: Victor Basu
Date created: 2021/08/30
Last modified: 2021/08/30

β“˜ This example uses Keras 2

View in Colab β€’ GitHub source

Description: Implement a depth estimation model with a convnet.


Introduction

Depth estimation is a crucial step towards inferring scene geometry from 2D images. The goal in monocular depth estimation is to predict the depth value of each pixel or inferring depth information, given only a single RGB image as input. This example will show an approach to build a depth estimation model with a convnet and simple loss functions.

depth


Setup

import os
import sys

import tensorflow as tf
from tensorflow.keras import layers

import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

tf.random.set_seed(123)

Downloading the dataset

We will be using the dataset DIODE: A Dense Indoor and Outdoor Depth Dataset for this tutorial. However, we use the validation set generating training and evaluation subsets for our model. The reason we use the validation set rather than the training set of the original dataset is because the training set consists of 81GB of data, which is challenging to download compared to the validation set which is only 2.6GB. Other datasets that you could use are NYU-v2 and KITTI.

annotation_folder = "/dataset/"
if not os.path.exists(os.path.abspath(".") + annotation_folder):
    annotation_zip = tf.keras.utils.get_file(
        "val.tar.gz",
        cache_subdir=os.path.abspath("."),
        origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
        extract=True,
    )
Downloading data from http://diode-dataset.s3.amazonaws.com/val.tar.gz
2774630400/2774625282 [==============================] - 90s 0us/step
2774638592/2774625282 [==============================] - 90s 0us/step

Preparing the dataset

We only use the indoor images to train our depth estimation model.

path = "val/indoors"

filelist = []

for root, dirs, files in os.walk(path):
    for file in files:
        filelist.append(os.path.join(root, file))

filelist.sort()
data = {
    "image": [x for x in filelist if x.endswith(".png")],
    "depth": [x for x in filelist if x.endswith("_depth.npy")],
    "mask": [x for x in filelist if x.endswith("_depth_mask.npy")],
}
df = pd.DataFrame(data)

df = df.sample(frac=1, random_state=42)

Preparing hyperparameters

HEIGHT = 256
WIDTH = 256
LR = 0.0002
EPOCHS = 30
BATCH_SIZE = 32

Building a data pipeline

  1. The pipeline takes a dataframe containing the path for the RGB images, as well as the depth and depth mask files.
  2. It reads and resize the RGB images.
  3. It reads the depth and depth mask files, process them to generate the depth map image and resize it.
  4. It returns the RGB images and the depth map images for a batch.
class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, data, batch_size=6, dim=(768, 1024), n_channels=3, shuffle=True):
        """
        Initialization
        """
        self.data = data
        self.indices = self.data.index.tolist()
        self.dim = dim
        self.n_channels = n_channels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.min_depth = 0.1
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, index):
        if (index + 1) * self.batch_size > len(self.indices):
            self.batch_size = len(self.indices) - index * self.batch_size
        # Generate one batch of data
        # Generate indices of the batch
        index = self.indices[index * self.batch_size : (index + 1) * self.batch_size]
        # Find list of IDs
        batch = [self.indices[k] for k in index]
        x, y = self.data_generation(batch)

        return x, y

    def on_epoch_end(self):

        """
        Updates indexes after each epoch
        """
        self.index = np.arange(len(self.indices))
        if self.shuffle == True:
            np.random.shuffle(self.index)

    def load(self, image_path, depth_map, mask):
        """Load input and target image."""

        image_ = cv2.imread(image_path)
        image_ = cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)
        image_ = cv2.resize(image_, self.dim)
        image_ = tf.image.convert_image_dtype(image_, tf.float32)

        depth_map = np.load(depth_map).squeeze()

        mask = np.load(mask)
        mask = mask > 0

        max_depth = min(300, np.percentile(depth_map, 99))
        depth_map = np.clip(depth_map, self.min_depth, max_depth)
        depth_map = np.log(depth_map, where=mask)

        depth_map = np.ma.masked_where(~mask, depth_map)

        depth_map = np.clip(depth_map, 0.1, np.log(max_depth))
        depth_map = cv2.resize(depth_map, self.dim)
        depth_map = np.expand_dims(depth_map, axis=2)
        depth_map = tf.image.convert_image_dtype(depth_map, tf.float32)

        return image_, depth_map

    def data_generation(self, batch):

        x = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, 1))

        for i, batch_id in enumerate(batch):
            x[i,], y[i,] = self.load(
                self.data["image"][batch_id],
                self.data["depth"][batch_id],
                self.data["mask"][batch_id],
            )

        return x, y

Visualizing samples

def visualize_depth_map(samples, test=False, model=None):
    input, target = samples
    cmap = plt.cm.jet
    cmap.set_bad(color="black")

    if test:
        pred = model.predict(input)
        fig, ax = plt.subplots(6, 3, figsize=(50, 50))
        for i in range(6):
            ax[i, 0].imshow((input[i].squeeze()))
            ax[i, 1].imshow((target[i].squeeze()), cmap=cmap)
            ax[i, 2].imshow((pred[i].squeeze()), cmap=cmap)

    else:
        fig, ax = plt.subplots(6, 2, figsize=(50, 50))
        for i in range(6):
            ax[i, 0].imshow((input[i].squeeze()))
            ax[i, 1].imshow((target[i].squeeze()), cmap=cmap)


visualize_samples = next(
    iter(DataGenerator(data=df, batch_size=6, dim=(HEIGHT, WIDTH)))
)
visualize_depth_map(visualize_samples)

png


3D point cloud visualization

depth_vis = np.flipud(visualize_samples[1][1].squeeze())  # target
img_vis = np.flipud(visualize_samples[0][1].squeeze())  # input

fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection="3d")

STEP = 3
for x in range(0, img_vis.shape[0], STEP):
    for y in range(0, img_vis.shape[1], STEP):
        ax.scatter(
            [depth_vis[x, y]] * 3,
            [y] * 3,
            [x] * 3,
            c=tuple(img_vis[x, y, :3] / 255),
            s=3,
        )
    ax.view_init(45, 135)

png


Building the model

  1. The basic model is from U-Net.
  2. Addditive skip-connections are implemented in the downscaling block.
class DownscaleBlock(layers.Layer):
    def __init__(
        self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
    ):
        super().__init__(**kwargs)
        self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
        self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
        self.reluA = layers.LeakyReLU(alpha=0.2)
        self.reluB = layers.LeakyReLU(alpha=0.2)
        self.bn2a = tf.keras.layers.BatchNormalization()
        self.bn2b = tf.keras.layers.BatchNormalization()

        self.pool = layers.MaxPool2D((2, 2), (2, 2))

    def call(self, input_tensor):
        d = self.convA(input_tensor)
        x = self.bn2a(d)
        x = self.reluA(x)

        x = self.convB(x)
        x = self.bn2b(x)
        x = self.reluB(x)

        x += d
        p = self.pool(x)
        return x, p


class UpscaleBlock(layers.Layer):
    def __init__(
        self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
    ):
        super().__init__(**kwargs)
        self.us = layers.UpSampling2D((2, 2))
        self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
        self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
        self.reluA = layers.LeakyReLU(alpha=0.2)
        self.reluB = layers.LeakyReLU(alpha=0.2)
        self.bn2a = tf.keras.layers.BatchNormalization()
        self.bn2b = tf.keras.layers.BatchNormalization()
        self.conc = layers.Concatenate()

    def call(self, x, skip):
        x = self.us(x)
        concat = self.conc([x, skip])
        x = self.convA(concat)
        x = self.bn2a(x)
        x = self.reluA(x)

        x = self.convB(x)
        x = self.bn2b(x)
        x = self.reluB(x)

        return x


class BottleNeckBlock(layers.Layer):
    def __init__(
        self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
    ):
        super().__init__(**kwargs)
        self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
        self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
        self.reluA = layers.LeakyReLU(alpha=0.2)
        self.reluB = layers.LeakyReLU(alpha=0.2)

    def call(self, x):
        x = self.convA(x)
        x = self.reluA(x)
        x = self.convB(x)
        x = self.reluB(x)
        return x

Defining the loss

We will optimize 3 losses in our mode. 1. Structural similarity index(SSIM). 2. L1-loss, or Point-wise depth in our case. 3. Depth smoothness loss.

Out of the three loss functions, SSIM contributes the most to improving model performance.

class DepthEstimationModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.ssim_loss_weight = 0.85
        self.l1_loss_weight = 0.1
        self.edge_loss_weight = 0.9
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        f = [16, 32, 64, 128, 256]
        self.downscale_blocks = [
            DownscaleBlock(f[0]),
            DownscaleBlock(f[1]),
            DownscaleBlock(f[2]),
            DownscaleBlock(f[3]),
        ]
        self.bottle_neck_block = BottleNeckBlock(f[4])
        self.upscale_blocks = [
            UpscaleBlock(f[3]),
            UpscaleBlock(f[2]),
            UpscaleBlock(f[1]),
            UpscaleBlock(f[0]),
        ]
        self.conv_layer = layers.Conv2D(1, (1, 1), padding="same", activation="tanh")

    def calculate_loss(self, target, pred):
        # Edges
        dy_true, dx_true = tf.image.image_gradients(target)
        dy_pred, dx_pred = tf.image.image_gradients(pred)
        weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
        weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))

        # Depth smoothness
        smoothness_x = dx_pred * weights_x
        smoothness_y = dy_pred * weights_y

        depth_smoothness_loss = tf.reduce_mean(abs(smoothness_x)) + tf.reduce_mean(
            abs(smoothness_y)
        )

        # Structural similarity (SSIM) index
        ssim_loss = tf.reduce_mean(
            1
            - tf.image.ssim(
                target, pred, max_val=WIDTH, filter_size=7, k1=0.01 ** 2, k2=0.03 ** 2
            )
        )
        # Point-wise depth
        l1_loss = tf.reduce_mean(tf.abs(target - pred))

        loss = (
            (self.ssim_loss_weight * ssim_loss)
            + (self.l1_loss_weight * l1_loss)
            + (self.edge_loss_weight * depth_smoothness_loss)
        )

        return loss

    @property
    def metrics(self):
        return [self.loss_metric]

    def train_step(self, batch_data):
        input, target = batch_data
        with tf.GradientTape() as tape:
            pred = self(input, training=True)
            loss = self.calculate_loss(target, pred)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.loss_metric.update_state(loss)
        return {
            "loss": self.loss_metric.result(),
        }

    def test_step(self, batch_data):
        input, target = batch_data

        pred = self(input, training=False)
        loss = self.calculate_loss(target, pred)

        self.loss_metric.update_state(loss)
        return {
            "loss": self.loss_metric.result(),
        }

    def call(self, x):
        c1, p1 = self.downscale_blocks[0](x)
        c2, p2 = self.downscale_blocks[1](p1)
        c3, p3 = self.downscale_blocks[2](p2)
        c4, p4 = self.downscale_blocks[3](p3)

        bn = self.bottle_neck_block(p4)

        u1 = self.upscale_blocks[0](bn, c4)
        u2 = self.upscale_blocks[1](u1, c3)
        u3 = self.upscale_blocks[2](u2, c2)
        u4 = self.upscale_blocks[3](u3, c1)

        return self.conv_layer(u4)

Model training

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LR,
    amsgrad=False,
)
model = DepthEstimationModel()
# Compile the model
model.compile(optimizer)

train_loader = DataGenerator(
    data=df[:260].reset_index(drop="true"), batch_size=BATCH_SIZE, dim=(HEIGHT, WIDTH)
)
validation_loader = DataGenerator(
    data=df[260:].reset_index(drop="true"), batch_size=BATCH_SIZE, dim=(HEIGHT, WIDTH)
)
model.fit(
    train_loader,
    epochs=EPOCHS,
    validation_data=validation_loader,
)
Epoch 1/30
9/9 [==============================] - 18s 1s/step - loss: 1.1543 - val_loss: 1.4281
Epoch 2/30
9/9 [==============================] - 3s 390ms/step - loss: 0.8727 - val_loss: 1.0686
Epoch 3/30
9/9 [==============================] - 4s 428ms/step - loss: 0.6659 - val_loss: 0.7884
Epoch 4/30
9/9 [==============================] - 3s 334ms/step - loss: 0.6462 - val_loss: 0.6198
Epoch 5/30
9/9 [==============================] - 3s 355ms/step - loss: 0.5689 - val_loss: 0.6207
Epoch 6/30
9/9 [==============================] - 3s 361ms/step - loss: 0.5067 - val_loss: 0.4876
Epoch 7/30
9/9 [==============================] - 3s 357ms/step - loss: 0.4680 - val_loss: 0.4698
Epoch 8/30
9/9 [==============================] - 3s 325ms/step - loss: 0.4622 - val_loss: 0.7249
Epoch 9/30
9/9 [==============================] - 3s 393ms/step - loss: 0.4215 - val_loss: 0.3826
Epoch 10/30
9/9 [==============================] - 3s 337ms/step - loss: 0.3788 - val_loss: 0.3289
Epoch 11/30
9/9 [==============================] - 3s 345ms/step - loss: 0.3347 - val_loss: 0.3032
Epoch 12/30
9/9 [==============================] - 3s 327ms/step - loss: 0.3488 - val_loss: 0.2631
Epoch 13/30
9/9 [==============================] - 3s 326ms/step - loss: 0.3315 - val_loss: 0.2383
Epoch 14/30
9/9 [==============================] - 3s 331ms/step - loss: 0.3349 - val_loss: 0.2379
Epoch 15/30
9/9 [==============================] - 3s 333ms/step - loss: 0.3394 - val_loss: 0.2151
Epoch 16/30
9/9 [==============================] - 3s 337ms/step - loss: 0.3073 - val_loss: 0.2243
Epoch 17/30
9/9 [==============================] - 3s 355ms/step - loss: 0.3951 - val_loss: 0.2627
Epoch 18/30
9/9 [==============================] - 3s 335ms/step - loss: 0.3657 - val_loss: 0.2175
Epoch 19/30
9/9 [==============================] - 3s 321ms/step - loss: 0.3404 - val_loss: 0.2073
Epoch 20/30
9/9 [==============================] - 3s 320ms/step - loss: 0.3549 - val_loss: 0.1972
Epoch 21/30
9/9 [==============================] - 3s 317ms/step - loss: 0.2802 - val_loss: 0.1936
Epoch 22/30
9/9 [==============================] - 3s 316ms/step - loss: 0.2632 - val_loss: 0.1893
Epoch 23/30
9/9 [==============================] - 3s 318ms/step - loss: 0.2862 - val_loss: 0.1807
Epoch 24/30
9/9 [==============================] - 3s 328ms/step - loss: 0.3083 - val_loss: 0.1923
Epoch 25/30
9/9 [==============================] - 3s 312ms/step - loss: 0.3666 - val_loss: 0.1795
Epoch 26/30
9/9 [==============================] - 3s 316ms/step - loss: 0.2928 - val_loss: 0.1753
Epoch 27/30
9/9 [==============================] - 3s 325ms/step - loss: 0.2945 - val_loss: 0.1790
Epoch 28/30
9/9 [==============================] - 3s 325ms/step - loss: 0.2642 - val_loss: 0.1775
Epoch 29/30
9/9 [==============================] - 3s 333ms/step - loss: 0.2546 - val_loss: 0.1810
Epoch 30/30
9/9 [==============================] - 3s 315ms/step - loss: 0.2650 - val_loss: 0.1795

<keras.callbacks.History at 0x7f5151799fd0>

Visualizing model output

We visualize the model output over the validation set. The first image is the RGB image, the second image is the ground truth depth map image and the third one is the predicted depth map image.

test_loader = next(
    iter(
        DataGenerator(
            data=df[265:].reset_index(drop="true"), batch_size=6, dim=(HEIGHT, WIDTH)
        )
    )
)
visualize_depth_map(test_loader, test=True, model=model)

test_loader = next(
    iter(
        DataGenerator(
            data=df[300:].reset_index(drop="true"), batch_size=6, dim=(HEIGHT, WIDTH)
        )
    )
)
visualize_depth_map(test_loader, test=True, model=model)

png

png


Possible improvements

  1. You can improve this model by replacing the encoding part of the U-Net with a pretrained DenseNet or ResNet.
  2. Loss functions play an important role in solving this problem. Tuning the loss functions may yield significant improvement.

References

The following papers go deeper into possible approaches for depth estimation. 1. Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos 2. Digging Into Self-Supervised Monocular Depth Estimation 3. Deeper Depth Prediction with Fully Convolutional Residual Networks

You can also find helpful implementations in the papers with code depth estimation task.

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.