
MixTransformer backbones

MiTBackbone class

keras_cv.models.MiTBackbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

Base class for Backbone models.

A backbone is the reusable feature-extraction part of a model trained on a standard task, such as ImageNet classification, which can then be reused for other tasks.


from_preset method

MiTBackbone.from_preset()

Instantiate MiTBackbone model from preset config and weights.

Arguments

  • preset: string. Must be one of "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5", "mit_b0_imagenet". The only preset with pretrained weights is "mit_b0_imagenet".
  • load_weights: Whether to load pretrained weights into the model. Defaults to None, in which case weights are loaded only if the preset has pretrained weights available.
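
The way load_weights=None follows the preset can be sketched in plain Python. This is only an illustration, not KerasCV's implementation: the helper resolve_load_weights and the two sets are hypothetical names, and raising ValueError when weights are requested for a preset that has none is an assumption.

```python
# Hypothetical sketch of how `load_weights=None` could be resolved.
# The preset names come from the argument list above; the helper and its
# error behavior are assumptions, not KerasCV's actual code.
PRESETS = {
    "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5",
    "mit_b0_imagenet",
}
PRESETS_WITH_WEIGHTS = {"mit_b0_imagenet"}


def resolve_load_weights(preset, load_weights=None):
    """Return True if pretrained weights should be loaded for `preset`."""
    if preset not in PRESETS:
        raise ValueError(f"Unknown preset: {preset!r}")
    has_weights = preset in PRESETS_WITH_WEIGHTS
    if load_weights is None:
        # Default: follow whether the preset has pretrained weights.
        return has_weights
    if load_weights and not has_weights:
        # Assumed behavior: requesting weights that do not exist is an error.
        raise ValueError(f"No pretrained weights for preset {preset!r}")
    return bool(load_weights)
```

Under these assumptions, resolve_load_weights("mit_b0") gives False and resolve_load_weights("mit_b0_imagenet") gives True, which matches the default described for load_weights.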

Examples

import keras_cv

# Load architecture and weights from preset
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
)

# Load a randomly initialized model from the preset architecture, without weights
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
    load_weights=False,
)

Preset name      Parameters  Description
mit_b0           3.32M       MiT (MixTransformer) model with 8 transformer blocks.
mit_b1           13.16M      MiT (MixTransformer) model with 8 transformer blocks.
mit_b2           24.20M      MiT (MixTransformer) model with 16 transformer blocks.
mit_b3           44.08M      MiT (MixTransformer) model with 28 transformer blocks.
mit_b4           60.85M      MiT (MixTransformer) model with 41 transformer blocks.
mit_b5           81.45M      MiT (MixTransformer) model with 52 transformer blocks.
mit_b0_imagenet  3.32M       MiT (MixTransformer) model with 8 transformer blocks. Pre-trained on ImageNet-1K and scores 69% top-1 accuracy on the validation set.
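
The block counts in the preset table can be related to the depths argument with a short sketch. The per-stage depths below are an assumption taken from the MiT-B0 to MiT-B5 configurations in the SegFormer paper and are not listed on this page; the parameter counts are copied from the table, and the helper names are hypothetical.

```python
# Per-stage depths: assumed from the SegFormer paper's MiT configurations.
MIT_DEPTHS = {
    "mit_b0": [2, 2, 2, 2],
    "mit_b1": [2, 2, 2, 2],
    "mit_b2": [3, 4, 6, 3],
    "mit_b3": [3, 4, 18, 3],
    "mit_b4": [3, 8, 27, 3],
    "mit_b5": [3, 6, 40, 3],
}

# Parameter counts in millions, copied from the preset table.
MIT_PARAMS_M = {
    "mit_b0": 3.32,
    "mit_b1": 13.16,
    "mit_b2": 24.20,
    "mit_b3": 44.08,
    "mit_b4": 60.85,
    "mit_b5": 81.45,
}


def total_blocks(preset):
    """Total transformer blocks: the sum of the per-stage depths."""
    return sum(MIT_DEPTHS[preset])


def largest_preset_under(budget_millions):
    """Return the largest preset within the parameter budget, or None."""
    fitting = [p for p, n in MIT_PARAMS_M.items() if n <= budget_millions]
    return max(fitting, key=MIT_PARAMS_M.get) if fitting else None
```

With these assumed depths, total_blocks reproduces the block counts in the table (8, 8, 16, 28, 41, 52), which suggests each count is the sum of the four stage depths.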

MiTB0Backbone class

keras_cv.models.MiTB0Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
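
What include_rescaling=True does to the inputs can be illustrated without Keras. This is a minimal sketch: the function rescale is a hypothetical stand-in for a Rescaling(scale=1 / 255) layer with no offset, applied here to a flat list of pixel values rather than a tensor.

```python
def rescale(pixels, scale=1 / 255):
    """Multiply every value by `scale`, as Rescaling(scale=1 / 255) would."""
    return [p * scale for p in pixels]


# Raw 8-bit pixel values in [0, 255] end up in [0, 1].
raw = [0, 51, 255]
scaled = rescale(raw)
```

This suggests a practical rule of thumb: feed raw pixel values in [0, 255] when include_rescaling is True, and already-normalized values when it is False.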

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB0Backbone()
output = model(input_data)
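
To get a feel for the feature maps a MiT backbone produces for this input, the per-stage shapes can be estimated in plain Python. The strides (4, 8, 16, 32) and the MiT-B0 channel widths are assumptions taken from the SegFormer paper, not from this page, and ceiling division for sizes that are not multiples of 32 is a further assumption about padding. The helper stage_shapes is hypothetical.

```python
import math

# Assumed MiT-B0 configuration (SegFormer paper): four stages with overall
# strides 4, 8, 16, 32 and these channel widths.
STAGE_STRIDES = [4, 8, 16, 32]
B0_EMBEDDING_DIMS = [32, 64, 160, 256]


def stage_shapes(height, width, embedding_dims=B0_EMBEDDING_DIMS):
    """Return the assumed (height, width, channels) of each stage's output."""
    return [
        (math.ceil(height / s), math.ceil(width / s), c)
        for s, c in zip(STAGE_STRIDES, embedding_dims)
    ]


# For the 224 x 224 input used in the example above.
shapes = stage_shapes(224, 224)
# shapes == [(56, 56, 32), (28, 28, 64), (14, 14, 160), (7, 7, 256)]
```

Under these assumptions, the last entry would be the spatial size and channel count of the final feature map, with the batch dimension of 8 carried through unchanged.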

MiTB1Backbone class

keras_cv.models.MiTB1Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB1Backbone()
output = model(input_data)

MiTB2Backbone class

keras_cv.models.MiTB2Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB2Backbone()
output = model(input_data)

MiTB3Backbone class

keras_cv.models.MiTB3Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB3Backbone()
output = model(input_data)

MiTB4Backbone class

keras_cv.models.MiTB4Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB4Backbone()
output = model(input_data)

MiTB5Backbone class

keras_cv.models.MiTB5Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT model.

For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.

Arguments

  • include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
  • input_shape: optional shape tuple, defaults to (224, 224, 3).
  • input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.

Example

import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = keras_cv.models.MiTB5Backbone()
output = model(input_data)