ImageClassifier class

```python
keras_hub.models.ImageClassifier(
    backbone,
    num_classes,
    preprocessor=None,
    pooling="avg",
    activation=None,
    dropout=0.0,
    head_dtype=None,
    **kwargs
)
```
Base class for all image classification tasks.

ImageClassifier tasks wrap a keras_hub.models.Backbone and
a keras_hub.models.Preprocessor to create a model that can be used for
image classification. ImageClassifier tasks take an additional
num_classes argument, controlling the number of predicted output classes.

To fine-tune with fit(), pass a dataset containing (x, y) pairs, where
x is an image and y is an integer from [0, num_classes).

All ImageClassifier tasks include a from_preset() constructor which can
be used to load a pre-trained config and weights.
Arguments
keras_hub.models.Backbone instance or a keras.Model.None, a keras_hub.models.Preprocessor instance,
a keras.Layer instance, or a callable. If None no preprocessing
will be applied to the inputs."avg" or "max". The type of pooling to apply on backbone
output. Defaults to average pooling.None, str, or callable. The activation function to use on
the Dense layer. Set activation=None to return the output
logits. Defaults to "softmax".None, str, or keras.mixed_precision.DTypePolicy. The
dtype to use for the classification head's computations and weights.Examples
Call predict() to run inference.
```python
import keras
import keras_hub
import numpy as np

# Load preset and run inference.
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet"
)
classifier.predict(images)
```
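predict() returns a batch of per-class scores: probabilities under the default "softmax" activation, or logits if the model was built with activation=None. A minimal sketch of mapping scores to class IDs:

```python
scores = classifier.predict(images)  # shape (batch_size, num_classes)
class_ids = scores.argmax(axis=-1)  # highest-scoring class per image
```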
Call fit() on a single batch.
```python
# Load preset and train.
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
labels = [0, 3]
classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet"
)
classifier.fit(x=images, y=labels, batch_size=2)
```
Call fit() with custom loss, optimizer and backbone.
```python
classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet"
)
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
)
classifier.backbone.trainable = False
classifier.fit(x=images, y=labels, batch_size=2)
```
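A common follow-up once the new head has trained (a sketch, not part of the original example): unfreeze the backbone, re-compile so the trainable change takes effect, and continue at a lower learning rate.

```python
# Unfreeze the backbone and fine-tune end to end at a lower rate.
classifier.backbone.trainable = True
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(1e-5),
)
classifier.fit(x=images, y=labels, batch_size=2)
```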
Custom backbone.
```python
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
labels = [0, 3]
backbone = keras_hub.models.ResNetBackbone(
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
    use_pre_activation=True,
    pooling="avg",
)
classifier = keras_hub.models.ImageClassifier(
    backbone=backbone,
    num_classes=4,
)
classifier.fit(x=images, y=labels, batch_size=2)
```
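As noted above, fit() also accepts a dataset of (x, y) pairs; a minimal sketch with tf.data, reusing the arrays from the example above:

```python
import tensorflow as tf

# Batch (image, label) pairs into a dataset and train on it.
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)
classifier.fit(dataset)
```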
from_preset method

```python
ImageClassifier.from_preset(preset, load_weights=True, **kwargs)
```
Instantiate a keras_hub.models.Task from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The preset can be passed as
one of:
1. a built-in preset identifier like 'bert_base_en'
2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
3. a Hugging Face handle like 'hf://user/bert_base_en'
4. a path to a local preset directory like './bert_base_en'

For any Task subclass, you can run cls.presets.keys() to list all
built-in presets available on the class.
This constructor can be called in one of two ways: either from a
task-specific base class like keras_hub.models.CausalLM.from_preset(),
or from a model class like
keras_hub.models.BertTextClassifier.from_preset().
If calling from a base class, the subclass of the returned object
will be inferred from the config in the preset directory.
Arguments
- **preset**: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
- **load_weights**: bool. If True, saved weights will be loaded into the model architecture. If False, all weights will be randomly initialized.

Examples
```python
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)
```

```python
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
```
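As noted above, the presets mapping on any task class lists its built-in presets:

```python
# List the identifiers of all built-in ImageClassifier presets.
print(keras_hub.models.ImageClassifier.presets.keys())
```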
| Preset | Parameters | Description |
|---|---|---|
| csp_resnext_50_ra_imagenet | 20.57M | A CSP-ResNeXt (Cross-Stage-Partial) image classification model pre-trained on the Randomly Augmented ImageNet 1k dataset at a 256x256 resolution. |
| csp_resnet_50_ra_imagenet | 21.62M | A CSP-ResNet (Cross-Stage-Partial) image classification model pre-trained on the Randomly Augmented ImageNet 1k dataset at a 256x256 resolution. |
| csp_darknet_53_ra_imagenet | 27.64M | A CSP-DarkNet (Cross-Stage-Partial) image classification model pre-trained on the Randomly Augmented ImageNet 1k dataset at a 256x256 resolution. |
| darknet_53_imagenet | 41.61M | A DarkNet image classification model pre-trained on the ImageNet 1k dataset at a 256x256 resolution. |
| deit_tiny_distilled_patch16_224_imagenet | 5.52M | DeiT-T16 model pre-trained on the ImageNet 1k dataset with image resolution of 224x224 |
| deit_small_distilled_patch16_224_imagenet | 21.67M | DeiT-S16 model pre-trained on the ImageNet 1k dataset with image resolution of 224x224 |
| deit_base_distilled_patch16_224_imagenet | 85.80M | DeiT-B16 model pre-trained on the ImageNet 1k dataset with image resolution of 224x224 |
| deit_base_distilled_patch16_384_imagenet | 86.09M | DeiT-B16 model pre-trained on the ImageNet 1k dataset with image resolution of 384x384 |
| densenet_121_imagenet | 7.04M | 121-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| densenet_169_imagenet | 12.64M | 169-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| densenet_201_imagenet | 18.32M | 201-layer DenseNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| efficientnet_lite0_ra_imagenet | 4.65M | EfficientNet-Lite model fine-tuned on the ImageNet 1k dataset with RandAugment recipe. |
| efficientnet_b0_ra_imagenet | 5.29M | EfficientNet B0 model pre-trained on the ImageNet 1k dataset with RandAugment recipe. |
| efficientnet_b0_ra4_e3600_r224_imagenet | 5.29M | EfficientNet B0 model pre-trained on the ImageNet 1k dataset by Ross Wightman. Trained with timm scripts using hyper-parameters inspired by the MobileNet-V4 small, mixed with go-to hparams from timm and 'ResNet Strikes Back'. |
| efficientnet_es_ra_imagenet | 5.44M | EfficientNet-EdgeTPU Small model trained on the ImageNet 1k dataset with RandAugment recipe. |
| efficientnet_em_ra2_imagenet | 6.90M | EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k dataset with RandAugment2 recipe. |
| efficientnet_b1_ft_imagenet | 7.79M | EfficientNet B1 model fine-tuned on the ImageNet 1k dataset. |
| efficientnet_b1_ra4_e3600_r240_imagenet | 7.79M | EfficientNet B1 model pre-trained on the ImageNet 1k dataset by Ross Wightman. Trained with timm scripts using hyper-parameters inspired by the MobileNet-V4 small, mixed with go-to hparams from timm and 'ResNet Strikes Back'. |
| efficientnet_b2_ra_imagenet | 9.11M | EfficientNet B2 model pre-trained on the ImageNet 1k dataset with RandAugment recipe. |
| efficientnet_el_ra_imagenet | 10.59M | EfficientNet-EdgeTPU Large model trained on the ImageNet 1k dataset with RandAugment recipe. |
| efficientnet_b3_ra2_imagenet | 12.23M | EfficientNet B3 model pre-trained on the ImageNet 1k dataset with RandAugment2 recipe. |
| efficientnet2_rw_t_ra2_imagenet | 13.65M | EfficientNet-v2 Tiny model trained on the ImageNet 1k dataset with RandAugment2 recipe. |
| efficientnet_b4_ra2_imagenet | 19.34M | EfficientNet B4 model pre-trained on the ImageNet 1k dataset with RandAugment2 recipe. |
| efficientnet2_rw_s_ra2_imagenet | 23.94M | EfficientNet-v2 Small model trained on the ImageNet 1k dataset with RandAugment2 recipe. |
| efficientnet_b5_sw_imagenet | 30.39M | EfficientNet B5 model pre-trained on the ImageNet 12k dataset by Ross Wightman. Based on Swin Transformer train / pretrain recipe with modifications (related to both DeiT and ConvNeXt recipes). |
| efficientnet_b5_sw_ft_imagenet | 30.39M | EfficientNet B5 model pre-trained on the ImageNet 12k dataset and fine-tuned on ImageNet-1k by Ross Wightman. Based on Swin Transformer train / pretrain recipe with modifications (related to both DeiT and ConvNeXt recipes). |
| efficientnet2_rw_m_agc_imagenet | 53.24M | EfficientNet-v2 Medium model trained on the ImageNet 1k dataset with adaptive gradient clipping. |
| hgnetv2_b4_ssld_stage2_ft_in1k | 13.60M | HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
| hgnetv2_b5_ssld_stage1_in22k_in1k | 33.42M | HGNetV2 B5 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K. |
| hgnetv2_b5_ssld_stage2_ft_in1k | 33.42M | HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
| hgnetv2_b6_ssld_stage1_in22k_in1k | 69.18M | HGNetV2 B6 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K. |
| hgnetv2_b6_ssld_stage2_ft_in1k | 69.18M | HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
| mit_b0_ade20k_512 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b0_cityscapes_1024 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b1_ade20k_512 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b1_cityscapes_1024 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b2_ade20k_512 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks. |
| mit_b2_cityscapes_1024 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks. |
| mit_b3_ade20k_512 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks. |
| mit_b3_cityscapes_1024 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks. |
| mit_b4_ade20k_512 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks. |
| mit_b4_cityscapes_1024 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks. |
| mit_b5_ade20k_640 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks. |
| mit_b5_cityscapes_1024 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks. |
| mobilenet_v3_small_050_imagenet | 278.78K | Small Mobilenet V3 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Has half channel multiplier. |
| mobilenet_v3_small_100_imagenet | 939.12K | Small Mobilenet V3 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Has baseline channel multiplier. |
| mobilenet_v3_large_100_imagenet | 3.00M | Large Mobilenet V3 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. Has baseline channel multiplier. |
| mobilenet_v3_large_100_imagenet_21k | 3.00M | Large Mobilenet V3 model pre-trained on the ImageNet 21k dataset at a 224x224 resolution. Has baseline channel multiplier. |
| mobilenetv5_300m_enc_gemma3n | 294.28M | Lightweight 300M-parameter convolutional vision encoder used as the image backbone for Gemma 3n |
| resnet_18_imagenet | 11.19M | 18-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_18_imagenet | 11.72M | 18-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_34_imagenet | 21.84M | 34-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_50_imagenet | 23.56M | 50-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_v2_50_imagenet | 23.56M | 50-layer ResNetV2 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_50_imagenet | 25.63M | 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_50_ssld_imagenet | 25.63M | 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation. |
| resnet_vd_50_ssld_v2_imagenet | 25.63M | 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation and AutoAugment. |
| resnet_vd_50_ssld_v2_fix_imagenet | 25.63M | 50-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation, AutoAugment and additional fine-tuning of the classification head. |
| resnet_101_imagenet | 42.61M | 101-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_v2_101_imagenet | 42.61M | 101-layer ResNetV2 model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_101_imagenet | 44.67M | 101-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_101_ssld_imagenet | 44.67M | 101-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution with knowledge distillation. |
| resnet_152_imagenet | 58.30M | 152-layer ResNet model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_152_imagenet | 60.36M | 152-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| resnet_vd_200_imagenet | 74.93M | 200-layer ResNetVD (ResNet with bag of tricks) model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| vgg_11_imagenet | 9.22M | 11-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| vgg_13_imagenet | 9.40M | 13-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| vgg_16_imagenet | 14.71M | 16-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| vgg_19_imagenet | 20.02M | 19-layer VGG model pre-trained on the ImageNet 1k dataset at a 224x224 resolution. |
| vit_base_patch16_224_imagenet | 85.80M | ViT-B16 model pre-trained on the ImageNet 1k dataset with image resolution of 224x224 |
| vit_base_patch16_224_imagenet21k | 85.80M | ViT-B16 backbone pre-trained on the ImageNet 21k dataset with image resolution of 224x224 |
| vit_base_patch16_384_imagenet | 86.09M | ViT-B16 model pre-trained on the ImageNet 1k dataset with image resolution of 384x384 |
| vit_base_patch32_224_imagenet21k | 87.46M | ViT-B32 backbone pre-trained on the ImageNet 21k dataset with image resolution of 224x224 |
| vit_base_patch32_384_imagenet | 87.53M | ViT-B32 model pre-trained on the ImageNet 1k dataset with image resolution of 384x384 |
| vit_large_patch16_224_imagenet | 303.30M | ViT-L16 model pre-trained on the ImageNet 1k dataset with image resolution of 224x224 |
| vit_large_patch16_224_imagenet21k | 303.30M | ViT-L16 backbone pre-trained on the ImageNet 21k dataset with image resolution of 224x224 |
| vit_large_patch16_384_imagenet | 303.69M | ViT-L16 model pre-trained on the ImageNet 1k dataset with image resolution of 384x384 |
| vit_large_patch32_224_imagenet21k | 305.51M | ViT-L32 backbone pre-trained on the ImageNet 21k dataset with image resolution of 224x224 |
| vit_large_patch32_384_imagenet | 305.61M | ViT-L32 model pre-trained on the ImageNet 1k dataset with image resolution of 384x384 |
| vit_huge_patch14_224_imagenet21k | 630.76M | ViT-H14 backbone pre-trained on the ImageNet 21k dataset with image resolution of 224x224 |
| xception_41_imagenet | 20.86M | 41-layer Xception model pre-trained on ImageNet 1k. |
compile method

```python
ImageClassifier.compile(optimizer="auto", loss="auto", metrics="auto", **kwargs)
```
Configures the ImageClassifier task for training.
The ImageClassifier task extends the default compilation signature of
keras.Model.compile with defaults for optimizer, loss, and
metrics. To override these defaults, pass any value
to these arguments during compilation.
Arguments

- **optimizer**: "auto", an optimizer name, or a keras.Optimizer instance. Defaults to "auto", which uses the default optimizer for the given model and task. See keras.Model.compile and keras.optimizers for more info on possible optimizer values.
- **loss**: "auto", a loss name, or a keras.losses.Loss instance. Defaults to "auto", where a keras.losses.SparseCategoricalCrossentropy loss will be applied for the classification task. See keras.Model.compile and keras.losses for more info on possible loss values.
- **metrics**: "auto", or a list of metrics to be evaluated by the model during training and testing. Defaults to "auto", where a keras.metrics.SparseCategoricalAccuracy will be applied to track the accuracy of the model during training. See keras.Model.compile and keras.metrics for more info on possible metrics values.
- **\*\*kwargs**: See keras.Model.compile for a full list of arguments supported by the compile method.
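Any "auto" default can be overridden by passing an explicit value; a minimal sketch (the optimizer and metric choices here are illustrative, not the task defaults):

```python
# Override the default optimizer and metrics; unspecified "auto"
# arguments keep their task defaults.
classifier.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    metrics=["sparse_categorical_accuracy"],
)
```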
save_to_preset method

```python
ImageClassifier.save_to_preset(preset_dir, max_shard_size=10)
```
Save task to a preset directory.
Arguments
- **preset_dir**: The path to the local model preset directory.
- **max_shard_size**: int or float. Maximum size in GB for each sharded file. If None, no sharding will be done. Defaults to 10.
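A round-trip sketch, using a hypothetical local directory:

```python
# Save the fine-tuned task, then reload it as a preset.
classifier.save_to_preset("./my_image_classifier")
restored = keras_hub.models.ImageClassifier.from_preset("./my_image_classifier")
```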
preprocessor property

```python
keras_hub.models.ImageClassifier.preprocessor
```

A keras_hub.models.Preprocessor layer used to preprocess input.
backbone property

```python
keras_hub.models.ImageClassifier.backbone
```
A keras_hub.models.Backbone model with the core architecture.
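Both properties can be used directly; for example, a sketch of running the backbone alone as a feature extractor (assuming the task has a preprocessor to apply to raw images first):

```python
# Apply preprocessing, then extract backbone features instead of
# class predictions.
preprocessed = classifier.preprocessor(images)
features = classifier.backbone.predict(preprocessed)
```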