nemo_automodel.loss.linear_ce
Module Contents
Functions
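| Function | Description |
|---|---|
| new_is_triton_greater_or_equal | Check if pytorch-triton version is greater than or equal to the specified version. |
| new_is_triton_greater_or_equal_3_2_0 | Check if pytorch-triton version is greater than or equal to 3.1.0. |
| fused_linear_cross_entropy | Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior. |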
API
- nemo_automodel.loss.linear_ce.new_is_triton_greater_or_equal(version_str)
Check if pytorch-triton version is greater than or equal to the specified version.
- Parameters:
version_str – Minimum version string to compare the installed pytorch-triton version against (for example, '3.1.0')
- Returns:
True if pytorch-triton version >= specified version
- Return type:
bool
- nemo_automodel.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()
Check if pytorch-triton version is greater than or equal to 3.1.0.
- Returns:
True if pytorch-triton version >= 3.1.0
- Return type:
bool
- nemo_automodel.loss.linear_ce.fused_linear_cross_entropy(
      hidden_states: torch.Tensor,
      lm_weight: torch.Tensor,
      labels: torch.Tensor,
      num_items_in_batch: int = None,
      ignore_index: int = -100,
      reduction: str = 'mean',
      logit_softcapping: float = 0,
      accuracy_threshold: str = 'auto',
  )
Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior.
- Parameters:
hidden_states – Input hidden states
lm_weight – Weight matrix of the language-model head, used to project hidden states to vocabulary logits
labels – Target labels
num_items_in_batch – Number of valid tokens (where labels != ignore_index)
ignore_index – Value to ignore in labels (default: -100)
reduction – Reduction method (‘mean’ or ‘sum’)
logit_softcapping – Value for softcapping logits (0 means no capping)
accuracy_threshold – Threshold for accuracy computation
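Since the function is described as matching PyTorch’s cross_entropy behavior, the unfused computation it should agree with can be sketched as below. This is a hypothetical reference (`reference_linear_cross_entropy` is not part of this module) that materializes the full logits tensor, which a fused kernel is typically designed to avoid. It assumes lm_weight has shape [vocab_size, hidden_size], labels are already aligned with hidden_states (no shifting), soft-capping takes the common form cap · tanh(logits / cap), and num_items_in_batch, when given, replaces the local valid-token count as the denominator of the 'mean' reduction. accuracy_threshold is omitted because its effect on the fused kernel is not described here.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def reference_linear_cross_entropy(
    hidden_states: torch.Tensor,  # [..., hidden_size]
    lm_weight: torch.Tensor,      # assumed [vocab_size, hidden_size]
    labels: torch.Tensor,         # [...], same leading dims as hidden_states
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    reduction: str = "mean",
    logit_softcapping: float = 0.0,
) -> torch.Tensor:
    # Unfused path: build the full [..., vocab_size] logits tensor.
    logits = hidden_states @ lm_weight.t()
    if logit_softcapping > 0:
        # Assumed soft-capping form: squashes logits smoothly into (-cap, cap).
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    # Flatten batch/sequence dims and compute the loss in float32.
    logits = logits.float().reshape(-1, logits.size(-1))
    labels = labels.reshape(-1)
    if reduction == "mean" and num_items_in_batch is not None:
        # Assumption: sum over valid tokens, divided by the caller-supplied count.
        total = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
        return total / num_items_in_batch
    return F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
```

Passing num_items_in_batch is mainly useful under gradient accumulation or data parallelism, where the number of valid tokens across the whole batch differs from the count in the local micro-batch; with the local count, the result equals the plain 'mean' reduction.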