Feature distillation loss

FeatureDistillation class

keras.distillation.FeatureDistillation(
    loss="mse", teacher_layer_name=None, student_layer_name=None
)

Feature distillation loss.

Feature distillation transfers knowledge from intermediate layers of the teacher model to the corresponding layers of the student model. Matching intermediate features, rather than only the final predictions, helps the student learn better internal representations and often yields better performance than logits-only distillation.
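
To make the idea concrete, the following is a minimal sketch in plain Python of the kind of quantity a feature distillation loss measures: a distance between a teacher feature vector and a student feature vector. The helper names (mse, cosine_similarity_loss) and the sample values are hypothetical and are not part of the Keras API. The sketch assumes both feature vectors are flat lists of the same length, and it follows the convention of the Keras 'cosine_similarity' loss, which is the negative of the cosine similarity.

```python
import math


def mse(teacher, student):
    """Mean squared error between two equal-length feature vectors."""
    if len(teacher) != len(student):
        raise ValueError("feature vectors must have the same length")
    return sum((t - s) ** 2 for t, s in zip(teacher, student)) / len(teacher)


def cosine_similarity_loss(teacher, student):
    """Negative cosine similarity; -1.0 means the vectors are aligned."""
    if len(teacher) != len(student):
        raise ValueError("feature vectors must have the same length")
    dot = sum(t * s for t, s in zip(teacher, student))
    norm_t = math.sqrt(sum(t * t for t in teacher))
    norm_s = math.sqrt(sum(s * s for s in student))
    if norm_t == 0.0 or norm_s == 0.0:
        # Assumption of this sketch: treat a zero vector as "no similarity".
        return 0.0
    return -dot / (norm_t * norm_s)


# Hypothetical features taken from one teacher layer and one student layer.
teacher_features = [1.0, 2.0, 3.0]
student_features = [1.0, 2.0, 5.0]

# The student is trained to make this value small.
feature_loss = mse(teacher_features, student_features)
```

MSE penalizes differences in magnitude as well as direction, while the cosine-based loss depends only on the direction of the feature vectors.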

Arguments

  • loss: Loss function to use for feature distillation. Can be:
    • String identifier (e.g., 'mse', 'cosine_similarity', 'mae')
    • Keras loss instance
    • Nested structure of losses matching the layer output structure
    • None to skip distillation for that output (useful for multi-output models where you only want to distill some outputs)
    At least one loss must be non-None. Defaults to 'mse'.
  • teacher_layer_name: Name of the teacher layer to extract features from. If None, uses the final output. Defaults to None.
  • student_layer_name: Name of the student layer to extract features from. If None, uses the final output. Defaults to None.
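
The handling of a nested loss structure with None entries can be illustrated with a small plain-Python sketch. It is not the Keras implementation: combined_feature_loss and the two helper losses are hypothetical names, outputs are modeled as flat lists of floats, and summing the per-output losses is an assumption made here for illustration only.

```python
def mean_squared_error(teacher, student):
    """Mean squared error between two equal-length feature vectors."""
    return sum((t - s) ** 2 for t, s in zip(teacher, student)) / len(teacher)


def mean_absolute_error(teacher, student):
    """Mean absolute error between two equal-length feature vectors."""
    return sum(abs(t - s) for t, s in zip(teacher, student)) / len(teacher)


def combined_feature_loss(losses, teacher_outputs, student_outputs):
    """Apply one loss per output, skipping outputs whose loss is None."""
    if not (len(losses) == len(teacher_outputs) == len(student_outputs)):
        raise ValueError("losses must match the output structure")
    if all(loss_fn is None for loss_fn in losses):
        raise ValueError("at least one loss must be non-None")
    total = 0.0
    for loss_fn, teacher, student in zip(losses, teacher_outputs, student_outputs):
        if loss_fn is None:
            continue  # this output is not distilled
        # Assumption of this sketch: per-output losses are simply summed.
        total += loss_fn(teacher, student)
    return total


# Three outputs; the middle one is skipped.
losses = [mean_squared_error, None, mean_absolute_error]
teacher_outputs = [[1.0, 2.0], [0.0, 0.0], [1.0, 1.0]]
student_outputs = [[1.0, 4.0], [9.0, 9.0], [2.0, 1.0]]

total = combined_feature_loss(losses, teacher_outputs, student_outputs)
# 2.0 (MSE on first output) + 0.5 (MAE on third output) = 2.5
```

The large mismatch on the middle output does not contribute to the total, which is the effect of passing None for that position.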

Example(s):

import keras
from keras.distillation import FeatureDistillation

# Basic feature distillation from final outputs
distillation_loss = FeatureDistillation(loss="mse")

# Distill from specific intermediate layers
distillation_loss = FeatureDistillation(
    loss="mse",
    teacher_layer_name="dense_1",
    student_layer_name="dense_1"
)

# Use cosine similarity for different feature sizes
distillation_loss = FeatureDistillation(
    loss="cosine_similarity",
    teacher_layer_name="conv2d_2",
    student_layer_name="conv2d_1"
)

# With custom loss instance
distillation_loss = FeatureDistillation(
    loss=keras.losses.MeanAbsoluteError()
)

# For multi-output models
distillation_loss = FeatureDistillation(
    loss=["mse", "cosine_similarity"]
)

# For multi-output models, only distill some outputs
distillation_loss = FeatureDistillation(
    loss=["mse", None, "cosine_similarity"]  # Skip middle output
)