nemo_automodel.components.moe.megatron.fused_a2a#
Module Contents#
Classes#

| Class | Description |
|---|---|
| FusedDispatch | Fused dispatch operation for MoE routing combining computation and communication. |
| FusedCombine | Fused combine operation for MoE output combining computation and communication. |
Functions#

| Function | Description |
|---|---|
| get_hidden_bytes | Calculate the number of hidden bytes for a tensor. |
| get_buffer | Get or create a buffer for all-to-all communication. |
Data#
API#
- nemo_automodel.components.moe.megatron.fused_a2a._buffer#
None
- nemo_automodel.components.moe.megatron.fused_a2a.get_hidden_bytes(x: torch.Tensor)#
Calculate the number of hidden bytes for a tensor.
- Parameters:
x (torch.Tensor) – Input tensor
- Returns:
Number of hidden bytes
- Return type:
int
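For reference, a minimal sketch of how such a helper is commonly implemented. The formula below, including the 2-byte floor on element size for low-precision dtypes, follows the analogous DeepEP/Megatron example code and should be treated as an assumption rather than this module's verbatim source:

```python
import torch

def get_hidden_bytes_sketch(x: torch.Tensor) -> int:
    # Bytes occupied by one token's hidden vector; element size is
    # floored at 2 bytes, matching DeepEP's buffer-sizing convention
    # for sub-16-bit dtypes (assumption).
    return x.size(1) * max(x.element_size(), 2)

# Example: an (8, 4096) bf16 activation -> 4096 * 2 = 8192 bytes per token.
print(get_hidden_bytes_sketch(torch.empty(8, 4096, dtype=torch.bfloat16)))
```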
- nemo_automodel.components.moe.megatron.fused_a2a.get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int)#
Get or create a buffer for all-to-all communication.
- Parameters:
group (torch.distributed.ProcessGroup) – Process group for communication
hidden_bytes (int) – Number of hidden bytes needed
- Returns:
Communication buffer
- Return type:
Buffer
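The `Buffer` return type refers to DeepEP's `deep_ep.Buffer`. Below is a hedged sketch of the usual caching pattern behind this helper, adapted from the DeepEP README example; the sizing calls are DeepEP APIs, while tying the cache to the module-level `_buffer` above is an assumption about this module's internals:

```python
from typing import Optional

import torch.distributed as dist
from deep_ep import Buffer

_buffer: Optional[Buffer] = None  # module-level cache, mirroring `_buffer` above

def get_buffer_sketch(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    """Return a cached DeepEP buffer, reallocating only when it is too small."""
    global _buffer
    num_nvl_bytes, num_rdma_bytes = 0, 0
    # Size the buffer for the worst case of the dispatch and combine configs.
    for config in (Buffer.get_dispatch_config(group.size()),
                   Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
                            num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
                             num_rdma_bytes)
    # Reallocate only if the cached buffer is missing, bound to a different
    # process group, or smaller than this call requires.
    if (_buffer is None or _buffer.group != group
            or _buffer.num_nvl_bytes < num_nvl_bytes
            or _buffer.num_rdma_bytes < num_rdma_bytes):
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer
```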
- class nemo_automodel.components.moe.megatron.fused_a2a.FusedDispatch#
Bases: torch.autograd.Function
Fused dispatch operation for MoE routing combining computation and communication.
- static forward(ctx, x, token_indices, token_probs, num_experts, group, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused dispatch.
- static backward(ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle)#
Backward pass of fused dispatch.
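Since `FusedDispatch` subclasses `torch.autograd.Function`, it is invoked through `.apply(...)` with positional arguments matching `forward` (minus `ctx`). A hypothetical call-site sketch; the tensor shapes in the comments and the structure of the return value are assumptions based on typical DeepEP-style dispatch, not documented here:

```python
import torch
import torch.distributed as dist

from nemo_automodel.components.moe.megatron.fused_a2a import FusedDispatch

def dispatch_tokens_sketch(
    hidden_states: torch.Tensor,  # assumed (num_tokens, hidden)
    token_indices: torch.Tensor,  # assumed top-k expert ids from the router
    token_probs: torch.Tensor,    # assumed routing probabilities per token
    num_experts: int,
    ep_group: dist.ProcessGroup,
):
    # autograd.Function.apply() takes positional arguments only, so the
    # two defaults are spelled out explicitly.
    return FusedDispatch.apply(
        hidden_states, token_indices, token_probs, num_experts, ep_group,
        False,  # async_finish
        False,  # allocate_on_comm_stream
    )
```

The result presumably bundles the received tokens, their permuted indices and probabilities, per-expert token counts, and a communication handle that the matching `FusedCombine` call needs.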
- class nemo_automodel.components.moe.megatron.fused_a2a.FusedCombine#
Bases: torch.autograd.Function
Fused combine operation for MoE output combining computation and communication.
- static forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused combine.
- static backward(ctx, grad_output, previous_event=None)#
Backward pass of fused combine.
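A matching sketch for the combine side: `handle` is assumed to be the communication handle produced by the earlier dispatch, which lets DeepEP reverse the token permutation. The argument names and the shape of the return value are assumptions:

```python
import torch
import torch.distributed as dist

from nemo_automodel.components.moe.megatron.fused_a2a import FusedCombine

def combine_tokens_sketch(
    expert_output: torch.Tensor,  # assumed (num_recv_tokens, hidden)
    ep_group: dist.ProcessGroup,
    handle,  # opaque handle from the matching FusedDispatch call (assumption)
):
    # Defaults for async_finish / allocate_on_comm_stream are kept; the
    # returned object is assumed to contain the combined activations
    # restored to the original token order.
    return FusedCombine.apply(expert_output, ep_group, handle)
```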