Losses
- class nemo.collections.common.losses.AggregatorLoss(*args: Any, **kwargs: Any)
Bases: Loss
Sums several losses into one.
- Parameters
num_inputs – number of input losses
weights – a list of coefficients, one per input loss, used when summing the losses
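A minimal plain-Python sketch of the aggregation, assuming each input loss is a scalar and that omitting `weights` gives every loss a coefficient of 1.0. The helper `aggregate_losses` is hypothetical; the actual module operates on tensors.

```python
from typing import Optional, Sequence


def aggregate_losses(losses: Sequence[float], weights: Optional[Sequence[float]] = None) -> float:
    """Weighted sum of scalar losses (illustrative only)."""
    if weights is None:
        # Assumption: without weights, every loss gets coefficient 1.0.
        weights = [1.0] * len(losses)
    if len(weights) != len(losses):
        raise ValueError("expected one weight per input loss")
    return sum(w * loss for w, loss in zip(weights, losses))


# num_inputs=2: merge two losses with coefficients 0.5 and 0.25.
total = aggregate_losses([2.0, 4.0], weights=[0.5, 0.25])
print(total)  # 2.0
```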
- property input_types
Returns definitions of module input ports.
- property output_types
Returns definitions of module output ports.
- class nemo.collections.common.losses.CrossEntropyLoss(*args: Any, **kwargs: Any)
Bases: CrossEntropyLoss, Serialization, Typing
- __init__(logits_ndim=2, weight=None, reduction='mean', ignore_index=-100)
- Parameters
logits_ndim (int) – number of dimensions (rank) of the logits tensor
weight (list) – list of rescaling weights, one per class
reduction (str) – type of reduction over the batch
ignore_index (int) – target value that is ignored and does not contribute to the loss
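The `weight`, `reduction`, and `ignore_index` arguments follow the semantics of PyTorch's `torch.nn.CrossEntropyLoss`, which this class extends. The plain-Python sketch below (hypothetical helper `cross_entropy`, 2-D logits only) shows how they interact: ignored targets are skipped, and the `'mean'` reduction divides by the sum of the class weights of the remaining targets rather than by the sample count.

```python
import math
from typing import List, Optional, Sequence, Union


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def cross_entropy(
    logits: Sequence[Sequence[float]],  # batch x num_classes (logits_ndim=2)
    labels: Sequence[int],              # batch
    weight: Optional[Sequence[float]] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
) -> Union[float, List[float]]:
    num_classes = len(logits[0])
    w = list(weight) if weight is not None else [1.0] * num_classes
    per_item: List[float] = []  # weighted negative log-likelihood per sample
    total_weight = 0.0          # sum of class weights over non-ignored samples
    for row, y in zip(logits, labels):
        if y == ignore_index:
            per_item.append(0.0)  # ignored targets contribute nothing
            continue
        per_item.append(-w[y] * log_softmax(row)[y])
        total_weight += w[y]
    if reduction == "none":
        return per_item
    if reduction == "sum":
        return sum(per_item)
    if reduction == "mean":
        # Weighted mean: divide by the summed weights, not the sample count.
        return sum(per_item) / total_weight if total_weight > 0 else float("nan")
    raise ValueError(f"unsupported reduction: {reduction!r}")


logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
labels = [0, 1, -100]  # the third sample is ignored
print(cross_entropy(logits, labels, weight=[1.0, 3.0]))
```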
- forward(logits, labels, loss_mask=None)
- Parameters
logits (float) – output of the classifier
labels (long) – ground-truth labels
loss_mask (bool/float/int) – mask tensor specifying which positions contribute to the loss
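A plain-Python sketch of masked cross-entropy for 3-D logits (batch x seq_len x num_classes, i.e. `logits_ndim=3`), assuming positions whose mask value is non-zero are kept and the loss is averaged over the kept positions only. The helpers `nll` and `masked_cross_entropy` are hypothetical and do not reproduce the library's edge-case handling.

```python
import math
from typing import Sequence


def nll(row: Sequence[float], label: int) -> float:
    """Negative log-softmax probability of `label` (numerically stable)."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return log_z - row[label]


def masked_cross_entropy(
    logits: Sequence[Sequence[Sequence[float]]],  # batch x seq_len x num_classes
    labels: Sequence[Sequence[int]],              # batch x seq_len
    loss_mask: Sequence[Sequence[float]],         # batch x seq_len (bool/int/float)
) -> float:
    losses = []
    # Walk the batch and sequence dimensions as if they were flattened into one.
    for seq_logits, seq_labels, seq_mask in zip(logits, labels, loss_mask):
        for row, label, keep in zip(seq_logits, seq_labels, seq_mask):
            if keep:  # assumption: non-zero mask values mark positions to keep
                losses.append(nll(row, label))
    if not losses:
        raise ValueError("loss_mask does not select any position")
    return sum(losses) / len(losses)


logits = [[[2.0, 0.0], [0.0, 2.0]], [[5.0, -5.0], [0.0, 0.0]]]
labels = [[0, 1], [1, 0]]
mask = [[1, 1], [0, 0]]  # the second sequence is fully masked out
print(masked_cross_entropy(logits, labels, mask))
```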
- property input_types
Returns definitions of module input ports.
- property output_types
Returns definitions of module output ports.
- class nemo.collections.common.losses.MSELoss(*args: Any, **kwargs: Any)
Bases: MSELoss, Serialization, Typing
- __init__(reduction: str = 'mean')
- Parameters
reduction – type of reduction over the batch
- forward(preds: torch.Tensor, labels: torch.Tensor) → torch.Tensor
- Parameters
preds – predicted values (output of the model)
labels – ground-truth target values
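`reduction` takes the same values as PyTorch's `torch.nn.MSELoss` (`'mean'`, `'sum'`, `'none'`), which this class extends. A plain-Python sketch of the three modes on 1-D inputs; the helper `mse` is hypothetical.

```python
from typing import List, Sequence, Union


def mse(
    preds: Sequence[float],
    labels: Sequence[float],
    reduction: str = "mean",
) -> Union[float, List[float]]:
    squared = [(p - y) ** 2 for p, y in zip(preds, labels)]  # per-element errors
    if reduction == "none":
        return squared
    if reduction == "sum":
        return sum(squared)
    if reduction == "mean":
        return sum(squared) / len(squared)
    raise ValueError(f"unsupported reduction: {reduction!r}")


preds = [1.0, 2.0, 4.0]
labels = [1.0, 3.0, 2.0]
print(mse(preds, labels, reduction="none"))  # [0.0, 1.0, 4.0]
print(mse(preds, labels, reduction="sum"))   # 5.0
```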
- property input_types
Returns definitions of module input ports.
- property output_types
Returns definitions of module output ports.
- class nemo.collections.common.losses.SmoothedCrossEntropyLoss(*args: Any, **kwargs: Any)
Bases: Loss
Calculates the cross-entropy loss with label smoothing for a batch of sequences.
SmoothedCrossEntropyLoss:
1) excludes padding tokens from the loss calculation;
2) supports label smoothing regularization;
3) can calculate the loss over only the desired number of last tokens;
4) skips the reduction over tokens when per_token_reduction is False.
- Parameters
label_smoothing (float) – label smoothing regularization coefficient
predict_last_k (int) – number of last tokens to calculate the loss for. For example, 0 (default) calculates the loss on the entire sequence (e.g., NMT), and 1 calculates it on the last token only (e.g., LM evaluation). Intermediate values control the trade-off between evaluation time (proportional to the number of batches) and evaluation performance (proportional to the number of context tokens).
pad_id (int) – padding id
eps (float) – small epsilon to avoid division by zero in the loss calculation
per_token_reduction (bool) – if False, disables the reduction over tokens
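The effect of `predict_last_k` can be illustrated with Python slicing, assuming the loss is restricted to the trailing `k` positions of each sequence. The helper `positions_scored` is hypothetical. Slicing with `-0` keeps the whole sequence, which matches the documented default of 0 meaning the entire sequence.

```python
from typing import List, Sequence


def positions_scored(seq: Sequence[str], predict_last_k: int = 0) -> List[str]:
    """Return the tokens whose loss would be counted for a given predict_last_k."""
    # seq[-0:] equals seq[0:], so k=0 naturally means "all tokens".
    return list(seq[-predict_last_k:])


tokens = ["The", "cat", "sat", "down"]
print(positions_scored(tokens, 0))  # ['The', 'cat', 'sat', 'down']
print(positions_scored(tokens, 1))  # ['down']
print(positions_scored(tokens, 2))  # ['sat', 'down']
```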
- __init__(pad_id: Optional[int] = None, label_smoothing: Optional[float] = 0.0, predict_last_k: Optional[int] = 0, eps: float = 1e-06, per_token_reduction: bool = True)
- forward(log_probs, labels, output_mask=None)
- Parameters
log_probs – float tensor of shape batch_size x seq_len x vocab_size, values should be log probabilities
labels – int tensor of shape batch_size x seq_len
output_mask – binary tensor of shape batch_size x seq_len
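A plain-Python sketch of one common formulation of label-smoothed cross-entropy: the target log-probability is mixed with the mean log-probability over the vocabulary, masked positions are zeroed, only the last `predict_last_k` positions are kept, and the masked sum is divided by the number of kept tokens plus `eps`. The function name, the pad-derived default mask, and the exact smoothing weights are assumptions; the library may scale the smoothing term differently.

```python
import math
from typing import List, Optional, Sequence, Union


def smoothed_cross_entropy(
    log_probs: Sequence[Sequence[Sequence[float]]],  # batch x seq_len x vocab_size
    labels: Sequence[Sequence[int]],                 # batch x seq_len
    output_mask: Optional[Sequence[Sequence[float]]] = None,  # batch x seq_len, 0/1
    pad_id: Optional[int] = None,
    label_smoothing: float = 0.0,
    predict_last_k: int = 0,
    eps: float = 1e-6,
    per_token_reduction: bool = True,
) -> Union[float, List[List[float]]]:
    if output_mask is None:
        # Assumption: without an explicit mask, padding tokens are excluded.
        output_mask = [[0.0 if y == pad_id else 1.0 for y in row] for row in labels]
    per_token: List[List[float]] = []
    kept_mask: List[List[float]] = []
    for seq_lp, seq_labels, seq_mask in zip(log_probs, labels, output_mask):
        row_losses = []
        for lp, y, m in zip(seq_lp, seq_labels, seq_mask):
            if not m:
                row_losses.append(0.0)  # masked positions contribute nothing
                continue
            # Mix the target log-probability with the mean log-probability over
            # the vocabulary (i.e. smoothing toward a uniform distribution).
            smoothed = (1.0 - label_smoothing) * lp[y] + label_smoothing * sum(lp) / len(lp)
            row_losses.append(-smoothed * m)
        # Keep only the last k positions; seq[-0:] keeps the whole sequence.
        per_token.append(row_losses[-predict_last_k:])
        kept_mask.append(list(seq_mask)[-predict_last_k:])
    if not per_token_reduction:
        return per_token  # unreduced, masked per-token losses
    total = sum(sum(row) for row in per_token)
    count = sum(sum(row) for row in kept_mask)
    return total / (count + eps)  # eps guards against an all-masked batch


uniform = [math.log(1 / 3)] * 3
log_probs = [[uniform, uniform, uniform]]
labels = [[1, 2, 0]]  # token id 0 is padding
print(smoothed_cross_entropy(log_probs, labels, pad_id=0, label_smoothing=0.1))
```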
- property input_types
Returns definitions of module input ports.
- property output_types
Returns definitions of module output ports.
- class nemo.collections.common.losses.SpanningLoss(*args: Any, **kwargs: Any)
Bases: Loss
Implements the start and end loss of a span, e.g., for question answering.
- forward(logits, start_positions, end_positions)
- Parameters
logits – output of the question answering head, which is a token classifier.
start_positions – ground-truth start positions of the answer with respect to the input sequence. If the question is unanswerable, this points to the start token of the input sequence, e.g., [CLS].
end_positions – ground-truth end positions of the answer with respect to the input sequence. If the question is unanswerable, this points to the start token of the input sequence, e.g., [CLS].
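A plain-Python sketch of a span loss, assuming (as in BERT-style extractive QA heads) that the token classifier emits two scores per token, one for the start and one for the end of the answer, and that the loss averages the start and end cross-entropies, each taken over sequence positions. The helper names and the channel layout are assumptions.

```python
import math
from typing import Sequence


def nll(scores: Sequence[float], target: int) -> float:
    """Cross-entropy of one softmax over sequence positions (numerically stable)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target]


def spanning_loss(
    logits: Sequence[Sequence[Sequence[float]]],  # batch x seq_len x 2
    start_positions: Sequence[int],
    end_positions: Sequence[int],
) -> float:
    start_losses, end_losses = [], []
    for seq, start, end in zip(logits, start_positions, end_positions):
        # Assumption: channel 0 scores "span starts here", channel 1 "span ends here".
        start_losses.append(nll([tok[0] for tok in seq], start))
        end_losses.append(nll([tok[1] for tok in seq], end))
    start_loss = sum(start_losses) / len(start_losses)
    end_loss = sum(end_losses) / len(end_losses)
    return (start_loss + end_loss) / 2  # average of the two cross-entropies


# One sequence of three tokens: [CLS], "Paris", "France".
logits = [[[0.0, 0.0], [3.0, 1.0], [0.0, 3.0]]]
print(spanning_loss(logits, start_positions=[1], end_positions=[2]))
# An unanswerable question points both positions at [CLS] (index 0).
print(spanning_loss(logits, start_positions=[0], end_positions=[0]))
```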
- property input_types
Returns definitions of module input ports.
- property output_types
Returns definitions of module output ports.