SwinTransformerBackbone model

SwinTransformerBackbone class

keras_hub.models.SwinTransformerBackbone(
    image_shape,
    embed_dim,
    depths,
    num_heads,
    window_size,
    patch_size=4,
    mlp_ratio=4.0,
    qkv_bias=True,
    dropout_rate=0.0,
    attention_dropout=0.0,
    drop_path=0.1,
    patch_norm=True,
    data_format=None,
    dtype=None,
    **kwargs
)

A Swin Transformer backbone network.

This network implements a hierarchical vision transformer as described in "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". It includes the patch embedding, transformer stages with shifted windows, and final normalization, but not the classification head.
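The shifted-window idea from the paper can be illustrated with a minimal NumPy sketch of window partitioning and the cyclic shift. `window_partition` and `cyclic_shift` are hypothetical helpers written for this page, not KerasHub APIs. The sketch omits the attention mask that the paper applies to shifted windows, and it assumes the feature map's height and width are divisible by the window size.

```python
import numpy as np


def window_partition(x, window_size):
    """Split an (H, W, C) feature map into non-overlapping square windows.

    Returns (num_windows, window_size * window_size, C): one token sequence
    per window, which is the unit local self-attention operates on.
    Assumes H and W are divisible by window_size.
    """
    h, w, c = x.shape
    x = x.reshape(h // window_size, window_size, w // window_size, window_size, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (rows of windows, cols of windows, ws, ws, C)
    return x.reshape(-1, window_size * window_size, c)


def cyclic_shift(x, window_size):
    """Roll the map up and left by half a window, so that windows taken
    afterwards straddle the previous block's window boundaries."""
    shift = window_size // 2
    return np.roll(x, shift=(-shift, -shift), axis=(0, 1))


# A 56x56 grid of 96-dim patch tokens (e.g. 4x4 patches of a 224x224 image).
tokens = np.random.default_rng(0).random((56, 56, 96)).astype("float32")
regular = window_partition(tokens, 7)
shifted = window_partition(cyclic_shift(tokens, 7), 7)
print(regular.shape, shifted.shape)  # (64, 49, 96) (64, 49, 96)
```

Both partitions produce 64 windows of 49 tokens; the shifted one simply starts its first window 3 patches down and to the right, which is what lets information flow across window borders between consecutive blocks.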

The default constructor gives a fully customizable, randomly initialized Swin Transformer with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the from_preset() constructor.
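When choosing a custom configuration, it can help to work out each stage's shape before building the model. The sketch below uses a hypothetical `stage_shapes` helper (not part of KerasHub) and assumes the standard Swin layout from the paper: patch embedding divides the spatial size by `patch_size`, and each later stage halves the height and width while doubling the channels. It also assumes a channels-last `image_shape` and checks that every stage's channel count splits evenly across its attention heads.

```python
def stage_shapes(image_shape, patch_size, embed_dim, depths, num_heads):
    """Return (height, width, channels, head_dim) for each stage.

    Assumes image_shape is (height, width, channels) and the standard Swin
    rule of halving resolution and doubling channels per stage.
    """
    if len(depths) != len(num_heads):
        raise ValueError("depths and num_heads need one entry per stage")
    height = image_shape[0] // patch_size
    width = image_shape[1] // patch_size
    shapes = []
    for stage, heads in enumerate(num_heads):
        channels = embed_dim * 2**stage
        if channels % heads:
            raise ValueError(
                f"stage {stage}: {channels} channels not divisible by {heads} heads"
            )
        shapes.append(
            (height // 2**stage, width // 2**stage, channels, channels // heads)
        )
    return shapes


# The Swin-Tiny style configuration used in the examples on this page.
for stage, shape in enumerate(
    stage_shapes((224, 224, 3), 4, 96, (2, 2, 6, 2), (3, 6, 12, 24))
):
    print(stage, shape)
# 0 (56, 56, 96, 32)
# 1 (28, 28, 192, 32)
# 2 (14, 14, 384, 32)
# 3 (7, 7, 768, 32)
```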

Arguments

  • image_shape: tuple of ints. The shape of the input images, excluding the batch dimension.
  • embed_dim: int. Embedding dimension of the first stage. In the standard Swin design, this dimension doubles at each subsequent stage through patch merging.
  • depths: tuple of ints. Number of transformer blocks in each stage.
  • num_heads: tuple of ints. Number of attention heads in each stage.
  • window_size: int. Side length, in patches, of the square local attention window.
  • patch_size: int. Size of the patches to be extracted from the input images. Defaults to 4.
  • mlp_ratio: float. Ratio of the MLP hidden dimension to the embedding dimension. Defaults to 4.0.
  • qkv_bias: bool. If True, add a learnable bias to the query, key, and value projections. Defaults to True.
  • dropout_rate: float. Dropout rate. Defaults to 0.0.
  • attention_dropout: float. Dropout rate for attention. Defaults to 0.0.
  • drop_path: float. Stochastic depth rate. Defaults to 0.1.
  • patch_norm: bool. If True, add normalization after patch embedding. Defaults to True.
  • data_format: str. Format of the input data, either "channels_last" or "channels_first". Defaults to None, which falls back to the global Keras image data format (normally "channels_last").
  • dtype: string or keras.mixed_precision.DTypePolicy. The dtype to use for model computations and weights. Defaults to None.
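The `drop_path` argument is a single stochastic depth rate. The original Swin reference code spreads it across blocks with a linear schedule, from 0 for the first block up to the given rate for the last. Assuming KerasHub follows the same rule (not confirmed by this page), the per-block rates can be sketched with a hypothetical `drop_path_schedule` helper:

```python
import numpy as np


def drop_path_schedule(drop_path, depths):
    """Per-block stochastic depth rates, grouped by stage.

    Assumption: rates grow linearly over all blocks, as in the original
    Swin reference code.
    """
    rates = np.linspace(0.0, drop_path, num=sum(depths))
    per_stage, start = [], 0
    for depth in depths:
        per_stage.append([float(r) for r in rates[start:start + depth]])
        start += depth
    return per_stage


schedule = drop_path_schedule(0.1, (2, 2, 6, 2))
for stage, rates in enumerate(schedule):
    print(stage, [round(r, 3) for r in rates])
```

Deeper blocks are dropped more often, which regularizes the later, higher-capacity stages more strongly than the early ones.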

Examples

import keras_hub
import numpy as np

# Pretrained Swin Transformer backbone.
model = keras_hub.models.SwinTransformerBackbone.from_preset(
    "swin_tiny_patch4_window7_224"
)
model(np.ones((1, 224, 224, 3)))

# Randomly initialized Swin Transformer with custom config.
model = keras_hub.models.SwinTransformerBackbone(
    image_shape=(224, 224, 3),
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    window_size=7,
)
model(np.ones((1, 224, 224, 3)))

from_preset method

SwinTransformerBackbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights, and other file assets used to save and load a pre-trained model. The preset can be passed as one of the following:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a ModelScope handle like 'modelscope://user/bert_base_en'
  5. a path to a local preset directory like './bert_base_en'

This constructor can be called in one of two ways: either from the base class, like keras_hub.models.Backbone.from_preset(), or from a model class, like keras_hub.models.GemmaBackbone.from_preset(). If called from the base class, the subclass of the returned object is inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, a ModelScope handle, or a path to a local directory.
  • load_weights: bool. If True, the pre-trained weights are loaded into the model architecture. If False, the weights are randomly initialized. Defaults to True.

Examples

import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a BERT backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)

Preset                          Parameters  Description
swin_tiny_patch4_window7_224    27.52M      Swin-Tiny model pre-trained on the ImageNet-1k dataset with image resolution of 224x224.
swin_small_patch4_window7_224   48.84M      Swin-Small model pre-trained on the ImageNet-1k dataset with image resolution of 224x224.
swin_base_patch4_window7_224    86.74M      Swin-Base model pre-trained on the ImageNet-1k dataset with image resolution of 224x224.
swin_base_patch4_window12_384   86.88M      Swin-Base model pre-trained on the ImageNet-1k dataset with image resolution of 384x384.
swin_large_patch4_window7_224   195.00M     Swin-Large model pre-trained on the ImageNet-1k dataset with image resolution of 224x224.
swin_large_patch4_window12_384  195.20M     Swin-Large model pre-trained on the ImageNet-1k dataset with image resolution of 384x384.
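The preset names appear to encode constructor arguments: `patch{N}` corresponds to `patch_size`, `window{N}` to `window_size`, and the trailing number to the input resolution. The sketch below uses hypothetical helpers (not KerasHub APIs) to parse a name and count windows per side at each stage, assuming four stages that each halve the feature map. As a simplification it raises an error when windows do not tile a stage exactly (the original Swin code pads instead); every preset in the table tiles exactly.

```python
import re

# Assumed naming convention, inferred from the preset table.
PRESET_PATTERN = re.compile(r"swin_(\w+?)_patch(\d+)_window(\d+)_(\d+)")


def parse_swin_preset(name):
    """Return (size, patch_size, window_size, resolution) from a preset name."""
    match = PRESET_PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized preset name: {name}")
    size, patch, window, resolution = match.groups()
    return size, int(patch), int(window), int(resolution)


def windows_per_stage(name, num_stages=4):
    """Windows per side at each stage, assuming resolution halves per stage."""
    _, patch, window, resolution = parse_swin_preset(name)
    side = resolution // patch  # patch-grid side length after patch embedding
    counts = []
    for _ in range(num_stages):
        if side % window:
            raise ValueError(f"{side}x{side} map not tiled by {window}x{window} windows")
        counts.append(side // window)
        side //= 2
    return counts


print(parse_swin_preset("swin_base_patch4_window12_384"))  # ('base', 4, 12, 384)
print(windows_per_stage("swin_base_patch4_window12_384"))  # [8, 4, 2, 1]
```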