core.transformer.moe.shared_experts#
Module Contents#
Classes#
- SharedExpertMLP: MLP layer for Shared Experts.
Functions#
- set_tensor_grad_fn_sequence_sr: Set sequence_sr for the grad_fn of a tensor to control the backward order. On older PyTorch versions this does nothing (the backward order is not changed). The larger the value, the earlier the grad_fn is scheduled.
API#
- class core.transformer.moe.shared_experts.SharedExpertMLP(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.mlp.MLPSubmodules,
- gate: bool,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
- )#
Bases: megatron.core.transformer.mlp.MLP
MLP layer for Shared Experts.
Initialization
- stream#
None
- forward(hidden_states)#
Forward function
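A minimal construction-and-forward sketch. It assumes torch.distributed and Megatron-Core's model-parallel state are already initialized; the config values, the submodule choices, and the (output, output_bias) return convention inherited from MLP are illustrative assumptions rather than requirements stated on this page.
```python
import torch
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative config values; moe_shared_expert_intermediate_size sets the
# shared expert's FFN width (assumed field name for the installed version).
config = TransformerConfig(
    num_layers=2,
    hidden_size=256,
    num_attention_heads=8,
    num_moe_experts=4,
    moe_shared_expert_intermediate_size=512,
)

# The same linear submodules used by the regular dense MLP.
submodules = MLPSubmodules(
    linear_fc1=ColumnParallelLinear,
    linear_fc2=RowParallelLinear,
)

shared_mlp = SharedExpertMLP(config, submodules, gate=False)

# Megatron-Core activations use the [sequence_len, batch, hidden_size] layout.
hidden_states = torch.randn(16, 2, config.hidden_size)
output, output_bias = shared_mlp.forward(hidden_states)
```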
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
- )#
Gets sharded state dict.
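A short usage sketch for distributed checkpointing. The module prefix is hypothetical, and the dist_checkpointing.save call is an assumption about how the result is typically consumed.
```python
from megatron.core import dist_checkpointing

# Hypothetical module path; use the prefix of this module inside the full model.
sharded_sd = shared_mlp.sharded_state_dict(
    prefix='decoder.layers.0.mlp.shared_experts.'
)
# The sharded state dict is consumed by Megatron-Core's distributed
# checkpointing (assumed API and directory name).
dist_checkpointing.save(sharded_sd, 'checkpoints/iter_0000100')
```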
- pre_forward_comm(input)#
All-gather for SP (sequence parallelism) before the forward pass. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and this interface may change.
- linear_fc1_forward_and_act(overlapped_comm_output=None)#
Performs the linear FC1 and activation function forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and this interface may change.
- linear_fc2_forward(overlapped_comm_output=None)#
Performs the linear FC2 forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and this interface may change.
- post_forward_comm()#
Reduce-scatter for SP after the forward pass. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and this interface may change.
- get_output()#
Gets the module's forward output. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and this interface may change.
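The five methods above split the shared expert's forward pass into stages so its compute can be overlapped with the token dispatcher's communication. The call order below is a sketch inferred from the docstrings, with the dispatcher steps shown only as comments; it is not the exact schedule used inside Megatron-Core.
```python
# Staged forward path when --moe-shared-expert-overlap is set (sketch).
shared_mlp.pre_forward_comm(hidden_states)      # SP all-gather before FC1
# ... token dispatcher launches its all-to-all for routed tokens ...
shared_mlp.linear_fc1_forward_and_act()         # FC1 + activation, overlapped with comm
shared_mlp.linear_fc2_forward()                 # FC2, overlapped with comm
# ... token dispatcher combines the routed experts' outputs ...
shared_mlp.post_forward_comm()                  # SP reduce-scatter after FC2
shared_expert_output = shared_mlp.get_output()  # final shared-expert output
```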
- core.transformer.moe.shared_experts.set_tensor_grad_fn_sequence_sr(tensor, value)#
Set sequence_sr for the grad_fn of a tensor to control the backward order. On older PyTorch versions this does nothing (the backward order is not changed). The larger the value, the earlier the grad_fn is scheduled.
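A small usage sketch: on PyTorch versions that expose the underlying hook, giving one branch a larger value schedules its grad_fn earlier during backward. The tensors and values below are arbitrary examples.
```python
import torch
from megatron.core.transformer.moe.shared_experts import set_tensor_grad_fn_sequence_sr

x = torch.randn(4, 8, requires_grad=True)
a = (x * 2.0).sum()   # branch whose backward should run first
b = (x * x).sum()     # branch whose backward can wait

# Larger value => grad_fn scheduled earlier; a no-op on older PyTorch versions.
set_tensor_grad_fn_sequence_sr(a, 100)
set_tensor_grad_fn_sequence_sr(b, 1)

(a + b).backward()
```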