nemo_automodel.components.loss.linear_ce
Module Contents#
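Classes#
FusedLinearCrossEntropy | Fused linear cross entropy loss. |
Functions#
new_is_triton_greater_or_equal | Check if pytorch-triton version is greater than or equal to the specified version. |
new_is_triton_greater_or_equal_3_2_0 | Check if pytorch-triton version is greater than or equal to 3.1.0. |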
API#
- nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal(version_str)#
Check if pytorch-triton version is greater than or equal to the specified version.
- Parameters:
version_str – Version string to check
- Returns:
True if pytorch-triton version >= specified version
- Return type:
bool
- nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()#
Check if pytorch-triton version is greater than or equal to 3.1.0.
- Returns:
True if pytorch-triton version >= 3.1.0
- Return type:
bool
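The two checks above can be sketched with the standard library alone. This is a minimal illustration, not the module's implementation: `parse_release`, `release_at_least`, `dist_is_at_least`, and `triton_is_at_least_3_1_0` are hypothetical names, the distribution name `pytorch-triton` is taken from the descriptions above, and returning `False` for a missing distribution is an assumption. The wrapper follows the documented threshold of 3.1.0, although the documented function name ends in `_3_2_0`.

```python
import re
from importlib import metadata


def parse_release(version_str):
    """Return the leading numeric release of a version string as a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version_str)
    if match is None:
        raise ValueError(f"no numeric release found in {version_str!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def release_at_least(current, required):
    """Compare two version strings by numeric release, padding with zeros."""
    cur, req = parse_release(current), parse_release(required)
    width = max(len(cur), len(req))
    cur += (0,) * (width - len(cur))
    req += (0,) * (width - len(req))
    return cur >= req


def dist_is_at_least(dist_name, version_str):
    """True if the installed distribution `dist_name` is >= `version_str`."""
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # Assumption: a missing distribution counts as "not new enough".
        return False
    return release_at_least(installed, version_str)


def triton_is_at_least_3_1_0():
    """Hypothetical fixed-version wrapper around the generic check."""
    return dist_is_at_least("pytorch-triton", "3.1.0")
```

This simplified parser ignores pre-release and local suffixes (for example, `3.1.0+git1234` compares equal to `3.1.0`); a dedicated version parser such as `packaging.version.parse` handles those cases more rigorously.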
- class nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy(
- ignore_index: int = -100,
- logit_softcapping: float = 0,
- reduction: str = 'sum',
)
Bases:
torch.nn.Module
Initialization
Fused linear cross entropy loss.
- Parameters:
ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.
logit_softcapping (float) – Value for softcapping logits (0 means no capping). Defaults to 0.
reduction (str) – Type of reduction. Defaults to 'sum'.
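As an illustration of the `logit_softcapping` parameter, the sketch below applies the common tanh-based softcap to a single logit. It assumes the widely used `cap * tanh(logit / cap)` formulation; whether this class uses exactly that formula is an assumption, and `softcap` is a hypothetical helper name.

```python
import math


def softcap(logit, cap):
    """Squash `logit` smoothly into [-cap, cap]; cap == 0 disables capping."""
    if cap == 0:
        # Matches the documented convention that 0 means no capping.
        return logit
    return cap * math.tanh(logit / cap)
```

Under this formulation, small logits pass through nearly unchanged while large ones saturate near `cap`, which bounds the magnitude of the values fed into the softmax.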
- forward(
- hidden_states: torch.Tensor,
- labels: torch.Tensor,
- lm_weight: torch.Tensor,
- num_label_tokens: Optional[int] = None,
)
Compute fused linear cross entropy loss that matches PyTorch's cross_entropy behavior.
- Parameters:
hidden_states – Input hidden states
labels – Target labels
lm_weight – Weight matrix for linear transformation
num_label_tokens – Number of non-padding tokens.
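To make the quantity being computed concrete, here is an unfused reference written in plain Python on nested lists. It is a numerical sketch only and does not call the library: `reference_linear_ce` is a hypothetical name, the `(vocab_size, hidden_size)` layout of `lm_weight` is an assumption, dividing the summed loss by `num_label_tokens` is an assumption about how that argument is used, and softcapping is omitted for brevity.

```python
import math


def reference_linear_ce(hidden_states, labels, lm_weight,
                        num_label_tokens=None, ignore_index=-100):
    """Sum-reduced cross entropy over logits = hidden_states @ lm_weight^T.

    hidden_states: list of token vectors, each of length hidden_size
    labels:        list of int class ids, one per token
    lm_weight:     list of vocab_size rows, each of length hidden_size (assumed)
    """
    total = 0.0
    for hidden, label in zip(hidden_states, labels):
        if label == ignore_index:
            # Ignored positions contribute nothing to the loss.
            continue
        # Linear projection of one token onto the vocabulary.
        logits = [sum(h * w for h, w in zip(hidden, row)) for row in lm_weight]
        # Numerically stable log-sum-exp.
        peak = max(logits)
        log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_norm - logits[label]
    if num_label_tokens is not None:
        # Assumption: the summed loss is normalized by the label-token count.
        total /= num_label_tokens
    return total
```

A fused implementation aims to produce the same value without materializing the full logits matrix, which is where the memory savings come from when the vocabulary is large.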