nemo_automodel.components.loss.linear_ce#

Module Contents#

Classes

FusedLinearCrossEntropy

Fused linear cross entropy loss.

Functions#

new_is_triton_greater_or_equal

Check if pytorch-triton version is greater than or equal to the specified version.

new_is_triton_greater_or_equal_3_2_0

Check if pytorch-triton version is greater than or equal to 3.2.0.

API#

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal(version_str)[source]#

Check if pytorch-triton version is greater than or equal to the specified version.

Parameters:

version_str – Version string to check

Returns:

True if pytorch-triton version >= specified version

Return type:

bool
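
A minimal usage sketch (the version string and the printed message below are illustrative):

```python
from nemo_automodel.components.loss.linear_ce import new_is_triton_greater_or_equal

# Gate an optional Triton fast path on the installed pytorch-triton version.
if new_is_triton_greater_or_equal("3.2.0"):
    print("pytorch-triton >= 3.2.0: fused kernels available")
```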

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()[source]#

Check if pytorch-triton version is greater than or equal to 3.2.0.

Returns:

True if pytorch-triton version >= 3.2.0

Return type:

bool
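
A minimal sketch of the convenience wrapper; it takes no arguments and returns a plain bool:

```python
from nemo_automodel.components.loss.linear_ce import new_is_triton_greater_or_equal_3_2_0

# Returns True when the installed pytorch-triton is at least 3.2.0, False otherwise.
if new_is_triton_greater_or_equal_3_2_0():
    print("pytorch-triton >= 3.2.0 detected")
```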

class nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy(
ignore_index: int = -100,
logit_softcapping: float = 0,
reduction: str = 'sum',
)[source]#

Fused linear cross entropy loss.

Initialization

Parameters:
  • ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.

  • logit_softcapping (float) – Value for softcapping logits (0 means no capping). Defaults to 0.

  • reduction (str) – Type of reduction. Defaults to “sum”.
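
A minimal construction sketch; the softcapping value below is an illustrative choice, not a recommended default (the example after __call__ shows how the configured loss is applied):

```python
from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy

# Default configuration: ignore_index=-100, no logit softcapping, sum reduction.
loss_fn = FusedLinearCrossEntropy()

# Same loss with logit softcapping enabled (the value 30.0 is illustrative).
softcapped_loss_fn = FusedLinearCrossEntropy(logit_softcapping=30.0, reduction="sum")
```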

__call__(
hidden_states: torch.Tensor,
labels: torch.Tensor,
lm_weight: torch.Tensor,
) → torch.Tensor[source]#

Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior.

Parameters:
  • hidden_states – Input hidden states

  • labels – Target labels

  • lm_weight – Weight matrix for linear transformation

Returns:

The computed cross entropy loss.

Return type:

torch.Tensor
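
A minimal usage sketch, assuming a flattened [tokens, hidden] layout for hidden_states, a [vocab, hidden] lm_weight, and a CUDA device for the fused Triton kernel; the shapes, dtype, and device are illustrative assumptions rather than documented requirements:

```python
import torch

from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy

# Illustrative sizes (assumptions): 16 tokens, hidden size 128, vocabulary of 32000.
num_tokens, hidden_size, vocab_size = 16, 128, 32000
device = "cuda"  # assumption: the fused Triton kernel targets a CUDA device

hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (num_tokens,), device=device)
labels[:2] = -100  # positions equal to ignore_index do not contribute to the loss
lm_weight = torch.randn(vocab_size, hidden_size, device=device, dtype=torch.bfloat16)

loss_fn = FusedLinearCrossEntropy(ignore_index=-100, reduction="sum")
loss = loss_fn(hidden_states, labels, lm_weight)  # scalar torch.Tensor
```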