Source code for nemo.collections.common.losses.smoothed_cross_entropy

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from nemo.core.classes import Exportable, Loss, NeuralModule, typecheck
from nemo.core.neural_types import LabelsType, LogprobsType, LossType, MaskType, NeuralType

__all__ = ["SmoothedCrossEntropyLoss", "SmoothedNLLLoss"]


class SmoothedCrossEntropyLoss(Loss):
    """
    Calculates cross-entropy loss with label smoothing for a batch of sequences.

    SmoothedCrossEntropyLoss:
    1) excludes padding tokens from the loss calculation
    2) supports label smoothing regularization
    3) allows computing the loss for only the last k tokens
    4) per_token_reduction - if False, disables per-token reduction

    Args:
        label_smoothing (float): label smoothing regularization coefficient
        predict_last_k (int): number of last tokens to calculate the loss for, for example
            0: (default) calculate loss on the entire sequence (e.g., NMT)
            1: calculate loss on the last token only (e.g., LM evaluation)
            Intermediate values allow controlling the trade-off between eval
            time (proportional to the number of batches) and eval performance
            (proportional to the number of context tokens)
        pad_id (int): padding id
        eps (float): small constant to avoid division by zero
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "log_probs": NeuralType(("B", "T", "D"), LogprobsType()),
            "labels": NeuralType(("B", "T"), LabelsType()),
            "output_mask": NeuralType(("B", "T"), MaskType(), optional=True),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"loss": NeuralType(elements_type=LossType())}
    def __init__(
        self,
        pad_id: Optional[int] = None,
        label_smoothing: Optional[float] = 0.0,
        predict_last_k: Optional[int] = 0,
        eps: float = 1e-6,
        per_token_reduction: bool = True,
    ):
        super().__init__()
        self._pad_id = pad_id
        self._eps = eps
        self._predict_last_k = predict_last_k
        self._label_smoothing = label_smoothing
        self._per_token_reduction = per_token_reduction
    @typecheck()
    def forward(self, log_probs, labels, output_mask=None):
        """
        Args:
            log_probs: float tensor of shape batch_size x seq_len x vocab_size, values should be log probabilities
            labels: int tensor of shape batch_size x seq_len
            output_mask: binary tensor of shape batch_size x seq_len
        """
        if output_mask is None and self._pad_id is None:
            raise ValueError("Both output_mask and pad_id are None")
        if output_mask is None and self._pad_id is not None:
            output_mask = (labels != self._pad_id).to(log_probs.dtype)

        if output_mask.dtype is not log_probs.dtype:
            output_mask = output_mask.to(log_probs.dtype)

        batch_size, seq_len, vocab_size = log_probs.size()
        smoothing = vocab_size * self._label_smoothing / (vocab_size - 1)
        target_log_probs = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
        smoothing_log_probs = log_probs.mean(dim=-1)
        neg_log_likelihood = (1.0 - smoothing) * target_log_probs + smoothing * smoothing_log_probs
        neg_log_likelihood = neg_log_likelihood[:, -self._predict_last_k :]
        output_mask = output_mask[:, -self._predict_last_k :]

        # when per_token_reduction is False, return the unreduced per-token loss
        if self._per_token_reduction:
            neg_log_likelihood = -torch.sum(neg_log_likelihood * output_mask)
            neg_log_likelihood = neg_log_likelihood / (output_mask.sum() + self._eps)
        else:
            neg_log_likelihood = -(neg_log_likelihood * output_mask)

        return neg_log_likelihood
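
# Illustrative usage sketch (not part of the original module): shows how
# SmoothedCrossEntropyLoss can be called on random tensors. The shapes, the
# pad id, and the smoothing value below are assumptions for the example only.
def _example_smoothed_cross_entropy():
    batch_size, seq_len, vocab_size = 2, 5, 16
    loss_fn = SmoothedCrossEntropyLoss(pad_id=0, label_smoothing=0.1)
    # log-probabilities over the vocabulary for each position
    log_probs = torch.randn(batch_size, seq_len, vocab_size).log_softmax(dim=-1)
    # integer labels; positions equal to pad_id are masked out internally
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    # typecheck()-decorated modules must be called with keyword arguments
    return loss_fn(log_probs=log_probs, labels=labels)
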
class SmoothedNLLLoss(NeuralModule, Exportable):
    """
    Calculates negative log likelihood for a sequence input and applies label smoothing (if set).
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "log_probs": NeuralType(("B", "T", "D"), LogprobsType()),
            "labels": NeuralType(("B", "T"), LabelsType()),
            "output_mask": NeuralType(("B", "T"), MaskType(), optional=True),
            "lengths": NeuralType(("B",), LabelsType(), optional=True),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, reduction='mean', label_smoothing=0.0, eps=1e-8, **kwargs):
        super().__init__()
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.nll_loss = torch.nn.NLLLoss(reduction='none', **kwargs)
        self.eps = eps  # small constant to avoid division by zero

    @typecheck()
    def forward(self, log_probs, labels, output_mask=None, lengths=None):
        """
        Args:
            log_probs: BxTxC
            labels: BxT
            output_mask: BxT
            lengths: B
        """
        if output_mask is None and lengths is None:
            # no mask or lengths given: treat every position as valid
            output_mask = torch.ones_like(labels, dtype=log_probs.dtype)
        elif output_mask is None and lengths is not None:
            output_mask = torch.arange(log_probs.size(1), device=log_probs.device)[None, :] < lengths[:, None]
            output_mask = output_mask.float()

        log_probs = log_probs.transpose(1, 2)  # BxTxC -> BxCxT
        loss = output_mask * self.nll_loss(log_probs, labels)
        batch_size = loss.size(0)
        if self.reduction == "mean":
            loss = loss.sum() / (torch.sum(output_mask) + self.eps)
        elif self.reduction == "batchmean":
            loss = loss.sum() / batch_size
        elif self.reduction == "batch":
            loss = loss.reshape(batch_size, -1).sum(1) / (output_mask.reshape(batch_size, -1).sum(1) + self.eps)

        if self.label_smoothing == 0.0:
            return loss
        else:
            # Regularizing Neural Networks by Penalizing Confident Output Distributions.
            # https://arxiv.org/abs/1701.06548
            loss_reg = torch.mean(log_probs, dim=1) * output_mask
            if self.reduction == "mean":
                loss_reg = torch.sum(loss_reg) / torch.sum(output_mask)
            elif self.reduction == "batchmean":
                loss_reg = torch.sum(loss_reg) / labels.shape[0]
            elif self.reduction == "batch":
                loss_reg = loss_reg.sum(1) / output_mask.sum(1)

            return -self.label_smoothing * loss_reg + (1 - self.label_smoothing) * loss
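
# Illustrative usage sketch (not part of the original module): shows
# SmoothedNLLLoss driven by per-sequence lengths instead of an explicit mask.
# Shapes, lengths, and the smoothing value are assumptions for the example.
def _example_smoothed_nll_loss():
    batch_size, seq_len, num_classes = 2, 5, 16
    loss_fn = SmoothedNLLLoss(reduction="mean", label_smoothing=0.1)
    log_probs = torch.randn(batch_size, seq_len, num_classes).log_softmax(dim=-1)
    labels = torch.randint(0, num_classes, (batch_size, seq_len))
    # valid lengths per sequence; positions beyond the length are masked out
    lengths = torch.tensor([5, 3])
    return loss_fn(log_probs=log_probs, labels=labels, lengths=lengths)


if __name__ == "__main__":
    # quick smoke test of the two example sketches above
    print("SmoothedCrossEntropyLoss:", _example_smoothed_cross_entropy().item())
    print("SmoothedNLLLoss:", _example_smoothed_nll_loss().item())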