core.transformer.moe.experts#

Module Contents#

Classes#

GroupedMLP

An efficient implementation of the Experts layer using GroupedGEMM.

TEGroupedMLP

An efficient implementation of the Experts layer using TE’s GroupedLinear.

SequentialMLP

An implementation of the Experts layer using a sequence of MLP layers.

Functions#

expert_dist_ckpt_decorator

Decorator for sharded_state_dict in expert layers, used for distributed checkpointing. Since !1940, the TP size of the expert layers can differ from that of the attention layers. To make distributed checkpointing work in such cases, the decorator replaces the default TP parallel states with the expert-TP parallel states.

Data#

API#

core.transformer.moe.experts.logger#

‘getLogger(…)’

core.transformer.moe.experts.expert_dist_ckpt_decorator(func)#

Decorator for sharded_state_dict in expert layers, used for distributed checkpointing. Since !1940, the TP size of the expert layers can differ from that of the attention layers. To make distributed checkpointing work in such cases, the decorator replaces the default TP parallel states with the expert-TP parallel states.

class core.transformer.moe.experts.GroupedMLP(
num_local_experts: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An efficient implementation of the Experts layer using GroupedGEMM.

Executes multiple experts in parallel to maximize computational efficiency.

Initialization

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
)#

Forward step of the GroupedMLP.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Maps local experts to global experts. The weight entries of the sharded_state_dict are compatible with SequentialMLP, whereas the optimizer states are not, because of a limitation imposed by weight transposition. That is, in a finetuning scenario the checkpoint is compatible with SequentialMLP.

When singleton_local_shards metadata flag is True, experts are broken down into separate tensors and stored under separate global keys. Additionally, similarly to MLP, layers with GLU activations are broken down into separate w and v tensors.
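The two ideas above, mapping a local expert to its global index and splitting a GLU weight into w and v, can be sketched as below. The sketch rests on assumptions not stated in this page: experts are assumed to be laid out in contiguous blocks across expert-parallel ranks, and the GLU weight is assumed to stack the w half before the v half along its first dimension. The function names and the `ep_rank` parameter are hypothetical, and plain Python lists stand in for tensors.

```python
def global_expert_index(ep_rank, num_local_experts, local_index):
    """Map a local expert index to a global one.

    Assumes (hypothetically) that each expert-parallel rank owns a
    contiguous block of `num_local_experts` experts.
    """
    if not 0 <= local_index < num_local_experts:
        raise ValueError("local_index out of range")
    return ep_rank * num_local_experts + local_index


def split_glu_rows(fc1_weight_rows):
    """Split a stacked GLU weight into its two halves.

    Assumes (hypothetically) that the w half precedes the v half along
    the first dimension, represented here as a list of rows.
    """
    if len(fc1_weight_rows) % 2 != 0:
        raise ValueError("GLU weight must have an even number of rows")
    half = len(fc1_weight_rows) // 2
    return fc1_weight_rows[:half], fc1_weight_rows[half:]
```

For example, under these assumptions the second local expert (`local_index=1`) on rank 2 with four experts per rank maps to global expert 9.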

backward_dw()#

Performs backward pass for weight gradients in Experts. Empty implementation for compatibility with SequentialMLP and TEGroupedMLP.

class core.transformer.moe.experts.TEGroupedMLP(
num_local_experts,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An efficient implementation of the Experts layer using TE’s GroupedLinear.

Executes multiple experts in parallel to maximize computational efficiency.

Initialization

static _apply_bias(
intermediate_parallel,
bias_parallel,
tokens_per_expert,
permuted_probs,
)#

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
) → Tuple[torch.Tensor, Optional[torch.Tensor]]#

Forward of TEGroupedMLP

Parameters:
  • permuted_local_hidden_states (torch.Tensor) – The permuted input hidden states of the local experts.

  • tokens_per_expert (torch.Tensor) – The number of tokens per expert.

  • permuted_probs (torch.Tensor) – The permuted probs of each token produced by the router.

Returns:

The output of the local experts.

Return type:

output (torch.Tensor)

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Maps local experts to global experts. The sharded state dict is interchangeable with SequentialMLP’s.

backward_dw()#

Performs backward pass for weight gradients in TEGroupedMLP.

This method executes the backward pass for weight gradients by calling backward_dw() on the linear layers in reverse order (fc2 followed by fc1). If an error occurs during execution, it is caught and re-raised with a descriptive message.
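The ordering and error handling described above might look roughly like the sketch below. `RecordingLinear` and `ExpertsSketch` are hypothetical stand-ins written for this illustration; the exception type and message wording are assumptions, not taken from the library.

```python
class RecordingLinear:
    """Hypothetical linear layer that records when backward_dw() is called."""

    def __init__(self, name, log, fail=False):
        self.name = name
        self.log = log
        self.fail = fail

    def backward_dw(self):
        if self.fail:
            raise ValueError(f"{self.name} has no stored gradients")
        self.log.append(self.name)


class ExpertsSketch:
    """Hypothetical container with two linear layers, fc1 and fc2."""

    def __init__(self, fc1, fc2):
        self.fc1 = fc1
        self.fc2 = fc2

    def backward_dw(self):
        try:
            # Reverse order of the forward pass: fc2 first, then fc1.
            self.fc2.backward_dw()
            self.fc1.backward_dw()
        except Exception as exc:
            # Re-raise with context, chaining the original error.
            raise RuntimeError(f"backward_dw failed in ExpertsSketch: {exc}") from exc
```

Chaining with `from exc` keeps the original traceback available while the outer message says which module failed.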

class core.transformer.moe.experts.SequentialMLP(
num_local_experts,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An implementation of the Experts layer using a sequence of MLP layers.

This class executes each expert sequentially.

Initialization

_pad_tensor_for_quantization(hidden, probs)#

Pads the tensor shape to a multiple of 16 or 32.
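The rounding arithmetic behind this kind of padding can be sketched as follows. The helper names are hypothetical, plain lists stand in for tensors, and padding along the token dimension with zeros is an assumption; this page does not say which multiple applies in which case, so it is left as a parameter.

```python
def padded_length(num_tokens, multiple):
    """Round `num_tokens` up to the nearest multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Ceiling division using integer arithmetic only.
    return -(-num_tokens // multiple) * multiple


def pad_rows(hidden, probs, multiple):
    """Append zero rows and zero probabilities up to the padded length.

    Hypothetical list-based stand-in: `hidden` is a list of token rows
    and `probs` holds one probability per token.
    """
    extra = padded_length(len(hidden), multiple) - len(hidden)
    width = len(hidden[0]) if hidden else 0
    padded_hidden = hidden + [[0.0] * width for _ in range(extra)]
    padded_probs = probs + [0.0] * extra
    return padded_hidden, padded_probs
```

With zero probabilities on the padded rows, those rows would contribute nothing if the outputs are later scaled by their probabilities.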

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
)#

Forward step of the SequentialMLP.
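Sequential execution over permuted tokens can be pictured with the sketch below. It assumes the tokens are already sorted by expert, so that `tokens_per_expert` delimits contiguous segments, and it scales each output by its router probability. Where the probabilities are actually applied inside the real module is not stated here, so that part is an assumption, as are the function name and the use of lists and callables in place of tensors and MLP modules.

```python
def sequential_experts_forward(tokens, tokens_per_expert, probs, experts):
    """Run each expert on its contiguous slice of the permuted tokens.

    tokens: list of token vectors, already grouped by expert.
    tokens_per_expert: number of tokens routed to each local expert.
    probs: one router probability per token, in the same order.
    experts: list of callables, one per local expert (hypothetical).
    """
    if sum(tokens_per_expert) != len(tokens):
        raise ValueError("tokens_per_expert must sum to the number of tokens")
    outputs = []
    start = 0
    for expert, count in zip(experts, tokens_per_expert):
        segment = tokens[start:start + count]
        segment_probs = probs[start:start + count]
        for token, prob in zip(segment, segment_probs):
            # Scale the expert output by the token's router probability.
            outputs.append([prob * value for value in expert(token)])
        start += count
    return outputs
```

An expert that receives zero tokens is simply skipped by the empty slice, which is one reason the per-expert counts are passed alongside the tokens.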

backward_dw()#

Backward pass for weight gradients in SequentialMLP.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Maps local experts to global experts.