
Segment Anything

SegmentAnythingModel class

keras_cv.models.SegmentAnythingModel(
    backbone, prompt_encoder, mask_decoder, **kwargs
)

The Segment Anything (SAM) Model.

Arguments

  • backbone (keras_cv.models.Backbone): A feature extractor for the input images.
  • prompt_encoder (keras_cv.models.SAMPromptEncoder): A Keras layer to compute embeddings for point, box, and mask prompts.
  • mask_decoder (keras_cv.models.SAMMaskDecoder): A Keras layer to generate segmentation masks given the embeddings generated by the backbone and the prompt encoder.

Examples

>>> import numpy as np
>>> from keras_cv.models import SegmentAnythingModel
>>> from keras_cv.models import ViTDetBBackbone
>>> from keras_cv.models import SAMPromptEncoder
>>> from keras_cv.models import SAMMaskDecoder

Create all the components of the SAM model:

>>> backbone = ViTDetBBackbone()
>>> prompt_encoder = SAMPromptEncoder()
>>> mask_decoder = SAMMaskDecoder()

Instantiate the model:

>>> sam = SegmentAnythingModel(
...     backbone=backbone,
...     prompt_encoder=prompt_encoder,
...     mask_decoder=mask_decoder
... )

Define the input of the backbone. This must be a batch of images of shape (1024, 1024, 3) for the ViT backbone we are using:

>>> image = np.ones((1, 1024, 1024, 3))

SAM works by prompting the input images. There are three ways to prompt:

  (1) Labelled points: Foreground points (points with label 1) are encoded such that the output masks generated by the mask decoder contain them, and background points (points with label 0) are encoded such that the generated masks don't contain them.
  (2) Box: A box tells the model which part/crop of the image to segment.
  (3) Mask: An input mask can be used to refine the output of the mask decoder.

These prompts can be mixed and matched but at least one of the prompts must be present. To turn off a particular prompt, simply exclude it from the inputs to the model.

  (1) For point prompts, the expected shape is (batch, num_points, 2). The labels must have a corresponding shape of (batch, num_points).
  (2) For the box prompt, the expected shape is (batch, 1, 2, 2).
  (3) Similarly, mask prompts have shape (batch, 1, H, W, 1).

For example, to pass in all the prompts, do:

>>> points = np.array([[[512., 512.], [100., 100.]]])
>>> # For labels: 1 means foreground point, 0 means background
>>> labels = np.array([[1., 0.]])
>>> box = np.array([[[[384., 384.], [640., 640.]]]])
>>> input_mask = np.ones((1, 1, 256, 256, 1))
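As a sanity check before building the input dictionary, the shape rules listed above can be verified with plain NumPy. This is a minimal sketch: `check_prompt_shapes` is a hypothetical helper written for illustration, not part of KerasCV.

```python
import numpy as np


def check_prompt_shapes(points=None, labels=None, boxes=None, masks=None):
    """Raise ValueError if a prompt does not match the documented shape."""
    if points is None and boxes is None and masks is None:
        raise ValueError("At least one prompt must be provided.")
    if points is not None:
        # Points: (batch, num_points, 2); labels: (batch, num_points).
        if points.ndim != 3 or points.shape[-1] != 2:
            raise ValueError(f"Bad points shape: {points.shape}")
        if labels is None or labels.shape != points.shape[:2]:
            raise ValueError("Labels must have shape (batch, num_points).")
    if boxes is not None:
        # Boxes: (batch, 1, 2, 2), i.e. one [(x1, y1), (x2, y2)] per image.
        if boxes.ndim != 4 or boxes.shape[1:] != (1, 2, 2):
            raise ValueError(f"Bad box shape: {boxes.shape}")
    if masks is not None:
        # Masks: (batch, 1, H, W, 1).
        if masks.ndim != 5 or masks.shape[1] != 1 or masks.shape[-1] != 1:
            raise ValueError(f"Bad mask shape: {masks.shape}")
    return True


points = np.array([[[512.0, 512.0], [100.0, 100.0]]])
labels = np.array([[1.0, 0.0]])
box = np.array([[[[384.0, 384.0], [640.0, 640.0]]]])
input_mask = np.ones((1, 1, 256, 256, 1))
ok = check_prompt_shapes(points, labels, box, input_mask)
```

Running a check like this before calling the model tends to give clearer error messages than a shape mismatch raised deep inside the prompt encoder.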

Prepare an input dictionary:

>>> inputs = {
...     "images": image,
...     "points": points,
...     "labels": labels,
...     "boxes": box,
...     "masks": input_mask
... }
...
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]

The first mask in the output masks (i.e. masks[:, 0, ...]) is the best mask predicted by the model based on the prompts. Other masks (i.e. masks[:, 1:, ...]) are alternate predictions that can be used if they are desired over the first one.
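The indexing described above can be tried out with NumPy. The sketch below uses random stand-in arrays rather than real model outputs. It assumes `masks` holds logits with shape (batch, num_masks, H, W) and `iou_pred` has shape (batch, num_masks); these shapes and the zero threshold are assumptions to check against the actual outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for outputs["masks"] and outputs["iou_pred"] (assumed shapes).
masks = rng.normal(size=(1, 4, 256, 256))
iou_pred = rng.uniform(size=(1, 4))

# The first mask is the model's best single prediction.
best_mask = masks[:, 0, ...]

# Among the alternates, pick the one with the highest predicted IoU.
alt_index = 1 + np.argmax(iou_pred[:, 1:], axis=1)
alt_mask = masks[np.arange(masks.shape[0]), alt_index]

# Assumed convention: positive logits mark foreground pixels.
binary_mask = best_mask > 0.0
```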

Now, in case of only points and box prompts, simply exclude the masks:

>>> inputs = {
...     "images": image,
...     "points": points,
...     "labels": labels,
...     "boxes": box,
... }
...
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]

Another case is when only point prompts are present. Note that if point prompts are present but no box prompt is present, the points must be padded with a zero point and a -1 label:

>>> padded_points = np.concatenate(
...     [points, np.zeros((1, 1, 2))], axis=1
... )
...
>>> padded_labels = np.concatenate(
...     [labels, -np.ones((1, 1))], axis=1
... )
>>> inputs = {
...     "images": image,
...     "points": padded_points,
...     "labels": padded_labels,
... }
...
>>> outputs = sam.predict(inputs)
>>> masks, iou_pred = outputs["masks"], outputs["iou_pred"]
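The padding step can be wrapped in a small function that works for any batch size. This is a minimal sketch; `pad_point_prompts` is a hypothetical helper name, not a KerasCV function.

```python
import numpy as np


def pad_point_prompts(points, labels):
    """Append one zero point with label -1 to every item in the batch."""
    batch = points.shape[0]
    pad_point = np.zeros((batch, 1, 2), dtype=points.dtype)
    pad_label = -np.ones((batch, 1), dtype=labels.dtype)
    padded_points = np.concatenate([points, pad_point], axis=1)
    padded_labels = np.concatenate([labels, pad_label], axis=1)
    return padded_points, padded_labels


points = np.array([[[512.0, 512.0], [100.0, 100.0]]])
labels = np.array([[1.0, 0.0]])
padded_points, padded_labels = pad_point_prompts(points, labels)
```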

Note that the Segment Anything Model currently supports inference only; training isn't supported yet, so calling the fit method will fail for now.


from_preset method

SegmentAnythingModel.from_preset()

Instantiate a SegmentAnythingModel from a preset config and weights.

Arguments

  • preset: string. Must be one of "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet18_v2", "resnet34_v2", "resnet50_v2", "resnet101_v2", "resnet152_v2", "mobilenet_v3_small", "mobilenet_v3_large", "csp_darknet_tiny", "csp_darknet_s", "csp_darknet_m", "csp_darknet_l", "csp_darknet_xl", "efficientnetv1_b0", "efficientnetv1_b1", "efficientnetv1_b2", "efficientnetv1_b3", "efficientnetv1_b4", "efficientnetv1_b5", "efficientnetv1_b6", "efficientnetv1_b7", "efficientnetv2_s", "efficientnetv2_m", "efficientnetv2_l", "efficientnetv2_b0", "efficientnetv2_b1", "efficientnetv2_b2", "efficientnetv2_b3", "densenet121", "densenet169", "densenet201", "efficientnetlite_b0", "efficientnetlite_b1", "efficientnetlite_b2", "efficientnetlite_b3", "efficientnetlite_b4", "yolo_v8_xs_backbone", "yolo_v8_s_backbone", "yolo_v8_m_backbone", "yolo_v8_l_backbone", "yolo_v8_xl_backbone", "vitdet_base", "vitdet_large", "vitdet_huge", "resnet50_imagenet", "resnet50_v2_imagenet", "mobilenet_v3_large_imagenet", "mobilenet_v3_small_imagenet", "csp_darknet_tiny_imagenet", "csp_darknet_l_imagenet", "efficientnetv2_s_imagenet", "efficientnetv2_b0_imagenet", "efficientnetv2_b1_imagenet", "efficientnetv2_b2_imagenet", "densenet121_imagenet", "densenet169_imagenet", "densenet201_imagenet", "yolo_v8_xs_backbone_coco", "yolo_v8_s_backbone_coco", "yolo_v8_m_backbone_coco", "yolo_v8_l_backbone_coco", "yolo_v8_xl_backbone_coco", "vitdet_base_sa1b", "vitdet_large_sa1b", "vitdet_huge_sa1b", "sam_base_sa1b", "sam_large_sa1b", "sam_huge_sa1b". 
If looking for a preset with pretrained weights, choose one of "resnet50_imagenet", "resnet50_v2_imagenet", "mobilenet_v3_large_imagenet", "mobilenet_v3_small_imagenet", "csp_darknet_tiny_imagenet", "csp_darknet_l_imagenet", "efficientnetv2_s_imagenet", "efficientnetv2_b0_imagenet", "efficientnetv2_b1_imagenet", "efficientnetv2_b2_imagenet", "densenet121_imagenet", "densenet169_imagenet", "densenet201_imagenet", "yolo_v8_xs_backbone_coco", "yolo_v8_s_backbone_coco", "yolo_v8_m_backbone_coco", "yolo_v8_l_backbone_coco", "yolo_v8_xl_backbone_coco", "vitdet_base_sa1b", "vitdet_large_sa1b", "vitdet_huge_sa1b", "sam_base_sa1b", "sam_large_sa1b", "sam_huge_sa1b".
  • load_weights: Whether to load pre-trained weights into model. Defaults to None, which follows whether the preset has pretrained weights available.
  • input_shape: Input shape that will be passed to backbone initialization. Defaults to None. If None, the preset value will be used.

Examples

import keras_cv

# Load architecture and weights from preset
model = keras_cv.models.SegmentAnythingModel.from_preset(
    "sam_base_sa1b",
)

# Load randomly initialized model from preset architecture without weights
model = keras_cv.models.SegmentAnythingModel.from_preset(
    "sam_base_sa1b",
    load_weights=False,
)

Preset name      Parameters   Description
sam_base_sa1b    93.74M       The base SAM model trained on the SA1B dataset.
sam_large_sa1b   312.34M      The large SAM model trained on the SA1B dataset.
sam_huge_sa1b    641.09M      The huge SAM model trained on the SA1B dataset.

SAMMaskDecoder class

keras_cv.models.SAMMaskDecoder(
    transformer_dim=256,
    transformer=None,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
    activation="gelu",
    **kwargs
)

Mask decoder for the Segment Anything Model (SAM).

This lightweight module efficiently maps the image embedding and a set of prompt embeddings to an output mask. Before applying the transformer decoder, the layer first inserts into the set of prompt embeddings a learned output token embedding that will be used at the decoder's output. For simplicity, these embeddings (not including the image embedding) are collectively called "tokens".

The image embeddings, positional image embeddings, and tokens are passed through a transformer decoder. After running the decoder, the layer upsamples the updated image embedding by 4x with two transposed convolutional layers (so it is then downscaled 4x relative to the input image). Then, the tokens attend once more to the image embedding, and the updated output token embedding is passed to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding. Finally, a mask is predicted with a spatially point-wise product between the upscaled image embedding and the MLP's output.
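The final step, the spatially point-wise product, can be illustrated with NumPy. This is a sketch using random stand-in tensors, not the layer's actual implementation. The sizes below (64x64 resolution, 32 channels, 4 masks) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels, num_masks = 1, 64, 64, 32, 4

# Stand-in for the upscaled image embedding: (batch, H, W, C).
upscaled_embedding = rng.normal(size=(batch, height, width, channels))
# Stand-in for the MLP outputs, one C-dimensional vector per output mask.
mask_vectors = rng.normal(size=(batch, num_masks, channels))

# Dot product over channels at every pixel: (batch, num_masks, H, W).
mask_logits = np.einsum("bhwc,bmc->bmhw", upscaled_embedding, mask_vectors)
```

Each pixel of each mask is the dot product of that pixel's embedding with the mask's vector, which is why the MLP output must match the channel dimension of the upscaled embedding.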

Arguments

  • transformer_dim (int, optional): The number of input features to the transformer decoder. Defaults to 256.
  • transformer (keras.layers.Layer, optional): A transformer decoder. Defaults to None. When None, a keras_cv.models.TwoWayTransformer layer is used.
  • num_multimask_outputs (int, optional): Number of multimask outputs. The model generates this many extra masks, so the total number of masks generated is 1 + num_multimask_outputs. Defaults to 3.
  • iou_head_depth (int, optional): The depth of the dense net used to predict the IoU confidence score. Defaults to 3.
  • iou_head_hidden_dim (int, optional): The number of units in the hidden layers used in the dense net to predict the IoU confidence score. Defaults to 256.
  • activation (str, optional): Activation to use in the mask upscaler network. Defaults to "gelu".

SAMPromptEncoder class

keras_cv.models.SAMPromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
    activation="gelu",
    **kwargs
)

Prompt Encoder for the Segment Anything Model (SAM).

The prompt encoder generates encodings for three types of prompts:

  • Point prompts: Points on the image along with a label indicating whether the point is in the foreground (part of the mask) or in the background (not a part of the mask).
  • Box prompts: A batch of bounding boxes with format [(x1, y1), (x2, y2)] used to determine the location of the masks in the image.
  • Masks: An input mask can be passed to refine the positional embeddings for the output mask.

First, the point prompts and box prompts are concatenated and positional encodings are generated using random spatial frequencies. A point is represented as the sum of a positional encoding of the point's location and one of two learned embeddings that indicate if the point is either in the foreground or background. A box is represented by an embedding pair:

(1) the positional encoding of its top-left corner summed with a learned embedding representing "top-left corner" and (2) the same structure but using a learned embedding indicating "bottom-right corner".

The box and point encodings are referred to as "sparse encodings".

If a mask prompt is passed, a convolutional neural net is used to downscale it to generate "dense encodings". If no mask prompt is passed, an embedding layer is used instead to generate a "no mask" embedding.
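The way sparse encodings are assembled can be sketched in NumPy. The arrays below are random stand-ins for the learned embeddings and for the positional encodings, so the sketch shows only how the pieces are summed and concatenated, not the layer's real weights.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim = 256

# Stand-ins for the learned embeddings: background point, foreground point,
# top-left box corner, bottom-right box corner.
learned = rng.normal(size=(4, embed_dim))
background, foreground, top_left, bottom_right = learned

# Stand-in positional encodings for 2 points and the 2 corners of 1 box.
point_pe = rng.normal(size=(1, 2, embed_dim))
corner_pe = rng.normal(size=(1, 2, embed_dim))
labels = np.array([[1, 0]])  # 1 = foreground, 0 = background

# A point is its positional encoding plus the embedding for its label.
label_embedding = np.where(labels[..., None] == 1, foreground, background)
point_tokens = point_pe + label_embedding

# A box is a pair of corner encodings, each with its own corner embedding.
box_tokens = corner_pe + np.stack([top_left, bottom_right])

# Sparse encodings: all point and box tokens side by side.
sparse_encodings = np.concatenate([point_tokens, box_tokens], axis=1)
```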

Arguments

  • embed_dim (int, optional): The number of features in the output embeddings. Defaults to 256.
  • image_embedding_size (tuple[int], optional): The spatial size (height, width) of the image embeddings generated by an image encoder. Defaults to (64, 64).
  • input_image_size (tuple[int], optional): A tuple of the height and width of the image being prompted. Defaults to (1024, 1024).
  • mask_in_chans (int, optional): The number of channels of the mask prompt. Defaults to 16.
  • activation (str, optional): The activation to use in the mask downscaler neural net. Defaults to "gelu".

TwoWayTransformer class

keras_cv.models.TwoWayTransformer(
    depth=2,
    embed_dim=256,
    num_heads=8,
    mlp_dim=2048,
    activation="relu",
    attention_downsample_rate=2,
    **kwargs
)

A two-way cross-attention transformer decoder.

A transformer decoder that attends to an input image using queries whose positional embedding is supplied.

The transformer decoder design is shown in [1]_. Each decoder layer performs 4 steps: (1) self-attention on the tokens, (2) cross-attention from tokens (as queries) to the image embedding, (3) a point-wise MLP updates each token, and (4) cross-attention from the image embedding (as queries) to tokens. This last step updates the image embedding with prompt information. Each self/cross-attention and MLP has a residual connection and layer normalization.

To ensure the decoder has access to critical geometric information, the positional encodings are added to the image embedding whenever it participates in an attention layer. Additionally, the original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. This allows for a strong dependence on both the prompt token's geometric location and type.
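The four steps of one decoder layer can be sketched in NumPy. This is a deliberately simplified illustration, not the KerasCV implementation: it uses single-head attention without learned projections and a layer normalization without learned scale or offset. The names `two_way_block`, `w1`, and `w2` are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def attention(q, k, v):
    """Single-head scaled dot-product attention, without projections."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def two_way_block(tokens, image, token_pe, image_pe, w1, w2):
    # (1) Self-attention on the tokens.
    q = tokens + token_pe
    tokens = layer_norm(tokens + attention(q, q, tokens))
    # (2) Cross-attention from tokens (queries) to the image embedding.
    q, k = tokens + token_pe, image + image_pe
    tokens = layer_norm(tokens + attention(q, k, image))
    # (3) Point-wise MLP on each token (ReLU hidden layer).
    tokens = layer_norm(tokens + np.maximum(tokens @ w1, 0.0) @ w2)
    # (4) Cross-attention from the image embedding (queries) to the tokens.
    q, k = tokens + token_pe, image + image_pe
    image = layer_norm(image + attention(k, q, tokens))
    return tokens, image


rng = np.random.default_rng(0)
dim, mlp_dim = 16, 32
tokens = rng.normal(size=(1, 5, dim))    # prompt and output tokens
image = rng.normal(size=(1, 64, dim))    # flattened 8x8 image embedding
token_pe = rng.normal(size=(1, 5, dim))  # stand-in for the original prompt tokens
image_pe = rng.normal(size=(1, 64, dim)) # stand-in for image positional encodings
w1 = 0.1 * rng.normal(size=(dim, mlp_dim))
w2 = 0.1 * rng.normal(size=(mlp_dim, dim))
new_tokens, new_image = two_way_block(tokens, image, token_pe, image_pe, w1, w2)
```

Note how `token_pe` and `image_pe` are added only to the queries and keys, never to the values, so the geometric information steers where attention looks without being mixed directly into the updated embeddings.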

Arguments

  • depth (int, optional): The depth of the attention blocks (the number of attention blocks to use). Defaults to 2.
  • embed_dim (int, optional): The number of features of the input image and point embeddings. Defaults to 256.
  • num_heads (int, optional): Number of heads to use in the attention layers. Defaults to 8.
  • mlp_dim (int, optional): The number of units in the hidden layer of the MLP block used in the attention layers. Defaults to 2048.
  • activation (str, optional): The activation of the MLP block's output layer used in the attention layers. Defaults to "relu".
  • attention_downsample_rate (int, optional): The downsample rate of the attention layers. Defaults to 2.

References


MultiHeadAttentionWithDownsampling class

keras_cv.layers.MultiHeadAttentionWithDownsampling(
    num_heads, key_dim, downsample_rate=1, **kwargs
)

Multi-Head Attention with downsampling.

An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values.

This layer first downscales the features of input queries, keys, and values using a dense layer. Multi-head attention is then performed and the attention map is projected back (upscaled) to the number of input features.

Arguments

  • num_heads (int): Number of attention heads.
  • key_dim (int): Size of each attention head for query, key, and value.
  • downsample_rate (int, optional): The factor by which to downscale the input features, i.e. the input features of size key_dim are projected down to key_dim // downsample_rate. Defaults to 1.
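The effect of downsample_rate on the tensor sizes can be sketched in NumPy. This is an illustration with random weights, not the layer's implementation. It assumes the input has num_heads * key_dim features in total, so that each head works on key_dim // downsample_rate features after the projection; `downsampled_attention` is a hypothetical name.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def downsampled_attention(query, key, value, weights, num_heads):
    """Multi-head attention on features projected to a smaller size."""
    wq, wk, wv, wo = weights

    def split_heads(x):
        # (batch, seq, internal) -> (batch, heads, seq, internal // heads)
        b, n, c = x.shape
        return x.reshape(b, n, num_heads, c // num_heads).transpose(0, 2, 1, 3)

    q = split_heads(query @ wq)
    k = split_heads(key @ wk)
    v = split_heads(value @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    out = softmax(scores) @ v
    # Merge heads, then project back up to the input feature size.
    b, h, n, d = out.shape
    return out.transpose(0, 2, 1, 3).reshape(b, n, h * d) @ wo


rng = np.random.default_rng(0)
num_heads, key_dim, downsample_rate = 8, 32, 2
features = num_heads * key_dim                       # input feature size
internal = num_heads * (key_dim // downsample_rate)  # size after downscaling
weights = (
    0.05 * rng.normal(size=(features, internal)),
    0.05 * rng.normal(size=(features, internal)),
    0.05 * rng.normal(size=(features, internal)),
    0.05 * rng.normal(size=(internal, features)),
)
tokens = rng.normal(size=(1, 5, features))
image = rng.normal(size=(1, 64, features))
out = downsampled_attention(tokens, image, image, weights, num_heads)
```

With downsample_rate=2, the attention itself runs on half as many features per head, while the output keeps the input feature size.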

TwoWayMultiHeadAttention class

keras_cv.layers.TwoWayMultiHeadAttention(
    num_heads,
    key_dim,
    mlp_dim,
    skip_first_layer_pe,
    attention_downsample_rate=2,
    activation="relu",
    **kwargs
)

Two-way multi-head attention layer.

Arguments

  • num_heads (int): Number of attention heads.
  • key_dim (int): Size of each attention head for query, key, and value.
  • mlp_dim (int): Number of hidden dims to use in the mlp block.
  • skip_first_layer_pe (bool): A boolean indicating whether to skip the first layer positional embeddings.
  • attention_downsample_rate (int, optional): The downsample rate to use in the attention layers. Defaults to 2.
  • activation (str, optional): The activation for the mlp block's output layer. Defaults to "relu".

RandomFrequencyPositionalEmbeddings class

keras_cv.layers.RandomFrequencyPositionalEmbeddings(
    num_positional_features, scale, **kwargs
)

Positional encoding using random spatial frequencies.

This layer maps coordinates/points in 2D space to positional encodings using random spatial frequencies.

Arguments

  • num_positional_features (int): Number of positional features in the output.
  • scale (float): The standard deviation of the random frequencies.
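The idea of a random-frequency positional encoding can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the layer's code: coordinates are assumed to be normalized to [0, 1], and the output width here is twice `num_frequencies` (sines and cosines). How `num_positional_features` maps to the output width in the layer is not specified above, so `num_frequencies` is a hypothetical name.

```python
import math

import numpy as np


def random_frequency_encoding(coords, frequencies):
    """Encode coordinates in [0, 1] as sines and cosines of random projections."""
    coords = 2.0 * coords - 1.0  # rescale to [-1, 1]
    projected = 2.0 * math.pi * (coords @ frequencies)
    return np.concatenate([np.sin(projected), np.cos(projected)], axis=-1)


rng = np.random.default_rng(0)
num_frequencies, scale = 128, 1.0
# Random spatial frequencies drawn once with standard deviation `scale`.
frequencies = scale * rng.normal(size=(2, num_frequencies))

# Pixel coordinates on a 1024x1024 image, normalized to [0, 1].
points = np.array([[[512.0, 512.0], [100.0, 100.0]]]) / 1024.0
encoding = random_frequency_encoding(points, frequencies)
```

A larger `scale` yields higher spatial frequencies, so nearby points receive more distinct encodings.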
