nemo_automodel.components.moe.megatron.fused_a2a#

Module Contents#

Classes#

FusedDispatch

Fused dispatch operation for MoE routing, combining computation and communication in one step.

FusedCombine

Fused combine operation for MoE output that merges computation and communication into one step.

Functions#

get_hidden_bytes

Calculate the number of hidden bytes for a tensor.

get_buffer

Get or create a buffer for all-to-all communication.

Data#

API#

nemo_automodel.components.moe.megatron.fused_a2a._buffer#

None

nemo_automodel.components.moe.megatron.fused_a2a.get_hidden_bytes(x: torch.Tensor) → int#

Calculate the number of hidden bytes for a tensor.

Parameters:

x (torch.Tensor) – Input tensor

Returns:

Number of hidden bytes

Return type:

int
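As an illustration of what this helper likely computes, the sketch below follows the common DeepEP-style convention of sizing a token's hidden vector in bytes, clamping the element size to at least 2 bytes (buffers are sized for at least fp16 elements). The function name `hidden_bytes` and the clamp are assumptions for illustration, not the actual implementation.

```python
def hidden_bytes(hidden_size: int, element_size: int) -> int:
    """Bytes occupied by one token's hidden vector.

    Assumed sketch: hidden dimension times per-element byte size,
    with the element size clamped to >= 2 bytes (fp16 minimum),
    as in DeepEP-style reference code.
    """
    return hidden_size * max(element_size, 2)

# e.g. a [num_tokens, 4096] bf16 tensor: 4096 elements x 2 bytes each
print(hidden_bytes(4096, 2))  # → 8192
```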

nemo_automodel.components.moe.megatron.fused_a2a.get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int) → Buffer#

Get or create a buffer for all-to-all communication.

Parameters:
  • group (torch.distributed.ProcessGroup) – Process group for communication

  • hidden_bytes (int) – Number of hidden bytes needed

Returns:

Communication buffer

Return type:

Buffer

class nemo_automodel.components.moe.megatron.fused_a2a.FusedDispatch#

Bases: torch.autograd.Function

Fused dispatch operation for MoE routing, combining computation and communication in one step.

static forward(ctx, x, token_indices, token_probs, num_experts, group, async_finish=False, allocate_on_comm_stream=False)#

Forward pass of fused dispatch.

static backward(ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle)#

Backward pass of fused dispatch.
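Note the signature symmetry: `backward` receives one gradient argument per `forward` output and must return one gradient per `forward` input (with `None` for non-differentiable inputs such as indices or the process group). A minimal `torch.autograd.Function` illustrating that contract, unrelated to the actual fused kernel:

```python
import torch


class ScaleSketch(torch.autograd.Function):
    """Toy example of the forward/backward contract FusedDispatch
    follows; the scaling operation itself is illustrative only."""

    @staticmethod
    def forward(ctx, x, scale):
        # Save non-tensor state on ctx for use in backward.
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for x,
        # and None for the non-differentiable scale argument.
        return grad_output * ctx.scale, None
```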

class nemo_automodel.components.moe.megatron.fused_a2a.FusedCombine#

Bases: torch.autograd.Function

Fused combine operation for MoE output that merges computation and communication into one step.

static forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False)#

Forward pass of fused combine.

static backward(ctx, grad_output, previous_event=None)#

Backward pass of fused combine.
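Setting the communication aside, dispatch and combine form a round trip: dispatch groups tokens by their routed expert, and combine scatters the expert outputs back into token order, weighted by the routing probabilities. A single-process sketch of that semantics, with illustrative names and scalar "hidden states" standing in for tensors:

```python
def dispatch(x, token_indices, num_experts):
    """Group token values by their routed expert.

    Returns a list of (token_id, value) buckets, one per expert.
    """
    per_expert = [[] for _ in range(num_experts)]
    for tok, expert in enumerate(token_indices):
        per_expert[expert].append((tok, x[tok]))
    return per_expert


def combine(per_expert, token_probs, num_tokens):
    """Sum expert outputs back into token order, scaled by the
    routing probability of each token."""
    out = [0.0] * num_tokens
    for bucket in per_expert:
        for tok, val in bucket:
            out[tok] += token_probs[tok] * val
    return out
```

The fused versions compute the same mapping across ranks, overlapping the all-to-all transfer with computation instead of materializing the per-expert buckets locally.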