bridge.training.post_training.distillation#

Module Contents#

Classes#

ModelOptDistillConfig

Configuration settings for Model Optimizer distillation.

Functions#

loss_func_kd

Loss function (with KD Loss support).

_mask_loss

Mask out portions of the loss using the provided loss mask.

API#

class bridge.training.post_training.distillation.ModelOptDistillConfig#

Bases: modelopt.torch.distill.plugins.megatron.DistillationConfig

Configuration settings for Model Optimizer distillation.
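
Because the class is a thin subclass of modelopt's DistillationConfig, it is typically constructed like a dataclass and handed to the distillation setup. A minimal sketch, assuming the inherited fields all have defaults (the exact field names depend on the installed modelopt version):

```python
from bridge.training.post_training.distillation import ModelOptDistillConfig

# Construct with the defaults inherited from
# modelopt.torch.distill.plugins.megatron.DistillationConfig; any overrides
# (e.g. which logit layers to match, how to scale the KD loss) would be passed
# as keyword arguments whose exact names depend on the installed modelopt version.
distill_cfg = ModelOptDistillConfig()
```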

bridge.training.post_training.distillation.loss_func_kd(
output_tensor: torch.Tensor,
loss_mask: torch.Tensor,
original_loss_fn: Callable,
model: megatron.core.transformer.MegatronModule,
)#

Loss function (with KD Loss support).

Parameters:
  • output_tensor (Tensor) – The tensor with the losses

  • loss_mask (Tensor) – Used to mask out some portions of the loss

  • original_loss_fn (Callable) – The original loss function

  • model (GPTModel) – The model (can be wrapped)
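
The signature follows the usual Megatron convention of binding everything except output_tensor up front, so the pipeline engine can later call the loss function with just the model output. A minimal wiring sketch under that assumption (the helper name and the surrounding forward-step variables are illustrative, not part of this module):

```python
import functools

import torch

from bridge.training.post_training.distillation import loss_func_kd


def make_kd_loss_fn(loss_mask: torch.Tensor, original_loss_fn, model):
    # Bind everything except output_tensor; the pipeline engine later calls
    # the returned callable with just the model output.
    return functools.partial(
        loss_func_kd,
        loss_mask=loss_mask,
        original_loss_fn=original_loss_fn,
        model=model,
    )


# A forward step would then end with something like:
#   return output_tensor, make_kd_loss_fn(loss_mask, original_loss_fn, model)
```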

bridge.training.post_training.distillation._mask_loss(output_tensor: torch.Tensor, loss_mask: torch.Tensor)#
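
For reference, masking a per-token loss usually amounts to zeroing the masked positions and normalizing over the tokens that remain. The sketch below is illustrative only and is not the actual _mask_loss implementation:

```python
import torch


def mask_loss_sketch(output_tensor: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Illustrative only: zero out masked positions and average over the tokens
    # that remain. The real _mask_loss may instead return an unreduced sum and
    # token count for later reduction.
    losses = output_tensor.float().view(-1)
    mask = loss_mask.view(-1).float()
    return torch.sum(losses * mask) / mask.sum().clamp(min=1)
```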