core.fusions.fused_pad_routing_map

Module Contents

Functions

  • _pad_routing_map_kernel

  • fused_pad_routing_map – Fused version of pad_routing_map.

API

core.fusions.fused_pad_routing_map._pad_routing_map_kernel(
    routing_map_ptr,
    output_ptr,
    num_tokens,
    pad_multiple: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
)
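This page gives no docstring for _pad_routing_map_kernel. Its signature (a routing-map pointer, an output pointer, num_tokens, and compile-time pad_multiple and BLOCK_SIZE) is consistent with one Triton program handling one expert's row of tokens, with BLOCK_SIZE large enough to cover all num_tokens. The sketch below models that per-row work in plain Python. The row-per-program layout, the choice to route the earliest unrouted tokens, and the helper name pad_row are all assumptions for illustration, not details taken from this page.

```python
from typing import List


def pad_row(row: List[int], pad_multiple: int) -> List[int]:
    """Hypothetical model of one kernel program: pad a single expert's row.

    `row` holds the routing flags of every token for one expert (nonzero = routed).
    Assumes the whole row fits in one block, i.e. BLOCK_SIZE >= len(row).
    """
    if pad_multiple <= 0:
        raise ValueError("pad_multiple must be positive")
    num_ones = sum(1 for v in row if v != 0)
    # Extra tokens needed to reach the next multiple of pad_multiple (0 if aligned).
    num_to_pad = (-num_ones) % pad_multiple
    out = []
    zero_rank = 0  # running count of unrouted tokens seen so far (a prefix sum)
    for v in row:
        if v == 0:
            zero_rank += 1
            # Route the first num_to_pad unrouted tokens; if the row runs out of
            # zeros first, it simply ends up fully routed.
            out.append(1 if zero_rank <= num_to_pad else 0)
        else:
            out.append(v)  # already-routed tokens are kept unchanged
    return out
```

For example, pad_row([1, 0, 1, 0, 0, 1, 0, 0], 4) returns [1, 1, 1, 0, 0, 1, 0, 0]: three routed tokens are rounded up to four by routing the first unrouted one. Ranking the zeros with a prefix sum, rather than searching for them, is what would let a GPU version process a whole row in parallel.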

core.fusions.fused_pad_routing_map.fused_pad_routing_map(
    routing_map: torch.Tensor,
    pad_multiple: int,
) -> torch.Tensor

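Fused version of pad_routing_map, which pads the routing map so that each expert’s token count is rounded up to a multiple of pad_multiple. The work is done by the Triton kernel _pad_routing_map_kernel.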
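For readers without a GPU, the padding behaviour can be approximated with ordinary PyTorch tensor ops on CPU. This is a minimal sketch, not the library's pad_routing_map: the name pad_routing_map_reference is hypothetical, and routing the earliest unrouted tokens of each expert (in token order) is an assumption about which tokens get padded.

```python
import torch


def pad_routing_map_reference(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tensor:
    """Hypothetical CPU reference: pad each expert column to a multiple of pad_multiple."""
    if pad_multiple <= 0:
        raise ValueError("pad_multiple must be positive")
    # Work on [num_experts, num_tokens] so each expert is one row.
    rows = routing_map.transpose(0, 1).to(torch.int32)
    is_zero = rows == 0
    num_ones = (~is_zero).sum(dim=1)               # routed tokens per expert
    num_to_pad = (-num_ones) % pad_multiple        # extra tokens to reach the next multiple
    # 1-based rank of each unrouted token within its expert row.
    zero_ranks = torch.cumsum(is_zero.to(torch.int32), dim=1)
    # Route only the first num_to_pad unrouted tokens of each expert.
    mask = is_zero & (zero_ranks <= num_to_pad.unsqueeze(1))
    padded = torch.where(mask, torch.ones_like(rows), rows)
    # Back to [num_tokens, num_experts] and the caller's dtype.
    return padded.transpose(0, 1).to(routing_map.dtype)
```

With a boolean map of 5 tokens and 2 experts, each expert holding 2 routed tokens, pad_routing_map_reference(rm, 4) routes the first two unrouted tokens of each expert so that both columns sum to 4.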

Parameters:
  • routing_map (torch.Tensor) – A boolean or integer tensor of shape [num_tokens, num_experts] indicating which tokens are routed to which experts.

  • pad_multiple (int) – The multiple to pad each expert’s token count to.
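Reading pad_multiple as "round up to the next multiple", the per-expert arithmetic is as follows, with $n_e$ the number of tokens routed to expert $e$, $m$ = pad_multiple, $p_e$ the number of padded tokens and $n'_e$ the padded count (symbols introduced here for illustration). For example, $n_e = 5$ and $m = 4$ give $p_e = 3$ and $n'_e = 8$. This assumes the expert has at least $p_e$ unrouted tokens available; how a shortfall is handled is not documented on this page.

```latex
p_e = (-n_e) \bmod m, \qquad
n'_e = n_e + p_e = m \left\lceil \frac{n_e}{m} \right\rceil
```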

Returns:

The padded routing map of shape [num_tokens, num_experts].

Return type:

torch.Tensor
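When testing fused_pad_routing_map, a few properties follow from the description above: the shape is unchanged and each expert’s token count is rounded up to a multiple of pad_multiple. The helper below also checks two things this page does not state, so treat them as assumptions: no originally routed token is dropped, and an expert with too few unrouted tokens is capped at num_tokens. The name check_padded_routing_map is hypothetical.

```python
import torch


def check_padded_routing_map(original: torch.Tensor, padded: torch.Tensor, pad_multiple: int) -> None:
    """Hypothetical test helper: raise AssertionError if `padded` breaks the assumed contract."""
    if pad_multiple <= 0:
        raise ValueError("pad_multiple must be positive")
    if padded.shape != original.shape:
        raise AssertionError(f"shape changed: {tuple(original.shape)} -> {tuple(padded.shape)}")
    orig = original.to(torch.bool)
    pad = padded.to(torch.bool)
    # Assumption: padding only adds tokens and never unroutes one.
    if not bool(pad[orig].all()):
        raise AssertionError("a routed token was dropped")
    num_tokens = original.shape[0]
    counts = orig.sum(dim=0)  # tokens per expert before padding
    # Round up to the next multiple, capped at num_tokens (assumed behaviour when an
    # expert has too few unrouted tokens to reach the multiple).
    rounded = ((counts + pad_multiple - 1) // pad_multiple) * pad_multiple
    expected = torch.clamp(rounded, max=num_tokens)
    actual = pad.sum(dim=0)
    if not torch.equal(actual, expected):
        raise AssertionError(f"per-expert counts {actual.tolist()} != expected {expected.tolist()}")
```

On a CUDA machine with Triton available, a call such as check_padded_routing_map(rm, fused_pad_routing_map(rm, 16), 16) would exercise the fused function against these assumed properties. Raising AssertionError explicitly, instead of using assert statements, keeps the checks active under python -O.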