nemo_automodel.loss.linear_ce

Module Contents

Functions

new_is_triton_greater_or_equal

Check if pytorch-triton version is greater than or equal to the specified version.

new_is_triton_greater_or_equal_3_2_0

Check if pytorch-triton version is greater than or equal to 3.1.0.

fused_linear_cross_entropy

Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior.

API

nemo_automodel.loss.linear_ce.new_is_triton_greater_or_equal(version_str)

Check if pytorch-triton version is greater than or equal to the specified version.

Parameters:

version_str – Version string to check

Returns:

True if pytorch-triton version >= specified version

Return type:

bool
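
A version check of this kind can be sketched with the standard library alone. The block below is an illustration, not the module's actual implementation: the helper names (`parse_release`, `version_at_least`, `dist_at_least`) are hypothetical, the distribution name `pytorch-triton` is taken from the description above, and returning False when the distribution is not installed is an assumption.

```python
import re
from importlib import metadata


def parse_release(version_str):
    """Return the leading numeric components, e.g. '3.2.0+git1' -> (3, 2, 0).

    Local and pre-release suffixes are ignored (a simplification).
    """
    match = re.match(r"\d+(?:\.\d+)*", version_str.strip())
    if match is None:
        raise ValueError(f"no numeric release found in {version_str!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(installed, required):
    """Compare two version strings numerically, padding with zeros."""
    a, b = parse_release(installed), parse_release(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def dist_at_least(required, dist_name="pytorch-triton"):
    """Hypothetical stand-in for the documented check.

    Assumption: a missing distribution is treated as 'not new enough'.
    """
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return version_at_least(installed, required)
```

Comparing integer tuples rather than raw strings matters here: as strings, "3.10.0" sorts before "3.9.0", while the numeric comparison orders them correctly.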

nemo_automodel.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()

Check if pytorch-triton version is greater than or equal to 3.1.0. Note that the documented threshold (3.1.0) differs from the 3.2.0 suggested by the function name.

Returns:

True if pytorch-triton version >= 3.1.0

Return type:

bool

nemo_automodel.loss.linear_ce.fused_linear_cross_entropy(
    hidden_states: torch.Tensor,
    lm_weight: torch.Tensor,
    labels: torch.Tensor,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    reduction: str = 'mean',
    logit_softcapping: float = 0,
    accuracy_threshold: str = 'auto',
)

Compute fused linear cross entropy loss that matches PyTorch’s cross_entropy behavior.

Parameters:
  • hidden_states – Input hidden states

  • lm_weight – Weight matrix for linear transformation

  • labels – Target labels

  • num_items_in_batch – Number of valid tokens (where labels != ignore_index)

  • ignore_index – Value to ignore in labels (default: -100)

  • reduction – Reduction method ('mean' or 'sum')

  • logit_softcapping – Value for softcapping logits (0 means no capping)

  • accuracy_threshold – Threshold for accuracy computation
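
To make the intended semantics concrete, the following is a pure-Python, unfused reference of "linear projection followed by cross entropy" on small lists. It is a sketch under stated assumptions, not the fused kernel: `reference_linear_ce` is a hypothetical name, `lm_weight` is assumed to have shape (vocab_size, hidden_dim), softcapping is assumed to be `cap * tanh(logits / cap)`, and `num_items_in_batch` and `accuracy_threshold` are not modeled.

```python
import math


def reference_linear_ce(hidden_states, lm_weight, labels,
                        ignore_index=-100, reduction="mean",
                        logit_softcapping=0.0):
    """Unfused reference: logits = hidden @ weight^T, then cross entropy.

    hidden_states: list of rows, each of length hidden_dim
    lm_weight:     list of vocab_size rows of length hidden_dim (assumed layout)
    labels:        list of int class ids, one per row of hidden_states
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"unsupported reduction: {reduction!r}")

    total, count = 0.0, 0
    for h, y in zip(hidden_states, labels):
        if y == ignore_index:
            continue  # ignored positions contribute nothing to the loss
        # Linear projection onto the vocabulary.
        logits = [sum(a * b for a, b in zip(h, w)) for w in lm_weight]
        if logit_softcapping:
            # Assumed softcapping form: squash logits into (-cap, cap).
            cap = logit_softcapping
            logits = [cap * math.tanh(z / cap) for z in logits]
        # Numerically stable log-sum-exp, then negative log-likelihood.
        peak = max(logits)
        lse = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += lse - logits[y]
        count += 1

    if reduction == "sum":
        return total
    # 'mean' averages over non-ignored positions, as torch's cross_entropy does.
    return total / count if count else float("nan")
```

For instance, with all-zero weights the logits are uniform, so the loss for any label is log(vocab_size). A fused implementation avoids materializing the full logits matrix that this reference builds row by row, which is where the memory savings come from.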