nemo_automodel.components.moe.megatron.fused_indices_converter
Module Contents
Classes
Functions
Data
API
Bases: Function
Convert moe topk indices to multihot representation.
This class implements custom forward and backward passes for efficiently converting indices to a multihot representation. It is an experimental feature and may change in future versions.
Backward function for IndicesToMultihot
Convert the multihot probs representation back to the indices layout. The indices input is ignored in the backward function.
Parameters:
[num_of_tokens, num_of_local_experts]
[num_of_tokens, num_of_local_experts]
Returns:
[num_of_tokens, topk]
Forward function for IndicesToMultihot
Convert indices to multihot representation.
Parameters:
[num_of_tokens, topk]
[num_of_tokens, topk]
int
Returns:
[num_of_tokens, num_of_local_experts]
Triton kernel for converting indices to multihot representation.
Output:
multihot_indices: [num_of_tokens, num_of_local_experts]
probs_in_multihot: [num_of_tokens, num_of_local_experts]
Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2, then the kernel can process the following conversion:
Output Example:
multihot_indices = [
    [1, 1, -1, -1],
    [-1, 1, 1, -1]
]
probs_in_multihot = [
    [0.1, 0.2, 0.0, 0.0],
    [0.0, 0.3, 0.4, 0.0]
]
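The conversion can be sketched in plain Python as a reference for what the kernel computes. This is a minimal illustration, not the Triton implementation: the function name `indices_to_multihot_reference` is hypothetical, the input values are an assumption inferred from the output example (they are not given in this page), and the markers follow the documented example (1 for a selected expert, -1 otherwise), which may differ from the kernel's internal encoding.

```python
def indices_to_multihot_reference(indices, probs_indices, num_of_local_experts):
    """Hypothetical plain-Python reference for the indices -> multihot conversion."""
    multihot_indices = []
    probs_in_multihot = []
    for index_row, prob_row in zip(indices, probs_indices):
        # Start every expert slot as "not selected" with zero probability.
        marker_row = [-1] * num_of_local_experts
        out_prob_row = [0.0] * num_of_local_experts
        for expert, prob in zip(index_row, prob_row):
            marker_row[expert] = 1       # mark the selected expert
            out_prob_row[expert] = prob  # scatter its probability
        multihot_indices.append(marker_row)
        probs_in_multihot.append(out_prob_row)
    return multihot_indices, probs_in_multihot


# Assumed inputs (inferred from the output example): topk = 2, 4 local experts.
indices = [[0, 1], [1, 2]]
probs_indices = [[0.1, 0.2], [0.3, 0.4]]

multihot_indices, probs_in_multihot = indices_to_multihot_reference(
    indices, probs_indices, num_of_local_experts=4
)
print(multihot_indices)
print(probs_in_multihot)
```

With these assumed inputs the sketch reproduces the output example's values. A reference like this is mainly useful for checking a fused kernel on small cases.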
Triton kernel for converting multihot representation to indices.
Output:
probs_indices: [num_of_tokens, topk]
Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2, then the kernel can process the following conversion:
Output Example:
probs_indices = [
    [0.7, 0.8],
    [0.1, 0.9]
]
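The reverse direction can be sketched the same way. Again this is a hedged illustration rather than the kernel itself: `multihot_to_indices_reference` is a hypothetical name, the multihot inputs are an assumption chosen to match the output example, and the sketch gathers values in ascending expert order, which is an assumption about ordering within each row.

```python
def multihot_to_indices_reference(multihot_indices, probs_in_multihot, topk):
    """Hypothetical plain-Python reference for the multihot -> indices conversion."""
    probs_indices = []
    for marker_row, prob_row in zip(multihot_indices, probs_in_multihot):
        # Gather the values at selected experts (marker != -1), in expert order.
        selected = [p for m, p in zip(marker_row, prob_row) if m != -1]
        # Pad with zeros if fewer than topk experts are selected in this row.
        selected += [0.0] * (topk - len(selected))
        probs_indices.append(selected[:topk])
    return probs_indices


# Assumed inputs: 2 tokens, 4 local experts, topk = 2.
multihot_indices = [[1, 1, -1, -1], [-1, 1, 1, -1]]
probs_in_multihot = [[0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0]]

probs_indices = multihot_to_indices_reference(multihot_indices, probs_in_multihot, topk=2)
print(probs_indices)
```

In the backward pass the same gather is applied to gradients instead of probabilities, so each top-k slot receives the gradient of the expert column it was scattered to.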
Convert moe topk indices to multihot representation.
This function is an experimental feature and may change in future versions.