SAMImageSegmenter model


SAMImageSegmenter class

keras_hub.models.SAMImageSegmenter(backbone, preprocessor=None, **kwargs)

The Segment Anything Model (SAM) image segmenter.

SAM works by prompting the input images. There are three ways to prompt:

  1. Labelled points: foreground points (points with label 1) are encoded so that the output masks generated by the mask decoder contain them, and background points (points with label 0) are encoded so that the generated masks don't contain them. Point prompts have shape (batch, num_points, 2), and the labels must have the corresponding shape (batch, num_points).
  2. Box: a box tells the model which part (crop) of the image to segment. Box prompts have shape (batch, 1, 2, 2).
  3. Mask: an input mask can be used to refine the output of the mask decoder. Mask prompts have shape (batch, 1, H, W, 1).

These prompts can be mixed and matched, but at least one of them must be present. To turn off a particular prompt, simply exclude it from the inputs to the model.
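A minimal sketch of the shape rules above, written as a hypothetical helper validate_sam_prompts that is not part of keras_hub. It takes a dictionary of shape tuples (for NumPy inputs, {k: v.shape for k, v in inputs.items()}). Accepting 0 masks along axis 1 as "no mask prompt" is an assumption based on the pretrained example that follows.

```python
from typing import Mapping, Sequence


def validate_sam_prompts(shapes: Mapping[str, Sequence[int]]) -> None:
    """Raise ValueError if the input shapes break the documented prompt rules.

    Hypothetical helper, not a keras_hub API. `shapes` maps input names
    ("images", "points", "labels", "boxes", "masks") to shape tuples.
    """
    batch = shapes["images"][0]
    has_prompt = False
    if "points" in shapes:
        points = tuple(shapes["points"])
        # Points: (batch, num_points, 2); labels: (batch, num_points).
        if len(points) != 3 or points[0] != batch or points[2] != 2:
            raise ValueError(f"points must be (batch, num_points, 2), got {points}")
        labels = tuple(shapes.get("labels", ()))
        if labels != points[:2]:
            raise ValueError(f"labels must have shape {points[:2]}, got {labels}")
        has_prompt = True
    if "boxes" in shapes:
        boxes = tuple(shapes["boxes"])
        # Box: (batch, 1, 2, 2), i.e. two corner points per example.
        if boxes != (batch, 1, 2, 2):
            raise ValueError(f"boxes must be (batch, 1, 2, 2), got {boxes}")
        has_prompt = True
    if "masks" in shapes:
        masks = tuple(shapes["masks"])
        # Mask: (batch, 1, H, W, 1); 0 masks on axis 1 is treated as "no mask".
        if len(masks) != 5 or masks[0] != batch or masks[1] not in (0, 1) or masks[4] != 1:
            raise ValueError(f"masks must be (batch, 1, H, W, 1), got {masks}")
        has_prompt = has_prompt or masks[1] == 1
    if not has_prompt:
        raise ValueError("At least one of points, boxes or masks must be given.")
```

For example, validate_sam_prompts({"images": (1, 1024, 1024, 3), "boxes": (1, 1, 2, 2)}) passes, while a dictionary containing only "images" raises ValueError.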

Arguments

  • backbone: A keras_hub.models.SAMBackbone instance.
  • preprocessor: None or a keras_hub.models.Preprocessor instance used to preprocess the inputs.

Example

Load a pretrained model using from_preset.

import keras_hub
import numpy as np

# The pretrained SAM presets expect 1024x1024 RGB images.
image_size = 1024
batch_size = 2
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    # An empty mask prompt (0 masks along axis 1).
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1), dtype="float32"
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
sam(input_data)

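Load a Segment Anything image segmenter with a custom backbone.

import numpy as np
from keras_hub.layers import SAMMaskDecoder, SAMPromptEncoder
from keras_hub.models import (
    SAMBackbone,
    SAMImageSegmenter,
    ViTDetBackbone,
)

image_size = 128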
batch_size = 2
images = np.ones(
    (batch_size, image_size, image_size, 3),
    dtype="float32",
)
image_encoder = ViTDetBackbone(
    hidden_size=16,
    num_layers=16,
    intermediate_dim=16 * 4,
    num_heads=16,
    global_attention_layer_indices=[2, 5, 8, 11],
    patch_size=16,
    num_output_channels=8,
    window_size=2,
    image_shape=(image_size, image_size, 3),
)
prompt_encoder = SAMPromptEncoder(
    hidden_size=8,
    image_embedding_size=(8, 8),
    input_image_size=(
        image_size,
        image_size,
    ),
    mask_in_channels=16,
)
mask_decoder = SAMMaskDecoder(
    num_layers=2,
    hidden_size=8,
    intermediate_dim=32,
    num_heads=8,
    embedding_dim=8,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=8,
)
backbone = SAMBackbone(
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    mask_decoder=mask_decoder,
    image_shape=(image_size, image_size, 3),
)
sam = SAMImageSegmenter(
    backbone=backbone
)
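The sizes in the custom backbone above are not independent of each other. A small sketch of the arithmetic, under the assumption (from the ViTDet patch embedding and the original SAM design) that the image encoder produces one embedding per patch and that the mask prompt is 4x the embedding grid. The channel sizes follow a similar pattern in this example: num_output_channels=8 on the image encoder matches hidden_size=8 on both the prompt encoder and the mask decoder. The helper name sam_grid_sizes is hypothetical.

```python
def sam_grid_sizes(image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (embedding_size, mask_prompt_size) for a square input image.

    Hypothetical helper. Assumes one image embedding per non-overlapping
    patch and a mask prompt 4x larger than the embedding grid, as in the
    original SAM design.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    embedding_size = image_size // patch_size
    return embedding_size, 4 * embedding_size


# Custom backbone above: 128 // 16 = 8, matching image_embedding_size=(8, 8).
print(sam_grid_sizes(128, 16))   # (8, 32)
# Pretrained presets (1024x1024 inputs, 16x16 patches): the 256x256 mask
# prompt used in the examples is 4 * 64.
print(sam_grid_sizes(1024, 16))  # (64, 256)
```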

For example, to pass all the prompts to a model that takes 1024x1024 images (such as the pretrained presets), do:

image = np.ones((1, 1024, 1024, 3), dtype="float32")
points = np.array([[[512., 512.], [100., 100.]]])
# For labels: 1 means foreground point, 0 means background point.
labels = np.array([[1., 0.]])
box = np.array([[[[384., 384.], [640., 640.]]]])
input_mask = np.ones((1, 1, 256, 256, 1))

Then prepare an input dictionary:

inputs = {
    "images": image,
    "points": points,
    "labels": labels,
    "boxes": box,
    "masks": input_mask
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]

The first mask in the output (masks[:, 0, ...]) is the best mask predicted by the model for the given prompts. The other masks (masks[:, 1:, ...]) are alternative predictions that you can use instead if they suit your needs better. If only point and box prompts are needed, simply exclude the masks:

inputs = {
    "images": image,
    "points": points,
    "labels": labels,
    "boxes": box,
}

outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
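If you would rather choose among the alternative masks than always take masks[:, 0, ...], one common heuristic (our suggestion, not something the model requires) is to keep the mask with the highest predicted IoU. A minimal sketch, assuming iou_pred has shape (batch, num_masks) with one score per output mask; best_mask_indices is a hypothetical helper.

```python
from typing import Sequence


def best_mask_indices(iou_pred: Sequence[Sequence[float]]) -> list[int]:
    """Return, per batch item, the index of the mask with the highest score.

    Hypothetical helper. Accepts nested lists or a 2D NumPy array of shape
    (batch, num_masks). Ties resolve to the lowest index.
    """
    return [
        max(range(len(scores)), key=lambda i: float(scores[i]))
        for scores in iou_pred
    ]


# Two examples with four candidate masks each.
print(best_mask_indices([[0.80, 0.91, 0.75, 0.60], [0.95, 0.40, 0.50, 0.30]]))  # [1, 0]
```

With NumPy outputs, idx = best_mask_indices(iou_pred) followed by masks[np.arange(len(idx)), idx] gathers the chosen mask for each example.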

Another case is when only point prompts are present. If point prompts are given without a box prompt, the points must be padded with a zero point that has label -1:

padded_points = np.concatenate(
    [points, np.zeros((1, 1, 2))], axis=1
)

padded_labels = np.concatenate(
    [labels, -np.ones((1, 1))], axis=1
)
inputs = {
    "images": image,
    "points": padded_points,
    "labels": padded_labels,
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
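The padding rule above can be folded into a small helper so that the input dictionary always matches the prompts that are present. This is a minimal sketch with a hypothetical name, build_sam_inputs; it works on nested Python lists, so convert the results with np.asarray before passing them to sam.predict.

```python
def build_sam_inputs(image, points=None, labels=None, box=None, mask=None):
    """Assemble a SAMImageSegmenter-style input dict from optional prompts.

    Hypothetical helper. `points` and `labels` are nested lists shaped
    (batch, num_points, 2) and (batch, num_points). Absent prompts are left
    out of the dict. When points are given without a box, each example gets
    an extra (0, 0) point with label -1, following the padding rule above.
    """
    if points is None and box is None and mask is None:
        raise ValueError("At least one prompt is required.")
    if (points is None) != (labels is None):
        raise ValueError("points and labels must be given together.")
    inputs = {"images": image}
    if points is not None:
        if box is None:
            # Pad every example with a zero point labelled -1.
            points = [list(p) + [[0.0, 0.0]] for p in points]
            labels = [list(lab) + [-1.0] for lab in labels]
        inputs["points"] = points
        inputs["labels"] = labels
    if box is not None:
        inputs["boxes"] = box
    if mask is not None:
        inputs["masks"] = mask
    return inputs


inputs = build_sam_inputs(
    "image", points=[[[512.0, 512.0], [100.0, 100.0]]], labels=[[1.0, 0.0]]
)
print(inputs["labels"])  # [[1.0, 0.0, -1.0]]
```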


from_preset method

SAMImageSegmenter.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Task from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'

For any Task subclass, you can run cls.presets.keys() to list all built-in presets available on the class.
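To illustrate how the four preset forms listed above differ, here is a rough, hypothetical classifier (classify_preset is not a keras_hub function). The kaggle:// and hf:// prefixes come from the examples above; treating any path-like string as a local directory is our simplification, and keras_hub's actual resolution logic may differ.

```python
import os


def classify_preset(preset: str) -> str:
    """Roughly classify a preset string into one of the four documented forms.

    Hypothetical, illustrative helper; not keras_hub's resolution logic.
    """
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    # Simplification: an existing directory or anything path-like is local.
    if os.path.isdir(preset) or preset.startswith((".", "/")) or os.sep in preset:
        return "local"
    return "builtin"


for p in ["sam_base_sa1b", "kaggle://user/bert/keras/bert_base_en",
          "hf://user/bert_base_en", "./bert_base_en"]:
    print(p, "->", classify_preset(p))
```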

This constructor can be called in one of two ways: either from a task-specific base class like keras_hub.models.CausalLM.from_preset(), or from a model class like keras_hub.models.BertTextClassifier.from_preset(). If calling from a base class, the subclass of the returned object will be inferred from the config in the preset directory.
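The base-class behaviour described above amounts to dispatching on a class name stored in the config. The toy sketch below shows the idea with a hypothetical registry and hypothetical class names (ToyTask, ToyImageSegmenter, ToySAMSegmenter); it is not how keras_hub is implemented internally, which may store and resolve the config differently.

```python
import json

# Hypothetical registry mapping class names to classes (not keras_hub code).
_REGISTRY: dict[str, type] = {}


def register(cls: type) -> type:
    _REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyTask:
    @classmethod
    def from_config_json(cls, config_json: str) -> "ToyTask":
        config = json.loads(config_json)
        # The concrete class comes from the config, not from `cls`.
        target = _REGISTRY[config["class_name"]]
        if not issubclass(target, cls):
            raise ValueError(f"{target.__name__} is not a {cls.__name__}")
        return target(**config.get("config", {}))


@register
class ToyImageSegmenter(ToyTask):
    pass


@register
class ToySAMSegmenter(ToyImageSegmenter):
    def __init__(self, name: str = "sam"):
        self.name = name


config_json = '{"class_name": "ToySAMSegmenter", "config": {"name": "demo"}}'
# Calling on the base class still returns the subclass named in the config.
model = ToyImageSegmenter.from_config_json(config_json)
print(type(model).__name__, model.name)  # ToySAMSegmenter demo
```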

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, saved weights will be loaded into the model architecture. If False, all weights will be randomly initialized.

Examples

import keras_hub

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)

Preset name       Parameters  Description
sam_base_sa1b     93.74M      The base SAM model trained on the SA-1B dataset.
sam_large_sa1b    312.34M     The large SAM model trained on the SA-1B dataset.
sam_huge_sa1b     641.09M     The huge SAM model trained on the SA-1B dataset.

backbone property

keras_hub.models.SAMImageSegmenter.backbone

A keras_hub.models.Backbone model with the core architecture.


preprocessor property

keras_hub.models.SAMImageSegmenter.preprocessor

A keras_hub.models.Preprocessor layer used to preprocess input.