HGNetV2ImageClassifier class

```python
keras_hub.models.HGNetV2ImageClassifier(
    backbone,
    preprocessor,
    num_classes,
    head_filters=None,
    pooling="avg",
    activation=None,
    dropout=0.0,
    head_dtype=None,
    **kwargs
)
```
HGNetV2 image classification model.

`HGNetV2ImageClassifier` wraps an `HGNetV2Backbone` and an `HGNetV2ImageClassifierPreprocessor` to create a model for image classification tasks. The model implements the HGNetV2 architecture and adds a classification head that includes a 1x1 convolution layer, global pooling, and a dense output layer.

The model takes an additional `num_classes` argument, which controls the number of predicted output classes. It can optionally take a `head_filters` argument, which sets the number of filters in the classification head's convolution layer. To fine-tune with `fit()`, pass a dataset of `(x, y)` tuples. Here `x` is an image tensor and `y` is an integer label in the range `[0, num_classes)`.
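The NumPy code below is a minimal, hypothetical illustration of the head described above. It is not the Keras Hub implementation. It assumes a channels-last feature map whose channel count equals the backbone's last `hidden_sizes` entry, and it uses made-up sizes. It also leaves out any normalization, nonlinearity, or dropout the real head may apply. The `classification_head` function and all sizes are illustrative only.

```python
import numpy as np

def classification_head(features, conv_kernel, conv_bias, dense_kernel,
                        dense_bias, pooling="avg", activation=None):
    """Illustrative conv -> global pool -> dense head (not the library code).

    features:     (batch, height, width, channels), channels-last.
    conv_kernel:  (channels, head_filters), weights of a 1x1 convolution.
    dense_kernel: (head_filters, num_classes), weights of the output layer.
    """
    # A 1x1 convolution applies one linear map at every spatial position,
    # so it reduces to a matrix product over the channel axis.
    x = features @ conv_kernel + conv_bias       # (batch, H, W, head_filters)
    # Global pooling collapses the two spatial axes.
    if pooling == "avg":
        x = x.mean(axis=(1, 2))
    elif pooling == "max":
        x = x.max(axis=(1, 2))
    else:
        raise ValueError(f"unsupported pooling: {pooling!r}")
    logits = x @ dense_kernel + dense_bias       # (batch, num_classes)
    if activation is None:
        return logits                            # raw logits
    if activation == "softmax":
        shifted = logits - logits.max(axis=-1, keepdims=True)
        e = np.exp(shifted)
        return e / e.sum(axis=-1, keepdims=True)
    raise ValueError(f"unsupported activation: {activation!r}")

rng = np.random.default_rng(0)
hidden_sizes = [16, 32, 48, 64]   # hypothetical backbone config
head_filters = None               # None -> fall back to hidden_sizes[-1]
head_filters = hidden_sizes[-1] if head_filters is None else head_filters
num_classes = 10

features = rng.normal(size=(2, 7, 7, hidden_sizes[-1]))
conv_kernel = rng.normal(size=(hidden_sizes[-1], head_filters))
conv_bias = np.zeros(head_filters)
dense_kernel = rng.normal(size=(head_filters, num_classes))
dense_bias = np.zeros(num_classes)

logits = classification_head(features, conv_kernel, conv_bias,
                             dense_kernel, dense_bias)
probs = classification_head(features, conv_kernel, conv_bias,
                            dense_kernel, dense_bias,
                            pooling="max", activation="softmax")
```

Because a 1x1 convolution mixes only channels, the sketch writes it as a single matrix product instead of a sliding-window convolution.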
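Arguments

- `backbone`: An `HGNetV2Backbone` instance.
- `preprocessor`: An `HGNetV2ImageClassifierPreprocessor` instance, a `keras.Layer` instance, or a callable. If `None`, no preprocessing will be applied to the inputs.
- `num_classes`: int. The number of classes to predict.
- `head_filters`: int, optional. The number of filters in the classification head's convolution layer. If `None`, it defaults to the last value of `hidden_sizes` from the backbone.
- `pooling`: `"avg"` or `"max"`. The type of global pooling to apply after the head convolution. Defaults to `"avg"`.
- `activation`: `None`, str, or callable. The activation function to use on the final `Dense` layer. Set `activation=None` to return the output logits. Defaults to `None`.
- `dropout`: float. The dropout rate used in the classification head. Defaults to `0.0`.
- `head_dtype`: `None`, str, or `keras.mixed_precision.DTypePolicy`. The dtype to use for the classification head's computations and weights.

Examples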
Call `predict()` to run inference.

```python
import numpy as np
import keras_hub

# Load preset and predict.
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
classifier.predict(images)
```
Call `fit()` on a single batch.

```python
# Load preset and train.
images = np.random.randint(0, 256, size=(2, 224, 224, 3))
labels = [0, 3]
classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
classifier.fit(x=images, y=labels, batch_size=2)
```
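The labels above are plain integer class ids, not one-hot vectors. The sketch below shows a hypothetical helper, `check_sparse_labels` (not part of Keras Hub), that checks this format before training. It assumes the preset's head has the standard 1000 ImageNet-1K outputs. The last two lines show how one-hot targets could be converted back to integer ids.

```python
import numpy as np

def check_sparse_labels(labels, num_classes):
    """Hypothetical helper: ensure labels are integer class ids in [0, num_classes)."""
    labels = np.asarray(labels)
    if not np.issubdtype(labels.dtype, np.integer):
        raise TypeError(f"labels must be integers, got dtype {labels.dtype}")
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(f"labels must lie in [0, {num_classes})")
    return labels

# Assuming a 1000-class ImageNet-1K head, [0, 3] is a valid batch of labels.
labels = check_sparse_labels([0, 3], num_classes=1000)

# One-hot targets can be turned back into integer ids with argmax.
one_hot = np.eye(1000, dtype=np.float32)[labels]
recovered = one_hot.argmax(axis=-1)
```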
Call `fit()` with a custom loss, optimizer, and frozen backbone.

```python
import keras

# Reuses `images` and `labels` from the previous example.
classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
)
# Freeze the backbone so that only the classification head is trained.
classifier.backbone.trainable = False
classifier.fit(x=images, y=labels, batch_size=2)
```
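`activation` defaults to `None`, so the classifier outputs raw logits. That is why the loss in the previous example sets `from_logits=True`. The NumPy code below is an illustration, not Keras's implementation, and the logit values are made up. It shows that sparse categorical cross-entropy computed directly on logits with a log-softmax gives the same value as taking explicit softmax probabilities first.

```python
import numpy as np

def log_softmax(logits):
    # Subtracting the row max keeps exp() from overflowing; it cancels out.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def sparse_ce_from_logits(labels, logits):
    """Mean sparse categorical cross-entropy computed directly on logits."""
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0, 0.1],
                   [0.3, 0.2, 0.1, 3.0]])
labels = np.array([0, 3])
loss = sparse_ce_from_logits(labels, logits)

# Same value via explicit probabilities (the from_logits=False route).
probs = np.exp(log_softmax(logits))
loss_from_probs = -np.log(probs[np.arange(len(labels)), labels]).mean()
```

Working on logits avoids taking the log of probabilities that may have underflowed to zero. This is the usual reason to leave `activation=None` during training.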
Create a custom HGNetV2 classifier with a specific head configuration.

```python
backbone = keras_hub.models.HGNetV2Backbone.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
preproc = keras_hub.models.HGNetV2ImageClassifierPreprocessor.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
classifier = keras_hub.models.HGNetV2ImageClassifier(
    backbone=backbone,
    preprocessor=preproc,
    num_classes=10,
    pooling="avg",
    dropout=0.2,
)
```
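The example above sets `dropout=0.2`. This sketch assumes the head uses standard Keras dropout. `keras.layers.Dropout` zeroes inputs at the given rate during training and scales the survivors by `1 / (1 - rate)`. The NumPy function `inverted_dropout` below is a hypothetical stand-in that shows this behavior. It is not code from Keras Hub.

```python
import numpy as np

def inverted_dropout(x, rate, training, rng):
    """Hypothetical sketch of inverted dropout."""
    if not training or rate == 0.0:
        return x  # identity at inference time
    keep = rng.random(x.shape) >= rate
    # Scale surviving units so the expected activation is unchanged.
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 64))
train_out = inverted_dropout(x, rate=0.2, training=True, rng=rng)
infer_out = inverted_dropout(x, rate=0.2, training=False, rng=rng)
```

Because of the rescaling, the mean activation during training stays close to the inference-time value. This means `predict()` needs no extra correction.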
from_preset method

```python
HGNetV2ImageClassifier.from_preset(preset, load_weights=True, **kwargs)
```
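Instantiate a `keras_hub.models.Task` from a model preset.

A preset is a directory of configs, weights, and other file assets used to save and load a pre-trained model. The `preset` can be passed as one of:

- a built-in preset identifier like `'bert_base_en'`
- a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
- a Hugging Face handle like `'hf://user/bert_base_en'`
- a path to a local preset directory like `'./bert_base_en'`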
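The four handle forms above can be told apart by their prefixes. The function below, `classify_preset`, is a hypothetical heuristic written only for illustration. It is not how Keras Hub actually resolves presets. It assumes that anything without a `kaggle://` or `hf://` scheme, and that is not an absolute, relative, or existing directory path, is a built-in identifier.

```python
import os

def classify_preset(preset: str) -> str:
    """Hypothetical heuristic: guess which kind of preset handle a string is."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    if (os.path.isabs(preset)
            or preset.startswith(("./", "../"))
            or os.path.isdir(preset)):
        return "local"
    return "builtin"

for handle in ["bert_base_en",
               "kaggle://user/bert/keras/bert_base_en",
               "hf://user/bert_base_en",
               "./bert_base_en"]:
    print(f"{handle!r}: {classify_preset(handle)}")
```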
For any `Task` subclass, you can run `cls.presets.keys()` to list all built-in presets available on the class.
This constructor can be called in one of two ways: from a task-specific base class like `keras_hub.models.CausalLM.from_preset()`, or from a model class like `keras_hub.models.BertTextClassifier.from_preset()`. When you call it from a base class, the subclass of the returned object is inferred from the config in the preset directory.
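Here is a toy illustration of inferring a subclass from a config, written in plain Python. It is not Keras Hub's loading code. The `ToyTask` classes, the `from_config_json` method, and the `class_name`/`config` JSON keys are all hypothetical. Subclasses register themselves by name, and the base-class constructor looks up the concrete class recorded in the config. It refuses classes that are not subclasses of the caller.

```python
import json

class ToyTask:
    """Hypothetical stand-in for a task base class."""
    registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass registers itself under its class name.
        ToyTask.registry[cls.__name__] = cls

    def __init__(self, **config):
        self.config = config

    @classmethod
    def from_config_json(cls, text):
        config = json.loads(text)
        subclass = ToyTask.registry[config["class_name"]]
        # Calling from a base class may return any registered subclass of it.
        if not issubclass(subclass, cls):
            raise TypeError(f"{subclass.__name__} is not a {cls.__name__}")
        return subclass(**config.get("config", {}))

class ToyImageClassifier(ToyTask):
    """Hypothetical task-specific base class."""

class ToyHGNetV2Classifier(ToyImageClassifier):
    """Hypothetical concrete model class."""

model = ToyImageClassifier.from_config_json(
    '{"class_name": "ToyHGNetV2Classifier", "config": {"num_classes": 10}}'
)
print(type(model).__name__, model.config)
```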
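Arguments

- `preset`: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
- `load_weights`: bool. If `True`, saved weights will be loaded into the model architecture. If `False`, all weights will be randomly initialized.

Examples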
```python
import keras_hub

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
```
Preset | Parameters | Description |
---|---|---|
hgnetv2_b4_ssld_stage2_ft_in1k | 13.60M | HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
hgnetv2_b5_ssld_stage1_in22k_in1k | 33.42M | HGNetV2 B5 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K. |
hgnetv2_b5_ssld_stage2_ft_in1k | 33.42M | HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
hgnetv2_b6_ssld_stage1_in22k_in1k | 69.18M | HGNetV2 B6 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K. |
hgnetv2_b6_ssld_stage2_ft_in1k | 69.18M | HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on ImageNet-1K. |
backbone property

`keras_hub.models.HGNetV2ImageClassifier.backbone`

A `keras_hub.models.Backbone` model with the core architecture.
preprocessor property

`keras_hub.models.HGNetV2ImageClassifier.preprocessor`

A `keras_hub.models.Preprocessor` layer used to preprocess input.