nemo_automodel.components.moe.megatron.fused_indices_converter#

Module Contents#

Classes#

IndicesToMultihot

Convert MoE topk indices to a multihot representation.

Functions#

_indices_to_multihot_kernel

Triton kernel for converting indices to multihot representation.

_multihot_to_indices_kernel

Triton kernel for converting multihot representation to indices.

fused_indices_to_multihot

Convert MoE topk indices to a multihot representation.

Data#

API#

nemo_automodel.components.moe.megatron.fused_indices_converter.null_decorator#

‘partial(…)’

nemo_automodel.components.moe.megatron.fused_indices_converter._indices_to_multihot_kernel(
indices_ptr,
probs_in_indices_ptr,
multihot_indices_ptr,
probs_in_multihot_ptr,
position_map_ptr,
num_of_local_experts: triton.language.constexpr,
num_of_local_experts_next_power_of_2: triton.language.constexpr,
topk: triton.language.constexpr,
topk_next_power_of_2: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Triton kernel for converting indices to multihot representation.

Input:
  • indices: [num_of_tokens, topk]
  • probs_in_indices: [num_of_tokens, topk]

Output:
  • multihot_indices: [num_of_tokens, num_of_local_experts]
  • probs_in_multihot: [num_of_tokens, num_of_local_experts]

Assuming topk = 2, num_of_local_experts = 4, and num_of_tokens = 2, the kernel performs the following conversion:

Input Example:
  indices = [ [0, 1], [1, 2] ]
  probs_in_indices = [ [0.1, 0.2], [0.3, 0.4] ]

Output Example:
  multihot_indices = [ [1, 1, -1, -1], [-1, 1, 1, -1] ]
  probs_in_multihot = [ [0.1, 0.2, 0.0, 0.0], [0.0, 0.3, 0.4, 0.0] ]
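The conversion in the example above can be reproduced without Triton as a plain PyTorch scatter. The sketch below is an unfused reference, not the kernel itself. The helper name indices_to_multihot_reference is hypothetical, every entry of indices is assumed to be a valid local expert id, and unselected experts are marked with -1 only because the example does so; the dtype and sentinel the kernel actually writes are not stated here.

```python
import torch


def indices_to_multihot_reference(indices, probs_in_indices, num_of_local_experts):
    """Unfused reference for the documented example (hypothetical helper)."""
    shape = (indices.shape[0], num_of_local_experts)
    # Start from "not selected": -1 markers and zero probabilities.
    # Assumption: the 1 / -1 encoding follows the example above.
    multihot_indices = torch.full(shape, -1, dtype=torch.long, device=indices.device)
    probs_in_multihot = torch.zeros(
        shape, dtype=probs_in_indices.dtype, device=probs_in_indices.device
    )
    # Write a 1 and the matching probability into each selected expert column.
    multihot_indices.scatter_(1, indices, 1)
    probs_in_multihot.scatter_(1, indices, probs_in_indices)
    return multihot_indices, probs_in_multihot


indices = torch.tensor([[0, 1], [1, 2]])
probs_in_indices = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
multihot_indices, probs_in_multihot = indices_to_multihot_reference(
    indices, probs_in_indices, 4
)
```

A reference of this kind is mainly useful for checking the shapes and values of a fused result on small inputs.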

nemo_automodel.components.moe.megatron.fused_indices_converter._multihot_to_indices_kernel(
probs_in_multihot_ptr,
position_map_ptr,
probs_indices_ptr,
num_of_local_experts: triton.language.constexpr,
num_of_local_experts_next_power_of_2: triton.language.constexpr,
topk: triton.language.constexpr,
topk_next_power_of_2: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Triton kernel for converting multihot representation to indices.

Input:
  • probs_in_multihot: [num_of_tokens, num_of_local_experts]
  • position_map: [num_of_tokens, num_of_local_experts]

Output:
  • probs_indices: [num_of_tokens, topk]

Assuming topk = 2, num_of_local_experts = 4, and num_of_tokens = 2, the kernel performs the following conversion:

Input Example:
  probs_in_multihot = [ [0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0] ]
  position_map = [ [1, 1, -1, -1], [-1, 1, 1, -1] ]

Output Example:
  probs_indices = [ [0.7, 0.8], [0.1, 0.9] ]
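The reverse conversion in the example above can be sketched in plain PyTorch as a masked gather. The example does not pin down how position_map encodes the topk slot of each expert, so this sketch makes two assumptions: -1 marks an unselected expert, and selected probabilities are emitted in ascending expert order. The helper name multihot_to_indices_reference is hypothetical, and the code is an unfused illustration rather than the Triton kernel.

```python
import torch


def multihot_to_indices_reference(probs_in_multihot, position_map, topk):
    """Unfused reference for the documented example (hypothetical helper)."""
    # Assumption: any value other than -1 marks a selected expert.
    selected = position_map != -1
    # The reshape below is only valid if every token selects exactly topk experts.
    if not bool((selected.sum(dim=1) == topk).all()):
        raise ValueError("each row of position_map must select exactly topk experts")
    # Boolean indexing visits rows in order and columns in ascending order.
    return probs_in_multihot[selected].reshape(-1, topk)


probs_in_multihot = torch.tensor([[0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0]])
position_map = torch.tensor([[1, 1, -1, -1], [-1, 1, 1, -1]])
probs_indices = multihot_to_indices_reference(probs_in_multihot, position_map, topk=2)
```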

class nemo_automodel.components.moe.megatron.fused_indices_converter.IndicesToMultihot#

Bases: torch.autograd.Function

Convert MoE topk indices to a multihot representation.

This class implements custom forward and backward passes for efficiently converting indices to a multihot representation. It is an experimental feature and may change in future versions.

static forward(ctx, indices, probs_indices, num_of_local_experts)#

Forward function for IndicesToMultihot

Convert indices to multihot representation.

Parameters:
  • indices – [num_of_tokens, topk]

  • probs_indices – [num_of_tokens, topk]

  • num_of_local_experts – int

Returns:
  • multihot_indices – [num_of_tokens, num_of_local_experts]

  • probs_in_multihot – [num_of_tokens, num_of_local_experts]

static backward(ctx, grad_multihot_indices, grad_probs_in_multihot)#

Backward function for IndicesToMultihot

Convert the multihot probs representation back to the indices layout. The indices are ignored in the backward function.

Parameters:
  • grad_multihot_indices – [num_of_tokens, num_of_local_experts]

  • grad_probs_in_multihot – [num_of_tokens, num_of_local_experts]

Returns:
  • grad_probs_indices – [num_of_tokens, topk]
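The forward and backward contract described above can be illustrated with a small unfused torch.autograd.Function. This is a sketch under stated assumptions, not the library's implementation: the class name ReferenceIndicesToMultihot is hypothetical, plain scatter and gather calls stand in for the Triton kernels, all indices are assumed to be valid local expert ids, and the 1 / -1 encoding of multihot_indices is borrowed from the kernel example earlier on this page.

```python
import torch


class ReferenceIndicesToMultihot(torch.autograd.Function):
    """Hypothetical unfused stand-in that mirrors the documented signatures."""

    @staticmethod
    def forward(ctx, indices, probs_indices, num_of_local_experts):
        shape = (indices.shape[0], num_of_local_experts)
        # Assumption: 1 marks a selected expert and -1 an unselected one.
        multihot_indices = torch.full(shape, -1, dtype=torch.long, device=indices.device)
        multihot_indices.scatter_(1, indices, 1)
        probs_in_multihot = torch.zeros(
            shape, dtype=probs_indices.dtype, device=probs_indices.device
        )
        probs_in_multihot.scatter_(1, indices, probs_indices)
        ctx.save_for_backward(indices)
        # The integer marker tensor carries no gradient.
        ctx.mark_non_differentiable(multihot_indices)
        return multihot_indices, probs_in_multihot

    @staticmethod
    def backward(ctx, grad_multihot_indices, grad_probs_in_multihot):
        (indices,) = ctx.saved_tensors
        # Read each topk slot's gradient from its expert column.
        grad_probs_indices = grad_probs_in_multihot.gather(1, indices)
        # One gradient per forward input; only probs_indices is differentiable.
        return None, grad_probs_indices, None


indices = torch.tensor([[0, 1], [1, 2]])
probs_indices = torch.tensor([[0.1, 0.2], [0.3, 0.4]], requires_grad=True)
_, probs_in_multihot = ReferenceIndicesToMultihot.apply(indices, probs_indices, 4)
# Weight each expert column differently so the gathered gradient is visible.
weights = torch.arange(8, dtype=probs_indices.dtype).reshape(2, 4)
(probs_in_multihot * weights).sum().backward()
```

After the backward call, probs_indices.grad holds, for each topk slot, the weight of the expert column that slot was scattered to, which is the [num_of_tokens, topk] shape listed for grad_probs_indices.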

nemo_automodel.components.moe.megatron.fused_indices_converter.fused_indices_to_multihot(
indices,
probs_indices,
num_of_local_experts,
)#

Convert MoE topk indices to a multihot representation.

This function is an experimental feature and may change in future versions.