nemo_automodel.components.moe.megatron.fused_indices_converter#
Module Contents#
Classes#
IndicesToMultihot | Convert moe topk indices to multihot representation.
Functions#
_indices_to_multihot_kernel | Triton kernel for converting indices to multihot representation.
_multihot_to_indices_kernel | Triton kernel for converting multihot representation to indices.
fused_indices_to_multihot | Convert moe topk indices to multihot representation.
Data#
API#
- nemo_automodel.components.moe.megatron.fused_indices_converter.null_decorator#
'partial(…)'
- nemo_automodel.components.moe.megatron.fused_indices_converter._indices_to_multihot_kernel(
- indices_ptr,
- probs_in_indices_ptr,
- multihot_indices_ptr,
- probs_in_multihot_ptr,
- position_map_ptr,
- num_of_local_experts: triton.language.constexpr,
- num_of_local_experts_next_power_of_2: triton.language.constexpr,
- topk: triton.language.constexpr,
- topk_next_power_of_2: triton.language.constexpr,
- BLOCK_SIZE: triton.language.constexpr,
)#
Triton kernel for converting indices to multihot representation.
Input:
indices: [num_of_tokens, topk]
probs_in_indices: [num_of_tokens, topk]
Output:
multihot_indices: [num_of_tokens, num_of_local_experts]
probs_in_multihot: [num_of_tokens, num_of_local_experts]
Assume that topk = 2, num_of_local_experts = 4, and num_of_tokens = 2. The kernel then performs the following conversion:
Input Example:
indices = [ [0, 1], [1, 2] ]
probs_in_indices = [ [0.1, 0.2], [0.3, 0.4] ]
Output Example:
multihot_indices = [ [1, 1, -1, -1], [-1, 1, 1, -1] ]
probs_in_multihot = [ [0.1, 0.2, 0.0, 0.0], [0.0, 0.3, 0.4, 0.0] ]
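The conversion described above can be sketched in plain Python, without Triton or a GPU. This is a reference illustration only, not the kernel itself. The function name `indices_to_multihot_reference` is hypothetical, and three conventions here are assumptions: the multihot mask is represented as booleans, `position_map` records the top-k slot each selected expert came from (with -1 elsewhere), and out-of-range indices are skipped.

```python
def indices_to_multihot_reference(indices, probs_in_indices, num_of_local_experts):
    """Pure-Python reference sketch (hypothetical helper, not the Triton kernel)."""
    multihot, probs, position_map = [], [], []
    for idx_row, prob_row in zip(indices, probs_in_indices):
        hot_row = [False] * num_of_local_experts   # assumed boolean mask
        out_probs = [0.0] * num_of_local_experts   # unselected experts get 0.0
        pos_row = [-1] * num_of_local_experts      # assumed: -1 means "not selected"
        for slot, (expert, prob) in enumerate(zip(idx_row, prob_row)):
            # Assumption: indices outside [0, num_of_local_experts) are skipped.
            if 0 <= expert < num_of_local_experts:
                hot_row[expert] = True
                out_probs[expert] = prob
                pos_row[expert] = slot
        multihot.append(hot_row)
        probs.append(out_probs)
        position_map.append(pos_row)
    return multihot, probs, position_map


multihot, probs, position_map = indices_to_multihot_reference(
    [[0, 1], [1, 2]], [[0.1, 0.2], [0.3, 0.4]], num_of_local_experts=4
)
print(probs)  # [[0.1, 0.2, 0.0, 0.0], [0.0, 0.3, 0.4, 0.0]]
```

The printed probabilities match `probs_in_multihot` in the documented example; a reference like this may be useful for checking a fused implementation on small inputs.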
- nemo_automodel.components.moe.megatron.fused_indices_converter._multihot_to_indices_kernel(
- probs_in_multihot_ptr,
- position_map_ptr,
- probs_indices_ptr,
- num_of_local_experts: triton.language.constexpr,
- num_of_local_experts_next_power_of_2: triton.language.constexpr,
- topk: triton.language.constexpr,
- topk_next_power_of_2: triton.language.constexpr,
- BLOCK_SIZE: triton.language.constexpr,
)#
Triton kernel for converting multihot representation to indices.
Input:
probs_in_multihot: [num_of_tokens, num_of_local_experts]
position_map: [num_of_tokens, num_of_local_experts]
Output:
probs_indices: [num_of_tokens, topk]
Assume that topk = 2, num_of_local_experts = 4, and num_of_tokens = 2. The kernel then performs the following conversion:
Input Example:
probs_in_multihot = [ [0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0] ]
position_map = [ [1, 1, -1, -1], [-1, 1, 1, -1] ]
Output Example:
probs_indices = [ [0.7, 0.8], [0.1, 0.9] ]
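The inverse conversion can likewise be sketched in plain Python. This is an illustration under stated assumptions, not the kernel: the helper name `multihot_to_indices_reference` is hypothetical, and `position_map` is assumed to hold, for each selected expert, the top-k slot its value should be written to, with -1 marking unselected experts. The `position_map` values below are chosen to follow that assumed convention.

```python
def multihot_to_indices_reference(probs_in_multihot, position_map, topk):
    """Pure-Python reference sketch (hypothetical helper, not the Triton kernel)."""
    probs_indices = []
    for prob_row, pos_row in zip(probs_in_multihot, position_map):
        out_row = [0.0] * topk
        for prob, slot in zip(prob_row, pos_row):
            # Assumption: slot is the destination top-k position, -1 means skip.
            if slot != -1:
                out_row[slot] = prob
        probs_indices.append(out_row)
    return probs_indices


probs_indices = multihot_to_indices_reference(
    [[0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0]],
    [[0, 1, -1, -1], [-1, 0, 1, -1]],
    topk=2,
)
print(probs_indices)  # [[0.7, 0.8], [0.1, 0.9]]
```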
- class nemo_automodel.components.moe.megatron.fused_indices_converter.IndicesToMultihot#
Bases:
torch.autograd.Function
Convert moe topk indices to multihot representation.
This class implements a custom forward and backward propagation operation for efficiently converting indices to multihot representation. It is an experimental feature and may change in future versions.
- static forward(ctx, indices, probs_indices, num_of_local_experts)#
Forward function for IndicesToMultihot
Convert indices to multihot representation.
- Parameters:
indices – [num_of_tokens, topk]
probs_indices – [num_of_tokens, topk]
num_of_local_experts – int
- Returns:
multihot_indices: [num_of_tokens, num_of_local_experts]
probs_in_multihot: [num_of_tokens, num_of_local_experts]
- static backward(ctx, grad_multihot_indices, grad_probs_in_multihot)#
Backward function for IndicesToMultihot
Convert multihot probs representation to indices. indices is ignored in the backward function.
- Parameters:
grad_multihot_indices – [num_of_tokens, num_of_local_experts]
grad_probs_in_multihot – [num_of_tokens, num_of_local_experts]
- Returns:
grad_probs_indices: [num_of_tokens, topk]
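Because the forward pass places each top-k probability at its expert's column, the gradient with respect to `probs_indices` can be thought of as a gather from `grad_probs_in_multihot` at those same columns. The sketch below illustrates that relationship in plain Python. The helper name `backward_reference` is hypothetical, the routing is passed in explicitly as `indices` purely for illustration, and out-of-range indices are assumed to receive a zero gradient. No gradient is produced for the indices themselves.

```python
def backward_reference(grad_probs_in_multihot, indices):
    """Pure-Python sketch of the gradient gather (hypothetical helper)."""
    grad_probs_indices = []
    for grad_row, idx_row in zip(grad_probs_in_multihot, indices):
        grad_probs_indices.append(
            # Assumption: out-of-range indices contribute a zero gradient.
            [grad_row[e] if 0 <= e < len(grad_row) else 0.0 for e in idx_row]
        )
    return grad_probs_indices


grad = backward_reference(
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
    [[0, 1], [1, 2]],
)
print(grad)  # [[1.0, 2.0], [6.0, 7.0]]
```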
- nemo_automodel.components.moe.megatron.fused_indices_converter.fused_indices_to_multihot(
- indices,
- probs_indices,
- num_of_local_experts,
)#
Convert moe topk indices to multihot representation.
This function is an experimental feature and may change in future versions.