core.transformer.moe.fused_a2a#

Module Contents#

Classes#

FusedDispatch

Fused dispatch operation for MoE routing, combining computation and communication.

FusedCombine

Fused combine operation for MoE output, combining computation and communication.

HybridEPDispatch

Fused dispatch operation for permute + dispatch all-to-all + permute using the HybridEP backend.

HybridEPCombine

Fused combine operation for permute + combine all-to-all + permute using the HybridEP backend.

Functions#

get_hidden_bytes

Calculate the number of hidden bytes for a tensor.

get_buffer

Get or create a buffer for all-to-all communication.

init_hybrid_ep_buffer

Initialize the HybridEP buffer, including buffer allocation and metadata initialization.

Data#

API#

core.transformer.moe.fused_a2a._buffer#

None

core.transformer.moe.fused_a2a.get_hidden_bytes(x: torch.Tensor) → int#

Calculate the number of hidden bytes for a tensor.

Parameters:

x (torch.Tensor) – Input tensor

Returns:

Number of hidden bytes

Return type:

int
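
The computation itself is small; the sketch below is an illustrative reimplementation, not the library's code. The helper name and the 2-byte minimum element size are assumptions.

```python
import torch

def get_hidden_bytes_sketch(x: torch.Tensor) -> int:
    # Per-token byte count along the hidden (last) dimension.
    # The max(..., 2) floor is an assumption: low-precision dtypes are
    # padded to at least 2-byte slots when sizing communication buffers.
    return x.size(1) * max(x.element_size(), 2)

x = torch.randn(128, 4096, dtype=torch.bfloat16)
print(get_hidden_bytes_sketch(x))  # 4096 * 2 = 8192 bytes per token
```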

core.transformer.moe.fused_a2a.get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int)#

Get or create a buffer for all-to-all communication.

Parameters:
  • group (torch.distributed.ProcessGroup) – Process group for communication

  • hidden_bytes (int) – Number of hidden bytes needed

Returns:

Communication buffer

Return type:

Buffer
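
A typical call site pairs get_buffer with get_hidden_bytes. The snippet below is a usage sketch; it assumes torch.distributed has already been initialized and that the default world group is used.

```python
import torch
import torch.distributed as dist

from core.transformer.moe.fused_a2a import get_buffer, get_hidden_bytes

group = dist.group.WORLD  # assumes dist.init_process_group(...) was called
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")

# Per the docstring, the buffer is fetched if it exists or created on demand.
buffer = get_buffer(group, get_hidden_bytes(x))
```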

class core.transformer.moe.fused_a2a.FusedDispatch#

Bases: torch.autograd.Function

Fused dispatch operation for MoE routing, combining computation and communication.

static forward(
ctx,
x,
token_indices,
token_probs,
num_experts,
group,
async_finish=False,
allocate_on_comm_stream=False,
)#

Forward pass of fused dispatch.

static backward(
ctx,
grad_output,
grad_token_indices,
grad_token_probs,
grad_tokens_per_expert,
grad_handle,
)#

Backward pass of fused dispatch.
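
Since FusedDispatch is a torch.autograd.Function, it is invoked through .apply. The argument order below follows the documented forward() signature; the five-way return unpacking and the output names are assumptions inferred from the five gradient slots in backward().

```python
import torch
import torch.distributed as dist

from core.transformer.moe.fused_a2a import FusedDispatch

group = dist.group.WORLD  # assumes torch.distributed is initialized
num_experts = 8
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
token_indices = torch.randint(0, num_experts, (128, 2), device="cuda")  # illustrative top-2 routing
token_probs = torch.rand(128, 2, device="cuda")

# Five-way unpacking is an assumption inferred from backward()'s gradient slots.
dispatched_x, recv_indices, recv_probs, tokens_per_expert, handle = (
    FusedDispatch.apply(x, token_indices, token_probs, num_experts, group)
)
```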

class core.transformer.moe.fused_a2a.FusedCombine#

Bases: torch.autograd.Function

Fused combine operation for MoE output, combining computation and communication.

static forward(
ctx,
x,
group,
handle,
async_finish=False,
allocate_on_comm_stream=False,
)#

Forward pass of fused combine.

static backward(ctx, grad_output, previous_event=None)#

Backward pass of fused combine.
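
FusedCombine is likewise invoked through .apply, consuming the communication handle produced by the matching dispatch. Continuing the FusedDispatch sketch above, the single-output form below is an assumption based on backward() receiving one grad_output.

```python
from core.transformer.moe.fused_a2a import FusedCombine

# `expert_output` is the result of running the local experts on the
# dispatched tokens; `handle` comes from the matching FusedDispatch call.
combined = FusedCombine.apply(expert_output, group, handle)
```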

core.transformer.moe.fused_a2a._hybrid_ep_buffer#

None

core.transformer.moe.fused_a2a.init_hybrid_ep_buffer(
group: torch.distributed.ProcessGroup,
hidden_dim: int,
seq_len: int,
num_local_experts: int,
num_sms_dispatch_api: int,
num_sms_combine_api: int,
fp8_dispatch: bool,
) → None#

Initialize the HybridEP buffer, including buffer allocation and metadata initialization.

If a dispatch or combine call at runtime requires a larger buffer than the one initialized here, the buffer is reallocated on the fly, incurring extra runtime overhead. A usage sketch follows the parameter list below.

Parameters:
  • group (torch.distributed.ProcessGroup) – Process group for HybridEP all-to-all communication.

  • hidden_dim (int) – Hidden dimension of the input tensor.

  • seq_len (int) – Maximum sequence length of the input tensor.

  • num_local_experts (int) – Number of local experts.

  • num_sms_dispatch_api (int) – Number of SMs used by the dispatch API.

  • num_sms_combine_api (int) – Number of SMs used by the combine API.

  • fp8_dispatch (bool) – Whether to use FP8 communication during the dispatch phase.
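
A hypothetical one-time setup call, with illustrative sizing values; only the keyword names come from the signature above.

```python
import torch.distributed as dist

from core.transformer.moe.fused_a2a import init_hybrid_ep_buffer

# Size the buffer for the largest expected workload up front so that
# later dispatch/combine calls do not trigger a runtime reallocation.
init_hybrid_ep_buffer(
    group=dist.group.WORLD,   # assumes torch.distributed is initialized
    hidden_dim=4096,          # illustrative model hidden size
    seq_len=8192,             # maximum sequence length expected
    num_local_experts=8,
    num_sms_dispatch_api=24,  # matches the defaults in HybridEPDispatch
    num_sms_combine_api=24,
    fp8_dispatch=False,
)
```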

class core.transformer.moe.fused_a2a.HybridEPDispatch#

Bases: torch.autograd.Function

Fused dispatch operation for permute + dispatch all-to-all + permute using the HybridEP backend.

static forward(
ctx,
x,
routing_map,
probs,
group,
num_local_experts,
num_sms_dispatch_api=24,
num_sms_combine_api=24,
num_permuted_tokens=None,
pad_multiple=None,
)#

Forward pass of fused dispatch using the HybridEP backend.

static backward(
ctx,
grad_x,
grad_probs,
grad_scaling_factor,
grad_tokens_per_expert,
grad_handle,
)#

Backward pass of fused dispatch using the HybridEP backend.
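
As with FusedDispatch, the call goes through .apply. Positional arguments follow the documented forward() signature; the routing_map layout, the five-way return unpacking, and its names are assumptions (the latter inferred from the gradient slots in backward()).

```python
import torch
import torch.distributed as dist

from core.transformer.moe.fused_a2a import HybridEPDispatch

group = dist.group.WORLD  # assumes torch.distributed is initialized
num_local_experts = 8
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
routing_map = torch.zeros(128, 64, dtype=torch.bool, device="cuda")  # assumed [tokens, experts] mask
routing_map[:, :2] = True                                            # illustrative top-2 routing
probs = torch.rand(128, 64, device="cuda") * routing_map

# Five-way unpacking is an assumption inferred from backward()'s gradient slots.
permuted_x, permuted_probs, scaling_factor, tokens_per_expert, handle = (
    HybridEPDispatch.apply(x, routing_map, probs, group, num_local_experts)
)
```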

class core.transformer.moe.fused_a2a.HybridEPCombine#

Bases: torch.autograd.Function

Fused combine operation for permute + combine all-to-all + permute using the HybridEP backend.

static forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None)#

Forward pass of fused combine using the HybridEP backend.

static backward(ctx, grad_x)#

Backward pass of fused combine using the HybridEP backend.
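
A matching combine call, again through .apply and continuing the HybridEPDispatch sketch above. The single-tensor return is an assumption based on backward() taking only grad_x.

```python
from core.transformer.moe.fused_a2a import HybridEPCombine

# `expert_output` is the local experts' output on the permuted tokens;
# `handle` comes from the matching HybridEPDispatch call.
combined = HybridEPCombine.apply(expert_output, handle)
```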