core.transformer.mlp#
Module Contents#
Classes#
| Class | Description |
|---|---|
| `MLPSubmodules` | The dataclass for ModuleSpecs of MLP submodules, including linear fc1, the activation function, and linear fc2. |
| `MLP` | MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear transformation, and projects the result back to hidden size h. |
Functions#
Data#
API#
- core.transformer.mlp.logger#
'getLogger(…)'
- class core.transformer.mlp.MLPSubmodules#
The dataclass for ModuleSpecs of MLP submodules, including linear fc1, the activation function, and linear fc2.
- linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- activation_func: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
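For context, a minimal sketch of how `MLPSubmodules` is typically filled in when building a layer spec. The choice of `ColumnParallelLinear`/`RowParallelLinear` from `megatron.core.tensor_parallel` is an illustrative assumption; any `ModuleSpec` or module type can be supplied for each field.

```python
# Hedged sketch: wiring MLPSubmodules into a ModuleSpec for the MLP block.
# The concrete tensor-parallel linear classes used here are an assumption,
# not mandated by MLPSubmodules itself.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec

mlp_module_spec = ModuleSpec(
    module=MLP,
    submodules=MLPSubmodules(
        linear_fc1=ColumnParallelLinear,  # h -> ffn_hidden_size projection
        linear_fc2=RowParallelLinear,     # ffn_hidden_size -> h projection
    ),
)
```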
- class core.transformer.mlp.MLP(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: core.transformer.mlp.MLPSubmodules,
- is_expert: bool = False,
- input_size: Optional[int] = None,
- ffn_hidden_size: int = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#
Bases: megatron.core.transformer.module.MegatronModule

MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear transformation, and projects the result back to hidden size h.
Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the bias returned is None.
We use the following notation:

- h: hidden size
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
Initialization
- forward(hidden_states, per_token_scale=None)#
Perform the forward pass through the MLP block.
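Below is a minimal, hedged sketch of constructing an `MLP` and running its forward pass. It assumes `torch.distributed` and Megatron's model-parallel state are already initialized (e.g. via `megatron.core.parallel_state.initialize_model_parallel()`); the config values and the `[s, b, h]` input layout follow Megatron-Core conventions but are otherwise illustrative.

```python
# Hedged sketch: build an MLP from a TransformerConfig and MLPSubmodules, then
# run the forward pass. Assumes model-parallel state is already initialized.
import torch
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=1,
    hidden_size=1024,
    ffn_hidden_size=4096,         # the 4*h expansion described in the class docstring
    num_attention_heads=16,
    use_cpu_initialization=True,  # illustrative: build the weights on CPU
)

mlp = MLP(
    config=config,
    submodules=MLPSubmodules(
        linear_fc1=ColumnParallelLinear,
        linear_fc2=RowParallelLinear,
    ),
)

# Megatron-Core uses the [sequence length, batch size, hidden size] layout.
hidden_states = torch.randn(128, 2, config.hidden_size)
output, output_bias = mlp(hidden_states)
# output has the same shape as hidden_states; output_bias is None when
# config.add_bias_linear is False.
```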
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
)#
Return the sharded state dictionary of the module.
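As a hedged illustration, the sharded state dict is usually consumed by Megatron's distributed checkpointing utilities; the prefix and checkpoint directory below are placeholders, and `mlp` is an already-constructed `MLP` instance.

```python
# Hedged sketch: collect the module's ShardedStateDict and save it with
# megatron.core.dist_checkpointing. Assumes torch.distributed is initialized.
from megatron.core import dist_checkpointing

sharded_sd = mlp.sharded_state_dict(prefix='decoder.layers.0.mlp.')
dist_checkpointing.save(sharded_sd, checkpoint_dir='/tmp/mlp_ckpt')
```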
- backward_dw()#
- core.transformer.mlp.apply_swiglu_sharded_factory(
- original_sh_ten,
- sharded_offsets,
- singleton_local_shards: bool = False,
)#