nemo_automodel.components.loss.linear_ce


Module Contents

Classes

Name | Description
--- | ---
FusedLinearCrossEntropy | Fused linear-projection and cross-entropy loss module.

Functions

Name | Description
--- | ---
new_is_triton_greater_or_equal | Check if pytorch-triton version is greater than or equal to the specified version.
new_is_triton_greater_or_equal_3_2_0 | Check if pytorch-triton version is greater than or equal to 3.1.0.

Data

HAVE_CUT_CROSS_ENTROPY

API

class nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy(
ignore_index: int = -100,
logit_softcapping: float = 0,
reduction: str = 'sum'
)

Bases: Module

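Fused linear-projection and cross-entropy loss module. It computes the language-model head projection and the cross-entropy loss together as one fused operation, rather than first materializing the full logits tensor. ignore_index marks label positions that are excluded from the loss (default -100, as in PyTorch). reduction selects how per-token losses are combined (default 'sum'). logit_softcapping sets the soft-capping value applied to the logits.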
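The main motivation for fusing the projection with the loss is usually memory. An unfused implementation computes a dense logits tensor of shape (num_tokens, vocab_size) from the hidden states and the LM-head weight. During the backward pass it typically also holds a gradient of the same size. The arithmetic below sizes that tensor. The batch shape and the 128,256-entry vocabulary are illustrative assumptions, and logits_bytes is a hypothetical helper.

```python
def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_element: int) -> int:
    """Bytes needed to hold a dense (num_tokens x vocab_size) logits tensor."""
    return num_tokens * vocab_size * bytes_per_element


GIB = 1024 ** 3

# Illustrative batch: 8 sequences of 1024 tokens, 128,256-entry vocabulary.
num_tokens = 8 * 1024
vocab_size = 128_256

fp32 = logits_bytes(num_tokens, vocab_size, 4)  # 4 bytes per float32 element
bf16 = logits_bytes(num_tokens, vocab_size, 2)  # 2 bytes per bfloat16 element
print(f"fp32 logits: {fp32 / GIB:.2f} GiB")  # fp32 logits: 3.91 GiB
print(f"bf16 logits: {bf16 / GIB:.2f} GiB")  # bf16 logits: 1.96 GiB
```

This cost grows linearly with both sequence length and vocabulary size. That is why large-vocabulary models benefit most from not materializing the logits.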

nemo_automodel.components.loss.linear_ce.FusedLinearCrossEntropy.forward(
hidden_states: torch.Tensor,
labels: torch.Tensor,
lm_weight: torch.Tensor,
num_label_tokens: typing.Optional[int] = None
) -> torch.Tensor

Compute the fused linear cross-entropy loss, matching the behavior of PyTorch's cross_entropy.
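Written out, the quantity being matched is the ordinary cross-entropy of the projected logits. This is a sketch under several assumptions: lm_weight is W with shape (V, d), as for a standard LM-head weight. Soft-capping uses the common tanh form, with c = logit_softcapping and c = 0 meaning no capping. num_label_tokens, when given, divides the summed loss. None of these details is confirmed by this page. Here h_i is the hidden state of token i and y_i is its label.

```latex
z_i = W h_i \in \mathbb{R}^{V}, \qquad
\tilde z_i =
\begin{cases}
  c \tanh\!\left( z_i / c \right) & c > 0 \\
  z_i & c = 0
\end{cases}

\mathcal{L}_{\text{sum}}
  = \sum_{i \,:\, y_i \neq \texttt{ignore\_index}}
    \Big( \log \sum_{v=1}^{V} \exp\!\left( \tilde z_{i,v} \right) - \tilde z_{i, y_i} \Big),
\qquad
\mathcal{L} = \frac{\mathcal{L}_{\text{sum}}}{\texttt{num\_label\_tokens}}
```

Each summand is the negative log-softmax probability of the target token, which is what cross_entropy computes per position. Positions labeled ignore_index contribute nothing.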

Parameters:

hidden_states
torch.Tensor

Input hidden states

labels
torch.Tensor

Target token labels; positions equal to ignore_index are excluded from the loss.

lm_weight
torch.Tensor

Weight matrix of the output (language-model head) linear projection

num_label_tokens
Optional[int], defaults to None

Number of non-padding tokens.
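For checking results on small inputs, an unfused reference helps make the parameters concrete. The sketch below is pure Python and is not the module's implementation. reference_linear_ce is a hypothetical name. It assumes lm_weight has one row per vocabulary entry, tanh soft-capping with 0 meaning disabled, 'sum' reduction, and division by num_label_tokens when it is provided.

```python
import math
from typing import Optional, Sequence


def reference_linear_ce(
    hidden_states: Sequence[Sequence[float]],  # (num_tokens, hidden_size)
    labels: Sequence[int],                     # (num_tokens,)
    lm_weight: Sequence[Sequence[float]],      # (vocab_size, hidden_size), assumed layout
    ignore_index: int = -100,
    logit_softcapping: float = 0.0,
    num_label_tokens: Optional[int] = None,
) -> float:
    """Hypothetical unfused reference: project to logits, then sum token losses."""
    total = 0.0
    for h, y in zip(hidden_states, labels):
        if y == ignore_index:
            continue  # ignored positions contribute nothing to the sum
        # Linear projection: one logit per vocabulary row of lm_weight.
        logits = [sum(w * x for w, x in zip(row, h)) for row in lm_weight]
        if logit_softcapping > 0:
            # Assumed tanh soft-capping; a value of 0 is treated as "disabled".
            c = logit_softcapping
            logits = [c * math.tanh(z / c) for z in logits]
        # Numerically stable log-sum-exp, then the negative log-probability of y.
        m = max(logits)
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[y]
    if num_label_tokens is not None:
        total /= num_label_tokens  # assumed normalization by label-token count
    return total


# Tiny example: vocabulary of 3, hidden size 2, last position ignored.
hidden = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
labels = [0, 1, -100]
loss_sum = reference_linear_ce(hidden, labels, weight)  # equals 2 * (log(e + 2) - 1)
loss_mean = reference_linear_ce(hidden, labels, weight, num_label_tokens=2)
```

In PyTorch, the same summed quantity without soft-capping is F.cross_entropy(hidden_states @ lm_weight.T, labels, ignore_index=-100, reduction='sum'). A fused kernel is expected to agree with it up to floating-point tolerance.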

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal(
version_str
)

Check if pytorch-triton version is greater than or equal to the specified version.

Parameters:

version_str

Version string to check

Returns:

True if pytorch-triton version >= specified version

nemo_automodel.components.loss.linear_ce.new_is_triton_greater_or_equal_3_2_0()

Check if pytorch-triton version is greater than or equal to 3.1.0.

Returns:

True if pytorch-triton version >= 3.1.0

nemo_automodel.components.loss.linear_ce.HAVE_CUT_CROSS_ENTROPY = True
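A flag like HAVE_CUT_CROSS_ENTROPY usually records whether an optional dependency is importable. Callers can then fail early or fall back to an unfused loss. The sketch below assumes the backend is the Cut Cross-Entropy package, importable as cut_cross_entropy. It probes for the package with importlib.util.find_spec. A try/except ImportError around the import is the other common form. require_cut_cross_entropy is a hypothetical helper.

```python
import importlib.util

# Assumed pattern: detect the optional backend without importing it eagerly.
HAVE_CUT_CROSS_ENTROPY = importlib.util.find_spec("cut_cross_entropy") is not None


def require_cut_cross_entropy() -> None:
    """Hypothetical guard: raise a clear error if the fused backend is missing."""
    if not HAVE_CUT_CROSS_ENTROPY:
        raise ImportError(
            "FusedLinearCrossEntropy needs the cut_cross_entropy package; "
            "install it or fall back to an unfused cross-entropy loss."
        )
```

On the documentation build shown above, the flag evaluated to True. In another environment it depends on what is installed.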