
Segment Anything in KerasCV!

Authors: Tirth Patel, Ian Stenbit
Date created: 2023/12/04
Last modified: 2023/12/19
Description: Segment anything using text, box, and point prompts in KerasCV.



Overview

The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

In this guide, we will show how to use KerasCV's implementation of the Segment Anything Model and demonstrate how large a performance boost the TensorFlow and JAX backends provide.

First, let's get all our dependencies and images for our demo.

!pip install -Uq keras-cv
!pip install -Uq keras
!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg

Choose your backend

With Keras 3, you can choose to use your favorite backend!

import os

os.environ["KERAS_BACKEND"] = "jax"

import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_cv

Helper functions

Let's define some helper functions for visualizing the images, prompts, and the segmentation results.

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    box = box.reshape(-1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def inference_resizing(image, pad=True):
    # Compute Preprocess Shape
    image = ops.cast(image, dtype="float32")
    old_h, old_w = image.shape[0], image.shape[1]
    scale = 1024 * 1.0 / max(old_h, old_w)
    new_h = old_h * scale
    new_w = old_w * scale
    preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)

    # Resize the image
    image = ops.image.resize(image[None, ...], preprocess_shape)[0]

    # Pad the shorter side
    if pad:
        pixel_mean = ops.array([123.675, 116.28, 103.53])
        pixel_std = ops.array([58.395, 57.12, 57.375])
        image = (image - pixel_mean) / pixel_std
        h, w = image.shape[0], image.shape[1]
        pad_h = 1024 - h
        pad_w = 1024 - w
        image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])
        # KerasCV now rescales the images and normalizes them.
        # Just unnormalize such that when KerasCV normalizes them
        # again, the padded values map to 0.
        image = image * pixel_std + pixel_mean
    return image

Get the pretrained SAM model

We can initialize a trained SAM model using KerasCV's from_preset factory method. Here, we use the huge ViT backbone trained on the SA-1B dataset (sam_huge_sa1b) for high-quality segmentation masks. You can also use one of the sam_large_sa1b or sam_base_sa1b presets for better performance (at the cost of lower-quality segmentation masks).

model = keras_cv.models.SegmentAnythingModel.from_preset("sam_huge_sa1b")

Understanding Prompts

Segment Anything allows prompting an image using points, boxes, and masks:

  1. Point prompts are the most basic of all: the model tries to guess the object given a point on an image. The point can either be a foreground point (i.e. the desired segmentation mask contains the point) or a background point (i.e. the point lies outside the desired mask).
  2. Another way to prompt the model is using boxes. Given a bounding box, the model tries to segment the object contained in it.
  3. Finally, the model can also be prompted using a mask itself. This is useful, for instance, to refine the borders of a previously predicted or known segmentation mask.

What makes the model incredibly powerful is the ability to combine the prompts above. Point, box, and mask prompts can be combined in several different ways to achieve the best result.

Let's see the semantics of passing these prompts to the Segment Anything model in KerasCV. Input to the SAM model is a dictionary with keys:

  1. "images": A batch of images to segment. Must be of shape (B, 1024, 1024, 3).
  2. "points": A batch of point prompts. Each point is an (x, y) coordinate originating from the top-left corner of the image. In other works, each point is of the form (r, c) where r and c are the row and column of the pixel in the image. Must be of shape (B, N, 2).
  3. "labels": A batch of labels for the given points. 1 represents foreground points and 0 represents background points. Must be of shape (B, N).
  4. "boxes": A batch of boxes. Note that the model only accepts one box per batch. Hence, the expected shape is (B, 1, 2, 2). Each box is a collection of 2 points: the top left corner and the bottom right corner of the box. The points here follow the same semantics as the point prompts. Here the 1 in the second dimension represents the presence of box prompts. If the box prompts are missing, a placeholder input of shape (B, 0, 2, 2) must be passed.
  5. "masks": A batch of masks. Just like box prompts, only one mask prompt per image is allowed. The shape of the input mask must be (B, 1, 256, 256, 1) if they are present and (B, 0, 256, 256, 1) for missing mask prompt.

Placeholder prompts are only required when calling the model directly (i.e. model(...)). When calling the predict method, missing prompts can be omitted from the input dictionary.
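For example, here is a minimal sketch of what a direct call with placeholder prompts could look like. The values are illustrative; only the shapes follow the semantics listed above.

# A minimal sketch of calling the model directly with placeholder prompts.
# Only the shapes matter here; the values are illustrative.
batch_size = 1
placeholder_inputs = {
    "images": ops.zeros((batch_size, 1024, 1024, 3)),  # batch of 1024x1024 images
    "points": ops.zeros((batch_size, 1, 2)),  # one point prompt per image
    "labels": ops.ones((batch_size, 1)),  # 1 = foreground point
    "boxes": ops.zeros((batch_size, 0, 2, 2)),  # placeholder: no box prompts
    "masks": ops.zeros((batch_size, 0, 256, 256, 1)),  # placeholder: no mask prompts
}
# outputs = model(placeholder_inputs)  # returns "masks" and "iou_pred"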


Point prompts

First, let's segment an image using point prompts. We load the image and resize it to shape (1024, 1024), the image size the pretrained SAM model expects.

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
plt.axis("on")
plt.show()

png

Next, we will define the point on the object we want to segment. Let's try to segment the truck's window pane at coordinates (284, 213.5).

# Define the input point prompt
input_point = np.array([[284, 213.5]])
input_label = np.array([1])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()

png

Now let's call the predict method of our model to get the segmentation masks.

Note: We don't call the model directly (model(...)) since placeholder prompts are required to do so. Missing prompts are handled automatically by the predict method so we call it instead. Also, when no box prompts are present, the points and labels need to be padded with a zero point prompt and -1 label prompt respectively. The cell below demonstrates how this works.

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": np.concatenate(
            [input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1
        ),
        "labels": np.concatenate(
            [input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1
        ),
    }
)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 48s 48s/step

SegmentAnythingModel.predict returns two outputs. The first is the segmentation masks (logits) of shape (1, 4, 256, 256), and the second is the IoU confidence scores (of shape (1, 4)) for each predicted mask. The pretrained SAM model predicts four masks: the first is the best mask the model could come up with for the given prompts, and the other 3 are alternative masks which can be used in case the best prediction doesn't contain the desired object. The user can choose whichever mask they prefer.
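If you'd rather not rely on the first mask always being the best one, here is a small sketch of picking the mask with the highest predicted IoU score from the outputs computed above.

# Sketch: select the mask with the highest predicted IoU score instead of
# always taking the first one. `outputs` comes from `model.predict` above.
best_idx = int(np.argmax(outputs["iou_pred"][0]))
best_mask_logits = outputs["masks"][0][best_idx]  # shape: (256, 256)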

Let's visualize the masks returned by the model!

# Resize the mask to our image shape i.e. (1024, 1024)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
# Convert the logits to a numpy array
# and convert the logits to a boolean mask
mask = ops.convert_to_numpy(mask) > 0.0
iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)
plt.axis("off")
plt.show()

png

As expected, the model returns a segmentation mask for the truck's window pane. But our point prompt can also mean a range of other things. For example, another possible mask that contains our point is just the right side of the window pane, or the whole truck.

Let's also visualize the other masks the model has predicted.

fig, ax = plt.subplots(1, 3, figsize=(20, 60))
masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]
for i, (mask, score) in enumerate(zip(masks, scores)):
    mask = inference_resizing(mask[..., None], pad=False)[..., 0]
    mask, score = map(ops.convert_to_numpy, (mask, score))
    mask = 1 * (mask > 0.0)
    ax[i].imshow(ops.convert_to_numpy(image) / 255.0)
    show_mask(mask, ax[i])
    show_points(input_point, input_label, ax[i])
    ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
    ax[i].axis("off")
plt.show()

png

Nice! SAM was able to capture the ambiguity of our point prompt and also returned other possible segmentation masks.


Box Prompts

Now, let's see how we can prompt the model using boxes. The box is specified using two points, the top-left corner and the bottom-right corner of the bounding box in xyxy format. Let's prompt the model using a bounding box around the left front tyre of the truck.

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])

outputs = model.predict(
    {"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 13s 13s/step

png

Boom! The model perfectly segments out the left front tyre in our bounding box.


Combining prompts

To bring out the true potential of the model, let's combine box and point prompts and see what the model does.

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
# Let's specify the point and mark it background
input_point = np.array([[325, 425]])
input_label = np.array([0])

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": input_point[np.newaxis, ...],
        "labels": input_label[np.newaxis, ...],
        "boxes": input_box[np.newaxis, np.newaxis, ...],
    }
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 16s 16s/step

png

Voila! The model understood that the object we wanted to exclude from our mask was the rim of the tyre.


Text prompts

Finally, let's see how text prompts can be used along with KerasCV's SegmentAnythingModel.

For this demo, we will use the official Grounding DINO model. Grounding DINO is a model that takes as input an (image, text) pair and generates a bounding box around the object in the image described by the text. You can refer to the paper for more details on the implementation of the model.

For this part of the demo, we will need to install the groundingdino package from source:

pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git

Then, we can download the pretrained model's weights and config:

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py
from groundingdino.util.inference import Model as GroundingDINO

CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"

grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

final text_encoder_type: bert-base-uncased

Let's load an image of a dog for this part!

filepath = keras.utils.get_file(
    origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
)
image = np.array(keras.utils.load_img(filepath))
image = ops.convert_to_numpy(inference_resizing(image))

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
plt.axis("on")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

We first predict the bounding box of the object we want to segment using the Grounding DINO model. Then, we prompt the SAM model using the bounding box to get the segmentation mask.

Let's try to segment out the harness of the dog. Change the image and text below to segment anything you want in your image using a text prompt!

# Let's predict the bounding box for the harness of the dog
boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")
boxes = np.array(boxes[0].xyxy)

outputs = model.predict(
    {
        "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
        "boxes": boxes.reshape(-1, 1, 2, 2),
    },
    batch_size=1,
)
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/transformers/modeling_utils.py:942: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.
  warnings.warn(
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(

 1/1 ━━━━━━━━━━━━━━━━━━━━ 13s 13s/step

And that's it! We got a segmentation mask for our text prompt using the combination of Grounding DINO + SAM! Combining different models like this is a very powerful technique for expanding the range of applications!

Let's visualize the results.

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)

for mask in outputs["masks"]:
    mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]
    mask = ops.convert_to_numpy(mask) > 0.0
    show_mask(mask, plt.gca())
    show_box(boxes, plt.gca())

plt.axis("off")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png
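If you plan to reuse this text-to-mask pipeline, here is an illustrative sketch that wraps the two steps above into a single helper. The name segment_with_text is hypothetical (not part of KerasCV); it assumes the grounding_dino and model objects defined earlier.

# Illustrative helper wrapping the Grounding DINO + SAM pipeline shown above.
# `segment_with_text` is a hypothetical name, not a KerasCV API.
def segment_with_text(image, text, detector=grounding_dino, sam=model):
    # Predict bounding boxes (xyxy format) for the text prompt.
    detections = detector.predict_with_caption(image.astype(np.uint8), text)
    boxes = np.array(detections[0].xyxy)
    # Prompt SAM with one box per (repeated) image.
    outputs = sam.predict(
        {
            "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
            "boxes": boxes.reshape(-1, 1, 2, 2),
        },
        batch_size=1,
        verbose=False,
    )
    return boxes, outputs["masks"]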


Optimizing SAM

You can use mixed_float16 or bfloat16 dtype policies to gain large speedups and memory savings with relatively little loss of precision.

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

# Specify the prompt
input_box = np.array([[240, 340], [400, 500]])

# Let's first see how fast the model is with float32 dtype
time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")

# Set the dtype policy in Keras
keras.mixed_precision.set_global_policy("mixed_float16")

model = keras_cv.models.SegmentAnythingModel.from_preset("sam_huge_sa1b")

time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")
Time taken with float32 dtype: 0.5304666963s
Time taken with float16 dtype: 0.1586400040s
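The same one-line change works with bfloat16 on hardware that supports it (TPUs, for example). Here is a sketch of the equivalent setup; it is not benchmarked in this guide.

# Sketch: the equivalent setup with the bfloat16 policy (timings not shown here).
keras.mixed_precision.set_global_policy("mixed_bfloat16")
model = keras_cv.models.SegmentAnythingModel.from_preset("sam_huge_sa1b")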

Here's a comparison of KerasCV's implementation with the original PyTorch implementation!

benchmark

The script used to generate the benchmarks is present here.


Conclusion

KerasCV's SegmentAnythingModel supports a variety of applications and, with the help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch! With the help of XLA in JAX and TensorFlow, the model runs several times faster than the original implementation. Moreover, using Keras's mixed precision support helps optimize memory use and computation time with just one line of code!

For more advanced uses, check out the Automatic Mask Generator demo.