transformer package

The transformer package provides a customizable and configurable implementation of the transformer model architecture. Each component of a transformer stack, from entire layers down to individual linear layers, can be customized by swapping in different PyTorch modules using the “spec” parameters (see here). The configuration of the transformer (hidden size, number of layers, number of attention heads, etc.) is provided via a TransformerConfig object.
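The sketch below illustrates the spec idea with a hypothetical, heavily simplified `ModuleSpec` dataclass and `build_module` helper written for this example; they are not Megatron's `spec_utils` code. Plain Python classes stand in for PyTorch modules. Swapping `FancyLinear` in for `PlainLinear` changes only the spec, not the parent class.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ModuleSpec:
    """Hypothetical, simplified spec: the class to build plus its sub-specs."""
    module: type
    params: Dict[str, Any] = field(default_factory=dict)
    submodules: Optional[Dict[str, Any]] = None


def build_module(spec_or_type, **kwargs):
    """Build a bare class directly, or a ModuleSpec together with its submodules."""
    if isinstance(spec_or_type, type):
        return spec_or_type(**kwargs)
    return spec_or_type.module(
        submodules=spec_or_type.submodules, **spec_or_type.params, **kwargs
    )


class PlainLinear:
    def __init__(self, size):
        self.size = size


class FancyLinear(PlainLinear):
    """A drop-in replacement, e.g. a fused or Transformer Engine layer."""


class ToyMLP:
    def __init__(self, submodules, hidden_size):
        # The parent never names a concrete class; it builds whatever the spec says.
        self.fc1 = build_module(submodules["linear_fc1"], size=4 * hidden_size)
        self.fc2 = build_module(submodules["linear_fc2"], size=hidden_size)


default_spec = ModuleSpec(ToyMLP, submodules={"linear_fc1": PlainLinear, "linear_fc2": PlainLinear})
custom_spec = ModuleSpec(ToyMLP, submodules={"linear_fc1": FancyLinear, "linear_fc2": PlainLinear})

mlp_default = build_module(default_spec, hidden_size=8)
mlp_custom = build_module(custom_spec, hidden_size=8)
```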

This is the entire attention portion, either self or cross attention, of a transformer layer including the query, key, and value projections, a “core” attention calculation (e.g. dot product attention), and final output linear projection.

class core.transformer.attention.Attention(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule, abc.ABC

Attention layer abstract class.

This layer only contains common modules required for the “self attn” and “cross attn” specializations.

forward(hidden_states, attention_mask, key_value_states=None, inference_params=None, rotary_pos_emb=None, packed_seq_params=None)

abstract get_query_key_value_tensors(hidden_states, key_value_states)

This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.

class core.transformer.attention.CrossAttention(*args: Any, **kwargs: Any)

Bases: core.transformer.attention.Attention

Cross-attention layer class

Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.

get_query_key_value_tensors(hidden_states, key_value_states)

Derives query tensor from hidden_states, and key/value tensors from key_value_states.

class core.transformer.attention.CrossAttentionSubmodules(linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None)

Bases: object

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

class core.transformer.attention.SelfAttention(*args: Any, **kwargs: Any)

Bases: core.transformer.attention.Attention

Self-attention layer class

Self-attention layer takes input with size [s, b, h] and returns output of the same size.

get_query_key_value_tensors(hidden_states, key_value_states=None)

Derives query, key and value tensors from hidden_states.
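To make the split between the two specializations concrete, here is a plain-Python sketch of the pattern. An abstract base calls `get_query_key_value_tensors`, and two subclasses differ only in where keys and values come from. The `Toy*` classes and the `project` helper, which just scales a vector, are hypothetical stand-ins for the real linear projections.

```python
from abc import ABC, abstractmethod


def project(vec, scale):
    """Toy stand-in for a linear projection: scale every element."""
    return [scale * v for v in vec]


class ToyAttention(ABC):
    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """Subclasses decide where query, key and value come from."""

    def forward(self, hidden_states, key_value_states=None):
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
        # A real layer would now run core attention and the output projection.
        return query, key, value


class ToySelfAttention(ToyAttention):
    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        # Like a fused linear_qkv: all three come from hidden_states.
        return project(hidden_states, 1), project(hidden_states, 2), project(hidden_states, 3)


class ToyCrossAttention(ToyAttention):
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        # Like linear_q on hidden_states and linear_kv on the context.
        query = project(hidden_states, 1)
        key, value = project(key_value_states, 2), project(key_value_states, 3)
        return query, key, value


self_qkv = ToySelfAttention().forward([1.0, 2.0])
cross_qkv = ToyCrossAttention().forward([1.0, 2.0], key_value_states=[10.0, 20.0])
```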

class core.transformer.attention.SelfAttentionSubmodules(linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None)

Bases: object

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

This is a PyTorch-only implementation of dot product attention. More efficient implementations, like those provided by FlashAttention or cuDNN’s FusedAttention, are typically used when training speed is important.
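For intuition, the core computation can be sketched for a single head in plain Python. Scores are `Q K^T / sqrt(d)`, an optional causal mask removes future positions, and a softmax over each row weights the values. This is an illustrative sketch, not `DotProductAttention` itself. It omits batching, multiple heads, tensor parallelism, dropout, and the fused-kernel paths.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(d)) V for [seq, dim] inputs."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if causal:
            # Query i may only attend to keys j <= i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        probs = softmax(scores)
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = dot_product_attention(q, k, v, causal=True)
```

With the causal mask, the first position can only see itself, so its output is exactly the first value row.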

class core.transformer.dot_product_attention.DotProductAttention(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See “Reducing Activation Recomputation in Large Transformer Models” for more details.

We use the following notation:

  • h: hidden size

  • n: number of attention heads

  • p: number of tensor model parallel partitions

  • b: batch size

  • s: sequence length

forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, attn_mask_type: Optional[megatron.core.transformer.enums.AttnMaskType] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None)
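Using the notation above, some shape arithmetic shows why this region is memory intensive. The helper below is hypothetical. It assumes the `[s, b, ...]` layout used throughout this package and an even split of the n heads over p tensor-parallel ranks. The score/softmax tensor holds one s-by-s matrix per local head, so it grows quadratically with sequence length, while the context output grows only linearly.

```python
def attention_shapes(s: int, b: int, h: int, n: int, p: int) -> dict:
    """Per-rank shapes inside core attention, using the notation above."""
    if n % p or h % n:
        raise ValueError("n must be divisible by p, and h by n")
    hn = h // n              # hidden size per attention head
    heads_per_rank = n // p  # attention heads held by each tensor-parallel rank
    return {
        "query": (s, b, heads_per_rank, hn),
        "scores": (b, heads_per_rank, s, s),      # one s-by-s matrix per local head
        "context": (s, b, heads_per_rank * hn),   # h / p features per rank
    }


def num_elements(shape) -> int:
    total = 1
    for dim in shape:
        total *= dim
    return total


shapes = attention_shapes(s=4096, b=1, h=8192, n=64, p=8)
ratio = num_elements(shapes["scores"]) // num_elements(shapes["context"])
```

With these numbers each rank's score tensor is 32 times larger than its context output (the ratio is s / (h / n)). Recomputing this region instead of storing it is therefore cheap in memory terms relative to the compute it costs.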

class core.transformer.enums.AttnMaskType(value)

Bases: enum.Enum

An enumeration.

causal = 2

no_mask = 3

padding = 1
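Here is a sketch of what each mask type means, using a local mirror of the enum values listed above. It assumes two conventions: `True` marks a score to be masked out, consistent with the upper-triangular causal mask returned by `get_default_causal_mask`, and masked scores are filled with a large negative number before the softmax. `build_mask`, `apply_mask`, and the fill value `-10000.0` are illustrative choices, not the library's API.

```python
import enum
from typing import List, Optional


class AttnMaskType(enum.Enum):
    """Local mirror of the documented values."""
    padding = 1
    causal = 2
    no_mask = 3


def build_mask(mask_type: AttnMaskType, sq: int, length: Optional[int] = None) -> List[List[bool]]:
    """Return an [sq, sq] mask where True marks a score to be masked out."""
    if mask_type is AttnMaskType.causal:
        # Upper triangle above the diagonal: query i may not attend to key j > i.
        return [[j > i for j in range(sq)] for i in range(sq)]
    if mask_type is AttnMaskType.padding:
        if length is None:
            raise ValueError("a padding mask needs the unpadded sequence length")
        # Keys at or past the true length are padding for every query.
        return [[j >= length for j in range(sq)] for _ in range(sq)]
    return [[False] * sq for _ in range(sq)]  # no_mask


def apply_mask(scores: List[List[float]], mask: List[List[bool]], fill: float = -10000.0):
    """Replace masked scores with a large negative value before the softmax."""
    return [[fill if m else s for s, m in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)]
```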

class core.transformer.enums.AttnType(value)

Bases: enum.Enum

An enumeration.

cross_attn = 2

self_attn = 1

class core.transformer.enums.ModelType(value)

Bases: enum.Enum

An enumeration.

encoder_and_decoder = 2

encoder_or_decoder = 1

This provides a pass-through module that can be used in specs to indicate that the operation should not be performed. For example, when the LayerNorm is fused into the subsequent linear layer, an IdentityOp can be passed in as the LayerNorm module.

class core.transformer.identity_op.IdentityFuncOp(*args: Any, **kwargs: Any)

Bases: core.transformer.identity_op.IdentityOp

This is a placeholder for IdentityFuncOp(…)(x) -> IdentityOp(x) -> x. Such a function is handy for ops like bias_dropout_fusion, which themselves return a function at runtime based on the arguments passed.

forward(*args, **kwargs)

class core.transformer.identity_op.IdentityOp(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

This is a placeholder for IdentityOp(x) -> x

forward(x, *args, **kwargs)
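The two placeholders can be mimicked in plain Python as below. These classes are stand-ins that drop the `torch.nn.Module` base and emulate its `__call__`-to-`forward` dispatch, so the example runs anywhere. The argument values passed to the bias-dropout-add slot are illustrative only.

```python
class IdentityOp:
    """Plain-Python stand-in for IdentityOp (no torch.nn.Module base)."""

    def __init__(self, *args, **kwargs):
        # Constructor arguments (e.g. hidden_size, eps) are accepted and ignored,
        # so the placeholder can replace a real module in a spec.
        pass

    def __call__(self, *args, **kwargs):
        # torch.nn.Module.__call__ dispatches to forward(); emulate that here.
        return self.forward(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        return x


class IdentityFuncOp(IdentityOp):
    def forward(self, *args, **kwargs):
        # Return the parent's forward *function* instead of applying it.
        return super().forward


norm = IdentityOp(hidden_size=8, eps=1e-5)
normed = norm([1.0, 2.0])

# Roughly how a layer might use a bias-dropout-add slot: first ask for a
# function (here with illustrative training / fusion flags), then call it.
bda = IdentityFuncOp()(True, False)
out = bda("attention_output", "residual", 0.1)
```

The returned function ignores everything after its first argument, so the "fused op" slot degenerates to returning its input unchanged.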

This is the entire MLP portion of the transformer layer with an input projection, non-linearity, and output projection.

class core.transformer.mlp.MLP(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

MLP takes an input with hidden size h, projects it to the FFN hidden dimension (ffn_hidden_size, 4*h by default), applies a nonlinear transformation, and projects the result back to hidden size h.

Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the bias returned is None.

We use the following notation:

  • h: hidden size

  • p: number of tensor model parallel partitions

  • b: batch size

  • s: sequence length
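The helper below sketches the per-rank weight shapes of the two linear layers, in `torch.nn.Linear`'s `[out, in]` order. It rests on three assumptions. First, the usual Megatron tensor-parallel layout applies: linear_fc1 is split along its output dimension and linear_fc2 along its input dimension. Second, `ffn_hidden_size` defaults to 4*h, as described under `TransformerConfig`. Third, `gated_linear_unit=True` doubles linear_fc1's output to produce both the gate and the value halves. `mlp_weight_shapes` is a hypothetical name.

```python
def mlp_weight_shapes(h: int, ffn_hidden_size=None, p: int = 1, gated_linear_unit: bool = False) -> dict:
    """Per-rank [out, in] weight shapes of linear_fc1 and linear_fc2."""
    ffn = ffn_hidden_size if ffn_hidden_size is not None else 4 * h
    if ffn % p:
        raise ValueError("ffn_hidden_size must be divisible by p")
    # With a gated linear unit, fc1 produces both the gate and the value halves.
    fc1_out = 2 * ffn if gated_linear_unit else ffn
    return {
        "linear_fc1": (fc1_out // p, h),  # output dimension split across ranks
        "linear_fc2": (h, ffn // p),      # input dimension split across ranks
    }
```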


sharded_state_dict(prefix: str = '', sharded_offsets: tuple = ()) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

class core.transformer.mlp.MLPSubmodules(linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None)

Bases: object

linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

This provides a common base class for all modules used in the transformer that contains some common functionality.

Megatron Module.

class core.transformer.module.Float16Module(*args: Any, **kwargs: Any)

Bases: core.transformer.module.MegatronModule

Float 16 Module.


Transformer config




Specifies if the model runs in fp16 mode




Specifies if the model runs in bf16 mode




config (TransformerConfig) – The transformer config used to initialize the model

forward(*inputs, **kwargs)

load_state_dict(state_dict, strict=True)


sharded_state_dict(prefix='', *args, **kwargs)

Retrieve sharded_state_dict from the module being wrapped.

state_dict(destination=None, prefix='', keep_vars=False)

state_dict_for_save_checkpoint(prefix='', keep_vars=False)

Retrieve state_dict from the module being wrapped.

class core.transformer.module.MegatronModule(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Base Megatron module inherited by all models.

Megatron-specific extensions of torch Module with support for pipelining.


config (TransformerConfig) – Transformer config


Sets the is_first_microbatch flag if it exists. When this flag is set, TE modules will update their fp8 parameter cache.

sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Default implementation for sharded state dict for distributed checkpointing.

General definition of sharded_state_dict simply calls sharded_state_dict_default (which calls the sharded_state_dict method if possible, or a default implementation otherwise) recursively on all submodules.

  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by parent modules. Passed along to ShardedTensor


Returns: dictionary of state dict keys mapped to ShardedTensors

Return type: ShardedStateDict
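The recursion described above can be sketched with plain Python objects. `Node`, `CustomNode`, and `generic_sharded_state_dict` are hypothetical, and the `sharded_state_dict_default` below is a simplified stand-in named after the real function. A module that defines its own `sharded_state_dict` is asked for it. Every other module contributes its local entries and recurses into its children with an extended prefix. The real implementation wraps values in `ShardedTensor` objects and threads `sharded_offsets` through; this sketch omits both.

```python
from typing import Dict, Optional


class Node:
    """Hypothetical module: local entries plus named child modules."""

    def __init__(self, params: Optional[Dict[str, object]] = None, **children: "Node"):
        self.params = params or {}
        self.children = children


class CustomNode(Node):
    """A child that provides its own sharded_state_dict."""

    def sharded_state_dict(self, prefix: str = "") -> Dict[str, object]:
        return {f"{prefix}custom_blob": "special"}


def sharded_state_dict_default(module: Node, prefix: str = "") -> Dict[str, object]:
    # Prefer the module's own implementation when it defines one.
    if hasattr(module, "sharded_state_dict"):
        return module.sharded_state_dict(prefix=prefix)
    return generic_sharded_state_dict(module, prefix)


def generic_sharded_state_dict(module: Node, prefix: str = "") -> Dict[str, object]:
    # Local entries first, then recurse into each child with an extended prefix.
    result = {f"{prefix}{name}": value for name, value in module.params.items()}
    for name, child in module.children.items():
        result.update(sharded_state_dict_default(child, prefix=f"{prefix}{name}."))
    return result


model = Node({"weight": 1}, mlp=Node({"fc1_weight": 2}), extra=CustomNode())
state = generic_sharded_state_dict(model, prefix="decoder.")
```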


state_dict_for_save_checkpoint(prefix: str = '', keep_vars: bool = False)

Override state dict for saving checkpoints. Use this function to override the state dict for saving checkpoints.

  • prefix (str, optional) – Prefix for the state dict keys. Defaults to ''.

  • keep_vars (bool, optional) – Passed through to state_dict(); if False, the returned tensors are detached from autograd. Defaults to False.





core.transformer.module.conversion_helper(val, conversion)


core.transformer.module.fp32_to_float16(val, float16_convertor)


A block, or stack, of several transformer layers. The layers can all be the same or each can be unique.

class core.transformer.transformer_block.TransformerBlock(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

Transformer class.

forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, context: Optional[torch.Tensor] = None, context_mask: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None, inference_params: Optional[megatron.core.InferenceParams] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None)

set_input_tensor(input_tensor: torch.Tensor)

Set input tensor to be used instead of forward()’s input.

When doing pipeline parallelism, the input from the previous stage comes from communication, not from the input, so the model’s forward_step_func won’t have it. This function is therefore used by internal code to bypass the input provided by the forward_step_func.
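The pattern can be sketched as follows. The sketch assumes, as the description implies, that any stage other than the first ignores the argument passed to `forward()` and uses the tensor supplied through `set_input_tensor` instead. `ToyPipelineStage` is hypothetical, and the "communication" is just a direct hand-off between two objects.

```python
class ToyPipelineStage:
    """Hypothetical stage; only the first stage uses forward()'s argument."""

    def __init__(self, is_first_stage: bool):
        self.is_first_stage = is_first_stage
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule with the activation received from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, hidden_states):
        if not self.is_first_stage:
            # Later stages ignore the argument and use the communicated tensor.
            hidden_states = self.input_tensor
        return [x + 1 for x in hidden_states]  # stand-in for running this stage's layers


first, second = ToyPipelineStage(True), ToyPipelineStage(False)
act = first.forward([0, 0])
second.set_input_tensor(act)  # "received" from the previous stage
out = second.forward(None)    # forward_step_func has no real input to give
```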

sharded_state_dict(prefix: str = '', sharded_offsets: tuple = ()) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

class core.transformer.transformer_block.TransformerBlockSubmodules(layer_specs: List[megatron.core.transformer.spec_utils.ModuleSpec] = None)

Bases: object

layer_specs: List[megatron.core.transformer.spec_utils.ModuleSpec] = None

core.transformer.transformer_block.get_num_layers_to_build(config: megatron.core.transformer.transformer_config.TransformerConfig) → int

This contains all of the configuration options for the transformer. Using a dataclass reduces code bloat by keeping all arguments together in a dataclass instead of passing several arguments through multiple layers of function calls.

class core.transformer.transformer_config.TransformerConfig(tensor_model_parallel_size: int = 1, context_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, sequence_parallel: bool = False, expert_model_parallel_size: int = 1, perform_initialization: bool = True, use_cpu_initialization: bool = False, fp16: bool = False, bf16: bool = False, params_dtype: torch.dtype = torch.float32, timers: Optional[Callable] = None, gradient_accumulation_fusion: bool = False, async_tensor_model_parallel_allreduce: bool = False, tp_comm_overlap: bool = False, tp_comm_split_ag: bool = True, tp_comm_atomic_ag: bool = False, tp_comm_split_rs: bool = True, tp_comm_atomic_rs: bool = False, tp_comm_bulk_wgrad: bool = True, tp_comm_bulk_dgrad: bool = True, finalize_model_grads_func: Optional[Callable] = None, pipeline_dtype: Optional[torch.dtype] = None, grad_scale_func: Optional[Callable] = None, enable_autocast: bool = False, autocast_dtype: Optional[torch.dtype] = None, variable_seq_lengths: bool = False, num_microbatches_with_partial_activation_checkpoints: Optional[int] = None, overlap_p2p_comm: bool = False, batch_p2p_comm: bool = True, batch_p2p_sync: bool = True, use_ring_exchange_p2p: bool = False, deallocate_pipeline_outputs: bool = False, no_sync_func: Optional[Callable] = None, grad_sync_func: Optional[Callable] = None, param_sync_func: Optional[Callable] = None, pipeline_model_parallel_split_rank: Optional[int] = None, cpu_offloading: bool = False, cpu_offloading_num_layers: int = 0, _cpu_offloading_context: Optional[ContextManager] = None, cpu_offloading_activations: bool = True, cpu_offloading_weights: bool = True, barrier_with_L1_time: bool = True, num_layers: int = 0, hidden_size: int = 0, num_attention_heads: int = 0, num_query_groups: Optional[int] = None, ffn_hidden_size: Optional[int] = None, kv_channels: Optional[int] = None, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, 
fp32_residual_connection: bool = False, apply_residual_connection_post_layernorm: bool = False, layernorm_epsilon: float = 1e-05, layernorm_zero_centered_gamma: bool = False, add_bias_linear: bool = True, add_qkv_bias: bool = False, gated_linear_unit: bool = False, activation_func: Callable = torch.nn.functional.gelu, num_moe_experts: Optional[int] = None, rotary_interleaved: bool = False, window_size: Optional[Tuple[int, int]] = None, init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, init_method_std: float = 0.02, apply_query_key_layer_scaling: bool = False, attention_softmax_in_fp32: bool = True, bias_activation_fusion: bool = False, masked_softmax_fusion: bool = False, persist_layer_norm: bool = False, memory_efficient_layer_norm: bool = False, bias_dropout_fusion: bool = False, apply_rope_fusion: bool = False, recompute_granularity: Optional[str] = None, recompute_method: Optional[str] = None, recompute_num_layers: Optional[int] = None, distribute_saved_activations: Optional[bool] = None, fp8: Optional[str] = None, fp8_margin: int = 0, fp8_interval: int = 1, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = 'most_recent', fp8_wgrad: bool = True, clone_scatter_output_in_embedding: bool = True, disable_parameter_transpose_cache: bool = False, normalization: bool = 'LayerNorm', moe_router_load_balancing_type: str = 'aux_loss', moe_router_topk: int = 2, moe_grouped_gemm: bool = False, moe_aux_loss_coeff: float = 0, moe_z_loss_coeff: Optional[float] = None, moe_input_jitter_eps: Optional[float] = None, moe_token_dropping: bool = False, max_position_embeddings: int = 0, rotary_percent: float = 0)

Bases: core.model_parallel_config.ModelParallelConfig

Configuration object for megatron-core transformers.

  • num_layers (int) – Number of transformer layers in a transformer block.

  • hidden_size (int) – Transformer hidden size.

  • ffn_hidden_size (int) – Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided. Defaults to None.

  • num_attention_heads (int) – Number of transformer attention heads.

  • kv_channels (int) – Projection weights dimension in multi-head attention. This is set to hidden_size // num_attention_heads if not provided. Defaults to None.

  • num_query_groups (int) – Number of query groups for grouped-query attention. If None, normal attention is used.

  • hidden_dropout (float) – Dropout probability for transformer hidden state. Defaults to 0.1.

  • attention_dropout (float) – Post attention dropout probability. Defaults to 0.1.

  • fp32_residual_connection (bool) – If true, move residual connections to fp32.

  • apply_residual_connection_post_layernorm (bool) – If true, uses the original BERT residual connection ordering. Defaults to False.

  • layernorm_epsilon (float) – Layernorm epsilon. Defaults to 1e-5.

  • layernorm_zero_centered_gamma (bool) – if set to ‘True’, the LayerNorm is adjusted to center the gamma values around 0. This improves numerical stability. Defaults to False.

  • add_bias_linear (bool) – Include a bias term in all linear layers (QKV projections, after core attention, and two in MLP layer). Default is True.

  • add_qkv_bias (bool) – Add a bias term only for QKV projections. Default is False.

  • gated_linear_unit (bool) – Use a gated linear unit for the first linear layer in the MLP. Defaults to False.

  • activation_func (Callable) – Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.

  • num_moe_experts (int) – Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Defaults to None (no MoE).

  • rotary_interleaved (bool) – If True, rotate pairs of even and odd dimensions (RoFormer style); if False, rotate pairs of the first half and second half (LLaMa style). Defaults to False.

  • init_method (Callable) – Method to initialize weights. Note that bias is always set to zero. Should be a function that takes a single Tensor and initializes it. Defaults to megatron.core.utils.init_method_normal(init_method_std), which is torch.nn.init.normal_ with mean=0.0 and std=init_method_std.

  • output_layer_init_method (Callable) – Method to initialize weights of the output layer of both attention and MLP blocks. Defaults to megatron.core.utils.scaled_init_method_normal(init_method_std), which is torch.nn.init.normal_ with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers).

  • init_method_std (float) – Standard deviation of the zero mean normal for the default initialization method, not used if init_method and output_layer_init_method are provided. Defaults to 0.02.

  • apply_query_key_layer_scaling (bool) – If true, scale Q * K^T by 1 / layer-number. Defaults to False.

  • attention_softmax_in_fp32 (bool) – If true, run attention masking and softmax in fp32. This should be true if apply_query_key_layer_scaling is true.

  • bias_activation_fusion (bool) – If true, fuses the bias addition and the activation function (e.g. GeLU). Defaults to False.

  • masked_softmax_fusion (bool) – If true, uses softmax fusion.

  • persist_layer_norm (bool) – If true, uses the persistent fused layer norm kernel. This kernel only supports a fixed set of hidden sizes. Defaults to False.

  • memory_efficient_layer_norm (bool) – If True, and using local layers (not from TransformerEngine), tells Apex to use the memory efficient fused LayerNorm kernel. Ignored if not using LayerNorm. Defaults to False.

  • bias_dropout_fusion (bool) – If true, uses bias dropout fusion.

  • recompute_granularity (str) – megatron-core supports ‘selective’ activation checkpointing, where only the memory intensive part of attention is checkpointed. These memory intensive activations are also less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See “Reducing Activation Recomputation in Large Transformer Models” for more details. ‘full’ will checkpoint the entire transformer layer. Must be ‘selective’ or ‘full’. ‘selective’ always uses all layers. Defaults to None.

  • recompute_method (str) – uniform will uniformly divide the total number of transformer layers in a transformer block and recompute the input activation of each divided chunk at the specified granularity. block will recompute the input activations for only a set number of transformer layers per pipeline stage. The rest of the layers in the pipeline stage will not have any activations recomputed. Must be ‘uniform’ or ‘block’. Defaults to None.

  • recompute_num_layers (int) – When recompute_method is uniform, recompute_num_layers is the number of transformer layers in each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is the number of transformer layers to recompute within each pipeline stage. Must be None for ‘selective’ activation checkpointing. Defaults to None.

  • distribute_saved_activations (bool) – If true, distribute recomputed activations across the model parallel group. Defaults to None.

  • fp8 (str) – If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices: (1) ‘e4m3’ uniformly uses e4m3 for all FP8 tensors, (2) ‘hybrid’ uses e4m3 for all FP8 activation and weight tensors and e5m2 for all FP8 output activation gradient tensors. Defaults to None.

  • fp8_margin (int) – Margin for the scaling factor computation.

  • fp8_interval (int) – Controls how often the scaling factor is recomputed.

  • fp8_amax_history_len (int) – The length of the amax history window used for scaling factor computation.

  • fp8_amax_compute_algo (str) – Algorithm used for choosing the amax value for the scaling factor computation. There are 2 predefined choices: max chooses the largest amax in the history window, while most_recent always chooses the most recently seen value.

  • fp8_wgrad (bool) – When set to False, override FP8 config options and do the wgrad computation in higher precision. Defaults to True.

  • clone_scatter_output_in_embedding (bool) – When set to true, clone the output of scatter_to_sequence_parallel_region in the embedding layer to facilitate garbage collection of the input. Defaults to True.

  • disable_parameter_transpose_cache (bool) – When set to true, the parameter transposes are not cached for subsequent iterations. Defaults to False.

  • normalization (str) – Switch between LayerNorm and RMSNorm as normalization layers. For now, these are primarily used by Transformer-Engine’s layers like LayerNormLinear. Default value is LayerNorm.

  • window_size ((int,int) or None) – If not None, sliding window attention is used. The size of the window is specified by the two numbers in the tuple; -1 is a special value meaning “infinite window size”.

  • moe_router_load_balancing_type (str) – Determines the load balancing strategy for the router. “aux_loss” corresponds to the load balancing loss used in GShard and SwitchTransformer, “sinkhorn” corresponds to the balancing algorithm used in S-BASE, and “none” implies no load balancing. The default is “aux_loss”.

  • moe_router_topk (int) – Number of experts to route to for each token. The default is 2.

  • moe_grouped_gemm (bool) – When there are multiple experts per rank, compress multiple local (potentially small) GEMMs into a single kernel launch to improve utilization and performance, by leveraging the Grouped GEMM feature introduced in CUTLASS 2.8. Defaults to False.

  • moe_aux_loss_coeff (float) – Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.

  • moe_z_loss_coeff (float) – Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended.

  • moe_input_jitter_eps (float) – Add noise to the input tensor by applying jitter with a specified epsilon value.

  • moe_token_dropping (bool) – This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note: Currently unsupported.
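Several of the defaults above are derived from other fields. The hypothetical `MiniTransformerConfig` below mirrors a small subset of `TransformerConfig` to show the arithmetic. `ffn_hidden_size` falls back to 4*hidden_size, `kv_channels` falls back to hidden_size // num_attention_heads, and the default output-layer initialization uses std = init_method_std / sqrt(2 * num_layers). Treating `num_query_groups=None` as one group per head (standard multi-head attention) is our assumption about how "normal attention" is represented.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniTransformerConfig:
    """Hypothetical subset of TransformerConfig, showing how some defaults are derived."""
    num_layers: int = 0
    hidden_size: int = 0
    num_attention_heads: int = 0
    ffn_hidden_size: Optional[int] = None
    kv_channels: Optional[int] = None
    num_query_groups: Optional[int] = None
    init_method_std: float = 0.02

    def __post_init__(self):
        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size
        if self.kv_channels is None:
            self.kv_channels = self.hidden_size // self.num_attention_heads
        if self.num_query_groups is None:
            # Assumption: one key/value group per head, i.e. standard multi-head attention.
            self.num_query_groups = self.num_attention_heads

    @property
    def output_layer_init_std(self) -> float:
        # Std used by the default (scaled) output-layer initialization.
        return self.init_method_std / math.sqrt(2.0 * self.num_layers)


cfg = MiniTransformerConfig(num_layers=24, hidden_size=2048, num_attention_heads=16)
gqa = MiniTransformerConfig(num_layers=2, hidden_size=4096, num_attention_heads=32, num_query_groups=8)
```

In the grouped-query example, each of the 8 key/value groups is shared by 32 / 8 = 4 query heads.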

add_bias_linear: bool = True

add_qkv_bias: bool = False

apply_query_key_layer_scaling: bool = False

apply_residual_connection_post_layernorm: bool = False

apply_rope_fusion: bool = False

attention_dropout: float = 0.1

attention_softmax_in_fp32: bool = True

bias_activation_fusion: bool = False

bias_dropout_fusion: bool = False

clone_scatter_output_in_embedding: bool = True

disable_parameter_transpose_cache: bool = False

distribute_saved_activations: bool = None

ffn_hidden_size: int = None

fp32_residual_connection: bool = False

fp8: str = None

fp8_amax_compute_algo: str = 'most_recent'

fp8_amax_history_len: int = 1

fp8_interval: int = 1

fp8_margin: int = 0

fp8_wgrad: bool = True

gated_linear_unit: bool = False

hidden_dropout: float = 0.1

hidden_size: int = 0

init_method: Callable = None

init_method_std: float = 0.02

kv_channels: int = None

layernorm_epsilon: float = 1e-05

layernorm_zero_centered_gamma: bool = False

masked_softmax_fusion: bool = False

max_position_embeddings: int = 0

memory_efficient_layer_norm: bool = False

moe_aux_loss_coeff: float = 0

moe_grouped_gemm: bool = False

moe_input_jitter_eps: float = None

moe_router_load_balancing_type: str = 'aux_loss'

moe_router_topk: int = 2

moe_token_dropping: bool = False

moe_z_loss_coeff: float = None

normalization: bool = 'LayerNorm'

num_attention_heads: int = 0

num_layers: int = 0

num_moe_experts: int = None

num_query_groups: int = None

output_layer_init_method: Callable = None

persist_layer_norm: bool = False

recompute_granularity: str = None

recompute_method: str = None

recompute_num_layers: int = None

rotary_interleaved: bool = False

rotary_percent: float = 0

window_size: Optional[Tuple[int, int]] = None
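The difference between the two `recompute_method` values (used with ‘full’ granularity) can be sketched as a plan over one pipeline stage's layers. This hypothetical helper reflects our reading of the `recompute_method` and `recompute_num_layers` descriptions above; it is not Megatron's checkpointing code. Each inner list is a unit whose input activations are kept and which is recomputed during the backward pass.

```python
from typing import List


def recompute_plan(num_layers: int, method: str, recompute_num_layers: int) -> List[List[int]]:
    """Group one pipeline stage's layer indices into recomputed units.

    Layers that appear in no unit keep all of their activations.
    """
    if method == "uniform":
        # Consecutive chunks of recompute_num_layers layers, covering every layer.
        return [list(range(start, min(start + recompute_num_layers, num_layers)))
                for start in range(0, num_layers, recompute_num_layers)]
    if method == "block":
        # Only the first recompute_num_layers layers are recomputed, one at a time.
        return [[i] for i in range(min(recompute_num_layers, num_layers))]
    raise ValueError("recompute_method must be 'uniform' or 'block'")
```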

A single standard transformer layer including attention and MLP blocks.

class core.transformer.transformer_layer.BaseTransformerLayer

Bases: abc.ABC

A common parent class for TransformerLayer like implementations.

A dummy class that is subclassed by similar TransformerLayers, e.g. the TransformerLayer in this file and possibly other TransformerLayer implementations that aim to use TransformerBlock as the base module. The main purpose is to check whether any layer (or module) provided in the spec is a subclass of this class, to allow fanning out that spec to all the layers in the TransformerBlock. See the _get_block_submodules method implementation for more details.
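The fan-out described above can be sketched in plain Python. `Spec`, `BlockSubmodules`, and `get_block_submodules` are simplified, hypothetical stand-ins for `ModuleSpec`, `TransformerBlockSubmodules`, and `_get_block_submodules`. Here the layer count is passed in explicitly rather than read from a config.

```python
from dataclasses import dataclass, field
from typing import List


class BaseTransformerLayer:
    """Marker base class, as described above."""


class MyLayer(BaseTransformerLayer):
    """Any layer implementation that opts in to fan-out."""


@dataclass
class Spec:
    module: type
    submodules: object = None


@dataclass
class BlockSubmodules:
    layer_specs: List[Spec] = field(default_factory=list)


def get_block_submodules(spec, num_layers: int) -> BlockSubmodules:
    """Fan a single layer spec out to every layer of the block."""
    if isinstance(spec, BlockSubmodules):
        return spec  # already one spec per layer (layers may differ)
    if isinstance(spec, Spec) and issubclass(spec.module, BaseTransformerLayer):
        return BlockSubmodules(layer_specs=[spec] * num_layers)
    raise TypeError("expected a layer spec or per-layer block submodules")


block = get_block_submodules(Spec(MyLayer), num_layers=4)
mixed = BlockSubmodules([Spec(MyLayer), Spec(MyLayer)])
```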

class core.transformer.transformer_layer.TransformerLayer(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule, core.transformer.transformer_layer.BaseTransformerLayer

A single transformer layer.

Transformer layer takes input with size [s, b, h] and returns an output of the same size.

forward(hidden_states, attention_mask, context=None, context_mask=None, rotary_pos_emb=None, inference_params=None, packed_seq_params=None)

sharded_state_dict(prefix: str = '', sharded_offsets: tuple = ()) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

class core.transformer.transformer_layer.TransformerLayerSubmodules(input_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attn_bda: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_cross_attn_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attn_bda: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_mlp_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp_bda: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, sharded_state_dict_keys_map: Dict[str, str] = <factory>)

Bases: object

sharded_state_dict_keys_map: Dict[str, str]
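The submodule fields map onto a pre-normalization layer roughly as sketched below. This is a simplified, hypothetical rendering, not `TransformerLayer.forward`. Plain callables replace modules, and attention and MLP return a single value rather than an (output, bias) pair. Each `*_bda` slot is reduced to a function of (output, residual). The sketch also shows why the identity defaults turn an unused stage, here cross-attention, into a pass-through.

```python
def transformer_layer_forward(x, sub, context=None):
    """Chain the TransformerLayerSubmodules slots in pre-norm order.

    `sub` maps each field name to a plain callable; every *_bda callable
    takes (output, residual) and returns the new hidden state.
    """
    residual = x
    x = sub["self_attn_bda"](sub["self_attention"](sub["input_layernorm"](x)), residual)

    residual = x
    x = sub["cross_attn_bda"](sub["cross_attention"](sub["pre_cross_attn_layernorm"](x), context), residual)

    residual = x
    x = sub["mlp_bda"](sub["mlp"](sub["pre_mlp_layernorm"](x)), residual)
    return x


def identity(x, *args):
    return x  # like IdentityOp / the function returned by IdentityFuncOp


def add_residual(output, residual):
    return output + residual  # bias-dropout-add without bias or dropout


sub = {
    "input_layernorm": identity, "self_attention": lambda h: 2 * h, "self_attn_bda": add_residual,
    "pre_cross_attn_layernorm": identity, "cross_attention": identity, "cross_attn_bda": identity,
    "pre_mlp_layernorm": identity, "mlp": lambda h: 10 * h, "mlp_bda": add_residual,
}
out = transformer_layer_forward(1.0, sub)
```

Because every cross-attention slot is an identity, the hidden state leaves that stage unchanged. Only the self-attention and MLP stages contribute to the result.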

Various utilities used in the transformer implementation.

Utilities for transformer layers.

core.transformer.utils.attention_mask_func(attention_scores, attention_mask)



OpenAI’s gelu implementation.

core.transformer.utils.get_default_causal_mask(sq: int) → torch.Tensor

Return the causal upper triangular mask for softmax input.

core.transformer.utils.get_linear_layer(rows, columns, init_method, perform_initialization=True)

Simple linear layer with weight initialization.

core.transformer.utils.make_sharded_object_for_checkpoint(obj: Any, key: str, sharded_offsets: Iterable[Tuple[int, int, int]] = (), replica_id: Union[None, int, Tuple[int, ...]] = None, **kwargs)

Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).

  • obj (object) – any object to be sharded

  • key (str) – unique identifier of the object

  • sharded_offsets (Iterable[Tuple[int, int, int]]) – offsets normally prepended to ShardedTensors, will be used as global offsets for ShardedObject

  • replica_id (Union[None, int, Tuple[int, ...]]) – replica id

core.transformer.utils.make_sharded_tensors_for_checkpoint(state_dict: megatron.core.dist_checkpointing.mapping.StateDict, prefix: str, tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None, sharded_offsets: Iterable[Tuple[int, int, int]] = (), extra_state_suffix: str = '_extra_state')

Wraps tensors from transformer layers with ShardedTensor or ShardedObject.

For a given state_dict, wraps:

  • all _extra_states with ShardedObject

  • all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor

  • other values with DP sharded ShardedTensor

  • state_dict (StateDict) – state_dict to convert

  • prefix (str) – prefix prepended to keys in the final state dict

  • tensor_parallel_layers_axis_map (Dict[str, int], optional) – dict mapping layer names to the axis for TP sharding

  • sharded_offsets (Iterable[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related), passed along to ShardedTensor

  • extra_state_suffix (str, default = '_extra_state') – layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor.
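The three wrapping rules can be summarized with a hypothetical helper that only reports which wrapper each key would get. It creates no `ShardedTensor` or `ShardedObject` objects and ignores `sharded_offsets`. Key names in the usage example are illustrative.

```python
from typing import Dict, Optional


def classify_for_checkpoint(state_dict: Dict[str, object], prefix: str,
                            tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
                            extra_state_suffix: str = "_extra_state") -> Dict[str, str]:
    """Report which wrapper each prefixed key would receive under the three rules."""
    axis_map = tensor_parallel_layers_axis_map or {}
    plan = {}
    for name in state_dict:
        key = f"{prefix}{name}"
        if name.endswith(extra_state_suffix):
            plan[key] = "ShardedObject"
        elif name in axis_map:
            plan[key] = f"TP+DP ShardedTensor (TP axis {axis_map[name]})"
        else:
            plan[key] = "DP ShardedTensor"
    return plan


plan = classify_for_checkpoint(
    {"weight": None, "layer_norm_weight": None, "_extra_state": None},
    prefix="decoder.layers.0.mlp.linear_fc1.",
    tensor_parallel_layers_axis_map={"weight": 0},
)
```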


core.transformer.utils.sharded_state_dict_default(module: torch.nn.Module, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Provides implementation for sharded_state_dict method for non-MegatronModules.

Tries to call module.sharded_state_dict when possible, otherwise uses regular state dict and assumes tensors are replicated across TP and DP.

keep_vars=True is passed to module.state_dict so that optimizer states can be sharded later on.

  • module (torch.nn.Module) – module whose sharded state dict we want to obtain

  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by parent modules. Passed along to ShardedTensor


Returns: dictionary of state dict keys mapped to ShardedTensors

Return type: ShardedStateDict


© Copyright 2022-2024, NVIDIA. Last updated on Mar 16, 2024.