nemo_automodel.components.moe.megatron.fused_a2a#
Module Contents#
Classes#
FusedDispatch – Fused dispatch operation for MoE routing combining computation and communication.
FusedCombine – Fused combine operation for MoE output combining computation and communication.
HybridEPDispatch – Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend.
HybridEPCombine – Fused combine operation for permute + combine a2a + permute using the HybridEP backend.
Functions#
_is_nvshmem_available – Check if DeepEP was compiled with NVSHMEM support.
get_hidden_bytes – Calculate the number of hidden bytes for a tensor.
get_buffer – Get or create a buffer for all-to-all communication.
init_hybrid_ep_buffer – Initialize the HybridEP buffer, including buffer allocation and metadata initialization.
reset_hybrid_ep_buffer – Reset the HybridEP buffer.
Data#
API#
- nemo_automodel.components.moe.megatron.fused_a2a._buffer#
None
- nemo_automodel.components.moe.megatron.fused_a2a._nvshmem_available#
None
- nemo_automodel.components.moe.megatron.fused_a2a._is_nvshmem_available() bool#
Check if DeepEP was compiled with NVSHMEM support.
Uses is_sm90_compiled() as a proxy: DeepEP's build enforces that NVSHMEM is disabled when SM90 features are disabled.
- nemo_automodel.components.moe.megatron.fused_a2a.get_hidden_bytes(x: torch.Tensor) int#
Calculate the number of hidden bytes for a tensor.
- Parameters:
x (torch.Tensor) – Input tensor
- Returns:
Number of hidden bytes
- Return type:
int
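As a rough sketch of what this computes (the exact formula is an assumption; the function name and one illustrative definition follow the Megatron-LM convention this module mirrors), the hidden-byte count is the per-token hidden dimension times the element byte size, and it is what sizes the communication buffer below:

```python
# Hedged sketch: bytes occupied by one token's hidden vector.
# The real get_hidden_bytes derives both values from the torch.Tensor itself.
def hidden_bytes(hidden_dim: int, element_size: int) -> int:
    """hidden_dim elements, element_size bytes each."""
    return hidden_dim * element_size

# A [num_tokens, 4096] bf16 tensor has element_size 2:
print(hidden_bytes(4096, 2))  # 8192
```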
- nemo_automodel.components.moe.megatron.fused_a2a.get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int)#
Get or create a buffer for all-to-all communication.
- Parameters:
group (torch.distributed.ProcessGroup) – Process group for communication
hidden_bytes (int) – Number of hidden bytes needed
- Returns:
Communication buffer
- Return type:
Buffer
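The module-level `_buffer` above suggests a cached, grow-only buffer: reuse the existing allocation when it is large enough for the requested group and size, otherwise reallocate. A hedged, single-process sketch of that caching pattern (the `Buffer` class here is a stand-in for DeepEP's; names and fields are illustrative):

```python
# Illustrative stand-in for DeepEP's communication buffer.
class Buffer:
    def __init__(self, group, num_bytes):
        self.group = group
        self.num_bytes = num_bytes

_buffer = None  # module-level cache, as in fused_a2a

def get_buffer(group, hidden_bytes):
    """Return the cached buffer, reallocating if it is absent or too small."""
    global _buffer
    if (_buffer is None
            or _buffer.group is not group
            or _buffer.num_bytes < hidden_bytes):
        _buffer = Buffer(group, hidden_bytes)  # (re)allocate
    return _buffer

g = object()  # placeholder for a ProcessGroup
a = get_buffer(g, 1024)
b = get_buffer(g, 512)   # smaller request: cached buffer is reused
assert a is b
```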
- class nemo_automodel.components.moe.megatron.fused_a2a.FusedDispatch#
Bases: torch.autograd.Function
Fused dispatch operation for MoE routing combining computation and communication.
- static forward(ctx, x, token_indices, token_probs, num_experts, group, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused dispatch.
- static backward(ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle)#
Backward pass of fused dispatch.
- class nemo_automodel.components.moe.megatron.fused_a2a.FusedCombine#
Bases: torch.autograd.Function
Fused combine operation for MoE output combining computation and communication.
- static forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused combine.
- static backward(ctx, grad_output, previous_event=None)#
Backward pass of fused combine.
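Conceptually, FusedDispatch routes each token to its selected expert and FusedCombine gathers the expert outputs back into the original token order; the handle returned by dispatch carries the routing state that combine needs. A toy, single-process simulation of that round trip (no DeepEP, no communication, top-1 routing only; all names are illustrative):

```python
# Toy simulation of dispatch -> combine token bookkeeping.
# The real ops fuse this permutation with an all-to-all across ranks.
def dispatch(tokens, token_indices):
    """Group tokens by expert; the handle records where each token went."""
    per_expert, handle = {}, []
    for tok, expert in zip(tokens, token_indices):
        per_expert.setdefault(expert, []).append(tok)
        handle.append((expert, len(per_expert[expert]) - 1))
    return per_expert, handle

def combine(per_expert, handle):
    """Restore the original token order using the dispatch handle."""
    return [per_expert[e][i] for e, i in handle]

tokens = ["t0", "t1", "t2", "t3"]
routed, handle = dispatch(tokens, token_indices=[1, 0, 1, 0])
assert combine(routed, handle) == tokens  # the round trip is the identity
```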
- nemo_automodel.components.moe.megatron.fused_a2a._hybrid_ep_buffer#
None
- nemo_automodel.components.moe.megatron.fused_a2a.init_hybrid_ep_buffer(group: torch.distributed.ProcessGroup, hidden_dim: int, seq_len: int, num_local_experts: int, num_sms_dispatch_api: int, num_sms_combine_api: int, fp8_dispatch: bool)#
Initialize the HybridEP buffer, including buffer allocation and metadata initialization.
If a runtime dispatch/combine requires a larger buffer than the one initialized, the buffer is reallocated at runtime, incurring extra runtime overhead.
- Parameters:
group – Process group for HybridEP all-to-all communication.
hidden_dim – Hidden dimension of the input tensor.
seq_len – Maximum sequence length of the input tensor.
num_local_experts – Number of local experts.
num_sms_dispatch_api – Number of SMs used by the dispatch API.
num_sms_combine_api – Number of SMs used by the combine API.
fp8_dispatch – Whether to use FP8 communication during the dispatch phase.
- nemo_automodel.components.moe.megatron.fused_a2a.reset_hybrid_ep_buffer()#
Reset the HybridEP buffer.
- class nemo_automodel.components.moe.megatron.fused_a2a.HybridEPDispatch#
Bases: torch.autograd.Function
Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend.
- static forward(ctx, x, routing_map, probs, group, num_local_experts, num_sms_dispatch_api=24, num_sms_combine_api=24, num_permuted_tokens=None, pad_multiple=None)#
Forward pass of fused dispatch of the HybridEP backend.
- static backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_expert, grad_handle)#
Backward pass of fused dispatch of the HybridEP backend.
- class nemo_automodel.components.moe.megatron.fused_a2a.HybridEPCombine#
Bases: torch.autograd.Function
Fused combine operation for permute + combine a2a + permute using the HybridEP backend.
- static forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None)#
Forward pass of fused combine of the HybridEP backend.
- static backward(ctx, grad_x)#
Backward pass of fused combine of the HybridEP backend.
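Both HybridEP ops accept num_permuted_tokens and pad_multiple, which suggests the permuted token buffer can be sized up front and padded to a fixed multiple (e.g. for kernel alignment). A toy illustration of that padding arithmetic, under that assumption; the helper name is hypothetical:

```python
# Hypothetical helper: round a permuted token count up to the next
# multiple of pad_multiple, as padded buffer sizing would require.
def padded_token_count(num_permuted_tokens: int, pad_multiple: int) -> int:
    """Ceiling division, then scale back up to the multiple."""
    return -(-num_permuted_tokens // pad_multiple) * pad_multiple

assert padded_token_count(100, 64) == 128  # 100 tokens padded to 2 * 64
assert padded_token_count(128, 64) == 128  # already aligned: no padding
```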