nemo_automodel.components.moe.megatron.moe_utils


Module Contents

Classes

| Name | Description |
| --- | --- |
| MoEAuxLossAutoScaler | An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss. |
| WeightedBiasQuickGeGLUFunction | Autograd function for token-wise weighted Quick-GEGLU with bias support. |
| WeightedGEGLUFunction | Autograd function for token-wise weighted GEGLU. |
| WeightedQuickGeGLUFunction | Autograd function for token-wise weighted Quick-GEGLU (no bias). |
| WeightedSwiGLUFunction | Autograd function for token-wise weighted SwiGLU. |

Functions

| Name | Description |
| --- | --- |
| geglu | GEGLU activation function. |
| geglu_back | Compute the input gradient for tanh-approximated GEGLU activation. |
| permute | Permute the tokens and probs based on the mask. |
| quick_geglu | Performs Quick-GELU-based GEGLU activation: quick_gelu(y1) * (y2 + offset). |
| quick_geglu_back | Compute the input gradient for Quick-GEGLU activation. |
| quick_gelu | Sigmoid approximation of GELU. |
| swiglu | Apply SwiGLU activation to an interleaved gate/up tensor. |
| swiglu_back | Compute the input gradient for SwiGLU activation. |
| unpermute | Restore the original order of tokens after permutation. If probs are provided, it also applies them to the tokens before restoring the order. |
| weighted_bias_geglu_impl | Token-wise-weighted bias GEGLU fusion (tanh-approximated GELU gating). |
| weighted_bias_quick_geglu | Token-wise weighted Quick-GEGLU activation with bias. |
| weighted_bias_quick_geglu_back | Backward helper for weighted Quick-GEGLU with bias. |
| weighted_bias_quick_geglu_impl | Token-wise-weighted bias quick_geglu fusion. |
| weighted_bias_swiglu_impl | Token-wise-weighted bias swiglu fusion. |
| weighted_geglu | Apply GEGLU activation and token-wise routing weights. |
| weighted_geglu_back | Compute input and weight gradients for weighted GEGLU. |
| weighted_quick_geglu | Token-wise-weighted Quick-GEGLU activation. |
| weighted_quick_geglu_back | Backward helper for weighted Quick-GEGLU. |
| weighted_swiglu | Apply SwiGLU activation and token-wise routing weights. |
| weighted_swiglu_back | Compute input and weight gradients for weighted SwiGLU. |

API

class nemo_automodel.components.moe.megatron.moe_utils.MoEAuxLossAutoScaler()

Bases: Function

An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.

main_loss_backward_scale: Tensor = None
nemo_automodel.components.moe.megatron.moe_utils.MoEAuxLossAutoScaler.backward(
ctx,
grad_output: torch.Tensor
)
staticmethod

Compute and scale the gradient for the auxiliary loss.

Parameters:

grad_output
torch.Tensor

The gradient of the output.

Returns:

Tuple[torch.Tensor, torch.Tensor]: The gradient of the output and the scaled auxiliary loss gradient.

nemo_automodel.components.moe.megatron.moe_utils.MoEAuxLossAutoScaler.forward(
ctx,
output: torch.Tensor,
aux_loss: torch.Tensor
)
staticmethod

Preserve the aux_loss by storing it in the context to avoid garbage collection.

Parameters:

output
torch.Tensor

The output tensor.

aux_loss
torch.Tensor

The auxiliary loss tensor.

Returns:

torch.Tensor: The output tensor.
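The forward/backward pair above can be illustrated with a minimal autograd Function. This is a sketch of the pattern, not the library's code: the class name `AuxLossAutoScalerSketch` is hypothetical, and treating an unset `main_loss_backward_scale` as 1.0 is an assumption. Forward returns `output` unchanged, but because `aux_loss` is an input to the Function, backward can hand it a gradient, so the auxiliary loss is optimized without ever being added to the main loss value.

```python
import torch


class AuxLossAutoScalerSketch(torch.autograd.Function):
    """Hypothetical re-implementation of the aux-loss autoscaler pattern."""

    # Assumed: class-level scale for the aux-loss gradient; treated as 1.0 when unset.
    main_loss_backward_scale = None

    @staticmethod
    def forward(ctx, output, aux_loss):
        # Save aux_loss on the context so it stays alive until backward.
        ctx.save_for_backward(aux_loss)
        # The forward value is unchanged: aux_loss does not modify the activations.
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        scale = AuxLossAutoScalerSketch.main_loss_backward_scale
        if scale is None:
            scale = torch.tensor(1.0, device=aux_loss.device)
        # d(aux_loss)/d(aux_loss) = 1, multiplied by the scale used for the main loss.
        scaled_aux_loss_grad = torch.ones_like(aux_loss) * scale
        return grad_output, scaled_aux_loss_grad


# Usage: attach a stand-in load-balancing loss to a layer's activations.
x = torch.ones(3, requires_grad=True)
router_param = torch.tensor(2.0, requires_grad=True)
hidden = x * 2.0                  # activations flowing to the next layer
aux_loss = router_param * 3.0     # stand-in for an auxiliary loss
out = AuxLossAutoScalerSketch.apply(hidden, aux_loss)
out.sum().backward()              # the main loss never adds aux_loss explicitly
```

After `backward()`, `router_param` receives a gradient from the auxiliary loss even though `out` holds the same values as `hidden`.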

class nemo_automodel.components.moe.megatron.moe_utils.WeightedBiasQuickGeGLUFunction()

Bases: Function

Autograd function for token-wise weighted Quick-GEGLU with bias support.

nemo_automodel.components.moe.megatron.moe_utils.WeightedBiasQuickGeGLUFunction.backward(
ctx,
grad_output
)
staticmethod
nemo_automodel.components.moe.megatron.moe_utils.WeightedBiasQuickGeGLUFunction.forward(
ctx,
input: torch.Tensor,
bias: torch.Tensor,
weights: torch.Tensor,
fp8_input_store: bool,
linear_offset: torch.Tensor
)
staticmethod
class nemo_automodel.components.moe.megatron.moe_utils.WeightedGEGLUFunction()

Bases: Function

Autograd function for token-wise weighted GEGLU.

nemo_automodel.components.moe.megatron.moe_utils.WeightedGEGLUFunction.backward(
ctx,
grad_output
)
staticmethod
nemo_automodel.components.moe.megatron.moe_utils.WeightedGEGLUFunction.forward(
ctx,
input,
weights,
fp8_input_store
)
staticmethod
class nemo_automodel.components.moe.megatron.moe_utils.WeightedQuickGeGLUFunction()

Bases: Function

Autograd function for token-wise weighted Quick-GEGLU (no bias).

nemo_automodel.components.moe.megatron.moe_utils.WeightedQuickGeGLUFunction.backward(
ctx,
grad_output
)
staticmethod
nemo_automodel.components.moe.megatron.moe_utils.WeightedQuickGeGLUFunction.forward(
ctx,
input: torch.Tensor,
weights: torch.Tensor,
fp8_input_store: bool,
linear_offset: torch.Tensor
)
staticmethod
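The weighted activation Functions in this module follow a common pattern: forward computes act(input) * weights with one routing weight per token, and backward returns gradients for both the input and the weights. A minimal sketch for the no-bias Quick-GEGLU case follows, assuming halves split with `torch.chunk`, weights shaped [tokens, 1], and `linear_offset` passed as a plain float; FP8 input storage (`fp8_input_store`) is omitted. `WeightedQuickGeGLUSketch` is a hypothetical name, and the hand-written backward is checked with `torch.autograd.gradcheck`.

```python
import torch


def quick_geglu(y, linear_offset=0.0, alpha=1.702):
    # quick_gelu(y1) * (y2 + linear_offset), with y split into halves on the last dim.
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return y1 * torch.sigmoid(alpha * y1) * (y2 + linear_offset)


def quick_geglu_back(g, y, linear_offset=0.0, alpha=1.702):
    y1, y2 = torch.chunk(y, 2, dim=-1)
    s = torch.sigmoid(alpha * y1)
    # d/dy1 [y1 * sigmoid(alpha * y1)] = s * (1 + alpha * y1 * (1 - s))
    dy1 = g * s * (1 + alpha * y1 * (1 - s)) * (y2 + linear_offset)
    dy2 = g * y1 * s
    return torch.cat((dy1, dy2), dim=-1)


class WeightedQuickGeGLUSketch(torch.autograd.Function):
    """Hypothetical stand-in for a weighted Quick-GEGLU autograd Function (no FP8 storage)."""

    @staticmethod
    def forward(ctx, y, weights, linear_offset):
        ctx.save_for_backward(y, weights)
        ctx.linear_offset = linear_offset
        return quick_geglu(y, linear_offset) * weights  # weights: [tokens, 1]

    @staticmethod
    def backward(ctx, grad_output):
        y, weights = ctx.saved_tensors
        off = ctx.linear_offset
        grad_y = quick_geglu_back(grad_output * weights, y, off)
        # One weight per token, so its gradient sums over the feature dimension.
        grad_w = (quick_geglu(y, off) * grad_output).sum(dim=-1, keepdim=True)
        return grad_y, grad_w, None  # no gradient for linear_offset


y = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
w = torch.rand(4, 1, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(lambda a, b: WeightedQuickGeGLUSketch.apply(a, b, 0.5), (y, w))
```

Recomputing the activation in backward instead of saving it keeps memory proportional to the input only, which is one plausible reason for this structure.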
class nemo_automodel.components.moe.megatron.moe_utils.WeightedSwiGLUFunction()

Bases: Function

Autograd function for token-wise weighted SwiGLU.

nemo_automodel.components.moe.megatron.moe_utils.WeightedSwiGLUFunction.backward(
ctx,
grad_output
)
staticmethod
nemo_automodel.components.moe.megatron.moe_utils.WeightedSwiGLUFunction.forward(
ctx,
input,
weights,
fp8_input_store
)
staticmethod
nemo_automodel.components.moe.megatron.moe_utils.geglu(
y
)

GEGLU activation function. Splits the input in half along the last dimension into y_gate and y_up, then applies GEGLU(y) = GELU_tanh(y_gate) * y_up.

Used by Gemma4 MoE expert layers (hidden_activation="gelu_pytorch_tanh").
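A minimal sketch of the GEGLU formula, assuming the first half of the last dimension is the gate and the second half is the up projection (the docstring does not state the order). `geglu_sketch` and `gelu_tanh_reference` are hypothetical helpers; `F.gelu(..., approximate="tanh")` is PyTorch's tanh-approximated GELU, which is what `"gelu_pytorch_tanh"` refers to.

```python
import math
import torch
import torch.nn.functional as F


def geglu_sketch(y: torch.Tensor) -> torch.Tensor:
    # Assumed order: first half of the last dim is the gate, second half is "up".
    y_gate, y_up = torch.chunk(y, 2, dim=-1)
    # approximate="tanh" selects PyTorch's tanh GELU ("gelu_pytorch_tanh").
    return F.gelu(y_gate, approximate="tanh") * y_up


def gelu_tanh_reference(x: float) -> float:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


y = torch.tensor([[1.0, -2.0, 3.0, 4.0]])  # gate = [1, -2], up = [3, 4]
out = geglu_sketch(y)
expected = torch.tensor([[gelu_tanh_reference(1.0) * 3.0, gelu_tanh_reference(-2.0) * 4.0]])
```

The output has half the width of the input, since each gate/up pair produces one value.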

nemo_automodel.components.moe.megatron.moe_utils.geglu_back(
g,
y
)

Compute the input gradient for tanh-approximated GEGLU activation.
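The input gradient follows from the product rule. With a = y_gate, b = y_up, k = sqrt(2/pi), and t = tanh(k(a + 0.044715 a^3)), the derivative of the tanh GELU is 0.5(1 + t) + 0.5 a (1 - t^2) k (1 + 3 * 0.044715 a^2). The sketch below assumes the same gate-first halves as the GEGLU sketch; `geglu_back_sketch` is a hypothetical name, and the result is checked against PyTorch autograd.

```python
import math
import torch
import torch.nn.functional as F


def geglu_back_sketch(g: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    a, b = torch.chunk(y, 2, dim=-1)  # gate, up (assumed order)
    k = math.sqrt(2.0 / math.pi)
    t = torch.tanh(k * (a + 0.044715 * a**3))
    gelu_a = 0.5 * a * (1.0 + t)
    # Derivative of the tanh-approximated GELU with respect to its input.
    dgelu_a = 0.5 * (1.0 + t) + 0.5 * a * (1.0 - t * t) * k * (1.0 + 3 * 0.044715 * a * a)
    # d/da = g * gelu'(a) * b, d/db = g * gelu(a); re-join in the input layout.
    return torch.cat((g * dgelu_a * b, g * gelu_a), dim=-1)


y = torch.randn(5, 8, dtype=torch.float64, requires_grad=True)
g = torch.randn(5, 4, dtype=torch.float64)
gate, up = torch.chunk(y, 2, dim=-1)
(F.gelu(gate, approximate="tanh") * up).backward(g)
manual = geglu_back_sketch(g, y.detach())
```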

nemo_automodel.components.moe.megatron.moe_utils.permute(
tokens,
routing_map,
probs: typing.Optional[torch.Tensor] = None,
num_out_tokens: typing.Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False
)

Permute the tokens and probs based on the mask. Tokens routed to the same expert are grouped together. The mask (routing_map) has shape [num_tokens, num_experts] and indicates which experts each token selected.

When drop_and_pad=True, each column of routing_map has exactly expert-capacity non-zeros. The function exploits this fixed count to use ops that are compatible with CUDA graphs.

Parameters:

tokens
torch.Tensor

The input token tensor, [num_tokens, hidden].

routing_map
torch.Tensor

The sparse token to expert mapping, [num_tokens, num_experts].

probs
torch.Tensor, defaults to None

The probs tensor, [num_tokens, num_experts].

num_out_tokens
int, defaults to None

The number of output tokens. If None, it’s set to the number of input tokens.

fused
bool, defaults to False

Whether to use the fused permute function.

drop_and_pad
bool, defaults to False

Whether the token dispatcher drops tokens and pads each expert's token count to the expert capacity. If True, routing_map has a fixed number of non-zeros in each column.

Returns: torch.Tensor

The permuted token tensor.
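A minimal sketch of the non-fused path with drop_and_pad=False, assuming expert-major grouping built from the transposed routing map. `permute_sketch` is a hypothetical name; it also returns the permuted probs and the sorted token indices, because unpermute needs those indices to restore the original order.

```python
import torch


def permute_sketch(tokens, routing_map, probs=None):
    num_tokens = tokens.shape[0]
    num_experts = routing_map.shape[1]
    # [num_tokens, num_experts] -> [num_experts, num_tokens], so tokens group by expert.
    mask = routing_map.bool().T.contiguous()
    token_ids = torch.arange(num_tokens, device=tokens.device).unsqueeze(0).expand(num_experts, -1)
    # Row-major selection yields expert 0's tokens, then expert 1's, and so on.
    sorted_indices = token_ids.masked_select(mask)
    permuted_probs = None
    if probs is not None:
        permuted_probs = probs.T.contiguous().masked_select(mask)
    permuted_tokens = tokens.index_select(0, sorted_indices)
    return permuted_tokens, permuted_probs, sorted_indices


tokens = torch.tensor([[0.0], [1.0], [2.0]])          # 3 tokens, hidden = 1
routing_map = torch.tensor([[1, 0], [0, 1], [1, 1]])  # token 2 goes to both experts
probs = torch.tensor([[0.9, 0.0], [0.0, 0.8], [0.4, 0.6]])
permuted, permuted_probs, sorted_indices = permute_sketch(tokens, routing_map, probs)
```

Expert 0 receives tokens 0 and 2, expert 1 receives tokens 1 and 2, so token 2 appears twice in the permuted output.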

nemo_automodel.components.moe.megatron.moe_utils.quick_geglu(
y: torch.Tensor,
linear_offset: float = 0.0
) -> torch.Tensor

Performs Quick-GELU-based GEGLU activation: quick_gelu(y1) * (y2 + linear_offset), where y1 and y2 are the two halves of y along the last dimension.

Parameters:

y
torch.Tensor

Input tensor split into two halves on the last dimension.

linear_offset
float, defaults to 0.0

Optional linear offset added to the second half before gating.

Returns: torch.Tensor

Tensor after applying the GEGLU activation.
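A minimal sketch of the formula, assuming y1 is the first half and y2 the second half of the last dimension, and using quick_gelu(x) = x * sigmoid(1.702 * x). `quick_geglu_sketch` is a hypothetical name.

```python
import torch


def quick_gelu(y, alpha=1.702):
    return y * torch.sigmoid(alpha * y)


def quick_geglu_sketch(y, linear_offset=0.0):
    # Split the last dimension into two halves: y1 is gated, y2 is the linear path.
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return quick_gelu(y1) * (y2 + linear_offset)


y = torch.tensor([[0.0, 2.0, 5.0, 1.0]])  # y1 = [0, 2], y2 = [5, 1]
out = quick_geglu_sketch(y, linear_offset=1.0)
```

A zero gate input zeroes its output regardless of the linear path, because quick_gelu(0) = 0.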

nemo_automodel.components.moe.megatron.moe_utils.quick_geglu_back(
g,
y,
linear_offset: float = 0.0
) -> torch.Tensor

Compute the input gradient for Quick-GEGLU activation.
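The gradient splits into two parts. With s = sigmoid(alpha * y1), the gate half gets g * s * (1 + alpha * y1 * (1 - s)) * (y2 + linear_offset), and the linear half gets g * quick_gelu(y1). A sketch under the same halves-on-last-dim assumption, with the hypothetical name `quick_geglu_back_sketch`, checked against autograd:

```python
import torch


def quick_geglu(y, linear_offset=0.0, alpha=1.702):
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return y1 * torch.sigmoid(alpha * y1) * (y2 + linear_offset)


def quick_geglu_back_sketch(g, y, linear_offset=0.0, alpha=1.702):
    y1, y2 = torch.chunk(y, 2, dim=-1)
    s = torch.sigmoid(alpha * y1)
    # d/dy1 quick_gelu(y1) = s * (1 + alpha * y1 * (1 - s))
    dy1 = g * s * (1 + alpha * y1 * (1 - s)) * (y2 + linear_offset)
    dy2 = g * y1 * s  # d/dy2 is just quick_gelu(y1)
    return torch.cat((dy1, dy2), dim=-1)


y = torch.randn(3, 6, dtype=torch.float64, requires_grad=True)
g = torch.randn(3, 3, dtype=torch.float64)
quick_geglu(y, 0.25).backward(g)
manual = quick_geglu_back_sketch(g, y.detach(), 0.25)
```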

nemo_automodel.components.moe.megatron.moe_utils.quick_gelu(
y: torch.Tensor,
alpha: float = 1.702
) -> torch.Tensor

Sigmoid approximation of GELU: quick_gelu(y) = y * sigmoid(alpha * y).
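Exact GELU is y * Phi(y), where Phi is the standard normal CDF; the quick variant replaces Phi(y) with sigmoid(alpha * y), which is cheaper to compute. The snippet compares the two numerically over [-6, 6]. `quick_gelu_sketch` is a hypothetical name, the reference GELU is built from `torch.erf`, and the 0.05 bound in the check is a loose tolerance chosen for this sketch, not a documented guarantee.

```python
import math
import torch


def quick_gelu_sketch(y: torch.Tensor, alpha: float = 1.702) -> torch.Tensor:
    # GELU(y) = y * Phi(y) is approximated by y * sigmoid(alpha * y).
    return y * torch.sigmoid(alpha * y)


x = torch.linspace(-6.0, 6.0, steps=1201, dtype=torch.float64)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))  # exact GELU
max_err = (quick_gelu_sketch(x) - exact).abs().max().item()
```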

nemo_automodel.components.moe.megatron.moe_utils.swiglu(
y
)

Apply SwiGLU activation to an interleaved gate/up tensor.
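SwiGLU computes silu(gate) * up. A minimal sketch, assuming "interleaved" means gate and up values alternate along the last dimension (gate at even indices, up at odd indices); this layout is an interpretation of the docstring, and a contiguous-halves layout would use `torch.chunk(y, 2, dim=-1)` instead. `swiglu_interleaved_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def swiglu_interleaved_sketch(y: torch.Tensor) -> torch.Tensor:
    # Assumed layout along the last dim: [g0, u0, g1, u1, ...]
    gate = y[..., 0::2]
    up = y[..., 1::2]
    return F.silu(gate) * up


y = torch.tensor([[1.0, 10.0, -1.0, 20.0]])  # gates = [1, -1], ups = [10, 20]
out = swiglu_interleaved_sketch(y)
```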

nemo_automodel.components.moe.megatron.moe_utils.swiglu_back(
g,
y
)

Compute the input gradient for SwiGLU activation.
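With s = sigmoid(gate), the SiLU derivative is s * (1 + gate * (1 - s)), so the gate positions receive g * s * (1 + gate * (1 - s)) * up and the up positions receive g * silu(gate). The sketch below keeps the even/odd interleaved-layout assumption and scatters both parts back into that layout; `swiglu_back_sketch` is a hypothetical name, checked against autograd.

```python
import torch
import torch.nn.functional as F


def swiglu(y):
    return F.silu(y[..., 0::2]) * y[..., 1::2]


def swiglu_back_sketch(g, y):
    gate, up = y[..., 0::2], y[..., 1::2]
    s = torch.sigmoid(gate)
    grad = torch.empty_like(y)
    # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    grad[..., 0::2] = g * s * (1 + gate * (1 - s)) * up
    grad[..., 1::2] = g * F.silu(gate)  # scatter back into the interleaved layout
    return grad


y = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 4, dtype=torch.float64)
swiglu(y).backward(g)
manual = swiglu_back_sketch(g, y.detach())
```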

nemo_automodel.components.moe.megatron.moe_utils.unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False
)

Restore the original order of tokens after permutation. If probs are provided, it will also apply them to the tokens before restoring the order.

When drop_and_pad=True, the tensors have the following properties:

  • In routing_map, the number of non-zeros in each column equals the expert capacity.
  • sorted_indices has num_experts * capacity entries; each consecutive block of capacity entries holds the indices of the tokens routed to one expert.

This function exploits these properties to use ops that are compatible with CUDA graphs.

Parameters:

permuted_tokens
torch.Tensor

The permuted token tensor.

sorted_indices
torch.Tensor

The indices used to sort the tokens.

restore_shape
torch.Size

The shape of the unpermuted tensor.

probs
torch.Tensor, defaults to None

The unpermuted probs tensor.

routing_map
torch.Tensor, defaults to None

Token to expert mapping, shape [num_tokens, num_experts].

fused
bool, defaults to False

Whether to use the fused unpermute function.

drop_and_pad
bool, defaults to False

Whether the token dispatcher drops tokens and pads each expert's token count to the expert capacity.

Returns:

torch.Tensor: The tokens restored to their original order.
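A minimal round-trip sketch of the non-fused, drop_and_pad=False path. It assumes probs are reordered into the same expert-major order as in permute and multiplied into each copy, and that copies of a token sent to several experts are summed back into one row with `scatter_add_`. `permute_sketch` and `unpermute_sketch` are hypothetical names. With identity "experts" and probs that sum to 1 per token, the round trip returns the original tokens.

```python
import torch


def permute_sketch(tokens, routing_map):
    mask = routing_map.bool().T.contiguous()  # [num_experts, num_tokens]
    ids = torch.arange(tokens.shape[0]).unsqueeze(0).expand(mask.shape[0], -1)
    sorted_indices = ids.masked_select(mask)  # expert-major token order
    return tokens.index_select(0, sorted_indices), sorted_indices


def unpermute_sketch(permuted_tokens, sorted_indices, restore_shape, probs=None, routing_map=None):
    if probs is not None:
        # Reorder probs into the same expert-major order and weight each copy.
        mask = routing_map.bool().T.contiguous()
        permuted_probs = probs.T.contiguous().masked_select(mask)
        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
    output = torch.zeros(restore_shape, dtype=permuted_tokens.dtype)
    # Copies of a token routed to several experts are summed back into one row.
    index = sorted_indices.unsqueeze(1).expand(-1, restore_shape[1])
    output.scatter_add_(0, index, permuted_tokens)
    return output


tokens = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
routing_map = torch.tensor([[1, 0], [0, 1], [1, 1]])
probs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.25, 0.75]])
permuted, sorted_indices = permute_sketch(tokens, routing_map)
restored = unpermute_sketch(permuted, sorted_indices, tokens.shape, probs, routing_map)
summed = unpermute_sketch(permuted, sorted_indices, tokens.shape)  # no probs: copies just add
```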

nemo_automodel.components.moe.megatron.moe_utils.weighted_bias_geglu_impl(
input,
weights,
fp8_input_store = False
)

Token-wise-weighted bias GEGLU fusion (tanh-approximated GELU gating).

nemo_automodel.components.moe.megatron.moe_utils.weighted_bias_quick_geglu(
y: torch.Tensor,
bias: torch.Tensor,
weights: torch.Tensor,
linear_offset: float = 0.0
) -> torch.Tensor

Token-wise weighted Quick-GEGLU activation with bias.

Parameters:

y
torch.Tensor

Input tensor before bias addition.

bias
torch.Tensor

Bias tensor broadcastable to y.

weights
torch.Tensor

Weight tensor with shape [tokens, 1] that broadcasts over the feature dimension.

linear_offset
float, defaults to 0.0

Optional linear offset for the second half before gating.

Returns: torch.Tensor

Activated tensor with same dtype as y.
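A minimal sketch of the documented behavior, assuming the order "add bias, apply Quick-GEGLU, scale by weights", then cast back to y's dtype. This matters when y is bfloat16 and the routing weights are float32: the product is promoted to float32, and the final cast restores the input dtype. `weighted_bias_quick_geglu_sketch` is a hypothetical name.

```python
import torch


def quick_geglu(y, linear_offset=0.0, alpha=1.702):
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return y1 * torch.sigmoid(alpha * y1) * (y2 + linear_offset)


def weighted_bias_quick_geglu_sketch(y, bias, weights, linear_offset=0.0):
    dtype = y.dtype
    # Add the bias, apply Quick-GEGLU, then scale each token row by its routing weight.
    res = quick_geglu(y + bias, linear_offset) * weights
    return res.to(dtype)  # keep the input dtype even though weights are float32


tokens, hidden = 4, 3
y = torch.randn(tokens, 2 * hidden).to(torch.bfloat16)
bias = torch.randn(2 * hidden).to(torch.bfloat16)
weights = torch.tensor([[0.0], [0.5], [1.0], [2.0]])  # float32, one weight per token
out = weighted_bias_quick_geglu_sketch(y, bias, weights)
```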

nemo_automodel.components.moe.megatron.moe_utils.weighted_bias_quick_geglu_back(
g,
y,
bias,
weights,
linear_offset: float = 0.0
)

Backward helper for weighted Quick-GEGLU with bias.

Returns gradients w.r.t. the input y, the bias, and the weights.
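A sketch of the three gradients, assuming the forward computes quick_geglu(y + bias) * weights. Because the bias is added before the activation, its gradient equals the input gradient; in this sketch it is summed over the token dimension to match a 1-D bias, which is an assumption about the bias shape (the library may return it in a different shape). `weighted_bias_quick_geglu_back_sketch` is a hypothetical name, checked against autograd.

```python
import torch


def quick_geglu(x, off=0.0, alpha=1.702):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return x1 * torch.sigmoid(alpha * x1) * (x2 + off)


def quick_geglu_back(g, x, off=0.0, alpha=1.702):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    s = torch.sigmoid(alpha * x1)
    return torch.cat((g * s * (1 + alpha * x1 * (1 - s)) * (x2 + off), g * x1 * s), dim=-1)


def weighted_bias_quick_geglu_back_sketch(g, y, bias, weights, off=0.0):
    x = y + bias
    grad_y = quick_geglu_back(g * weights, x, off)  # chain rule through the weight scale
    grad_bias = grad_y.sum(dim=0)                   # assumed 1-D bias broadcast over tokens
    grad_w = (quick_geglu(x, off) * g).sum(dim=-1, keepdim=True)  # weights are [tokens, 1]
    return grad_y, grad_bias, grad_w


y = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
bias = torch.randn(6, dtype=torch.float64, requires_grad=True)
w = torch.rand(4, 1, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 3, dtype=torch.float64)
(quick_geglu(y + bias, 0.1) * w).backward(g)
gy, gb, gw = weighted_bias_quick_geglu_back_sketch(g, y.detach(), bias.detach(), w.detach(), 0.1)
```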

nemo_automodel.components.moe.megatron.moe_utils.weighted_bias_quick_geglu_impl(
input,
bias,
weights,
fp8_input_store = False,
linear_offset = 0.0,
clamp_value = None,
alpha = 1.702
)

Token-wise-weighted bias quick_geglu fusion. Arguments and output:

  • input: [num_selected_experts * seq_len, hidden_size * 2]
  • bias: None
  • weights: [num_selected_experts * seq_len, 1]
  • fp8_input_store: bool
  • linear_offset: float
  • output: [num_selected_experts * seq_len, hidden_size]

nemo_automodel.components.moe.megatron.moe_utils.weighted_bias_swiglu_impl(
input,
weights,
fp8_input_store = False
)

Token-wise-weighted bias swiglu fusion.

nemo_automodel.components.moe.megatron.moe_utils.weighted_geglu(
y,
weights
)

Apply GEGLU activation and token-wise routing weights.

nemo_automodel.components.moe.megatron.moe_utils.weighted_geglu_back(
g,
y,
weights
)

Compute input and weight gradients for weighted GEGLU.
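Since the forward output is geglu(y) * weights, the input gradient is the GEGLU backward applied to g * weights, and each token's weight gradient is the sum over features of geglu(y) * g. A sketch assuming gate-first halves and weights shaped [tokens, 1]; `weighted_geglu_back_sketch` is a hypothetical name, checked against autograd.

```python
import math
import torch
import torch.nn.functional as F


def geglu(y):
    a, b = torch.chunk(y, 2, dim=-1)
    return F.gelu(a, approximate="tanh") * b


def geglu_back(g, y):
    a, b = torch.chunk(y, 2, dim=-1)
    k = math.sqrt(2.0 / math.pi)
    t = torch.tanh(k * (a + 0.044715 * a**3))
    dgelu = 0.5 * (1 + t) + 0.5 * a * (1 - t * t) * k * (1 + 3 * 0.044715 * a * a)
    return torch.cat((g * dgelu * b, g * 0.5 * a * (1 + t)), dim=-1)


def weighted_geglu_back_sketch(g, y, weights):
    # Output = geglu(y) * weights, so scale g by the weights before the GEGLU backward.
    input_grad = geglu_back(g * weights, y)
    # Each token has one weight, so its gradient sums over the feature dimension.
    weights_grad = (geglu(y) * g).sum(dim=-1, keepdim=True)
    return input_grad, weights_grad


y = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
w = torch.rand(3, 1, dtype=torch.float64, requires_grad=True)
g = torch.randn(3, 4, dtype=torch.float64)
(geglu(y) * w).backward(g)
gy, gw = weighted_geglu_back_sketch(g, y.detach(), w.detach())
```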

nemo_automodel.components.moe.megatron.moe_utils.weighted_quick_geglu(
y: torch.Tensor,
weights: torch.Tensor,
linear_offset: float = 0.0
) -> torch.Tensor

Token-wise-weighted Quick-GEGLU activation.

The weights tensor is expected to have the same first-dimension length as y and a trailing singleton dimension so that it broadcasts over the feature dimension.
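The trailing singleton dimension matters because broadcasting aligns shapes from the right: a 1-D weights tensor would be matched against the feature dimension instead of the token dimension, which either fails or, when the sizes happen to coincide, silently scales the wrong axis. The sketch below adds an explicit shape check for illustration (the check is not claimed to exist in the library) and assumes halves split with `torch.chunk`; `weighted_quick_geglu_sketch` is a hypothetical name.

```python
import torch


def weighted_quick_geglu_sketch(y, weights, linear_offset=0.0, alpha=1.702):
    # Illustrative check: one weight per row of y, with a trailing singleton dim.
    if tuple(weights.shape) != (y.shape[0], 1):
        raise ValueError(f"expected weights of shape ({y.shape[0]}, 1), got {tuple(weights.shape)}")
    y1, y2 = torch.chunk(y, 2, dim=-1)
    # weights [tokens, 1] broadcasts over the feature dimension of [tokens, hidden].
    return y1 * torch.sigmoid(alpha * y1) * (y2 + linear_offset) * weights


y = torch.randn(5, 8)
w = torch.rand(5, 1)
out = weighted_quick_geglu_sketch(y, w)
unweighted = weighted_quick_geglu_sketch(y, torch.ones(5, 1))

try:
    weighted_quick_geglu_sketch(y, torch.rand(5))  # missing trailing singleton dim
    raised = False
except ValueError:
    raised = True
```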

nemo_automodel.components.moe.megatron.moe_utils.weighted_quick_geglu_back(
g,
y,
weights,
linear_offset: float = 0.0
)

Backward helper for weighted Quick-GEGLU. Returns gradients w.r.t. the input y and the weights.

nemo_automodel.components.moe.megatron.moe_utils.weighted_swiglu(
y,
weights
)

Apply SwiGLU activation and token-wise routing weights.

nemo_automodel.components.moe.megatron.moe_utils.weighted_swiglu_back(
g,
y,
weights
)

Compute input and weight gradients for weighted SwiGLU.
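The weighted SwiGLU backward has the same structure as the weighted GEGLU one: run the SwiGLU backward on g * weights for the input gradient, and sum swiglu(y) * g over features for each token's weight gradient. This sketch keeps the even/odd interleaved-layout assumption from the swiglu sketch and assumes weights shaped [tokens, 1]; all function names here are hypothetical, and the result is checked against autograd.

```python
import torch
import torch.nn.functional as F


def swiglu(y):
    return F.silu(y[..., 0::2]) * y[..., 1::2]


def swiglu_back(g, y):
    gate, up = y[..., 0::2], y[..., 1::2]
    s = torch.sigmoid(gate)
    grad = torch.empty_like(y)
    grad[..., 0::2] = g * s * (1 + gate * (1 - s)) * up
    grad[..., 1::2] = g * F.silu(gate)
    return grad


def weighted_swiglu(y, weights):
    return (swiglu(y) * weights).to(y.dtype)


def weighted_swiglu_back_sketch(g, y, weights):
    input_grad = swiglu_back(g * weights, y)  # chain rule through the weight scale
    weights_grad = (swiglu(y) * g).sum(dim=-1, keepdim=True)  # one weight per token
    return input_grad, weights_grad


y = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
w = torch.rand(3, 1, dtype=torch.float64, requires_grad=True)
g = torch.randn(3, 4, dtype=torch.float64)
weighted_swiglu(y, w).backward(g)
gy, gw = weighted_swiglu_back_sketch(g, y.detach(), w.detach())
```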