core.transformer.moe.shared_experts#
Module Contents#
Classes#
- SharedExpertState: State machine states for the SharedExpertMLP overlapped forward pass.
- SharedExpertMLP: MLP layer for Shared Experts.
Functions#
- overlap_state_check: Decorator to validate the overlap state and cached variables before method execution, and update the state after method execution.
- set_tensor_grad_fn_sequence_sr: Set sequence_sr for the grad_fn of a tensor to control the backward order. For older PyTorch versions this is a no-op (the backward order is not changed). The bigger the value, the earlier the grad_fn is scheduled.
API#
- class core.transformer.moe.shared_experts.SharedExpertState(*args, **kwds)#
Bases: enum.Enum
State machine states for the SharedExpertMLP overlapped forward pass.
Initialization
- IDLE#
0
- PRE_FORWARD_COMM_DONE#
1
- FC1_FORWARD_DONE#
2
- FC2_FORWARD_DONE#
3
- POST_FORWARD_COMM_DONE#
4
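The five states above are walked in order during one overlapped forward pass, returning to IDLE at the end. A minimal, Megatron-independent sketch of the same enum and its ordering (illustrative only, not the library's code):

```python
import enum


class SharedExpertState(enum.Enum):
    """Mirror of the state machine above (illustrative only)."""
    IDLE = 0
    PRE_FORWARD_COMM_DONE = 1
    FC1_FORWARD_DONE = 2
    FC2_FORWARD_DONE = 3
    POST_FORWARD_COMM_DONE = 4


# One overlapped forward pass advances through the states in this order,
# then get_output() returns the machine to IDLE.
FORWARD_ORDER = [
    SharedExpertState.IDLE,
    SharedExpertState.PRE_FORWARD_COMM_DONE,
    SharedExpertState.FC1_FORWARD_DONE,
    SharedExpertState.FC2_FORWARD_DONE,
    SharedExpertState.POST_FORWARD_COMM_DONE,
]

print([s.value for s in FORWARD_ORDER])  # [0, 1, 2, 3, 4]
```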
- core.transformer.moe.shared_experts.overlap_state_check(
- required_state: core.transformer.moe.shared_experts.SharedExpertState,
- next_state: core.transformer.moe.shared_experts.SharedExpertState,
- )#
Decorator to validate overlap state and cached variables before method execution, and update state after method execution.
- Parameters:
required_state – The expected SharedExpertState before this method runs.
next_state – The SharedExpertState to transition to after method execution.
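A state-checking decorator of this shape can be sketched in pure Python. This is a hypothetical illustration, not Megatron's implementation; the attribute name `overlap_state` and the toy class are assumptions:

```python
import enum
import functools


class SharedExpertState(enum.Enum):
    IDLE = 0
    PRE_FORWARD_COMM_DONE = 1


def overlap_state_check(required_state, next_state):
    """Validate the current overlap state before the call; advance it after."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # 'overlap_state' is a hypothetical attribute name for this sketch.
            assert self.overlap_state is required_state, (
                f"expected {required_state}, got {self.overlap_state}"
            )
            result = func(self, *args, **kwargs)
            self.overlap_state = next_state
            return result
        return wrapper
    return decorator


class Toy:
    """Toy class demonstrating the decorator; not part of Megatron."""
    def __init__(self):
        self.overlap_state = SharedExpertState.IDLE

    @overlap_state_check(SharedExpertState.IDLE,
                         SharedExpertState.PRE_FORWARD_COMM_DONE)
    def pre_forward_comm(self):
        return "comm done"


t = Toy()
t.pre_forward_comm()
print(t.overlap_state.name)  # PRE_FORWARD_COMM_DONE
```

Calling `pre_forward_comm` a second time from the wrong state raises an `AssertionError`, which is the point of the check: each overlapped method may only run from the state its predecessor left behind.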
- class core.transformer.moe.shared_experts._BackwardStreamWait#
Bases: torch.autograd.Function
- static forward(ctx, input, stream)#
Forward pass.
- static backward(ctx, grad_output)#
Backward pass with stream wait.
- class core.transformer.moe.shared_experts.SharedExpertMLP(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.mlp.MLPSubmodules,
- gate: bool,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
- )#
Bases: megatron.core.transformer.mlp.MLP
MLP layer for Shared Experts.
Initialization
- stream#
None
- forward(hidden_states: torch.Tensor) → torch.Tensor#
Forward function
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
- )#
Gets sharded state dict.
- wait_current_stream()#
Wait for the current stream to complete.
- pre_forward_comm(input, wait_current_stream=True)#
All-gather for sequence parallelism (SP) before forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and it may change.
- linear_fc1_forward_and_act(overlapped_comm_output=None)#
Linear FC1 and activation function forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and it may change.
- linear_fc2_forward(overlapped_comm_output=None)#
Linear FC2 forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and it may change.
- post_forward_comm()#
Reduce-scatter for sequence parallelism (SP) after forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and it may change.
- get_output()#
Gets the module forward output. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set, and it may change.
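The five methods above are driven in a fixed order that matches the SharedExpertState machine, with dispatcher work interleaved between the steps. A toy, Megatron-independent sketch of that call sequence (the method names come from the docs above; the bodies are placeholders, not the real communication and GEMM kernels):

```python
import enum


class State(enum.Enum):
    IDLE = 0
    PRE_FORWARD_COMM_DONE = 1
    FC1_FORWARD_DONE = 2
    FC2_FORWARD_DONE = 3
    POST_FORWARD_COMM_DONE = 4


class ToySharedExpert:
    """Toy stand-in for SharedExpertMLP's overlapped API (illustrative only)."""

    def __init__(self):
        self.state = State.IDLE
        self.cached = None

    def pre_forward_comm(self, x):
        self.cached = x                 # stand-in for the SP all-gather
        self.state = State.PRE_FORWARD_COMM_DONE

    def linear_fc1_forward_and_act(self):
        self.cached = self.cached * 2   # stand-in for FC1 + activation
        self.state = State.FC1_FORWARD_DONE

    def linear_fc2_forward(self):
        self.cached = self.cached + 1   # stand-in for FC2
        self.state = State.FC2_FORWARD_DONE

    def post_forward_comm(self):
        self.state = State.POST_FORWARD_COMM_DONE  # stand-in for reduce-scatter

    def get_output(self):
        out, self.cached = self.cached, None
        self.state = State.IDLE         # one full pass returns to IDLE
        return out


expert = ToySharedExpert()
expert.pre_forward_comm(3)
# ... dispatcher communication can overlap with the expert compute here ...
expert.linear_fc1_forward_and_act()
expert.linear_fc2_forward()
expert.post_forward_comm()
print(expert.get_output())  # 7
```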
- core.transformer.moe.shared_experts.set_tensor_grad_fn_sequence_sr(tensor, value)#
Set sequence_sr for the grad_fn of a tensor to control the backward order. For older PyTorch versions this is a no-op (the backward order is not changed). The bigger the value, the earlier the grad_fn is scheduled.