core.transformer.mlp#
Module Contents#
Classes#
| LinearFc1Interface | Interface for linear_fc1 module in MLP. |
| LinearFc1Builder | Protocol describing how to build a linear_fc1 layer in MLP. |
| TEActivationFunctionInterface | Interface for activation_function module in MLP. |
| TEActivationFunctionBuilder | Protocol for activation_function module in MLP. |
| LinearFc2Interface | Interface for linear_fc2 module in MLP. |
| LinearFc2Builder | Protocol describing how to build a linear_fc2 layer in MLP. |
| MLPSubmodules | The dataclass for ModuleSpecs of MLP submodules including linear fc1, activation function, linear fc2. |
| MLP | MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. |
Functions#
Data#
API#
- core.transformer.mlp.logger#
'getLogger(…)'
- class core.transformer.mlp.LinearFc1Interface#
Bases: typing.Protocol
Interface for linear_fc1 module in MLP.
- forward(
- hidden_states: torch.Tensor,
- /,
- )
Forward method for linear_fc1 module.
- backward_dw() → None#
Backward method for linear_fc1 module.
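Because these interfaces are `typing.Protocol` classes, any module with matching `forward` and `backward_dw` methods satisfies them structurally; no inheritance is required. A minimal sketch, with tensor types replaced by plain objects and the `Fc1Like`/`MyLinearFc1` names purely hypothetical:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Fc1Like(Protocol):
    """Simplified stand-in for LinearFc1Interface (real code uses torch.Tensor)."""
    def forward(self, hidden_states, /): ...
    def backward_dw(self) -> None: ...

class MyLinearFc1:
    """Hypothetical module; satisfies the protocol purely by structure."""
    def forward(self, hidden_states, /):
        # Sketch only: real modules return (output, bias); the bias slot is
        # None when bias addition is disabled.
        return hidden_states, None
    def backward_dw(self) -> None:
        # A deferred weight-gradient computation would run here.
        pass

print(isinstance(MyLinearFc1(), Fc1Like))  # True (structural check)
```

`runtime_checkable` makes `isinstance` check only that the methods exist, which is enough to validate plug-in modules at construction time.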
- class core.transformer.mlp.LinearFc1Builder#
Bases: typing.Protocol
Protocol describing how to build a linear_fc1 layer in MLP.
- __call__(
- input_size: int,
- output_size: int,
- /,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: collections.abc.Callable[[torch.Tensor], None],
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: str | None,
- tp_group: torch.distributed.ProcessGroup | None,
- stride: int = 1,
- )
Builds a linear_fc1 layer for MLP.
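A conforming builder is simply a callable with this signature that returns a linear_fc1-style module. A shape-only sketch under that assumption; `DummyFc1` and `build_linear_fc1` are hypothetical stand-ins, and a real builder would construct a tensor-parallel layer from the `TransformerConfig` and process group:

```python
from dataclasses import dataclass

@dataclass
class DummyFc1:
    """Hypothetical placeholder for the module a real builder would return."""
    input_size: int
    output_size: int
    def forward(self, hidden_states, /):
        return hidden_states, None
    def backward_dw(self) -> None:
        pass

def build_linear_fc1(input_size, output_size, /, *, config, init_method,
                     gather_output, bias, skip_bias_add, is_expert,
                     tp_comm_buffer_name, tp_group, stride=1):
    # The sizes are positional-only; everything else is keyword-only,
    # mirroring the protocol's `/` and `*` markers.
    return DummyFc1(input_size, output_size)

fc1 = build_linear_fc1(1024, 4096, config=None, init_method=lambda t: None,
                       gather_output=False, bias=True, skip_bias_add=True,
                       is_expert=False, tp_comm_buffer_name=None, tp_group=None)
print(fc1.output_size)  # 4096
```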
- class core.transformer.mlp.TEActivationFunctionInterface#
Bases: typing.Protocol
Interface for activation_function module in MLP.
- forward(input_: torch.Tensor, /) → torch.Tensor#
Forward method for activation_function module.
- class core.transformer.mlp.TEActivationFunctionBuilder#
Bases: typing.Protocol
Protocol for activation_function module in MLP.
- __call__(
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- )
Builds an activation function module for MLP.
- class core.transformer.mlp.LinearFc2Interface#
Bases: typing.Protocol
Interface for linear_fc2 module in MLP.
- forward(
- hidden_states: torch.Tensor,
- /,
- )
Forward method for linear_fc2 module.
- backward_dw() → None#
Backward method for linear_fc2 module.
- class core.transformer.mlp.LinearFc2Builder#
Bases: typing.Protocol
Protocol describing how to build a linear_fc2 layer in MLP.
- __call__(
- input_size: int,
- output_size: int,
- /,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: collections.abc.Callable[[torch.Tensor], None],
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: str | None,
- tp_group: torch.distributed.ProcessGroup | None,
- )
Builds a linear_fc2 layer for MLP.
- class core.transformer.mlp.MLPSubmodules#
The dataclass for ModuleSpecs of MLP submodules including linear fc1, activation function, linear fc2.
- linear_fc1: core.transformer.mlp.LinearFc1Builder#
None
- linear_fc2: core.transformer.mlp.LinearFc2Builder#
None
- activation_func: core.transformer.mlp.TEActivationFunctionBuilder | None#
None
Builder for an activation function module; only used if config.use_te_activation_func is True.
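Wiring the dataclass together might look like the following sketch. The `MLPSubmodules` shown here is a simplified mirror of the one documented above, and the placeholder lambdas stand in for real `LinearFc1Builder`/`LinearFc2Builder` callables:

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class MLPSubmodules:
    """Simplified mirror of core.transformer.mlp.MLPSubmodules for illustration."""
    linear_fc1: Callable
    linear_fc2: Callable
    # Only consulted when config.use_te_activation_func is True.
    activation_func: Optional[Callable] = None

# Placeholder builders; real ones construct tensor-parallel linear layers.
submodules = MLPSubmodules(
    linear_fc1=lambda *args, **kwargs: "fc1-module",
    linear_fc2=lambda *args, **kwargs: "fc2-module",
)
print(submodules.activation_func is None)  # True
```

Leaving `activation_func` as `None` keeps the default activation path; supplying a builder only matters when the TE activation function is enabled in the config.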
- class core.transformer.mlp.MLP(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: core.transformer.mlp.MLPSubmodules,
- is_expert: bool = False,
- input_size: Optional[int] = None,
- ffn_hidden_size: Optional[int] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Bases: megatron.core.transformer.module.MegatronModule
MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension.
Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the bias returned is None.
We use the following notation:
- h: hidden size
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
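The h → 4h → h dataflow can be made concrete with plain Python lists. This is a shape-only sketch: ReLU stands in for whatever activation the config selects, constant weights replace trained ones, and real code operates on torch tensors of shape [s, b, h]:

```python
h = 8                # hidden size
ffn = 4 * h          # linear_fc1 projects h -> 4h by default
tokens = [[0.5] * h for _ in range(3)]   # 3 tokens, each of hidden size h

def matmul(x, w):
    # x: [n][d_in], w: [d_in][d_out] -> [n][d_out]
    return [[sum(xi[k] * w[k][j] for k in range(len(w)))
             for j in range(len(w[0]))] for xi in x]

w1 = [[0.01] * ffn for _ in range(h)]    # fc1 weight: [h][4h]
w2 = [[0.01] * h for _ in range(ffn)]    # fc2 weight: [4h][h]

inter = matmul(tokens, w1)                           # [3][4h]
act = [[max(v, 0.0) for v in row] for row in inter]  # ReLU stand-in
out = matmul(act, w2)                                # [3][h]: back to hidden size
print(len(out), len(out[0]))  # 3 8
```

Under tensor model parallelism the 4h dimension is split across p partitions, but the end-to-end shapes seen by the caller are unchanged.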
Initialization
- forward(
- hidden_states: torch.Tensor,
- per_token_scale: torch.Tensor | None = None,
- **kwargs,
- )
Perform the forward pass through the MLP block.
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
- )
Return the sharded state dictionary of the module.
- backward_dw()#
- core.transformer.mlp.apply_swiglu_sharded_factory(
- original_sh_ten,
- sharded_offsets,
- singleton_local_shards: bool = False,
- )