DFineBackbone class

keras_hub.models.DFineBackbone(
backbone,
decoder_in_channels,
encoder_hidden_dim,
num_labels,
num_denoising,
learn_initial_query,
num_queries,
anchor_image_size,
feat_strides,
num_feature_levels,
hidden_dim,
encoder_in_channels,
encode_proj_layers,
num_attention_heads,
encoder_ffn_dim,
num_encoder_layers,
hidden_expansion,
depth_multiplier,
eval_idx,
num_decoder_layers,
decoder_attention_heads,
decoder_ffn_dim,
decoder_n_points,
lqe_hidden_dim,
num_lqe_layers,
decoder_method="default",
label_noise_ratio=0.5,
box_noise_scale=1.0,
labels=None,
seed=None,
image_shape=(None, None, 3),
out_features=None,
data_format=None,
dtype=None,
**kwargs
)
D-FINE Backbone for Object Detection.
This class implements the core D-FINE architecture, which serves as the
backbone for DFineObjectDetector. It integrates an HGNetV2Backbone for
initial feature extraction, a DFineHybridEncoder for multi-scale feature
fusion using FPN/PAN pathways, and a DFineDecoder for refining object
queries.
The backbone orchestrates the entire forward pass, from processing raw
pixels to generating intermediate predictions. Key steps include:
1. Extracting multi-scale feature maps using the HGNetV2 backbone.
2. Fusing these features with the hybrid encoder.
3. Generating anchor proposals and selecting the top-k to initialize
decoder queries and reference points.
4. Generating noisy queries for contrastive denoising (if the labels
argument is provided); a simplified sketch of this step follows the list.
5. Passing the queries and fused features through the transformer decoder
to produce iterative predictions for bounding boxes and class logits.
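To make step 4 concrete, here is a simplified, self-contained sketch of how
noisy denoising queries can be built from ground-truth boxes and labels. The
helper make_noisy_queries and its exact noise scheme are illustrative
assumptions, not the library's implementation; they only show how
label_noise_ratio and box_noise_scale enter the picture.

import numpy as np

def make_noisy_queries(
    boxes,       # [N, 4] normalized (cx, cy, w, h) ground-truth boxes
    class_ids,   # [N] ground-truth class indices
    num_labels,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
    rng=None,
):
    # Hypothetical helper, for illustration only.
    rng = rng or np.random.default_rng(0)
    # Flip a fraction of the class labels to random classes.
    flip = rng.random(class_ids.shape) < label_noise_ratio
    noisy_ids = np.where(
        flip, rng.integers(0, num_labels, class_ids.shape), class_ids
    )
    # Jitter each box proportionally to its own width and height.
    wh = boxes[:, 2:]
    noise = (rng.random(boxes.shape) * 2.0 - 1.0) * box_noise_scale
    noisy_boxes = boxes + noise * np.concatenate([wh / 2.0, wh], axis=-1)
    return np.clip(noisy_boxes, 0.0, 1.0), noisy_ids

In the actual model, such perturbed queries are processed alongside the
regular object queries so that the decoder learns to recover the original
boxes and labels.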
Arguments
backbone: A keras.Model instance that serves as the feature extractor.
    While any keras.Model can be used, we highly recommend using a
    keras_hub.models.HGNetV2Backbone instance, as this architecture is
    optimized for its outputs. If a custom backbone is provided, it must
    have a stage_names attribute, or the out_features argument for this
    model must be specified. This requirement helps prevent hard-to-debug
    downstream dimensionality errors.
decoder_in_channels: Channel dimensions of the decoder inputs, typically
    encoder_hidden_dim repeated for each feature level.
num_denoising: Number of denoising queries used for contrastive denoising
    training. Set to 0 to disable denoising.
anchor_image_size: Image size used to generate anchor proposals, as a
    tuple of (height, width).
encoder_in_channels: Channel dimensions of the feature maps from the
    feature extractor (e.g. HGNetV2Backbone) that are fed into the hybrid
    encoder.
eval_idx: Index of the decoder layer whose outputs are used at evaluation
    time; -1 for the last layer.
decoder_method: Decoder sampling method, either "default" or "discrete".
    Defaults to "default".
label_noise_ratio: Amount of label noise applied when generating denoising
    queries. Defaults to 0.5.
box_noise_scale: Scale of the box noise applied when generating denoising
    queries. Defaults to 1.0.
labels: List of dicts, one per image, each with "boxes" (numpy array of
    shape [N, 4] with normalized coordinates) and "labels" (numpy array of
    shape [N] with class indices). Required when num_denoising > 0.
    Defaults to None.
seed: Random seed for reproducibility. Defaults to None.
image_shape: Shape of the input images as (height, width, channels).
    Height and width can be None for variable input sizes. Defaults to
    (None, None, 3).
out_features: Names of the backbone stages to use as features. If None,
    uses the last len(decoder_in_channels) features from the backbone's
    stage_names. Defaults to None.
data_format: "channels_first" or "channels_last". If None is specified,
    it will use the image_data_format value found in your Keras config
    file at ~/.keras/keras.json. Defaults to None.
dtype: None or str or keras.mixed_precision.DTypePolicy. The dtype to use
    for the model's computations and weights. Defaults to None.
Example
import keras
import numpy as np
from keras_hub.models import DFineBackbone
from keras_hub.models import HGNetV2Backbone
# Example 1: Basic usage without denoising.
# First, build the `HGNetV2Backbone` instance.
hgnetv2 = HGNetV2Backbone(
stem_channels=[3, 16, 16],
stackwise_stage_filters=[
[16, 16, 64, 1, 3, 3],
[64, 32, 256, 1, 3, 3],
[256, 64, 512, 2, 3, 5],
[512, 128, 1024, 1, 3, 5],
],
apply_downsample=[False, True, True, True],
use_lightweight_conv_block=[False, False, True, True],
depths=[1, 1, 2, 1],
hidden_sizes=[64, 256, 512, 1024],
embedding_size=16,
use_learnable_affine_block=True,
hidden_act="relu",
image_shape=(None, None, 3),
out_features=["stage3", "stage4"],
data_format="channels_last",
)
# Then, pass the backbone instance to `DFineBackbone`.
backbone = DFineBackbone(
backbone=hgnetv2,
decoder_in_channels=[128, 128],
encoder_hidden_dim=128,
num_denoising=0, # Disable denoising
num_labels=80,
hidden_dim=128,
learn_initial_query=False,
num_queries=300,
anchor_image_size=(256, 256),
feat_strides=[16, 32],
num_feature_levels=2,
encoder_in_channels=[512, 1024],
encode_proj_layers=[1],
num_attention_heads=8,
encoder_ffn_dim=512,
num_encoder_layers=1,
hidden_expansion=0.34,
depth_multiplier=0.5,
eval_idx=-1,
num_decoder_layers=3,
decoder_attention_heads=8,
decoder_ffn_dim=512,
decoder_n_points=[6, 6],
lqe_hidden_dim=64,
num_lqe_layers=2,
out_features=["stage3", "stage4"],
image_shape=(None, None, 3),
data_format="channels_last",
seed=0,
)
# Prepare input data.
input_data = keras.random.uniform((2, 256, 256, 3))
# Forward pass.
outputs = backbone(input_data)
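# The exact structure of `outputs` is an implementation detail; assuming it
# is a dict of named tensors (for example, intermediate class logits and box
# predictions from each decoder layer), it can be inspected as follows.
if isinstance(outputs, dict):
    for name, value in outputs.items():
        print(name, tuple(value.shape))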
# Example 2: With contrastive denoising training.
labels = [
{
"boxes": np.array([[0.5, 0.5, 0.2, 0.2], [0.4, 0.4, 0.1, 0.1]]),
"labels": np.array([1, 10]),
},
{
"boxes": np.array([[0.6, 0.6, 0.3, 0.3]]),
"labels": np.array([20]),
},
]
# Pass the `HGNetV2Backbone` instance to `DFineBackbone`.
backbone_with_denoising = DFineBackbone(
backbone=hgnetv2,
decoder_in_channels=[128, 128],
encoder_hidden_dim=128,
num_denoising=100, # Enable denoising
num_labels=80,
hidden_dim=128,
learn_initial_query=False,
num_queries=300,
anchor_image_size=(256, 256),
feat_strides=[16, 32],
num_feature_levels=2,
encoder_in_channels=[512, 1024],
encode_proj_layers=[1],
num_attention_heads=8,
encoder_ffn_dim=512,
num_encoder_layers=1,
hidden_expansion=0.34,
depth_multiplier=0.5,
eval_idx=-1,
num_decoder_layers=3,
decoder_attention_heads=8,
decoder_ffn_dim=512,
decoder_n_points=[6, 6],
lqe_hidden_dim=64,
num_lqe_layers=2,
out_features=["stage3", "stage4"],
image_shape=(None, None, 3),
seed=0,
labels=labels,
)
# Forward pass with denoising.
outputs_with_denoising = backbone_with_denoising(input_data)
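Since DFineBackbone is a regular Keras model, the standard config round-trip
applies. This is generic Keras behavior, shown here as a sketch that assumes
get_config fully serializes the nested HGNetV2 backbone.

# Rebuild the same architecture from its config. Weights are not copied;
# the restored model starts from fresh initialization.
config = backbone.get_config()
restored = DFineBackbone.from_config(config)
print(restored.count_params())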
from_preset method

DFineBackbone.from_preset(preset, load_weights=True, **kwargs)
Instantiate a keras_hub.models.Backbone from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The preset can be passed as
one of:
1. a built-in preset identifier like 'bert_base_en'
2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
3. a Hugging Face handle like 'hf://user/bert_base_en'
4. a ModelScope handle like 'modelscope://user/bert_base_en'
5. a path to a local preset directory like './bert_base_en'
This constructor can be called in one of two ways. Either from the base
class like keras_hub.models.Backbone.from_preset(), or from
a model class like keras_hub.models.GemmaBackbone.from_preset().
If calling from the base class, the subclass of the returning object
will be inferred from the config in the preset directory.
For any Backbone subclass, you can run cls.presets.keys() to list
all built-in presets available on the class.
Arguments
preset: string. A built-in preset identifier, a Kaggle Models handle, a
    Hugging Face handle, or a path to a local directory.
load_weights: bool. If True, the weights will be loaded into the model
    architecture. If False, the weights will be randomly initialized.
Examples
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)
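The same pattern applies to D-FINE with any preset from the table below,
for example:

import keras_hub

# Load a D-FINE backbone with pre-trained COCO weights.
backbone = keras_hub.models.DFineBackbone.from_preset("dfine_nano_coco")
# List every built-in preset registered for this class.
print(keras_hub.models.DFineBackbone.presets.keys())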
| Preset | Parameters | Description |
|---|---|---|
| dfine_nano_coco | 3.79M | D-FINE Nano model, the smallest variant in the family, pretrained on the COCO dataset. Ideal for applications where computational resources are limited. |
| dfine_small_coco | 10.33M | D-FINE Small model pretrained on the COCO dataset. Offers a balance between performance and computational efficiency. |
| dfine_small_obj2coco | 10.33M | D-FINE Small model first pretrained on Objects365 and then fine-tuned on COCO, combining broad feature learning with benchmark-specific adaptation. |
| dfine_small_obj365 | 10.62M | D-FINE Small model pretrained on the large-scale Objects365 dataset, enhancing its ability to recognize a wider variety of objects. |
| dfine_medium_coco | 19.62M | D-FINE Medium model pretrained on the COCO dataset. A solid baseline with strong performance for general-purpose object detection. |
| dfine_medium_obj2coco | 19.62M | D-FINE Medium model using a two-stage training process: pretraining on Objects365 followed by fine-tuning on COCO. |
| dfine_medium_obj365 | 19.99M | D-FINE Medium model pretrained on the Objects365 dataset. Benefits from a larger and more diverse pretraining corpus. |
| dfine_large_coco | 31.34M | D-FINE Large model pretrained on the COCO dataset. Provides high accuracy and is suitable for more demanding tasks. |
| dfine_large_obj2coco_e25 | 31.34M | D-FINE Large model pretrained on Objects365 and then fine-tuned on COCO for 25 epochs. A high-performance model with specialized tuning. |
| dfine_large_obj365 | 31.86M | D-FINE Large model pretrained on the Objects365 dataset for improved generalization and performance on diverse object categories. |
| dfine_xlarge_coco | 62.83M | D-FINE X-Large model, the largest COCO-pretrained variant, designed for state-of-the-art performance where accuracy is the top priority. |
| dfine_xlarge_obj2coco | 62.83M | D-FINE X-Large model, pretrained on Objects365 and fine-tuned on COCO, representing the most powerful model in this series for COCO-style tasks. |
| dfine_xlarge_obj365 | 63.35M | D-FINE X-Large model pretrained on the Objects365 dataset, offering maximum performance by leveraging a vast number of object categories during pretraining. |