core.transformer.mlp#

Module Contents#

Classes#

LinearFc1Interface

Interface for linear_fc1 module in MLP.

LinearFc1Builder

Protocol describing how to build a linear_fc1 layer in MLP.

TEActivationFunctionInterface

Interface for activation_function module in MLP.

TEActivationFunctionBuilder

Protocol describing how to build an activation_function module in MLP.

LinearFc2Interface

Interface for linear_fc2 module in MLP.

LinearFc2Builder

Protocol describing how to build a linear_fc2 layer in MLP.

MLPSubmodules

Dataclass bundling the ModuleSpecs of the MLP submodules: linear_fc1, the activation function, and linear_fc2.

MLP

MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear transformation, and projects the state back to hidden size h.

Functions#

apply_swiglu_sharded_factory

Data#

logger

API#

core.transformer.mlp.logger#

‘getLogger(…)’

class core.transformer.mlp.LinearFc1Interface#

Bases: typing.Protocol

Interface for linear_fc1 module in MLP.

forward(
hidden_states: torch.Tensor,
/,
) → tuple[torch.Tensor, torch.Tensor | None]#

Forward method for linear_fc1 module.

backward_dw() → None#

Backward method for linear_fc1 module.

class core.transformer.mlp.LinearFc1Builder#

Bases: typing.Protocol

Protocol describing how to build a linear_fc1 layer in MLP.

__call__(
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
tp_group: torch.distributed.ProcessGroup | None,
stride: int = 1,
) → core.transformer.mlp.LinearFc1Interface#

Builds a linear_fc1 layer for MLP.
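The builder signature above makes the two sizes positional-only and everything else keyword-only. A hedged sketch of a conforming builder, using simplified stand-ins (`TransformerConfig`, `DummyFc1`, `build_linear_fc1` are illustrative, not the real Megatron classes):

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class TransformerConfig:
    """Simplified stand-in for the real TransformerConfig (assumption)."""
    hidden_size: int = 1024


class DummyFc1:
    """Toy module satisfying the LinearFc1Interface shape."""

    def __init__(self, input_size: int, output_size: int):
        self.input_size, self.output_size = input_size, output_size

    def forward(self, hidden_states, /):
        return hidden_states, None

    def backward_dw(self) -> None:
        pass


def build_linear_fc1(
    input_size: int,
    output_size: int,
    /,
    *,
    config: TransformerConfig,
    init_method: Callable = lambda t: None,
    gather_output: bool = False,
    bias: bool = True,
    skip_bias_add: bool = True,
    is_expert: bool = False,
    tp_comm_buffer_name: Optional[str] = None,
    tp_group=None,
    stride: int = 1,
) -> DummyFc1:
    """Hypothetical builder conforming to the LinearFc1Builder protocol."""
    return DummyFc1(input_size, output_size)


# Sizes are positional-only; all behavioral options must be passed by keyword.
fc1 = build_linear_fc1(1024, 4096, config=TransformerConfig())
print(fc1.output_size)  # 4096
```

Keeping the sizes positional-only lets concrete builders rename those parameters freely, while the keyword-only options keep call sites self-documenting.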

class core.transformer.mlp.TEActivationFunctionInterface#

Bases: typing.Protocol

Interface for activation_function module in MLP.

forward(input_: torch.Tensor, /) → torch.Tensor#

Forward method for activation_function module.

class core.transformer.mlp.TEActivationFunctionBuilder#

Bases: typing.Protocol

Protocol describing how to build an activation_function module in MLP.

__call__(
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
) → core.transformer.mlp.TEActivationFunctionInterface#

Builds an activation function module for MLP.

class core.transformer.mlp.LinearFc2Interface#

Bases: typing.Protocol

Interface for linear_fc2 module in MLP.

forward(
hidden_states: torch.Tensor,
/,
) → tuple[torch.Tensor, torch.Tensor | None]#

Forward method for linear_fc2 module.

backward_dw() → None#

Backward method for linear_fc2 module.

class core.transformer.mlp.LinearFc2Builder#

Bases: typing.Protocol

Protocol describing how to build a linear_fc2 layer in MLP.

__call__(
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
tp_group: torch.distributed.ProcessGroup | None,
) → core.transformer.mlp.LinearFc2Interface#

Builds a linear_fc2 layer for MLP.

class core.transformer.mlp.MLPSubmodules#

Dataclass bundling the ModuleSpecs of the MLP submodules: linear_fc1, the activation function, and linear_fc2.

linear_fc1: core.transformer.mlp.LinearFc1Builder#

None

linear_fc2: core.transformer.mlp.LinearFc2Builder#

None

activation_func: core.transformer.mlp.TEActivationFunctionBuilder | None#

None

Builder for an activation function module; only used if config.use_te_activation_func is True.
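A simplified analogue of the dataclass above, showing how the three builders are bundled and that activation_func is optional. The lambdas are placeholder builders for illustration; the real fields hold LinearFc1Builder, LinearFc2Builder, and TEActivationFunctionBuilder callables:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class MLPSubmodules:
    """Simplified stand-in for the real MLPSubmodules dataclass (assumption)."""
    linear_fc1: Callable
    linear_fc2: Callable
    # Only consulted when config.use_te_activation_func is True.
    activation_func: Optional[Callable] = None


submodules = MLPSubmodules(
    linear_fc1=lambda *args, **kwargs: "fc1",   # placeholder builder
    linear_fc2=lambda *args, **kwargs: "fc2",   # placeholder builder
    # activation_func left as None: the MLP then uses the activation
    # configured on TransformerConfig instead of a TE activation module.
)
print(submodules.activation_func is None)  # True
```

Storing builders rather than constructed modules lets one MLPSubmodules spec be reused across layers, with each MLP instance constructing its own submodules from its own config.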

class core.transformer.mlp.MLP(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.mlp.MLPSubmodules,
is_expert: bool = False,
input_size: Optional[int] = None,
ffn_hidden_size: Optional[int] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear transformation, and projects the state back to hidden size h.

Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the returned bias is None.

We use the following notation:

h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length

Initialization

forward(
hidden_states: torch.Tensor,
per_token_scale: torch.Tensor | None = None,
**kwargs,
)#

Perform the forward pass through the MLP block.
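As noted above, the forward pass returns an (output, bias) pair rather than adding the bias internally, so the caller can fuse the bias add with residual/dropout. A minimal sketch of that convention, with plain floats standing in for tensors and `mlp_forward` as a hypothetical helper, not the real method:

```python
def mlp_forward(hidden_states, add_bias_linear=True):
    """Toy stand-in for MLP.forward illustrating the (output, bias) contract."""
    output = [h * 2.0 for h in hidden_states]          # pretend projection
    bias = [0.5] * len(output) if add_bias_linear else None
    return output, bias


output, bias = mlp_forward([1.0, 2.0])
# The caller (e.g. the transformer layer) is responsible for applying the
# bias; in the real model this add is fused with dropout and the residual.
if bias is not None:
    output = [o + b for o, b in zip(output, bias)]
print(output)  # [2.5, 4.5]
```

With add_bias_linear=False the helper returns bias as None, matching the documented behavior, and the caller's None check skips the add.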

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Return the sharded state dictionary of the module.

backward_dw()#

core.transformer.mlp.apply_swiglu_sharded_factory(
original_sh_ten,
sharded_offsets,
singleton_local_shards: bool = False,
)#