SegFormerBackbone
class keras_hub.models.SegFormerBackbone(image_encoder, projection_filters, **kwargs)
A Keras model implementing SegFormer for semantic segmentation.
This class implements the majority of the SegFormer architecture described in the paper "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers", and is based on the TensorFlow implementation from DeepVision.
SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and use a very lightweight all-MLP decoder head.
The MiT encoder is a hierarchical transformer that outputs features at multiple scales, similar to the hierarchical feature maps typically produced by CNNs.
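To make the multi-scale idea concrete, here is a minimal pure-Python shape calculation. It assumes each MiT stage begins with an overlapping patch embedding, that is, a convolution with kernel patch_size, the stage's stride and zero padding patch_size // 2 (the overlapping patch merging described in the SegFormer paper). It also assumes the decoder follows the paper's all-MLP head: project every stage to projection_filters channels, upsample to the first stage's 1/4 resolution, concatenate, and fuse back to projection_filters channels. The helpers conv_out, mit_stage_shapes and decoder_shapes are hypothetical; they only track shapes and are not part of keras_hub. The numbers reuse the configuration from the custom-backbone example in this section.

```python
def conv_out(size, kernel, stride, pad):
    # Output length of a convolution along one spatial axis.
    return (size + 2 * pad - kernel) // stride + 1


def mit_stage_shapes(image_hw, patch_sizes, strides, hidden_dims):
    """Return (height, width, channels) for each MiT stage output.

    Assumption: each stage starts with an overlapping patch embedding,
    i.e. a convolution with kernel `patch_size`, the given stride and
    zero padding `patch_size // 2`.
    """
    h, w = image_hw
    shapes = []
    for k, s, c in zip(patch_sizes, strides, hidden_dims):
        h = conv_out(h, k, s, k // 2)
        w = conv_out(w, k, s, k // 2)
        shapes.append((h, w, c))
    return shapes


def decoder_shapes(stage_shapes, projection_filters):
    """Shape flow through an all-MLP decoder head, following the paper."""
    # 1. Each stage is projected to `projection_filters` channels.
    projected = [(h, w, projection_filters) for h, w, _ in stage_shapes]
    # 2. Every projected map is resized to the first (finest) stage's size.
    target_h, target_w, _ = stage_shapes[0]
    upsampled = [(target_h, target_w, c) for _, _, c in projected]
    # 3. Concatenate along channels, then fuse back to `projection_filters`.
    concat = (target_h, target_w, sum(c for _, _, c in upsampled))
    fused = (target_h, target_w, projection_filters)
    return concat, fused


stages = mit_stage_shapes(
    (224, 224), [7, 3, 3, 3], [4, 2, 2, 2], [32, 64, 160, 256]
)
print(stages)  # [(56, 56, 32), (28, 28, 64), (14, 14, 160), (7, 7, 256)]
print(decoder_shapes(stages, projection_filters=256))
# ((56, 56, 1024), (56, 56, 256))
```

The four stages land at 1/4, 1/8, 1/16 and 1/32 of the input resolution, which is the CNN-like pyramid the encoder description refers to. The paper's final linear layer that maps the fused features to per-class scores is left out of this sketch.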
Arguments
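image_encoder: keras.Model. The backbone network for the model that is used as a feature extractor for the SegFormer encoder. It should be used with the MiT backbone model (keras_hub.models.MiTBackbone), which was created specifically for SegFormers.
projection_filters: int. The number of filters (channels) that the decoder head projects the multi-scale encoder features to.
Examples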
Using the class with a custom backbone:
import keras_hub

backbone = keras_hub.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    image_shape=(224, 224, 3),
    hidden_dims=[32, 64, 160, 256],
    num_layers=4,
    blockwise_num_heads=[1, 2, 5, 8],
    blockwise_sr_ratios=[8, 4, 2, 1],
    max_drop_path_rate=0.1,
    patch_sizes=[7, 3, 3, 3],
    strides=[4, 2, 2, 2],
)
segformer_backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=backbone, projection_filters=256
)
Using the class with a preset backbone:
import keras_hub

backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
segformer_backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=backbone, projection_filters=256
)
from_preset method
SegFormerBackbone.from_preset(preset, load_weights=True, **kwargs)
Instantiate a keras_hub.models.Backbone from a model preset.
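A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:
1. a built-in preset identifier like 'bert_base_en'
2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
3. a Hugging Face handle like 'hf://user/bert_base_en'
4. a path to a local preset directory like './bert_base_en'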
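The four handle formats above are distinguished mainly by their prefix. As a rough illustration only, the hypothetical helper preset_source below (not a keras_hub function) sorts a preset string into one of those four categories with simple prefix rules; the library's real resolution logic may check additional conditions.

```python
def preset_source(preset: str) -> str:
    """Hypothetical helper: classify a preset string by its handle format."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    # Assumption: explicit relative or absolute paths mark a local directory.
    if preset.startswith(("./", "../", "/")):
        return "local"
    # Anything else is treated as a built-in preset identifier.
    return "builtin"


for handle in [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]:
    print(handle, "->", preset_source(handle))
```

A real loader would also have to cope with local paths written without a "./" prefix, which is one reason to read this strictly as an illustration of the four formats.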
This constructor can be called in one of two ways: either from the base class, like keras_hub.models.Backbone.from_preset(), or from a model class, like keras_hub.models.GemmaBackbone.from_preset().
If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.
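Inferring the subclass from a config is a registry-dispatch pattern. The toy sketch below assumes the preset's config records the model's class name under a "class_name" key and its constructor arguments under "config"; those keys, the ToyBackbone, ToyMiTBackbone and ToyGemmaBackbone classes, and the from_config_json method are hypothetical stand-ins, not the keras_hub implementation. Calling it on the base class returns whichever subclass the config names; as an added assumption of this sketch, calling it on an unrelated subclass raises an error.

```python
import json


class ToyBackbone:
    """Hypothetical stand-in for a backbone base class (not keras_hub code)."""

    # Maps class names to classes; filled in automatically for every subclass.
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        ToyBackbone._registry[cls.__name__] = cls

    def __init__(self, **config):
        self.config = config

    @classmethod
    def from_config_json(cls, config_text):
        """Build a model from a preset-style JSON config (assumed format)."""
        data = json.loads(config_text)
        # Assumption: the config names its class under "class_name".
        target = ToyBackbone._registry[data["class_name"]]
        # On the base class, any registered subclass is accepted. On a
        # subclass, the config must name that class or one of its subclasses.
        if not issubclass(target, cls):
            raise ValueError(
                f"Config describes {target.__name__}, not {cls.__name__}"
            )
        return target(**data["config"])


class ToyMiTBackbone(ToyBackbone):
    pass


class ToyGemmaBackbone(ToyBackbone):
    pass


config_text = json.dumps(
    {"class_name": "ToyMiTBackbone", "config": {"depths": [2, 2, 2, 2]}}
)
model = ToyBackbone.from_config_json(config_text)
print(type(model).__name__)  # ToyMiTBackbone
```

Registering subclasses in __init_subclass__ means new model classes become loadable from the base class without any extra bookkeeping at the call site.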
For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.
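Arguments
preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.
Examples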
import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)