core.transformer.moe.shared_experts#

Module Contents#

Classes#

SharedExpertState

State machine states for SharedExpertMLP overlapped forward pass.

_BackwardStreamWait

SharedExpertMLP

MLP layer for Shared Experts.

Functions#

overlap_state_check

Decorator to validate overlap state and cached variables before method execution, and update state after method execution.

set_tensor_grad_fn_sequence_sr

Set sequence_sr on the grad_fn of a tensor to control the backward order. On older PyTorch versions this is a no-op (the backward order is not changed). The larger the value, the earlier the grad_fn is scheduled.

API#

class core.transformer.moe.shared_experts.SharedExpertState(*args, **kwds)#

Bases: enum.Enum

State machine states for SharedExpertMLP overlapped forward pass.

Initialization

IDLE#

0

PRE_FORWARD_COMM_DONE#

1

FC1_FORWARD_DONE#

2

FC2_FORWARD_DONE#

3

POST_FORWARD_COMM_DONE#

4

core.transformer.moe.shared_experts.overlap_state_check(
required_state: core.transformer.moe.shared_experts.SharedExpertState,
next_state: core.transformer.moe.shared_experts.SharedExpertState,
)#

Decorator to validate overlap state and cached variables before method execution, and update state after method execution.

Parameters:
  • required_state – The expected SharedExpertState before this method runs.

  • next_state – The SharedExpertState to transition to after method execution.
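The following is a minimal, illustrative sketch of the contract this decorator enforces: assert that the module is in required_state before the wrapped method runs, then move it to next_state afterwards. The attribute name overlap_state is an assumption made for illustration, and the cached-variable checks mentioned in the docstring are omitted; the real implementation may differ.

```python
import functools


def overlap_state_check(required_state, next_state):
    """Sketch: check the overlap state before the call, advance it after."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # `overlap_state` is a hypothetical attribute name used for illustration.
            assert self.overlap_state == required_state, (
                f"expected {required_state.name}, got {self.overlap_state.name}"
            )
            result = func(self, *args, **kwargs)
            self.overlap_state = next_state  # transition to the next state
            return result

        return wrapper

    return decorator
```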

class core.transformer.moe.shared_experts._BackwardStreamWait#

Bases: torch.autograd.Function

static forward(ctx, input, stream)#

Forward pass.

static backward(ctx, grad_output)#

Backward pass with stream wait.
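A minimal sketch of what such an autograd function can look like, based only on the docstrings above: the forward pass returns the input unchanged and remembers the side stream, and the backward pass makes the current stream wait on it before gradients continue. The class name and details below are illustrative, not the actual Megatron implementation.

```python
import torch


class _BackwardStreamWaitSketch(torch.autograd.Function):
    """Identity in forward; synchronizes with a side stream in backward."""

    @staticmethod
    def forward(ctx, input, stream):
        ctx.stream = stream  # remember the stream for the backward pass
        return input         # pass the tensor through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        # Block the current (backward) stream until work queued on the saved
        # side stream has finished, then let the gradient flow onwards.
        torch.cuda.current_stream().wait_stream(ctx.stream)
        return grad_output, None  # no gradient w.r.t. the stream argument
```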

class core.transformer.moe.shared_experts.SharedExpertMLP(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
gate: bool,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.mlp.MLP

MLP layer for Shared Experts.

Initialization

stream#

None

forward(hidden_states: torch.Tensor) torch.Tensor#

Forward pass.

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Gets sharded state dict.

wait_current_stream()#

Wait for the current stream to complete.

pre_forward_comm(input, wait_current_stream=True)#

All-gather for sequence parallelism (SP) before the forward pass. This function is used to overlap the shared experts with the dispatcher. It is only effective when --moe-shared-expert-overlap is set and may change in the future.

linear_fc1_forward_and_act(overlapped_comm_output=None)#

Performs the linear FC1 forward and the activation function. This function is used to overlap the shared experts with the dispatcher. It is only effective when --moe-shared-expert-overlap is set and may change in the future.

linear_fc2_forward(overlapped_comm_output=None)#

Performs the linear FC2 forward. This function is used to overlap the shared experts with the dispatcher. It is only effective when --moe-shared-expert-overlap is set and may change in the future.

post_forward_comm()#

Reduce-scatter for sequence parallelism (SP) after the forward pass. This function is used to overlap the shared experts with the dispatcher. It is only effective when --moe-shared-expert-overlap is set and may change in the future.

get_output()#

Gets the module forward output. This function is used to overlap the shared experts with the dispatcher. It is only effective when --moe-shared-expert-overlap is set and may change in the future.
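Taken together, the step-wise methods above walk the SharedExpertState machine from IDLE back to IDLE while dispatcher work runs in between. The sketch below shows one plausible caller-side interleaving; dispatcher and routed_experts are hypothetical placeholders for the MoE token dispatcher and the routed-expert computation, not APIs documented here.

```python
def overlapped_shared_expert_forward(shared_mlp, hidden_states, dispatcher, routed_experts):
    """Schematic interleaving of SharedExpertMLP steps with dispatcher work (illustrative)."""
    shared_mlp.pre_forward_comm(hidden_states)       # IDLE -> PRE_FORWARD_COMM_DONE
    dispatched = dispatcher.dispatch(hidden_states)  # routed tokens: dispatch communication

    shared_mlp.linear_fc1_forward_and_act()          # PRE_FORWARD_COMM_DONE -> FC1_FORWARD_DONE
    expert_output = routed_experts(dispatched)       # routed-expert compute

    shared_mlp.linear_fc2_forward()                  # FC1_FORWARD_DONE -> FC2_FORWARD_DONE
    combined = dispatcher.combine(expert_output)     # routed tokens: combine communication

    shared_mlp.post_forward_comm()                   # FC2_FORWARD_DONE -> POST_FORWARD_COMM_DONE
    return combined + shared_mlp.get_output()        # POST_FORWARD_COMM_DONE -> IDLE
```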

core.transformer.moe.shared_experts.set_tensor_grad_fn_sequence_sr(tensor, value)#

Set sequence_sr on the grad_fn of a tensor to control the backward order. On older PyTorch versions this is a no-op (the backward order is not changed). The larger the value, the earlier the grad_fn is scheduled.
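A hedged usage example: ask autograd to schedule one branch's grad_fn earlier than another's during backward. The tensors and the chosen value below are purely illustrative; on older PyTorch versions the call simply has no effect.

```python
import torch

from megatron.core.transformer.moe.shared_experts import set_tensor_grad_fn_sequence_sr

x = torch.randn(4, 4, requires_grad=True)
shared_out = x * 2.0   # stands in for the shared-expert branch
routed_out = x + 1.0   # stands in for the routed-expert branch

# A larger value means the grad_fn is scheduled earlier in the backward pass.
set_tensor_grad_fn_sequence_sr(shared_out, 1_000_000)

(shared_out.sum() + routed_out.sum()).backward()
```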