bridge.training.losses
Module Contents
Functions
create_masked_next_token_loss_function | Create a partial loss function configured for masked next-token loss.
masked_next_token_loss | Loss function.
Data
API
- bridge.training.losses.SPIKY_LOSS_FACTOR: int = 10
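This page does not say how `SPIKY_LOSS_FACTOR` is applied. One plausible reading, offered only as an assumption, is that it acts as a multiplier on some expected upper bound for the loss, so that any value above the product is treated as a spike. The helper `is_spiky` and its `expected_max_loss` argument below are hypothetical and are not part of `bridge.training.losses`.

```python
SPIKY_LOSS_FACTOR = 10  # value taken from the entry above


def is_spiky(loss, expected_max_loss, factor=SPIKY_LOSS_FACTOR):
    """Hypothetical check: flag a loss larger than factor * expected_max_loss.

    Assumption: the real module may define "spiky" differently.
    """
    return loss > factor * expected_max_loss


# With an assumed expected maximum of 10.0, the threshold is 100.0.
spike = is_spiky(120.0, expected_max_loss=10.0)
normal = is_spiky(50.0, expected_max_loss=10.0)
```

A multiplicative threshold like this tolerates ordinary fluctuation while still catching order-of-magnitude jumps, which is one reason a factor as coarse as 10 can be a reasonable default.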
- bridge.training.losses.create_masked_next_token_loss_function(
  loss_mask: torch.Tensor,
  check_for_nan_in_loss: bool,
  check_for_spiky_loss: bool,
  )
Create a partial loss function configured for masked next-token loss.
This replaces the generic helper previously in utils/loss_utils.py.
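To illustrate what "a partial loss function" means here, the sketch below uses `functools.partial` to bind a mask and the two flags up front, leaving only the model output to be supplied later. It is a minimal sketch under stated assumptions: `toy_masked_loss` and `make_loss_fn` are hypothetical stand-ins, plain Python lists replace `torch.Tensor`, and the real function may bind its arguments differently.

```python
from functools import partial


def toy_masked_loss(loss_mask, output_tensor,
                    check_for_nan_in_loss=True, check_for_spiky_loss=False):
    """Hypothetical stand-in for masked_next_token_loss on flat lists."""
    # The two flags are accepted but unused in this toy version.
    total = sum(loss * mask for loss, mask in zip(output_tensor, loss_mask))
    num_tokens = sum(loss_mask)
    return total, num_tokens


def make_loss_fn(loss_mask, check_for_nan_in_loss, check_for_spiky_loss):
    # Bind the mask and flags now; the caller passes output_tensor later.
    return partial(
        toy_masked_loss,
        loss_mask,
        check_for_nan_in_loss=check_for_nan_in_loss,
        check_for_spiky_loss=check_for_spiky_loss,
    )


loss_fn = make_loss_fn([1, 1, 0, 1], True, False)
total, num_tokens = loss_fn([0.5, 1.5, 9.0, 1.0])  # third position is masked out
```

Binding the mask ahead of time lets a training loop hand the resulting callable to code that only knows about the model output, without threading the mask through every call site.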
- bridge.training.losses.masked_next_token_loss(
  loss_mask: torch.Tensor,
  output_tensor: torch.Tensor,
  check_for_nan_in_loss: bool = True,
  check_for_spiky_loss: bool = False,
  )
Compute the masked next-token loss for a single micro-batch.
- Parameters:
  loss_mask – Mask that excludes selected positions (for example, padding) from the loss
  output_tensor – The tensor containing the losses
  check_for_nan_in_loss – Whether to check for NaN values in the loss
  check_for_spiky_loss – Whether to check for spiky loss values
- Returns:
  A tuple containing:
  The loss scalar for this micro-batch
  The number of non-padded tokens in this micro-batch
  A dict of reporting metrics on the loss and the number of tokens across the data-parallel ranks
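The masking and the three-part return value described above can be sketched in plain Python. This is an illustration under stated assumptions, not the module's implementation: `masked_loss_sketch` is a hypothetical name, flat lists stand in for `torch.Tensor`, the report key `"lm loss"` and its `(sum, count)` layout are assumed, and the spiky-loss check and any reduction across data-parallel ranks are omitted.

```python
import math


def masked_loss_sketch(loss_mask, per_token_losses, check_for_nan=True):
    """Illustrative only: sum the losses at unmasked positions."""
    if len(loss_mask) != len(per_token_losses):
        raise ValueError("mask and losses must have the same length")
    # Keep only positions where the mask is non-zero (non-padded tokens).
    loss_sum = sum(l for l, m in zip(per_token_losses, loss_mask) if m)
    num_tokens = sum(1 for m in loss_mask if m)
    if check_for_nan and math.isnan(loss_sum):
        raise ValueError("NaN detected in masked loss")
    # Assumed report layout: (sum, count) pairs that could be averaged later.
    report = {"lm loss": (loss_sum, num_tokens)}
    return loss_sum, num_tokens, report


loss, n_tokens, report = masked_loss_sketch([1, 0, 1, 1], [2.0, 100.0, 1.0, 3.0])
```

Note one difference from a tensor implementation: this sketch skips masked positions entirely, whereas multiplying a tensor by a 0/1 mask would still propagate a NaN from a masked position.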