Loss

ClassifierLossReduction

Bases: BERTMLMLossWithReduction

A class for calculating the cross entropy loss of classification output.

This class is used for calculating the loss and for logging the reduced loss across micro-batches.

Source code in bionemo/esm2/model/finetune/loss.py
class ClassifierLossReduction(BERTMLMLossWithReduction):
    """A class for calculating the cross entropy loss of classification output.

    This class is used for calculating the loss and for logging the reduced loss across micro-batches.
    """

    def forward(
        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside the classification head.
        """
        targets = batch["labels"].squeeze()  # [b] or [b, s] for sequence-level or token-level classification
        loss_mask = batch["loss_mask"]

        classification_output = forward_out["classification_output"]  # [b, num_class] or [b, s, num_class]
        if classification_output.dim() == 3:
            classification_output = classification_output.permute(1, 0, 2).contiguous()  # change to [s, b, num_class]
        elif classification_output.dim() == 2:
            # NOTE: this is for sequence-level classification, we artificially create a sequence dimension to use the same code path as token-level classification
            classification_output = classification_output.unsqueeze(0)  # change to [1, b, num_class]
            targets = targets.unsqueeze(1)  # change to [b, 1]
            loss_mask = torch.ones((targets.shape[0], 1), dtype=loss_mask.dtype, device=loss_mask.device)
        else:
            raise ValueError(f"Unexpected classification output dimension: {classification_output.dim()}")

        # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss are [batch, sequence]
        unreduced_token_loss = unreduced_token_loss_fn(classification_output, targets)  # [b s]
        loss_sum, num_valid_tokens = masked_token_loss(unreduced_token_loss, loss_mask)  # torch.float, torch.int

        if self.validation_step and not self.val_drop_last and loss_sum.isnan():
            assert num_valid_tokens == 0, "Got NaN loss with non-empty input"
            if loss_mask.count_nonzero() != 0:
                raise ValueError("Got NaN loss with non-empty input")
            loss_sum = torch.zeros_like(num_valid_tokens)

        num_valid_tokens = num_valid_tokens.clone().detach().to(torch.int)
        loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
        return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}
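
The two branches above encode a shape convention worth spelling out. Below is a minimal, illustrative sketch of the dictionaries that forward expects for token-level versus sequence-level classification; the key names come from the source above, while the sizes (b=2, s=8, num_class=3) and the boolean mask dtype are assumptions made for the example.

import torch

b, s, num_class = 2, 8, 3  # arbitrary sizes for illustration

# Token-level classification: one label per token/residue.
token_batch = {
    "labels": torch.randint(0, num_class, (b, s)),    # [b, s]
    "loss_mask": torch.ones(b, s, dtype=torch.bool),  # [b, s], True where the token counts toward the loss
}
token_forward_out = {"classification_output": torch.randn(b, s, num_class)}  # [b, s, num_class]

# Sequence-level classification: one label per sequence. Inside forward, the output is
# unsqueezed to [1, b, num_class], the labels to [b, 1], and the loss mask is replaced
# by an all-ones [b, 1] mask.
seq_batch = {
    "labels": torch.randint(0, num_class, (b, 1)),    # squeezed to [b] inside forward
    "loss_mask": torch.ones(b, s, dtype=torch.bool),
}
seq_forward_out = {"classification_output": torch.randn(b, num_class)}  # [b, num_class]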

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

Parameters:

    batch (Dict[str, Tensor]): A batch of data that gets passed to the original forward inside LitAutoEncoder. Required.
    forward_out (Dict[str, Tensor]): The output of the forward method inside the classification head. Required.
Source code in bionemo/esm2/model/finetune/loss.py
def forward(
    self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside the classification head.
    """
    targets = batch["labels"].squeeze()  # [b] or [b, s] for sequence-level or token-level classification
    loss_mask = batch["loss_mask"]

    classification_output = forward_out["classification_output"]  # [b, num_class] or [b, s, num_class]
    if classification_output.dim() == 3:
        classification_output = classification_output.permute(1, 0, 2).contiguous()  # change to [s, b, num_class]
    elif classification_output.dim() == 2:
        # NOTE: this is for sequence-level classification, we artificially create a sequence dimension to use the same code path as token-level classification
        classification_output = classification_output.unsqueeze(0)  # change to [1, b, num_class]
        targets = targets.unsqueeze(1)  # change to [b, 1]
        loss_mask = torch.ones((targets.shape[0], 1), dtype=loss_mask.dtype, device=loss_mask.device)
    else:
        raise ValueError(f"Unexpected classification output dimension: {classification_output.dim()}")

    # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss are [batch, sequence]
    unreduced_token_loss = unreduced_token_loss_fn(classification_output, targets)  # [b s]
    loss_sum, num_valid_tokens = masked_token_loss(unreduced_token_loss, loss_mask)  # torch.float, torch.int

    if self.validation_step and not self.val_drop_last and loss_sum.isnan():
        assert num_valid_tokens == 0, "Got NaN loss with non-empty input"
        if loss_mask.count_nonzero() != 0:
            raise ValueError("Got NaN loss with non-empty input")
        loss_sum = torch.zeros_like(num_valid_tokens)

    num_valid_tokens = num_valid_tokens.clone().detach().to(torch.int)
    loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
    return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}
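
The heavy lifting is delegated to unreduced_token_loss_fn and masked_token_loss from the BioNeMo/Megatron stack, which work in the Megatron [s, b, ...] layout. As a rough sanity check of the math only (not the class's implementation), the same reduction can be written with plain PyTorch in the familiar [b, s, ...] layout; the function name below is hypothetical.

import torch
import torch.nn.functional as F


def masked_cross_entropy_sum(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor):
    """Per-micro-batch reduction: (summed token loss, number of valid tokens)."""
    # Per-token cross entropy with no reduction: logits [b, s, C] -> [b, C, s] for F.cross_entropy.
    per_token = F.cross_entropy(logits.permute(0, 2, 1), labels, reduction="none")  # [b, s]
    loss_sum = (per_token * loss_mask).sum()          # scalar: sum of losses over unmasked tokens
    num_valid_tokens = loss_mask.sum().to(torch.int)  # scalar: how many tokens contributed
    return loss_sum, num_valid_tokens

Megatron's pipeline schedule then averages these pairs across micro-batches, which is why the method returns the raw loss_sum and num_valid_tokens (and repeats them in loss_sum_and_ub_size for logging) rather than a pre-averaged loss.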

RegressorLossReduction

Bases: BERTMLMLossWithReduction

A class for calculating the MSE loss of regression output.

This class is used for calculating the loss and for logging the reduced loss across micro-batches.

Source code in bionemo/esm2/model/finetune/loss.py
class RegressorLossReduction(BERTMLMLossWithReduction):
    """A class for calculating the MSE loss of regression output.

    This class is used for calculating the loss and for logging the reduced loss across micro-batches.
    """

    def forward(
        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Calculates the sum of squared errors within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss, i.e. MSE loss, is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside the regression head.
        """
        regression_output = forward_out["regression_output"]
        targets = batch["labels"].to(dtype=regression_output.dtype)  # [b, 1]

        num_valid_tokens = torch.tensor(targets.numel(), dtype=torch.int, device=targets.device)

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            loss_sum = ((regression_output - targets) ** 2).sum()  # torch.float
        else:
            raise NotImplementedError("Context Parallel support is not implemented for this loss")

        loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
        return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}

forward(batch, forward_out)

Calculates the sum of squared errors within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss, i.e. MSE loss, is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

Parameters:

    batch (Dict[str, Tensor]): A batch of data that gets passed to the original forward inside LitAutoEncoder. Required.
    forward_out (Dict[str, Tensor]): The output of the forward method inside the regression head. Required.
Source code in bionemo/esm2/model/finetune/loss.py
def forward(
    self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
    """Calculates the sum of squared errors within a micro-batch. A micro-batch is a batch of data on a single GPU. The averaging of the loss, i.e. MSE loss, is done in https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L304-L314.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside the regression head.
    """
    regression_output = forward_out["regression_output"]
    targets = batch["labels"].to(dtype=regression_output.dtype)  # [b, 1]

    num_valid_tokens = torch.tensor(targets.numel(), dtype=torch.int, device=targets.device)

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        loss_sum = ((regression_output - targets) ** 2).sum()  # torch.float
    else:
        raise NotImplementedError("Context Parallel support is not implemented for this loss")

    loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
    return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}
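
For reference, the arithmetic reduces to a sum of squared errors per micro-batch, with the mean taken only after the sums and counts are accumulated across micro-batches by the Megatron scheduler. The sketch below restates that with plain PyTorch; the helper name and the dummy data are illustrative assumptions, not part of the class's API.

import torch


def sse_reduction(regression_output: torch.Tensor, targets: torch.Tensor):
    """Per-micro-batch reduction: (sum of squared errors, number of target elements)."""
    loss_sum = ((regression_output - targets.to(regression_output.dtype)) ** 2).sum()
    num_valid = torch.tensor(targets.numel(), dtype=torch.int, device=targets.device)
    return loss_sum, num_valid


# Combining micro-batches: the total SSE divided by the total element count gives the MSE,
# rather than a mean of per-micro-batch means.
micro_batches = [(torch.randn(4, 1), torch.randn(4, 1)) for _ in range(3)]
sums, counts = zip(*(sse_reduction(out, tgt) for out, tgt in micro_batches))
mse = torch.stack(sums).sum() / torch.stack(counts).sum()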