nemo_automodel.components.loss.linear_ce
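Module Contents
Classes
FusedLinearCrossEntropy | Fused linear cross entropy loss. |
Functions
new_is_triton_greater_or_equal | Check if pytorch-triton version is greater than or equal to the specified version. |
new_is_triton_greater_or_equal_3_2_0 | Check if pytorch-triton version is greater than or equal to 3.1.0. |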
API
- nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal(version_str)
Check whether the installed pytorch-triton version is greater than or equal to the specified version.
- Parameters:
version_str – Version string to compare the installed version against
- Returns:
True if the installed pytorch-triton version >= version_str
- Return type:
bool
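A version check of this kind can be sketched with the standard library alone. The following is a minimal illustration, not the module's actual implementation: the helper names `release_tuple`, `is_version_at_least`, and `package_at_least` are hypothetical, and treating a missing distribution as `False` is an assumption. Only the leading numeric release segment is compared, so local suffixes such as `+git...` are ignored.

```python
import re
from importlib import metadata


def release_tuple(version_str: str) -> tuple:
    """Return the leading dotted-integer release segment as a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version_str.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version_str!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_version_at_least(installed: str, required: str) -> bool:
    """Compare two version strings numerically, padding the shorter with zeros."""
    a, b = release_tuple(installed), release_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def package_at_least(dist_name: str, required: str) -> bool:
    """Look up an installed distribution and compare it with `required`.

    Assumption: a distribution that is not installed counts as "too old".
    """
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return is_version_at_least(installed, required)


# The result depends on what is installed in the current environment.
print(package_at_least("pytorch-triton", "3.1.0"))
```

Comparing integer tuples rather than raw strings matters here: as strings, "3.10.0" sorts before "3.2.0", which would give the wrong answer.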
- nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()
Check if pytorch-triton version is greater than or equal to 3.1.0. Note that the documented threshold is 3.1.0, although the function name ends in 3_2_0.
- Returns:
True if pytorch-triton version >= 3.1.0
- Return type:
bool
- class nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy(
- ignore_index: int = -100,
- logit_softcapping: float = 0,
- reduction: str = 'sum',
- )
Initialization
Fused linear cross entropy loss.
- Parameters:
ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.
logit_softcapping (float) – Value for softcapping logits (0 means no capping). Defaults to 0.
reduction (str) – Type of reduction. Defaults to “sum”.
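To illustrate what the logit_softcapping parameter refers to, here is a small sketch of the commonly used tanh-based softcap applied to a single logit. It is an assumption that this class uses the same formula; the function name `softcap` is hypothetical and only serves the example.

```python
import math


def softcap(logit: float, cap: float) -> float:
    """Squash a logit smoothly into the open interval (-cap, cap).

    Assumed formula: cap * tanh(logit / cap). A cap of 0 disables capping,
    mirroring the "0 means no capping" convention described above.
    """
    if cap == 0:
        return logit
    return cap * math.tanh(logit / cap)


# Small logits pass through almost unchanged; large ones saturate near the cap.
for value in (0.1, 10.0, 1000.0):
    print(value, softcap(value, 30.0))
```

Under this definition the capped logit never exceeds the cap in magnitude, which bounds the values that enter the softmax.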
- __call__(
- hidden_states: torch.Tensor,
- labels: torch.Tensor,
- lm_weight: torch.Tensor,
- )
Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior.
- Parameters:
hidden_states – Input hidden states
labels – Target labels
lm_weight – Weight matrix for linear transformation
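For comparison, the unfused computation that a fused loss of this kind is meant to reproduce can be sketched in plain PyTorch. This is a reference sketch under stated assumptions, not the class's implementation: `unfused_linear_cross_entropy` is a hypothetical name, `lm_weight` is assumed to have shape [vocab_size, hidden_size] as in `torch.nn.Linear`, and softcapping is left out.

```python
import torch
import torch.nn.functional as F


def unfused_linear_cross_entropy(
    hidden_states: torch.Tensor,
    labels: torch.Tensor,
    lm_weight: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = "sum",
) -> torch.Tensor:
    """Materialize the full logits, then apply PyTorch's cross entropy."""
    # Assumed shapes: hidden_states [..., hidden], lm_weight [vocab, hidden].
    logits = hidden_states @ lm_weight.t()
    # Flatten all leading dimensions so each row is one token position.
    logits = logits.reshape(-1, logits.size(-1)).float()
    targets = labels.reshape(-1)
    return F.cross_entropy(
        logits, targets, ignore_index=ignore_index, reduction=reduction
    )


torch.manual_seed(0)
batch, seq_len, hidden, vocab = 2, 4, 8, 16
hidden_states = torch.randn(batch, seq_len, hidden)
lm_weight = torch.randn(vocab, hidden)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[0, 0] = -100  # this position is excluded from the loss

loss = unfused_linear_cross_entropy(hidden_states, labels, lm_weight)
print(loss)
```

The unfused version allocates a [tokens, vocab] logits tensor, which is the memory cost a fused kernel is designed to avoid. With reduction "sum", dividing by the number of non-ignored labels gives a per-token mean.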