core.transformer.mlp#

Module Contents#

Classes#

MLPSubmodules

Dataclass holding the ModuleSpecs of the MLP submodules: linear_fc1, the activation function, and linear_fc2.

MLP

The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension (by default), applies a nonlinear transformation, and projects the state back to hidden size h.

Functions#

apply_swiglu_sharded_factory

Data#

logger

API#

core.transformer.mlp.logger#

‘getLogger(…)’

class core.transformer.mlp.MLPSubmodules#

Dataclass holding the ModuleSpecs of the MLP submodules: linear_fc1, the activation function, and linear_fc2.

linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

activation_func: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
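For orientation, here is a minimal sketch of populating MLPSubmodules inside a ModuleSpec. It assumes the plain tensor-parallel linear layers shipped with Megatron-Core and a standard GELU activation; production specs may substitute other layer classes (e.g. Transformer Engine variants).

```python
# A minimal sketch, assuming the plain tensor-parallel layers from
# megatron.core.tensor_parallel; real specs may use different classes.
import torch.nn.functional as F

from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec

mlp_spec = ModuleSpec(
    module=MLP,
    submodules=MLPSubmodules(
        linear_fc1=ColumnParallelLinear,  # h -> ffn hidden projection
        activation_func=F.gelu,           # nonlinearity between the two projections
        linear_fc2=RowParallelLinear,     # ffn hidden -> h projection
    ),
)
```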

class core.transformer.mlp.MLP(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.mlp.MLPSubmodules,
is_expert: bool = False,
input_size: Optional[int] = None,
ffn_hidden_size: int = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension (by default), applies a nonlinear transformation, and projects the state back to hidden size h.

Returns an output tensor and a bias to be added to that output; if config.add_bias_linear is False, the returned bias is None.

We use the following notation:

h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length

Initialization
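A hedged construction sketch, assuming torch.distributed and megatron.core.parallel_state have already been initialized and reusing the hypothetical mlp_spec from the MLPSubmodules example above; the sizes are illustrative only.

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Hypothetical sizes; ffn_hidden_size defaults to 4 * hidden_size when left unset.
config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    ffn_hidden_size=4096,
)

# Assumes model-parallel state is initialized, since the tensor-parallel
# linears fall back to the default TP group when tp_group is None.
mlp = MLP(config=config, submodules=mlp_spec.submodules)
```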

forward(hidden_states, per_token_scale=None)#

Perform the forward pass through the MLP block.
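As a usage sketch, assuming the [s, b, h] activation layout from the notation above and the mlp instance built earlier:

```python
import torch

# Sequence-first layout [s, b, h], per the notation above.
hidden_states = torch.randn(128, 4, config.hidden_size)

output, output_bias = mlp(hidden_states)
# output has shape [s, b, h]; output_bias is None when config.add_bias_linear
# is False, otherwise it is returned unadded so callers can fuse the addition.
```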

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Return the sharded state dictionary of the module.
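A sketch of how the result is typically passed to megatron.core.dist_checkpointing; the prefix and checkpoint path here are hypothetical.

```python
from megatron.core import dist_checkpointing

# Collect ShardedTensor/ShardedObject entries describing this rank's shards.
sharded_sd = mlp.sharded_state_dict(prefix='mlp.')

# Save collectively across ranks into a distributed checkpoint directory.
dist_checkpointing.save(sharded_sd, '/tmp/mlp_ckpt')
```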

backward_dw()#
core.transformer.mlp.apply_swiglu_sharded_factory(
original_sh_ten,
sharded_offsets,
singleton_local_shards: bool = False,
)#