Distiller class

keras.distillation.Distiller(
    teacher,
    student,
    distillation_losses,
    distillation_loss_weights=None,
    student_loss_weight=0.5,
    name="distiller",
    **kwargs
)
Distillation model for transferring knowledge from teacher to student.
Knowledge distillation transfers knowledge from a large, complex model (teacher) to a smaller, simpler model (student). The student learns from both ground truth labels and the teacher's predictions, often achieving better performance than training on labels alone.
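For intuition, the sketch below shows one common way the supervised and distillation terms can be combined. It is a simplified illustration, not the actual Distiller implementation; the convex weighting via student_loss_weight and the temperature-softened KL term (in the spirit of LogitsDistillation) are assumptions.

import keras

def combined_loss(y_true, teacher_logits, student_logits,
                  student_loss_weight=0.5, temperature=3.0):
    # Supervised term: student loss against the ground truth labels.
    hard_loss = keras.losses.sparse_categorical_crossentropy(
        y_true, student_logits, from_logits=True
    )
    # Distillation term: KL divergence between temperature-softened
    # teacher and student distributions.
    teacher_probs = keras.ops.softmax(teacher_logits / temperature)
    student_probs = keras.ops.softmax(student_logits / temperature)
    soft_loss = keras.losses.kl_divergence(teacher_probs, student_probs)
    # Weighted combination of the two terms (weighting scheme assumed).
    return (
        student_loss_weight * hard_loss
        + (1.0 - student_loss_weight) * soft_loss
    )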
Arguments

teacher: keras.Model that serves as the knowledge source. The teacher model is frozen during distillation.
student: keras.Model to be trained through distillation.
distillation_losses: A single distillation loss or a list of distillation losses. Each can be a keras.distillation.LogitsDistillation, keras.distillation.FeatureDistillation, or a custom distillation loss.
distillation_loss_weights: Optional weights for each loss in distillation_losses. If None, equal weights are used.
student_loss_weight: Weight applied to the student's loss on the ground truth labels. Defaults to 0.5.
name: Name of the model. Defaults to "distiller".
**kwargs: Additional keyword arguments passed to the Model class.

Attributes

teacher: The frozen teacher model.
student: The student model being trained. Retrieve it after training via distiller.student.
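As a quick illustration of how these arguments interact, the call below (values are illustrative, and teacher and student are assumed to be models as described above) relies more heavily on the teacher than on the labels by lowering student_loss_weight:

from keras.distillation import Distiller, LogitsDistillation

distiller = Distiller(
    teacher=teacher,
    student=student,
    distillation_losses=LogitsDistillation(temperature=2.0),
    student_loss_weight=0.3,  # illustrative: weight the label loss less than the teacher signal
)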
Examples
# Basic distillation with KerasHub models
import keras_hub as hub
from keras.distillation import Distiller, LogitsDistillation, FeatureDistillation

teacher = hub.models.CausalLM.from_preset("gemma_2b_en")
student = hub.models.CausalLM.from_preset(
    "gemma_1.1_2b_en", load_weights=False
)
# Single distillation loss
distiller = Distiller(
    teacher=teacher,
    student=student,
    distillation_losses=LogitsDistillation(temperature=3.0),
)
# Compile the distiller (like any Keras model)
distiller.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Train the distiller
distiller.fit(x_train, y_train, epochs=10)
# Access the trained student model
trained_student = distiller.student
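# The trained student is a standalone model: keras_hub CausalLM models
# expose generate(), and the student can be saved independently of the
# teacher (prompt and file name below are illustrative).
print(trained_student.generate("Knowledge distillation is", max_length=64))
trained_student.save("distilled_student.keras")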
# Multiple distillation losses
distiller = Distiller(
    teacher=teacher,
    student=student,
    distillation_losses=[
        LogitsDistillation(temperature=3.0),
        FeatureDistillation(
            teacher_layer_name="dense_1",
            student_layer_name="dense_1",
        ),
    ],
    distillation_loss_weights=[1.0, 0.5],
)
# Compile with custom settings
distiller.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
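# Train as before; the two distillation losses are combined according to
# distillation_loss_weights (training data is illustrative).
distiller.fit(x_train, y_train, epochs=10)
trained_student = distiller.student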