SegFormerBackbone model


SegFormerBackbone class

keras_hub.models.SegFormerBackbone(image_encoder, projection_filters, **kwargs)

A Keras model implementing SegFormer for semantic segmentation.

This class implements most of the SegFormer architecture described in "SegFormer: Simple and Efficient Design for Semantic Segmentation" and is based on the TensorFlow implementation from DeepVision.

SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and use a very lightweight all-MLP decoder head.

The MiT encoder is a hierarchical transformer that outputs features at multiple scales, similar to the hierarchical feature maps typically produced by CNNs.
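To make the multi-scale output concrete, the following sketch computes the feature-map shape each encoder stage would produce. It uses the strides and hidden dimensions from the custom-backbone example on this page. The helper `feature_map_shapes` is hypothetical (it is not part of keras_hub), and it assumes each stage pads its input so that the spatial size divides exactly by the stage stride.

```python
# Hypothetical helper for illustration only; not a keras_hub function.
def feature_map_shapes(image_size, strides, hidden_dims):
    """Return (height, width, channels) for each encoder stage.

    Assumes every stage pads its input so that the spatial size is
    divided exactly by that stage's stride.
    """
    shapes = []
    size = image_size
    for stride, channels in zip(strides, hidden_dims):
        size = size // stride  # each stage downsamples the previous one
        shapes.append((size, size, channels))
    return shapes


shapes = feature_map_shapes(
    image_size=224,
    strides=[4, 2, 2, 2],
    hidden_dims=[32, 64, 160, 256],
)
for stage, shape in enumerate(shapes, start=1):
    print(f"stage {stage}: {shape}")
```

Under these assumptions the four stages sit at 1/4, 1/8, 1/16 and 1/32 of the input resolution, which is the same pyramid layout a CNN backbone would typically expose.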


  • image_encoder: keras.Model. The backbone network for the model that is used as a feature extractor for the SegFormer encoder. Should be used with the MiT backbone model (keras_hub.models.MiTBackbone) which was created specifically for SegFormers.
  • num_classes: int, the number of classes for the segmentation model, including the background class.
  • projection_filters: int, number of filters in the convolution layer projecting the concatenated features into a segmentation map. Defaults to 256.
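The role of projection_filters is easiest to see as shape bookkeeping. The sketch below assumes the decoder follows the all-MLP head from the SegFormer paper: each encoder feature map is projected to projection_filters channels, resized to the resolution of the largest feature map, concatenated, and then projected back to projection_filters channels. The helper `decoder_shapes` and the input shapes are illustrative assumptions, not keras_hub code.

```python
# Hypothetical helper for illustration only; shape arithmetic, no tensors.
def decoder_shapes(feature_shapes, projection_filters=256):
    """Return (concatenated_shape, fused_shape) for an all-MLP decoder.

    `feature_shapes` is a list of (height, width, channels) tuples,
    ordered from the highest-resolution stage to the lowest.
    """
    # Every stage is resized to the resolution of the first (largest) map.
    target_h, target_w, _ = feature_shapes[0]
    # Each stage contributes `projection_filters` channels after projection.
    concat_channels = projection_filters * len(feature_shapes)
    concatenated = (target_h, target_w, concat_channels)
    # The final projection maps the concatenation back down.
    fused = (target_h, target_w, projection_filters)
    return concatenated, fused


# Assumed encoder outputs for a 224x224 input with four stages.
features = [(56, 56, 32), (28, 28, 64), (14, 14, 160), (7, 7, 256)]
concatenated, fused = decoder_shapes(features, projection_filters=256)
print("concatenated:", concatenated)
print("fused:", fused)
```

In this sketch the output resolution depends only on the first encoder stage, while projection_filters controls the channel width of the decoder.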


Using the class with a custom backbone:

import keras_hub

backbone = keras_hub.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    image_shape=(224, 224, 3),
    hidden_dims=[32, 64, 160, 256],
    blockwise_num_heads=[1, 2, 5, 8],
    blockwise_sr_ratios=[8, 4, 2, 1],
    patch_sizes=[7, 3, 3, 3],
    strides=[4, 2, 2, 2],
)

segformer_backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=backbone, projection_filters=256)

Using the class with a preset backbone:

import keras_hub

backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
segformer_backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=backbone, projection_filters=256)


from_preset method

SegFormerBackbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'
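The four forms above can be told apart by their prefix. The following is a rough heuristic sketch of that distinction, using the example strings from the list. The function `classify_preset` is hypothetical and is not how keras_hub resolves presets internally.

```python
# Hypothetical helper for illustration only; not keras_hub's own logic.
def classify_preset(preset):
    """Guess which of the four preset forms a string uses."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "hugging_face"
    # Anything that looks like a filesystem path is treated as local.
    if preset.startswith((".", "/", "~")) or "/" in preset or "\\" in preset:
        return "local_path"
    # A bare name with no scheme or separators is assumed to be built in.
    return "built_in"


examples = [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]
for preset in examples:
    print(f"{preset!r}: {classify_preset(preset)}")
```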

This constructor can be called in one of two ways: either from the base class, like keras_hub.models.Backbone.from_preset(), or from a model class, like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.


  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.


import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)