Logits distillation loss

LogitsDistillation class

keras.distillation.LogitsDistillation(temperature=3.0, loss="kl_divergence")

Distillation loss that transfers knowledge from final model outputs.

This distillation loss applies temperature scaling to the teacher's logits before computing the loss between teacher and student predictions. It's the most common approach for knowledge distillation.
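
For intuition, here is a minimal sketch of the kind of computation this loss performs, written directly with keras.ops and keras.losses.KLDivergence on made-up logits. It assumes the common convention of dividing both teacher and student logits by the temperature before the softmax; whether this class also rescales the result (for example by the squared temperature) is not specified here, so treat it as an illustration rather than the exact implementation.

import keras
from keras import ops

# Made-up logits for a batch of two samples and three classes.
teacher_logits = ops.convert_to_tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
student_logits = ops.convert_to_tensor([[1.5, 0.8, 0.3], [0.2, 2.0, 0.6]])

temperature = 3.0

# Temperature scaling softens both distributions before they are compared.
teacher_probs = ops.softmax(teacher_logits / temperature, axis=-1)
student_probs = ops.softmax(student_logits / temperature, axis=-1)

# KL divergence between the softened teacher (target) and student predictions.
kl = keras.losses.KLDivergence()
print(float(kl(teacher_probs, student_probs)))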

Arguments

  • temperature: Temperature for softmax scaling. Higher values produce softer probability distributions that are easier for the student to learn. Typical values range from 3 to 5 (see the sketch after this list for the effect on the softened probabilities). Defaults to 3.0.
  • loss: Loss function to use for distillation. Can be:
    • String identifier (e.g., 'kl_divergence', 'categorical_crossentropy')
    • Keras loss instance
    • Nested structure of losses matching the model output structure
    • None to skip distillation for that output (useful for multi-output models where you only want to distill some of the outputs)
    At least one loss must be non-None. Defaults to 'kl_divergence'.
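
To illustrate the temperature argument described above, the standalone sketch below (plain keras.ops, not part of this class) prints the same logits passed through a softmax at increasing temperatures; higher temperatures flatten the distribution and make the relative ranking of the non-argmax classes visible to the student.

from keras import ops

# Made-up logits for a single sample with three classes.
logits = ops.convert_to_tensor([4.0, 2.0, 0.5])

# Higher temperatures flatten the softmax output ("softer" distributions).
for temperature in (1.0, 3.0, 5.0):
    print(temperature, ops.softmax(logits / temperature))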

Examples:

import keras
from keras.distillation import LogitsDistillation

# Basic logits distillation with KL divergence
distillation_loss = LogitsDistillation(temperature=3.0)

# With categorical crossentropy loss
distillation_loss = LogitsDistillation(
    temperature=4.0,
    loss="categorical_crossentropy"
)

# With custom loss instance
distillation_loss = LogitsDistillation(
    temperature=4.0,
    loss=keras.losses.CategoricalCrossentropy(from_logits=True)
)

# For multi-output models
distillation_loss = LogitsDistillation(
    temperature=3.0,
    loss=["kl_divergence", "categorical_crossentropy"]
)

# For multi-output models, only distill some outputs
distillation_loss = LogitsDistillation(
    temperature=3.0,
    loss=["kl_divergence", None]  # Skip second output
)