core.transformer.moe.fused_a2a#
Module Contents#
Classes#
| Class | Description |
|---|---|
| FusedDispatch | Fused dispatch operation for MoE routing combining computation and communication. |
| FusedCombine | Fused combine operation for MoE output combining computation and communication. |
| HybridEPDispatch | Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend. |
| HybridEPCombine | Fused combine operation for permute + combine a2a + permute using the HybridEP backend. |
Functions#
| Function | Description |
|---|---|
| get_hidden_bytes | Calculate the number of hidden bytes for a tensor. |
| get_buffer | Get or create a buffer for all-to-all communication. |
| init_hybrid_ep_buffer | Initialize the HybridEP buffer, including buffer allocation and metadata initialization. |
Data#
API#
- core.transformer.moe.fused_a2a._buffer#
None
- core.transformer.moe.fused_a2a.get_hidden_bytes(x: torch.Tensor)#
Calculate the number of hidden bytes for a tensor.
- Parameters:
x (torch.Tensor) – Input tensor
- Returns:
Number of hidden bytes
- Return type:
int
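The byte count is derived from the tensor's hidden dimension and element size. The exact formula is not shown on this page; as a hedged illustration, a minimal sketch of one plausible implementation might be:

```python
import torch

def get_hidden_bytes_sketch(x: torch.Tensor) -> int:
    """Illustrative only: bytes per token along the hidden dimension.

    The 2-byte floor is an assumption for low-precision (e.g. FP8) inputs,
    whose communication payload is assumed to be at least 2 bytes per element.
    """
    return x.size(-1) * max(x.element_size(), 2)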
- core.transformer.moe.fused_a2a.get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int)#
Get or create a buffer for all-to-all communication.
- Parameters:
group (torch.distributed.ProcessGroup) – Process group for communication
hidden_bytes (int) – Number of hidden bytes needed
- Returns:
Communication buffer
- Return type:
Buffer
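A hedged usage sketch chaining the two helpers above (`ep_group` and `hidden_states` are hypothetical placeholders); the returned buffer is presumably cached in the module-level `_buffer` and reused while the group matches and the requested size fits:

```python
# Hypothetical usage: size the communication buffer from the activation tensor.
hidden_bytes = get_hidden_bytes(hidden_states)  # bytes per token to exchange
buffer = get_buffer(ep_group, hidden_bytes)     # allocated once, then reused
```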
- class core.transformer.moe.fused_a2a.FusedDispatch#
Bases: torch.autograd.Function
Fused dispatch operation for MoE routing combining computation and communication.
- static forward(ctx, x, token_indices, token_probs, num_experts, group, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused dispatch.
- static backward(ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle)#
Backward pass of fused dispatch.
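Since `FusedDispatch` is a `torch.autograd.Function`, it is invoked through `apply` with the arguments documented above. A hedged sketch (tensor names and the structure of the return value are assumptions, not guaranteed by this page):

```python
# Hypothetical usage: route `hidden_states` to experts across the EP group.
# The return value is assumed to include the dispatched tokens, per-expert
# token counts, and a communication handle for the matching combine step.
dispatch_result = FusedDispatch.apply(
    hidden_states,   # [num_tokens, hidden_dim]
    token_indices,   # expert index per token per top-k slot
    token_probs,     # routing probabilities matching `token_indices`
    num_experts,
    ep_group,
)
```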
- class core.transformer.moe.fused_a2a.FusedCombine#
Bases: torch.autograd.Function
Fused combine operation for MoE output combining computation and communication.
- static forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False)#
Forward pass of fused combine.
- static backward(ctx, grad_output, previous_event=None)#
Backward pass of fused combine.
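`FusedCombine` reverses the dispatch using the communication `handle` produced by `FusedDispatch`. A hedged sketch (variable names are placeholders; the return structure is an assumption):

```python
# Hypothetical usage: return expert outputs to their source ranks/positions.
# `handle` is assumed to come from the earlier FusedDispatch call.
combine_result = FusedCombine.apply(expert_output, ep_group, handle)
```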
- core.transformer.moe.fused_a2a._hybrid_ep_buffer#
None
- core.transformer.moe.fused_a2a.init_hybrid_ep_buffer(group: torch.distributed.ProcessGroup, hidden_dim: int, seq_len: int, num_local_experts: int, num_sms_dispatch_api: int, num_sms_combine_api: int, fp8_dispatch: bool)#
Initialize the HybridEP buffer, including buffer allocation and metadata initialization.
If a dispatch or combine call at runtime requires a larger buffer than the one initialized here, the buffer is reallocated on the fly, incurring extra runtime overhead.
- Parameters:
group (torch.distributed.ProcessGroup) – Process group for HybridEP all-to-all communication.
hidden_dim (int) – Hidden dimension of the input tensor.
seq_len (int) – Maximum sequence length of the input tensor.
num_local_experts (int) – Number of local experts.
num_sms_dispatch_api (int) – Number of SMs used by the dispatch API.
num_sms_combine_api (int) – Number of SMs used by the combine API.
fp8_dispatch (bool) – Whether to use FP8 communication during the dispatch phase.
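A hedged setup sketch, calling the initializer once before the first HybridEP dispatch so the buffer is sized for the worst case expected at runtime (the concrete values below are placeholders, not recommendations):

```python
# Hypothetical setup values; size the buffer generously to avoid the
# runtime reallocation overhead noted above.
init_hybrid_ep_buffer(
    group=ep_group,
    hidden_dim=4096,
    seq_len=8192,             # maximum tokens per rank expected at runtime
    num_local_experts=8,
    num_sms_dispatch_api=24,
    num_sms_combine_api=24,
    fp8_dispatch=False,
)
```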
- class core.transformer.moe.fused_a2a.HybridEPDispatch#
Bases: torch.autograd.Function
Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend.
- static forward(ctx, x, routing_map, probs, group, num_local_experts, num_sms_dispatch_api=24, num_sms_combine_api=24, num_permuted_tokens=None, pad_multiple=None)#
Forward pass of fused dispatch of the HybridEP backend.
- static backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_expert, grad_handle)#
Backward pass of fused dispatch of the HybridEP backend.
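As with the classes above, `HybridEPDispatch` is driven through `apply`. A hedged sketch using only positional arguments and the documented defaults (names and the return structure are assumptions):

```python
# Hypothetical usage: fused permute + dispatch a2a + permute on the HybridEP path.
# The return value is assumed to include the permuted tokens, probabilities,
# per-expert token counts, and a handle consumed by HybridEPCombine.
dispatch_result = HybridEPDispatch.apply(
    hidden_states,   # [num_tokens, hidden_dim]
    routing_map,     # token-to-expert routing map
    probs,           # routing probabilities
    ep_group,
    num_local_experts,
)
```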
- class core.transformer.moe.fused_a2a.HybridEPCombine#
Bases: torch.autograd.Function
Fused combine operation for permute + combine a2a + permute using the HybridEP backend.
- static forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None)#
Forward pass of fused combine of the HybridEP backend.
- static backward(ctx, grad_x)#
Backward pass of fused combine of the HybridEP backend.
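`HybridEPCombine` consumes the `handle` produced by `HybridEPDispatch` to undo the permutation and return tokens to their source ranks. A hedged sketch (placeholder names; the return structure is an assumption):

```python
# Hypothetical usage: fused permute + combine a2a + permute on the HybridEP path.
combine_result = HybridEPCombine.apply(expert_output, handle)
```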