LogitsDistillation class

keras.distillation.LogitsDistillation(temperature=3.0, loss="kl_divergence")
Distillation loss that transfers knowledge from final model outputs.
This distillation loss applies temperature scaling to the teacher's logits before computing the loss between the teacher's and student's predictions. It is the most common approach to knowledge distillation.
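For intuition, the sketch below reproduces the classic temperature-scaled computation with plain Keras ops: logits are divided by the temperature before the softmax, and a KL divergence compares the softened teacher targets with the student's predictions. This is an illustration of the idea, not this class's implementation (which, per the description above, scales the teacher's logits), and the logits values are invented.

import keras
import numpy as np

temperature = 3.0
teacher_logits = np.array([[2.0, 1.0, 0.1]], dtype="float32")  # made-up values
student_logits = np.array([[1.5, 0.8, 0.3]], dtype="float32")  # made-up values

# Dividing logits by the temperature flattens the softmax, exposing the
# teacher's relative preferences among the non-top classes.
teacher_probs = keras.ops.softmax(teacher_logits / temperature, axis=-1)
student_probs = keras.ops.softmax(student_logits / temperature, axis=-1)

# KL(teacher || student) on the softened distributions.
kl = keras.losses.KLDivergence()
print(float(kl(teacher_probs, student_probs)))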
Arguments
temperature: Temperature used to soften the teacher's logits before the loss
    is computed. Higher values produce softer probability distributions.
    Defaults to 3.0.
loss: Loss used to compare the teacher's and student's outputs. May be a
    string identifier (e.g. "kl_divergence"), a keras.losses.Loss instance,
    or a list of these for multi-output models. Use None to skip distillation
    for that output (useful for multi-output models where you only want to
    distill some outputs). At least one loss must be non-None. Defaults to
    "kl_divergence".

Examples
import keras

# Basic logits distillation with KL divergence
distillation_loss = keras.distillation.LogitsDistillation(temperature=3.0)

# With categorical crossentropy loss
distillation_loss = keras.distillation.LogitsDistillation(
    temperature=4.0,
    loss="categorical_crossentropy",
)

# With a custom loss instance
distillation_loss = keras.distillation.LogitsDistillation(
    temperature=4.0,
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
)

# For multi-output models
distillation_loss = keras.distillation.LogitsDistillation(
    temperature=3.0,
    loss=["kl_divergence", "categorical_crossentropy"],
)

# For multi-output models, only distill some outputs
distillation_loss = keras.distillation.LogitsDistillation(
    temperature=3.0,
    loss=["kl_divergence", None],  # Skip distillation for the second output
)