nemo_automodel.components.loss.linear_ce
API
Bases: torch.nn.Module
Fused linear-projection and cross-entropy loss module.
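The point of fusing the projection with the loss is that the full [num_tokens x vocab_size] logits matrix never has to exist at once. The following is a minimal pure-Python sketch of that idea. It assumes the loss is the summed token-level cross-entropy of softmax(W h) and replaces the real kernel with simple chunking over tokens, so at most `chunk_size` rows of logits are alive at a time. It is an illustrative stand-in, not the Triton-based implementation this module uses, and the names `chunked_linear_ce` and `chunk_size` are hypothetical.

```python
import math
from typing import List, Sequence


def _logsumexp(values: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(v))) via max-subtraction.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def chunked_linear_ce(
    hidden: List[List[float]],   # [num_tokens][hidden_dim]
    labels: List[int],           # [num_tokens]
    weight: List[List[float]],   # [vocab_size][hidden_dim] (assumed layout)
    chunk_size: int = 2,         # hypothetical knob: tokens projected at once
) -> float:
    """Summed cross-entropy of softmax(weight @ h) against labels."""
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        # Project only this chunk of tokens to vocabulary logits.
        chunk_logits = [
            [sum(w_k * h_k for w_k, h_k in zip(w_row, h)) for w_row in weight]
            for h in hidden[start:start + chunk_size]
        ]
        for logits, y in zip(chunk_logits, labels[start:start + chunk_size]):
            # -log softmax(logits)[y] = logsumexp(logits) - logits[y]
            total += _logsumexp(logits) - logits[y]
        # chunk_logits is replaced on the next iteration, so the full
        # [num_tokens][vocab_size] logits matrix is never built.
    return total
```

Because each token's loss depends only on its own row of logits, the result is the same for any `chunk_size`; only the peak size of the temporary logits changes.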
Compute the fused linear cross-entropy loss, matching the behavior of PyTorch's cross_entropy.
Parameters:
hidden_states: Input hidden states.
labels: Target labels.
lm_weight: Weight matrix for the linear transformation.
num_label_tokens: Number of non-padding tokens.
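Since the loss is documented to match PyTorch's cross_entropy, an unfused per-token reference helps pin down what it computes. This sketch rests on assumptions this page does not spell out: `lm_weight` is laid out as [vocab_size][hidden_dim], labels equal to -100 (PyTorch's default `ignore_index`) are skipped, and when `num_label_tokens` is given the summed loss is divided by it, otherwise it is averaged over non-ignored tokens. The function name `reference_linear_ce` is hypothetical.

```python
import math
from typing import List, Optional

IGNORE_INDEX = -100  # assumption: PyTorch's default ignore_index


def reference_linear_ce(
    hidden_states: List[List[float]],  # [num_tokens][hidden_dim]
    labels: List[int],                 # [num_tokens]
    lm_weight: List[List[float]],      # [vocab_size][hidden_dim] (assumed)
    num_label_tokens: Optional[int] = None,
) -> float:
    """Unfused reference: project to logits, then token-level cross-entropy."""
    total, counted = 0.0, 0
    for h, y in zip(hidden_states, labels):
        if y == IGNORE_INDEX:
            continue  # ignored (e.g. padding) positions contribute nothing
        logits = [sum(w * x for w, x in zip(row, h)) for row in lm_weight]
        m = max(logits)
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[y]
        counted += 1
    # Assumed normalization: caller-supplied token count if given,
    # else the number of non-ignored tokens (like reduction="mean").
    denom = num_label_tokens if num_label_tokens is not None else counted
    # With nothing to average over, PyTorch's mean reduction yields nan.
    return total / denom if denom else float("nan")
```

Under the same assumptions, the unfused PyTorch equivalent would be F.cross_entropy(hidden_states @ lm_weight.T, labels, ignore_index=-100, reduction="sum") / num_label_tokens, which materializes the full logits tensor that the fused module avoids.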
Check whether the installed pytorch-triton version is greater than or equal to the specified version.
Parameters:
version_str: Version string to compare against.
Returns:
True if the pytorch-triton version is >= the specified version.
Check whether the installed pytorch-triton version is greater than or equal to 3.1.0.
Returns:
True if the pytorch-triton version is >= 3.1.0.
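The fixed-threshold check is naturally a thin wrapper over a general comparison. The self-contained sketch below repeats a compact version parser. The names `is_pytorch_triton_gte_3_1_0` and `pick_loss_backend` are hypothetical, and the idea that this check gates a fused Triton path with an unfused fallback is an assumption, not something this page states.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

MIN_TRITON = (3, 1, 0)


def _numeric_version(version_str: str) -> Tuple[int, ...]:
    # Keep only the leading dotted-integer run: "3.1.0+abc" -> (3, 1, 0).
    match = re.match(r"\d+(?:\.\d+)*", version_str)
    return tuple(int(p) for p in match.group(0).split(".")) if match else ()


def is_pytorch_triton_gte_3_1_0(installed: Optional[str] = None) -> bool:
    if installed is None:
        try:
            installed = version("pytorch-triton")
        except PackageNotFoundError:
            return False  # not installed: treat as too old
    found = _numeric_version(installed)
    found += (0,) * (len(MIN_TRITON) - len(found))  # "3.1" -> (3, 1, 0)
    return found >= MIN_TRITON


def pick_loss_backend(installed: Optional[str] = None) -> str:
    # Hypothetical gating: use the fused kernel only on a new-enough Triton.
    return "fused" if is_pytorch_triton_gte_3_1_0(installed) else "unfused"
```

For example, pick_loss_backend("3.2.0+git35c6c7c6") selects "fused", while pick_loss_backend("3.0.0") falls back to "unfused".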