core.fusions.fused_pad_routing_map
Module Contents
Functions
fused_pad_routing_map | Fused version of pad_routing_map.
API
- core.fusions.fused_pad_routing_map._pad_routing_map_kernel(routing_map_ptr, output_ptr, num_tokens, pad_multiple: triton.language.constexpr, BLOCK_SIZE: triton.language.constexpr)
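The sketch below is a hypothetical pure-Python emulation of how a Triton kernel with this signature could be organized. It is not the actual kernel. It makes several assumptions: there is one program instance per expert (a 1-D launch grid); the input is a contiguous, row-major [num_experts, num_tokens] buffer, so num_tokens acts as the row stride; BLOCK_SIZE is the tile width used for loads and stores; and padding flips the earliest unrouted tokens of each expert. The helper names emulate_pad_program and launch are illustrative only.

```python
# Hypothetical emulation of one kernel program per expert (not the real kernel).
# Assumed layout: flat, row-major [num_experts, num_tokens] list of 0/1 values.


def emulate_pad_program(pid, routing_flat, output_flat, num_tokens, pad_multiple, block_size):
    # Expert `pid` owns elements [pid * num_tokens, (pid + 1) * num_tokens).
    row_start = pid * num_tokens

    # Pass 1: count this expert's routed tokens, one block_size-wide tile at a time.
    # The min() bound plays the role of a load mask on the last, partial tile.
    routed = 0
    for tile_start in range(0, num_tokens, block_size):
        for offs in range(tile_start, min(tile_start + block_size, num_tokens)):
            routed += routing_flat[row_start + offs]

    # Extra tokens needed to reach the next multiple of pad_multiple (0 if already aligned).
    num_to_pad = (-routed) % pad_multiple

    # Pass 2: copy the row to the output, flipping the first `num_to_pad` zeros to ones.
    zeros_seen = 0
    for tile_start in range(0, num_tokens, block_size):
        for offs in range(tile_start, min(tile_start + block_size, num_tokens)):
            val = routing_flat[row_start + offs]
            if val == 0 and zeros_seen < num_to_pad:
                zeros_seen += 1
                val = 1
            output_flat[row_start + offs] = val


def launch(routing_flat, num_experts, num_tokens, pad_multiple, block_size):
    # Emulates a 1-D grid of num_experts programs writing to a separate output buffer.
    output = [0] * len(routing_flat)
    for pid in range(num_experts):
        emulate_pad_program(pid, routing_flat, output, num_tokens, pad_multiple, block_size)
    return output


# Two experts, five tokens each, padded to a multiple of 4.
print(launch([1, 0, 1, 0, 0, 0, 0, 0, 0, 1], num_experts=2, num_tokens=5, pad_multiple=4, block_size=2))
# → [1, 1, 1, 1, 0, 1, 1, 1, 0, 1]
```

In this assumed layout, each expert row is independent, so the programs need no coordination with one another. Changing block_size changes only how a row is tiled, not the result.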
- core.fusions.fused_pad_routing_map.fused_pad_routing_map(routing_map: torch.Tensor, pad_multiple: int)
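Fused version of pad_routing_map: pads the routing map so that the number of tokens assigned to each expert becomes a multiple of pad_multiple.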
- Parameters:
routing_map (torch.Tensor) – A boolean or integer tensor of shape [num_tokens, num_experts] indicating which tokens are routed to which experts.
pad_multiple (int) – The multiple to pad each expert’s token count to.
- Returns:
The padded routing map of shape [num_tokens, num_experts].
- Return type:
torch.Tensor
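Here is a minimal, unfused reference sketch in plain Python that makes the documented contract concrete. It assumes that padding works by switching an expert's earliest unrouted tokens from 0 to 1 until that expert's count reaches the next multiple of pad_multiple. This page does not state that rule. Nested lists of 0/1 (or bools) stand in for the [num_tokens, num_experts] torch.Tensor, and pad_routing_map_reference is a hypothetical name that is not part of core.fusions.

```python
# Hypothetical reference for the padding contract (not the library implementation).


def pad_routing_map_reference(routing_map, pad_multiple):
    """routing_map: [num_tokens][num_experts] nested list of 0/1 or bool values."""
    num_tokens = len(routing_map)
    num_experts = len(routing_map[0]) if num_tokens else 0
    # Copy the input so it is not mutated; bools become 0/1 ints.
    padded = [[int(v) for v in row] for row in routing_map]

    for e in range(num_experts):
        count = sum(padded[t][e] for t in range(num_tokens))
        # Tokens needed to reach the next multiple of pad_multiple (0 if already aligned).
        num_to_pad = (-count) % pad_multiple
        # Assumed rule: route the earliest unrouted tokens to this expert.
        for t in range(num_tokens):
            if num_to_pad == 0:
                break
            if padded[t][e] == 0:
                padded[t][e] = 1
                num_to_pad -= 1
    return padded


routing_map = [[1, 0], [0, 0], [1, 1], [0, 0]]  # 4 tokens, 2 experts
print(pad_routing_map_reference(routing_map, pad_multiple=2))
# → [[1, 1], [0, 0], [1, 1], [0, 0]]
```

An expert with c routed tokens receives (-c) mod pad_multiple extra tokens. Python's % returns a non-negative value here, so an expert whose count is already a multiple gets no padding. In the example, expert 0 already has 2 tokens, and expert 1 grows from 1 token to 2. The output keeps the input shape, as the Returns entry above specifies.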