nemo_automodel.components.loss.linear_ce

Module Contents

Classes

FusedLinearCrossEntropy

Fused linear cross entropy loss.

Functions

new_is_triton_greater_or_equal

Check if pytorch-triton version is greater than or equal to the specified version.

new_is_triton_greater_or_equal_3_2_0

Check if pytorch-triton version is greater than or equal to 3.1.0.

API

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal(version_str)

Check if pytorch-triton version is greater than or equal to the specified version.

Parameters:

version_str – Version string to check

Returns:

True if pytorch-triton version >= specified version

Return type:

bool
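
As an illustration of this kind of check, here is a minimal standard-library sketch. It is not the module's implementation: the helper names `version_at_least` and `triton_at_least` are hypothetical, and the distribution names `pytorch-triton` and `triton` are assumptions about how the package may be installed.

```python
import re
from importlib import metadata


def _release_tuple(version_str: str) -> tuple:
    """Extract the leading numeric release, e.g. '3.1.0+cf34004' -> (3, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version_str.strip())
    if match is None:
        raise ValueError(f"cannot parse version string: {version_str!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(installed: str, required: str) -> bool:
    """Return True if `installed` >= `required`, comparing numeric release parts."""
    a, b = _release_tuple(installed), _release_tuple(required)
    width = max(len(a), len(b))
    # Pad the shorter tuple with zeros so that "3.2" compares equal to "3.2.0".
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def triton_at_least(required: str, dist_names=("pytorch-triton", "triton")) -> bool:
    """Hypothetical helper: check the first installed distribution in `dist_names`.

    Returns False when none of the assumed distribution names is installed.
    """
    for name in dist_names:
        try:
            return version_at_least(metadata.version(name), required)
        except metadata.PackageNotFoundError:
            continue
    return False
```

Comparing integer tuples rather than raw strings matters here: as strings, "3.10.0" would sort before "3.2.0".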

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()

Check if pytorch-triton version is greater than or equal to 3.1.0. Note that the threshold documented here is 3.1.0, although the function name ends in 3_2_0.

Returns:

True if pytorch-triton version >= 3.1.0

Return type:

bool

class nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy(
ignore_index: int = -100,
logit_softcapping: float = 0,
reduction: str = 'sum',
)

Bases: torch.nn.Module

Initialization

Fused linear cross entropy loss.

Parameters:
  • ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.

  • logit_softcapping (float) – Value for softcapping logits (0 means no capping). Defaults to 0.

  • reduction (str) – Type of reduction. Defaults to “sum”.
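
To make the three parameters concrete, the following pure-Python sketch computes an unfused cross entropy over precomputed logits. It is a reference for the semantics only, not the fused kernel. The softcapping formula `cap * tanh(logit / cap)` is an assumption based on the form commonly used for logit softcapping, and the names `softcap` and `cross_entropy` are hypothetical.

```python
import math


def softcap(logit: float, cap: float) -> float:
    """Squash a logit into (-cap, cap); a cap of 0 disables capping (assumed formula)."""
    return logit if cap == 0 else cap * math.tanh(logit / cap)


def cross_entropy(logits, labels, ignore_index=-100, logit_softcapping=0.0, reduction="sum"):
    """Unfused reference: per-token negative log-likelihood over rows of logits."""
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored targets contribute nothing to the loss
        capped = [softcap(x, logit_softcapping) for x in row]
        # Numerically stable log-sum-exp of the (capped) logits.
        peak = max(capped)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in capped))
        losses.append(log_z - capped[label])
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Averages over non-ignored tokens; returns 0.0 if every token was ignored.
        return sum(losses) / max(len(losses), 1)
    raise ValueError(f"unsupported reduction in this sketch: {reduction!r}")
```

With two equally likely classes, each counted token contributes log 2, so `cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, -100])` sums over only the first row.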

forward(
hidden_states: torch.Tensor,
labels: torch.Tensor,
lm_weight: torch.Tensor,
num_label_tokens: Optional[int] = None,
) → torch.Tensor

Compute fused linear cross entropy loss that matches PyTorch's cross_entropy behavior.

Parameters:
  • hidden_states – Input hidden states

  • labels – Target labels

  • lm_weight – Weight matrix for linear transformation

  • num_label_tokens – Number of non-padding tokens.
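
For comparison, here is an unfused PyTorch reference for the computation that forward is described as matching. It is a sketch under stated assumptions rather than the module's code: `reference_linear_ce` is a hypothetical name, `lm_weight` is assumed to have shape `[vocab_size, hidden_size]` (as in `torch.nn.Linear.weight`), and dividing the summed loss by `num_label_tokens` is an assumed interpretation of that argument.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def reference_linear_ce(
    hidden_states: torch.Tensor,
    labels: torch.Tensor,
    lm_weight: torch.Tensor,
    num_label_tokens: Optional[int] = None,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Unfused reference: project to logits, then apply a summed cross entropy."""
    # Flatten any leading batch/sequence dimensions to [num_tokens, hidden_size].
    flat_hidden = hidden_states.reshape(-1, hidden_states.shape[-1]).float()
    # Assumption: lm_weight is [vocab_size, hidden_size], so logits = h @ W^T.
    logits = flat_hidden @ lm_weight.float().T
    loss = F.cross_entropy(
        logits, labels.reshape(-1), ignore_index=ignore_index, reduction="sum"
    )
    # Assumption: when given, num_label_tokens normalizes the summed loss.
    if num_label_tokens is not None:
        loss = loss / num_label_tokens
    return loss
```

Unlike a fused implementation, this version materializes the full `[num_tokens, vocab_size]` logits tensor, which is convenient for checking numerical results on small inputs.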