core.transformer.moe.shared_experts#

Module Contents#

Classes#

SharedExpertMLP

MLP layer for Shared Experts.

Functions#

set_tensor_grad_fn_sequence_sr

Set sequence_sr for the grad_fn of a tensor to control the backward order. On older PyTorch versions, this does nothing (the backward order is unchanged). The larger the value, the earlier the grad_fn is scheduled.

API#

class core.transformer.moe.shared_experts.SharedExpertMLP(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
gate: bool,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.mlp.MLP

MLP layer for Shared Experts.

stream#

None

forward(hidden_states)#

Forward pass of the shared expert MLP.

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) -> megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Gets sharded state dict.

pre_forward_comm(input)#

All-gather for sequence parallelism (SP) before the forward pass. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and its interface may change.

linear_fc1_forward_and_act(overlapped_comm_output=None)#

Runs the forward pass of Linear FC1 and the activation function. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and its interface may change.

linear_fc2_forward(overlapped_comm_output=None)#

Runs the forward pass of Linear FC2. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and its interface may change.

post_forward_comm()#

Reduce-scatter for sequence parallelism (SP) after the forward pass. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and its interface may change.

get_output()#

Gets the module's forward output. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and its interface may change.
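The five staged methods above split one forward pass into steps that a caller can interleave with dispatcher work. The sketch below is a toy model of that staging only. ToySharedExpert is a hypothetical stand-in, not the real class: it uses plain Python lists instead of tensors, replaces the communication steps with copies, and the caching attributes and arithmetic are assumptions made for illustration. The call order shown is inferred from the method names.

```python
class ToySharedExpert:
    """Hypothetical stand-in that mirrors the staged method names only."""

    def __init__(self):
        self._fc1_input = None
        self._fc2_input = None
        self._fc2_output = None
        self._output = None

    def pre_forward_comm(self, input):
        # Stand-in for the sequence-parallel all-gather: just copy the input.
        self._fc1_input = list(input)

    def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
        # overlapped_comm_output is accepted to mirror the signature; unused here.
        assert self._fc1_input is not None, "call pre_forward_comm first"
        # Stand-in for FC1 plus activation: scale by 2, then apply ReLU.
        self._fc2_input = [max(0.0, 2.0 * x) for x in self._fc1_input]
        self._fc1_input = None

    def linear_fc2_forward(self, overlapped_comm_output=None):
        assert self._fc2_input is not None, "call linear_fc1_forward_and_act first"
        # Stand-in for FC2: add a constant bias.
        self._fc2_output = [x + 1.0 for x in self._fc2_input]
        self._fc2_input = None

    def post_forward_comm(self):
        assert self._fc2_output is not None, "call linear_fc2_forward first"
        # Stand-in for the sequence-parallel reduce-scatter: pass through.
        self._output = self._fc2_output
        self._fc2_output = None

    def get_output(self):
        # Hand the result over once and clear the cache.
        out, self._output = self._output, None
        return out


expert = ToySharedExpert()
expert.pre_forward_comm([-1.0, 0.5, 2.0])
# ... a real caller could launch dispatcher communication here ...
expert.linear_fc1_forward_and_act()
# ... and more dispatcher or routed-expert work here ...
expert.linear_fc2_forward()
expert.post_forward_comm()
output = expert.get_output()
second_read = expert.get_output()
```

In this toy, each stage consumes the cached result of the previous one, so calling a stage out of order fails fast with an assertion instead of silently producing a stale result.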

core.transformer.moe.shared_experts.set_tensor_grad_fn_sequence_sr(tensor, value)#

Set sequence_sr for the grad_fn of a tensor to control the backward order. On older PyTorch versions, this does nothing (the backward order is unchanged). The larger the value, the earlier the grad_fn is scheduled.
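To illustrate the rule "the larger the value, the earlier the grad_fn is scheduled", here is a toy priority queue in plain Python. It is a sketch of the ordering rule only and does not model PyTorch's autograd engine. The function schedule and the node names are hypothetical and chosen for illustration.

```python
import heapq


def schedule(nodes):
    """Return node names in the order a max-priority queue would pop them.

    nodes is a list of (name, sequence_value) pairs.
    """
    # heapq is a min-heap, so push the negated value to pop the largest first.
    heap = [(-value, name) for name, value in nodes]
    heapq.heapify(heap)
    order = []
    while heap:
        _, name = heapq.heappop(heap)
        order.append(name)
    return order


# Hypothetical ready nodes with their default sequence values.
ready = [("routed_grad", 10), ("shared_fc1_grad", 11), ("attention_grad", 3)]
default_order = schedule(ready)

# Raising one node's value moves it to the front of the schedule.
boosted = [("routed_grad", 10), ("shared_fc1_grad", 11), ("attention_grad", 2**31 - 1)]
boosted_order = schedule(boosted)
```

With the default values, the node with the highest value runs first; after boosting attention_grad to a very large value, it is scheduled ahead of the other two.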