nemo_automodel.components.moe.megatron.fused_indices_converter

Module Contents

Classes

| Name | Description |
| --- | --- |
| IndicesToMultihot | Convert MoE top-k indices to multihot representation. |

Functions

| Name | Description |
| --- | --- |
| _indices_to_multihot_kernel | Triton kernel for converting indices to multihot representation. |
| _multihot_to_indices_kernel | Triton kernel for converting multihot representation to indices. |
| fused_indices_to_multihot | Convert MoE top-k indices to multihot representation. |

Data

HAVE_TRITON

null_decorator

API

class nemo_automodel.components.moe.megatron.fused_indices_converter.IndicesToMultihot()

Bases: Function

Convert MoE top-k indices to multihot representation.

This class implements a custom forward and backward propagation operation for efficiently converting indices to multihot representation. It is an experimental feature and may change in future versions.

nemo_automodel.components.moe.megatron.fused_indices_converter.IndicesToMultihot.backward(
ctx,
grad_multihot_indices,
grad_probs_in_multihot
)
staticmethod

Backward function for IndicesToMultihot.

Convert the multihot probs representation back to the indices layout. The indices are ignored in the backward function.

Parameters:

grad_multihot_indices

[num_of_tokens, num_of_local_experts]

grad_probs_in_multihot

[num_of_tokens, num_of_local_experts]

Returns:

[num_of_tokens, topk]

nemo_automodel.components.moe.megatron.fused_indices_converter.IndicesToMultihot.forward(
ctx,
indices,
probs_indices,
num_of_local_experts
)
staticmethod

Forward function for IndicesToMultihot.

Convert indices to multihot representation.

Parameters:

indices

[num_of_tokens, topk]

probs_indices

[num_of_tokens, topk]

num_of_local_experts

int

Returns:

[num_of_tokens, num_of_local_experts]

nemo_automodel.components.moe.megatron.fused_indices_converter._indices_to_multihot_kernel(
indices_ptr,
probs_in_indices_ptr,
multihot_indices_ptr,
probs_in_multihot_ptr,
position_map_ptr,
num_of_local_experts: triton.language.constexpr,
num_of_local_experts_next_power_of_2: triton.language.constexpr,
topk: triton.language.constexpr,
topk_next_power_of_2: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr
)

Triton kernel for converting indices to multihot representation.

Output:
    multihot_indices: [num_of_tokens, num_of_local_experts]
    probs_in_multihot: [num_of_tokens, num_of_local_experts]

Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2, then the kernel can process the following conversion:

Output Example:
    multihot_indices = [
        [1, 1, -1, -1],
        [-1, 1, 1, -1]
    ]
    probs_in_multihot = [
        [0.1, 0.2, 0.0, 0.0],
        [0.0, 0.3, 0.4, 0.0]
    ]
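The indices-to-multihot conversion can be sketched in plain Python as a reference for the semantics. This is not the Triton kernel. The function name `indices_to_multihot_reference` is hypothetical, the input values are an assumption inferred from the output example (this page does not list them), and skipping expert ids outside `[0, num_of_local_experts)` is an assumed convention.

```python
def indices_to_multihot_reference(indices, probs_in_indices, num_of_local_experts):
    """Pure-Python reference sketch (hypothetical helper, not the Triton kernel).

    indices:          [num_of_tokens, topk] expert ids per token
    probs_in_indices: [num_of_tokens, topk] probability for each selected expert
    Returns two [num_of_tokens, num_of_local_experts] nested lists.
    """
    multihot_indices = []
    probs_in_multihot = []
    for token_indices, token_probs in zip(indices, probs_in_indices):
        # Unselected experts are marked -1 / 0.0, matching the output example.
        hot_row = [-1] * num_of_local_experts
        prob_row = [0.0] * num_of_local_experts
        for expert, prob in zip(token_indices, token_probs):
            # Assumption: ids outside the local expert range are skipped.
            if 0 <= expert < num_of_local_experts:
                hot_row[expert] = 1
                prob_row[expert] = prob
        multihot_indices.append(hot_row)
        probs_in_multihot.append(prob_row)
    return multihot_indices, probs_in_multihot


# Assumed inputs for topk = 2, num_of_local_experts = 4, num_of_tokens = 2.
multihot_indices, probs_in_multihot = indices_to_multihot_reference(
    indices=[[0, 1], [1, 2]],
    probs_in_indices=[[0.1, 0.2], [0.3, 0.4]],
    num_of_local_experts=4,
)
```

With these assumed inputs the sketch reproduces the output example shown for the kernel.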

nemo_automodel.components.moe.megatron.fused_indices_converter._multihot_to_indices_kernel(
probs_in_multihot_ptr,
position_map_ptr,
probs_indices_ptr,
num_of_local_experts: triton.language.constexpr,
num_of_local_experts_next_power_of_2: triton.language.constexpr,
topk: triton.language.constexpr,
topk_next_power_of_2: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr
)

Triton kernel for converting multihot representation to indices.

Output:
    probs_indices: [num_of_tokens, topk]

Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2, then the kernel can process the following conversion:

Output Example:
    probs_indices = [
        [0.7, 0.8],
        [0.1, 0.9]
    ]
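The reverse conversion can be sketched in plain Python in the same spirit. This is a reference for the semantics only, not the Triton kernel. The function name `multihot_to_indices_reference` is hypothetical. The sketch assumes that `position_map` has shape `[num_of_tokens, num_of_local_experts]` and stores, for each expert, the top-k slot it came from, with -1 for unselected experts. The input values below are made up so that the result matches the output example.

```python
def multihot_to_indices_reference(probs_in_multihot, position_map, topk):
    """Pure-Python reference sketch (hypothetical helper, not the Triton kernel).

    probs_in_multihot: [num_of_tokens, num_of_local_experts]
    position_map:      [num_of_tokens, num_of_local_experts], assumed to hold
                       the top-k slot of each expert, or -1 if not selected
    Returns a [num_of_tokens, topk] nested list.
    """
    probs_indices = []
    for prob_row, position_row in zip(probs_in_multihot, position_map):
        out_row = [0.0] * topk
        for prob, slot in zip(prob_row, position_row):
            # Only experts with a valid slot contribute to the output.
            if 0 <= slot < topk:
                out_row[slot] = prob
        probs_indices.append(out_row)
    return probs_indices


# Hypothetical inputs for topk = 2, num_of_local_experts = 4, num_of_tokens = 2.
probs_indices = multihot_to_indices_reference(
    probs_in_multihot=[[0.7, 0.8, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0]],
    position_map=[[0, 1, -1, -1], [-1, 0, 1, -1]],
    topk=2,
)
```

Under these assumptions the result equals the `probs_indices` output example.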

nemo_automodel.components.moe.megatron.fused_indices_converter.fused_indices_to_multihot(
indices,
probs_indices,
num_of_local_experts
)

Convert MoE top-k indices to multihot representation.

This function is an experimental feature and may change in future versions.

nemo_automodel.components.moe.megatron.fused_indices_converter.HAVE_TRITON = True
nemo_automodel.components.moe.megatron.fused_indices_converter.null_decorator = partial(lambda x: x)
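
A flag such as `HAVE_TRITON` and a no-op decorator are commonly paired so that a module can still be imported when Triton is not installed. The following is a minimal sketch of that general pattern, not necessarily how this module wires them up. The `null_decorator` defined here and the `kernel_decorator` name are illustrative assumptions.

```python
try:
    import triton  # optional dependency

    HAVE_TRITON = True
except ImportError:
    triton = None
    HAVE_TRITON = False


def null_decorator(*args, **kwargs):
    """No-op decorator (illustrative): works as @null_decorator or @null_decorator(...)."""
    if len(args) == 1 and not kwargs and callable(args[0]):
        # Used bare: return the decorated function unchanged.
        return args[0]
    # Used with arguments: return a decorator that leaves the function unchanged.
    return lambda fn: fn


# Hypothetical name: pick the real JIT decorator only when Triton is importable.
kernel_decorator = triton.jit if HAVE_TRITON else null_decorator
```

Callers would then check `HAVE_TRITON` before launching a kernel and fall back to another code path otherwise.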