bridge.training.losses

Module Contents

Functions

masked_next_token_loss

Compute the masked next-token loss for a micro-batch.

Data

SPIKY_LOSS_FACTOR

API

bridge.training.losses.SPIKY_LOSS_FACTOR: int = 10

bridge.training.losses.masked_next_token_loss(
loss_mask: torch.Tensor,
output_tensor: torch.Tensor,
check_for_nan_in_loss: bool = True,
check_for_spiky_loss: bool = False,
) → tuple[torch.Tensor, torch.Tensor, dict[str, tuple[torch.Tensor, torch.Tensor]]]

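Compute the masked next-token loss for one micro-batch. Positions masked out by loss_mask do not contribute to the loss.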

Parameters:
  • loss_mask – Mask that selects which positions contribute to the loss; masked-out positions are ignored

  • output_tensor – Tensor holding the loss values to be masked

  • check_for_nan_in_loss – If True (the default), check the loss for NaN values

  • check_for_spiky_loss – If True, check the loss for spiky values (default: False)

Returns:

A tuple containing:

  • The scalar loss for this micro-batch

  • The number of non-padded tokens in this micro-batch

  • A dict of reporting metrics on the loss and the number of tokens across the data-parallel ranks

Return type:

tuple[torch.Tensor, torch.Tensor, dict[str, tuple[torch.Tensor, torch.Tensor]]]
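
As a rough illustration of the masking and token counting described above, the sketch below applies a mask to a tensor of per-token losses and returns a three-element tuple of the same shape as the documented return value. It is not the library's implementation. The function name masked_loss_sketch, the dict key "loss", the choice to return the masked sum rather than a mean, and the use of ValueError for the NaN check are all assumptions made for illustration. The spiky-loss check and any reduction across data-parallel ranks are omitted.

```python
import torch


def masked_loss_sketch(
    loss_mask: torch.Tensor,
    output_tensor: torch.Tensor,
    check_for_nan_in_loss: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, tuple[torch.Tensor, torch.Tensor]]]:
    """Hypothetical sketch of a masked loss; not the bridge.training.losses code."""
    # Flatten both tensors so they can be multiplied element-wise.
    losses = output_tensor.view(-1).float()
    mask = loss_mask.view(-1).float()

    # Assumption: the scalar is the sum of the unmasked per-token losses.
    loss = torch.sum(losses * mask)

    # Count the positions that the mask keeps.
    num_tokens = mask.sum().to(torch.int)

    # Assumption: a NaN is reported by raising ValueError.
    if check_for_nan_in_loss and torch.isnan(loss):
        raise ValueError("NaN detected in loss")

    # Hypothetical reporting entry; the key name "loss" is illustrative only.
    report = {"loss": (loss.detach(), num_tokens)}
    return loss, num_tokens, report


# Two sequences of three tokens each; three positions are kept by the mask.
per_token_losses = torch.tensor([[0.5, 1.5, 2.0], [1.0, 3.0, 4.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

loss, num_tokens, report = masked_loss_sketch(mask, per_token_losses)
mean_loss = loss / num_tokens  # average over the unmasked tokens
```

Returning the sum together with the token count lets a caller form a token-weighted average over several micro-batches. Whether the documented function returns a sum or a mean is not stated here.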