bridge.training.losses
Module Contents
Functions
masked_next_token_loss | Loss function.
Data
SPIKY_LOSS_FACTOR
API
- bridge.training.losses.SPIKY_LOSS_FACTOR: int = 10
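This page records only the value of SPIKY_LOSS_FACTOR, not how the spike check enabled by check_for_spiky_loss uses it. As a hedged illustration, assume the check flags a loss that exceeds SPIKY_LOSS_FACTOR times a reference loss, such as the largest loss observed so far. The helper `is_spiky` and the choice of reference value are hypothetical and are not part of bridge.training.losses.

```python
# Hypothetical spike check: the threshold rule below is an assumption,
# not the documented behavior of bridge.training.losses.
SPIKY_LOSS_FACTOR = 10  # value documented above


def is_spiky(loss: float, reference_loss: float, factor: int = SPIKY_LOSS_FACTOR) -> bool:
    """Return True if `loss` exceeds `factor` times `reference_loss`.

    `reference_loss` is an assumed baseline, e.g. the largest loss seen so far.
    """
    if reference_loss <= 0:
        # No meaningful baseline yet, so nothing is flagged.
        return False
    return loss > factor * reference_loss


history = [2.1, 2.0, 1.9]
print(is_spiky(25.0, max(history)))  # True  (25.0 > 10 * 2.1)
print(is_spiky(3.0, max(history)))   # False (3.0 <= 10 * 2.1)
```

A multiplicative factor like this flags sudden order-of-magnitude jumps while tolerating ordinary step-to-step noise in the loss.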
- bridge.training.losses.masked_next_token_loss(
    loss_mask: torch.Tensor,
    output_tensor: torch.Tensor,
    check_for_nan_in_loss: bool = True,
    check_for_spiky_loss: bool = False,
)
Compute the masked next-token loss for one micro-batch.
- Parameters:
loss_mask – Mask selecting which token positions contribute to the loss (e.g., to exclude padding)
output_tensor – Tensor of per-token losses
check_for_nan_in_loss – Whether to check the loss for NaN values
check_for_spiky_loss – Whether to check for unexpectedly large (spiky) loss values
- Returns:
A tuple containing:
the loss scalar for this micro-batch;
the number of non-padded tokens in this micro-batch;
a dict containing reporting metrics on the loss and the number of tokens across the data-parallel ranks.
- Return type:
tuple
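The parameters and return values above suggest a masked reduction: per-token losses are kept only where the mask is set, and the number of unmasked tokens is reported alongside. The framework-free sketch below illustrates that arithmetic, using plain Python lists in place of torch.Tensor. It assumes the returned loss is the masked sum rather than the mean, raises ValueError as a stand-in for the NaN check, and computes metrics for a single rank only. The name `masked_next_token_loss_sketch` and the metrics key `"lm loss"` are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def masked_next_token_loss_sketch(
    loss_mask: Sequence[float],
    per_token_losses: Sequence[float],
    check_for_nan_in_loss: bool = True,
) -> Tuple[float, int, Dict[str, List[float]]]:
    """Hypothetical, framework-free sketch of a masked next-token loss.

    Assumes the returned loss is the masked *sum* (not the mean), so that a
    caller can later divide by the total token count.
    """
    if len(loss_mask) != len(per_token_losses):
        raise ValueError("loss_mask and per_token_losses must have the same length")

    # Positions where the mask is 0 (e.g. padding) contribute nothing.
    loss = sum(value * mask for value, mask in zip(per_token_losses, loss_mask))

    # Stand-in for the NaN check; rejecting Inf as well is an assumption.
    if check_for_nan_in_loss and (math.isnan(loss) or math.isinf(loss)):
        raise ValueError(f"invalid loss value: {loss}")

    # Number of tokens that count toward the loss (the non-padded tokens).
    num_tokens = int(sum(loss_mask))

    # Single-rank reporting metric [summed loss, token count]; the documented
    # function reports across data-parallel ranks. The key name is assumed.
    metrics = {"lm loss": [loss, float(num_tokens)]}
    return loss, num_tokens, metrics


loss, num_tokens, metrics = masked_next_token_loss_sketch(
    loss_mask=[1, 1, 1, 0],
    per_token_losses=[2.0, 1.0, 3.0, 5.0],
)
print(loss, num_tokens, loss / num_tokens)  # 6.0 3 2.0
```

Returning the summed loss together with the token count, rather than a per-micro-batch mean, would let a caller compute a single token-weighted average over several micro-batches or ranks (total loss divided by total tokens). Whether bridge.training.losses reduces exactly this way is not stated on this page.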