Megatron Core User Guide

transformer package

The transformer package provides a customizable and configurable implementation of the transformer model architecture. Each component of a transformer stack, from entire layers down to individual linear layers, can be customized by swapping in different PyTorch modules using the “spec” parameters. The configuration of the transformer (hidden size, number of layers, number of attention heads, etc.) is provided via a TransformerConfig object.
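
For orientation, here is a minimal sketch of constructing a TransformerConfig with the core shape parameters documented on this page; the values are illustrative only, not a recommended setup.

# Minimal sketch: build a TransformerConfig with the shape parameters
# documented on this page. All values are illustrative.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=12,           # transformer layers in the block
    hidden_size=768,         # h
    num_attention_heads=12,  # n
    ffn_hidden_size=3072,    # defaults to 4 * hidden_size when left as None
)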

This is the entire attention portion, either self or cross attention, of a transformer layer including the query, key, and value projections, a “core” attention calculation (e.g. dot product attention), and final output linear projection.

class core.transformer.attention.Attention(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule, abc.ABC

Attention layer abstract class.

This layer only contains common modules required for the “self attn” and “cross attn” specializations.

flash_decoding(sequence_len_offset: torch.Tensor, query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, inference_key_memory: torch.Tensor, inference_value_memory: torch.Tensor, rotary_cos: torch.Tensor, rotary_sin: torch.Tensor)

The flash decoding kernel does the following in a single execution:
  1. Compute the RoPE embedding with precomputed cos & sin tensors.
  2. Update the KV cache.
  3. Perform the flash attention operation.

forward(hidden_states, attention_mask, key_value_states=None, inference_params=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, attention_bias=None, packed_seq_params=None, sequence_len_offset=None)

Perform a forward pass through the attention module.

abstract get_query_key_value_tensors(hidden_states, key_value_states)

This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.

class core.transformer.attention.CrossAttention(*args: Any, **kwargs: Any)

Bases: core.transformer.attention.Attention

Cross-attention layer class

Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.

get_query_key_value_tensors(hidden_states, key_value_states)

Derives query tensor from hidden_states, and key/value tensors from key_value_states.

class core.transformer.attention.CrossAttentionSubmodules(linear_q: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_kv: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, core_attention: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_proj: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None)

Bases: object

Configuration class for specifying the submodules of a cross-attention.

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
class core.transformer.attention.SelfAttention(*args: Any, **kwargs: Any)

Bases: core.transformer.attention.Attention

Self-attention layer class

Self-attention layer takes input with size [s, b, h] and returns output of the same size.

get_query_key_value_tensors(hidden_states, key_value_states=None)

Derives query, key and value tensors from hidden_states.

run_realtime_tests()

Performs a consistency check.

This function makes sure that tensors across devices are the same during an experiment. This is often not guaranteed because of silent hardware failures (e.g., memory corruption when loading a checkpoint, or network corruption during data transmission).

(TODO) In the future, more tensors should be checked across the training run, and checked every X iterations. Exact equality of tensors is probably not required; transmitting hashes would be sufficient.

class core.transformer.attention.SelfAttentionSubmodules(linear_qkv: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, core_attention: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_proj: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, q_layernorm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, k_layernorm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None)

Bases: object

Configuration class for specifying the submodules of a self-attention.

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
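
As an illustration of how these submodule fields are used, the following is a sketch of a self-attention ModuleSpec. The specific linear and core-attention classes are assumptions chosen to mirror the local (non-Transformer-Engine) layer specs shipped with Megatron-LM.

# Sketch (assumed module choices): wire SelfAttention together via a ModuleSpec.
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

self_attention_spec = ModuleSpec(
    module=SelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,    # fused QKV projection
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,      # output projection
        # q_layernorm / k_layernorm keep their None defaults (no QK LayerNorm)
    ),
)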

This is a PyTorch-only implementation of dot product attention. More efficient implementations, such as those provided by FlashAttention or cuDNN's FusedAttention, are typically used when training speed is important.

class core.transformer.dot_product_attention.DotProductAttention(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

Region where selective activation recomputation is applied. This region is memory-intensive but less compute-intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details.

We use the following notation:

  • h: hidden size
  • n: number of attention heads
  • p: number of tensor model parallel partitions
  • b: batch size
  • s: sequence length

forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, attn_mask_type: Optional[megatron.core.transformer.enums.AttnMaskType] = None, attention_bias: Optional[torch.Tensor] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None)

Forward.

class core.transformer.enums.AttnBackend(value)

Bases: enum.Enum

Attention Backend

auto = 5
flash = 1
fused = 2
local = 4
unfused = 3
class core.transformer.enums.AttnMaskType(value)

Bases: enum.Enum

Attention Mask Type

arbitrary = 5
causal = 2
no_mask = 3
padding = 1
padding_causal = 4
class core.transformer.enums.AttnType(value)

Bases: enum.Enum

Attention type

cross_attn = 2
self_attn = 1
class core.transformer.enums.ModelType(value)

Bases: enum.Enum

Model Type

  • encoder_or_decoder: for BERT, GPT, etc.
  • encoder_and_decoder: for multimodal models, T5, etc.

encoder_and_decoder = 2
encoder_or_decoder = 1

This provides a pass-through module that can be used in specs to indicate that the operation should not be performed. For example, when the LayerNorm is fused into the subsequent linear layer, an IdentityOp can be passed in as the LayerNorm module.

class core.transformer.identity_op.IdentityFuncOp(*args: Any, **kwargs: Any)

Bases: core.transformer.identity_op.IdentityOp

This is a placeholder for IdentityFuncOp(…)(x) -> IdentityOp(x) -> x. Such a functional op is handy for ops like bias_dropout_fusion, which themselves return a function at runtime based on the passed arguments.

forward(*args, **kwargs)
class core.transformer.identity_op.IdentityOp(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

This is a placeholder for IdentityOp(x) -> x

forward(x, *args, **kwargs)
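
A quick sketch of the pass-through behavior (illustrative only):

# Illustrative only: IdentityOp returns its input unchanged, so a spec can
# "disable" an optional submodule by pointing it at IdentityOp.
from megatron.core.transformer.identity_op import IdentityOp

op = IdentityOp()
x = object()
assert op(x) is x  # forward(x, *args, **kwargs) returns x unchanged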

This is the entire MLP portion of the transformer layer with an input projection, non-linearity, and output projection.

class core.transformer.mlp.MLP(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, performs a nonlinear transformation, and projects the state back to hidden size h.

Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the bias returned is None.

We use the following notation:

  • h: hidden size
  • p: number of tensor model parallel partitions
  • b: batch size
  • s: sequence length

forward(hidden_states)

Perform the forward pass through the MLP block.

sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
class core.transformer.mlp.MLPSubmodules(linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None)

Bases: object

linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
core.transformer.mlp.apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets)
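
As with the attention submodules above, an MLP is typically wired up through a ModuleSpec. The following sketch uses assumed tensor-parallel linear implementations for the two projections.

# Sketch (assumed module choices): an MLP ModuleSpec using Megatron Core's
# tensor-parallel linear layers.
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

mlp_spec = ModuleSpec(
    module=MLP,
    submodules=MLPSubmodules(
        linear_fc1=ColumnParallelLinear,  # h -> ffn_hidden_size projection
        linear_fc2=RowParallelLinear,     # ffn_hidden_size -> h projection
    ),
)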

This provides a common base class for all modules used in the transformer that contains some common functionality.

Megatron Module.

class core.transformer.module.Float16Module(*args: Any, **kwargs: Any)

Bases: core.transformer.module.MegatronModule

Float 16 Module.

config

Transformer config

Type

TransformerConfig

fp16

Specifies if the model runs in fp16 mode

Type

bool

bf16

Specifies if the model runs in bf16 mode

Type

bool

Parameters

config (TransformerConfig) – The transformer config used to initialize the model

forward(*inputs, **kwargs)
load_state_dict(state_dict, strict=True)
set_input_tensor(input_tensor)
sharded_state_dict(prefix='', *args, **kwargs)

Retrieve sharded_state_dict from the module being wrapped.

state_dict(destination=None, prefix='', keep_vars=False)
state_dict_for_save_checkpoint(prefix='', keep_vars=False)

Retrieve state_dict from the module being wrapped.

class core.transformer.module.MegatronModule(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Base Megatron module inherited by all models.

Megatron specific extensions of torch Module with support for pipelining

Parameters

config (TransformerConfig) – Transformer config

set_is_first_microbatch()

Sets the is_first_microbatch flag if it exists and config.fp8==True. When this flag is set, TE modules will update their fp8 parameter cache.

sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Default implementation for sharded state dict for distributed checkpointing.

This general definition of sharded_state_dict simply calls sharded_state_dict_default (which calls the sharded_state_dict method if available, or a default implementation otherwise) recursively on all submodules.

Parameters
  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor

  • metadata (dict, optional) – metadata passed recursively to sharded_state_dict methods

Returns

dictionary of state dict keys mapped to ShardedTensors

Return type

dict

state_dict_for_save_checkpoint(prefix: str = '', keep_vars: bool = False)

Override the state dict used for saving checkpoints.

Parameters
  • prefix (str, optional) – prefix for the state dict keys. Defaults to ''.

  • keep_vars (bool, optional) – if True, returned tensors are not detached from autograd. Defaults to False.

Returns

The state dict to use when saving a checkpoint.

Return type

dict

core.transformer.module.conversion_helper(val, conversion)
core.transformer.module.float16_to_fp32(val)
core.transformer.module.fp32_to_float16(val, float16_convertor)
core.transformer.module.param_is_not_shared(param)

A block, or stack, of several transformer layers. The layers can all be the same or each can be unique.

class core.transformer.transformer_block.TransformerBlock(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule

Transformer class.

forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, context: Optional[torch.Tensor] = None, context_mask: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_cos: Optional[torch.Tensor] = None, rotary_pos_sin: Optional[torch.Tensor] = None, attention_bias: Optional[torch.Tensor] = None, inference_params: Optional[megatron.core.InferenceParams] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None, sequence_len_offset: Optional[torch.Tensor] = None)

Perform the forward pass through the transformer block.

This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.

Parameters
  • hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention.

  • context_mask (Tensor, optional) – Mask for cross-attention context

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • attention_bias (Tensor, optional) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.

  • inference_params (InferenceParams, optional) – Parameters for inference-time optimizations.

  • packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.

Returns

The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.

Return type

Union[Tensor, Tuple[Tensor, Tensor]]

get_cuda_graph_optional_args(attention_mask: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_bias: torch.Tensor, inference_params: megatron.core.InferenceParams, packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams)

Get optional tensor arguments for CUDA graph.

set_input_tensor(input_tensor: torch.Tensor)

Set input tensor to be used instead of forward()’s input.

When using pipeline parallelism, the input from the previous stage comes from communication, not from the forward arguments, so the model's forward_step_func won't have it. This function is used by internal code to bypass the input provided by the forward_step_func.

sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Generate a sharded state dictionary for the transformer block.

Parameters
  • prefix (str, optional) – Prefix to be added to all keys in the state dict. Defaults to an empty string.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (dict, optional) – Additional metadata for sharding. Can specify if layers are non-homogeneous. Defaults to None.

Returns

A dictionary containing the sharded state of the model.

Return type

ShardedStateDict

class core.transformer.transformer_block.TransformerBlockSubmodules(layer_specs: Optional[List[megatron.core.transformer.spec_utils.ModuleSpec]] = None, layer_norm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, torch.nn.Module]] = None)

Bases: object

Dataclass for specifying the submodules of a transformer block.

This class defines the structure for configuring the layers and normalization within a transformer block, allowing for flexible and customizable architecture designs.

Parameters
  • layer_specs (List[ModuleSpec], optional) – A list of module specifications for the layers within the transformer block. Each specification typically defines a complete transformer layer (e.g., self-attention, feed-forward network).

  • layer_norm (Optional[Union[ModuleSpec, torch.nn.Module]], optional) – Specification or instance of the layer normalization to be applied.

layer_norm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, torch.nn.Module]] = None
layer_specs: List[megatron.core.transformer.spec_utils.ModuleSpec] = None
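
For example, a block can be described by listing one spec per layer; the sketch below builds a homogeneous 12-layer block. get_gpt_layer_local_spec is an assumed helper from Megatron-LM's GPT layer specs; any per-layer ModuleSpec would do, and a non-homogeneous block simply lists a different spec per position.

# Sketch (assumed helper): a 12-layer block built from one repeated layer spec.
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules

layer_spec = get_gpt_layer_local_spec()
block_submodules = TransformerBlockSubmodules(
    layer_specs=[layer_spec] * 12,
    layer_norm=None,  # optional final layer norm (ModuleSpec or nn.Module)
)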
core.transformer.transformer_block.get_num_layers_to_build(config: megatron.core.transformer.transformer_config.TransformerConfig) → int

Determine the number of transformer layers to build for the current pipeline stage.

Parameters

config (TransformerConfig) – Configuration object containing transformer model parameters.

Returns

The number of layers to be built for the current pipeline stage.

Return type

int

This contains all of the configuration options for the transformer. Using a dataclass reduces code bloat by keeping all arguments together in a dataclass instead of passing several arguments through multiple layers of function calls.

class core.transformer.transformer_config.MLATransformerConfig(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, sequence_parallel: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[list[int]] = None, expert_model_parallel_size: int = 1, expert_tensor_parallel_size: Optional[int] = None, moe_extended_tp: bool = False, perform_initialization: bool = True, use_cpu_initialization: bool = False, fp16: bool = False, bf16: bool = False, params_dtype: torch.dtype = torch.float32, timers: Optional[Callable] = None, finalize_model_grads_func: Optional[Callable] = None, grad_scale_func: Optional[Callable] = None, no_sync_func: Optional[Callable] = None, grad_sync_func: Optional[Callable] = None, param_sync_func: Optional[Callable] = None, deterministic_mode: bool = False, enable_autocast: bool = False, autocast_dtype: Optional[torch.dtype] = None, num_microbatches_with_partial_activation_checkpoints: Optional[int] = None, gradient_accumulation_fusion: bool = False, async_tensor_model_parallel_allreduce: bool = False, use_te_rng_tracker: bool = False, tp_comm_overlap: bool = False, tp_comm_bulk_wgrad: bool = True, tp_comm_bulk_dgrad: bool = True, tp_comm_overlap_ag: bool = True, tp_comm_overlap_rs: bool = True, tp_comm_overlap_rs_dgrad: bool = False, tp_comm_split_ag: bool = True, tp_comm_atomic_ag: bool = False, tp_comm_split_rs: bool = True, tp_comm_atomic_rs: bool = False, cross_entropy_loss_fusion: bool = False, tp_comm_overlap_disable_qkv: bool = False, tp_comm_overlap_disable_fc1: bool = False, tp_comm_bootstrap_backend: str = 'nccl', pipeline_dtype: Optional[torch.dtype] = None, variable_seq_lengths: bool = False, overlap_p2p_comm: bool = False, batch_p2p_comm: bool = True, batch_p2p_sync: bool = True, use_ring_exchange_p2p: bool = False, deallocate_pipeline_outputs: bool = False, defer_embedding_wgrad_compute: bool = False, wgrad_deferral_limit: int = 0, pipeline_model_parallel_split_rank: Optional[int] = None, overlap_p2p_comm_warmup_flush: bool = False, microbatch_group_size_per_vp_stage: Optional[int] = None, cpu_offloading: bool = False, cpu_offloading_num_layers: int = 0, _cpu_offloading_context: Optional[ContextManager] = None, cpu_offloading_activations: bool = True, cpu_offloading_weights: bool = True, barrier_with_L1_time: bool = True, num_layers: int = 0, first_pipeline_num_layers: Optional[int] = None, last_pipeline_num_layers: Optional[int] = None, hidden_size: int = 0, num_attention_heads: int = 0, attention_backend: megatron.core.transformer.enums.AttnBackend = megatron.core.transformer.enums.AttnBackend.auto, softmax_scale: Optional[float] = None, num_query_groups: Optional[int] = None, ffn_hidden_size: Optional[int] = None, kv_channels: Optional[int] = None, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, fp32_residual_connection: bool = False, apply_residual_connection_post_layernorm: bool = False, layernorm_epsilon: float = 1e-05, layernorm_zero_centered_gamma: bool = False, add_bias_linear: bool = True, add_qkv_bias: bool = False, gated_linear_unit: bool = False, activation_func: Callable = torch.nn.functional.gelu, activation_func_fp8_input_store: bool = False, num_moe_experts: Optional[int] = None, rotary_interleaved: bool = False, window_size: Optional[Tuple[int, int]] = None, normalization: str = 'RMSNorm', qk_layernorm: bool = False, test_mode: bool = False, calculate_per_token_loss: bool = False, multi_latent_attention: bool = 
True, init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, init_method_std: float = 0.02, apply_query_key_layer_scaling: bool = False, attention_softmax_in_fp32: bool = True, bias_activation_fusion: bool = False, masked_softmax_fusion: bool = False, persist_layer_norm: bool = False, memory_efficient_layer_norm: bool = False, bias_dropout_fusion: bool = False, apply_rope_fusion: bool = False, recompute_granularity: Optional[str] = None, recompute_method: Optional[str] = None, recompute_num_layers: Optional[int] = None, distribute_saved_activations: Optional[bool] = None, fp8: Optional[str] = None, fp8_margin: int = 0, fp8_interval: int = 1, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = 'most_recent', fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, tp_only_amax_red: bool = False, moe_shared_expert_intermediate_size: Optional[int] = None, moe_shared_expert_overlap: bool = False, moe_layer_freq: int = 1, moe_ffn_hidden_size: Optional[int] = None, moe_router_load_balancing_type: str = 'aux_loss', moe_router_topk: int = 2, moe_router_topk_limited_devices: Optional[int] = None, moe_router_pre_softmax: bool = False, moe_router_topk_scaling_factor: Optional[float] = None, moe_grouped_gemm: bool = False, moe_use_legacy_grouped_gemm: bool = False, moe_aux_loss_coeff: float = 0, moe_z_loss_coeff: Optional[float] = None, moe_input_jitter_eps: Optional[float] = None, moe_token_dropping: bool = False, moe_token_dispatcher_type: str = 'allgather', moe_per_layer_logging: bool = False, moe_expert_capacity_factor: Optional[float] = None, moe_pad_expert_input_to_capacity: bool = False, moe_token_drop_policy: str = 'probs', moe_layer_recompute: bool = False, cp_comm_type: Optional[Union[str, List[str]]] = None, clone_scatter_output_in_embedding: bool = True, disable_parameter_transpose_cache: bool = False, enable_cuda_graph: bool = False, cuda_graph_retain_backward_graph: bool = False, external_cuda_graph: bool = False, config_logger_dir: str = '', flash_decode: bool = False, inference_rng_tracker: bool = False, q_lora_rank: int = 512, kv_lora_rank: int = 512, qk_head_dim: int = 128, qk_pos_emb_head_dim: int = 64, v_head_dim: int = 128, rotary_base: float = 10000, rotary_scaling_factor: float = 40, max_position_embeddings: int = 163840, beta_fast: float = 32, beta_slow: float = 1, mscale: float = 0.707, mscale_all_dim: float = 0.707)

Bases: core.transformer.transformer_config.TransformerConfig

Configuration object for megatron-core Multi-Latent Attention (MLA) transformers.

The initialization function has an argument for each parameter, including those in ModelParallelConfig. Includes the YaRN RoPE parameters that are fused in MLA.

beta_fast: float = 32

Beta fast for YaRN RoPE.

beta_slow: float = 1

Beta slow for YaRN RoPE.

kv_lora_rank: int = 512

Rank of Key and Value tensors’ low rank representation.

max_position_embeddings: int = 163840

Maximum position embeddings for the original model.

mscale: float = 0.707

Mscale for YaRN RoPE in Multi-Latent Attention.

mscale_all_dim: float = 0.707

Mscale all dimensions for YaRN RoPE in Multi-Latent Attention.

multi_latent_attention: bool = True

Whether to use Multi-Latent Attention.

normalization: str = 'RMSNorm'

Default normalization layer for MLA models is RMSNorm.

q_lora_rank: int = 512

Rank of Query tensor’s low rank representation.

qk_head_dim: int = 128

Dimension of the head in the QK projection. q_head_dim = qk_head_dim + qk_pos_emb_head_dim

qk_pos_emb_head_dim: int = 64

Dimension of the position embedding in the QK projection.

rotary_base: float = 10000

Rotary base for the rotary embeddings.

rotary_scaling_factor: float = 40

Rotary scaling factor for the rotary embeddings.

v_head_dim: int = 128

Dimension of the head in the V projection.
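
A sketch of constructing an MLA config, with illustrative values; only the shape and MLA fields documented above are set, and everything else keeps its default.

# Sketch (illustrative values): construct an MLATransformerConfig.
from megatron.core.transformer.transformer_config import MLATransformerConfig

mla_config = MLATransformerConfig(
    num_layers=24,
    hidden_size=2048,
    num_attention_heads=16,
    q_lora_rank=512,
    kv_lora_rank=512,
    qk_head_dim=128,
    qk_pos_emb_head_dim=64,
    v_head_dim=128,
)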

class core.transformer.transformer_config.TransformerConfig(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, sequence_parallel: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[list[int]] = None, expert_model_parallel_size: int = 1, expert_tensor_parallel_size: Optional[int] = None, moe_extended_tp: bool = False, perform_initialization: bool = True, use_cpu_initialization: bool = False, fp16: bool = False, bf16: bool = False, params_dtype: torch.dtype = torch.float32, timers: Optional[Callable] = None, finalize_model_grads_func: Optional[Callable] = None, grad_scale_func: Optional[Callable] = None, no_sync_func: Optional[Callable] = None, grad_sync_func: Optional[Callable] = None, param_sync_func: Optional[Callable] = None, deterministic_mode: bool = False, enable_autocast: bool = False, autocast_dtype: Optional[torch.dtype] = None, num_microbatches_with_partial_activation_checkpoints: Optional[int] = None, gradient_accumulation_fusion: bool = False, async_tensor_model_parallel_allreduce: bool = False, use_te_rng_tracker: bool = False, tp_comm_overlap: bool = False, tp_comm_bulk_wgrad: bool = True, tp_comm_bulk_dgrad: bool = True, tp_comm_overlap_ag: bool = True, tp_comm_overlap_rs: bool = True, tp_comm_overlap_rs_dgrad: bool = False, tp_comm_split_ag: bool = True, tp_comm_atomic_ag: bool = False, tp_comm_split_rs: bool = True, tp_comm_atomic_rs: bool = False, cross_entropy_loss_fusion: bool = False, tp_comm_overlap_disable_qkv: bool = False, tp_comm_overlap_disable_fc1: bool = False, tp_comm_bootstrap_backend: str = 'nccl', pipeline_dtype: Optional[torch.dtype] = None, variable_seq_lengths: bool = False, overlap_p2p_comm: bool = False, batch_p2p_comm: bool = True, batch_p2p_sync: bool = True, use_ring_exchange_p2p: bool = False, deallocate_pipeline_outputs: bool = False, defer_embedding_wgrad_compute: bool = False, wgrad_deferral_limit: int = 0, pipeline_model_parallel_split_rank: Optional[int] = None, overlap_p2p_comm_warmup_flush: bool = False, microbatch_group_size_per_vp_stage: Optional[int] = None, cpu_offloading: bool = False, cpu_offloading_num_layers: int = 0, _cpu_offloading_context: Optional[ContextManager] = None, cpu_offloading_activations: bool = True, cpu_offloading_weights: bool = True, barrier_with_L1_time: bool = True, num_layers: int = 0, first_pipeline_num_layers: Optional[int] = None, last_pipeline_num_layers: Optional[int] = None, hidden_size: int = 0, num_attention_heads: int = 0, attention_backend: megatron.core.transformer.enums.AttnBackend = megatron.core.transformer.enums.AttnBackend.auto, softmax_scale: Optional[float] = None, num_query_groups: Optional[int] = None, ffn_hidden_size: Optional[int] = None, kv_channels: Optional[int] = None, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, fp32_residual_connection: bool = False, apply_residual_connection_post_layernorm: bool = False, layernorm_epsilon: float = 1e-05, layernorm_zero_centered_gamma: bool = False, add_bias_linear: bool = True, add_qkv_bias: bool = False, gated_linear_unit: bool = False, activation_func: Callable = torch.nn.functional.gelu, activation_func_fp8_input_store: bool = False, num_moe_experts: Optional[int] = None, rotary_interleaved: bool = False, window_size: Optional[Tuple[int, int]] = None, normalization: bool = 'LayerNorm', qk_layernorm: bool = False, test_mode: bool = False, calculate_per_token_loss: bool = False, multi_latent_attention: bool = 
False, init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, init_method_std: float = 0.02, apply_query_key_layer_scaling: bool = False, attention_softmax_in_fp32: bool = True, bias_activation_fusion: bool = False, masked_softmax_fusion: bool = False, persist_layer_norm: bool = False, memory_efficient_layer_norm: bool = False, bias_dropout_fusion: bool = False, apply_rope_fusion: bool = False, recompute_granularity: Optional[str] = None, recompute_method: Optional[str] = None, recompute_num_layers: Optional[int] = None, distribute_saved_activations: Optional[bool] = None, fp8: Optional[str] = None, fp8_margin: int = 0, fp8_interval: int = 1, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = 'most_recent', fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, tp_only_amax_red: bool = False, moe_shared_expert_intermediate_size: Optional[int] = None, moe_shared_expert_overlap: bool = False, moe_layer_freq: int = 1, moe_ffn_hidden_size: Optional[int] = None, moe_router_load_balancing_type: str = 'aux_loss', moe_router_topk: int = 2, moe_router_topk_limited_devices: Optional[int] = None, moe_router_pre_softmax: bool = False, moe_router_topk_scaling_factor: Optional[float] = None, moe_grouped_gemm: bool = False, moe_use_legacy_grouped_gemm: bool = False, moe_aux_loss_coeff: float = 0, moe_z_loss_coeff: Optional[float] = None, moe_input_jitter_eps: Optional[float] = None, moe_token_dropping: bool = False, moe_token_dispatcher_type: str = 'allgather', moe_per_layer_logging: bool = False, moe_expert_capacity_factor: Optional[float] = None, moe_pad_expert_input_to_capacity: bool = False, moe_token_drop_policy: str = 'probs', moe_layer_recompute: bool = False, cp_comm_type: Optional[Union[str, List[str]]] = None, clone_scatter_output_in_embedding: bool = True, disable_parameter_transpose_cache: bool = False, enable_cuda_graph: bool = False, cuda_graph_retain_backward_graph: bool = False, external_cuda_graph: bool = False, config_logger_dir: str = '', flash_decode: bool = False, inference_rng_tracker: bool = False)

Bases: core.model_parallel_config.ModelParallelConfig

Configuration object for megatron-core transformers.

The initialization function has an argument for each parameter, including those in ModelParallelConfig.

activation_func: Callable

Activation function to use for the non-linearity in the MLP.

activation_func_fp8_input_store: bool = False

Store the input of the MLP activation function in FP8 for backprop to save memory. The stored input is cast back to the original precision before the backprop computation.

add_bias_linear: bool = True

Include a bias term in all linear layers (the QKV projections, the projection after core attention, and the two linear layers in the MLP).

add_qkv_bias: bool = False

Add a bias term only for QKV projections.

apply_query_key_layer_scaling: bool = False

If true, scale Q * K^T by 1 / layer-number. This improves numerical stability when training with fp16.

apply_residual_connection_post_layernorm: bool = False

If True, uses the original BERT residual connection ordering.

apply_rope_fusion: bool = False

If True, use fused RoPE kernel.

attention_backend: megatron.core.transformer.enums.AttnBackend

Attention backend to run. By default, Transformer Engine decides the best backend to use (except in the local case, where the local PyTorch implementation in Megatron Core is used). Users can specify an exact backend by changing this config.

attention_dropout: float = 0.1

Post attention dropout probability.

attention_softmax_in_fp32: bool = True

If True, run attention masking and softmax in fp32. This should be True if apply_query_key_layer_scaling is True.

bias_activation_fusion: bool = False

If True, fuses bias addition and the activation function when possible.

bias_dropout_fusion: bool = False

If True, uses bias dropout fusion.

calculate_per_token_loss: bool = False

Whether cross entropy loss is calculated over the actual number of non-padded tokens in the global batch, versus the default behavior of assuming all tokens are non-padded.

clone_scatter_output_in_embedding: bool = True

When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer to facilitate garbage collection of input.

config_logger_dir: str = ''

When non-empty, dumps entry-point configs to config_logger_dir.

cp_comm_type: Union[str, List[str]] = None

Inter-GPU communication type for context parallelism. A str means all layers share the same communication type; a List[str] gives each layer its own. The cp_comm_type of each layer can be one of:
  • “p2p”: Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.
  • “all_gather”: All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.
  • “a2a”: Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
  • “a2a+p2p”: A hierarchical implementation of context parallelism for attention. It uses A2A communications in low-level CP groups (e.g., via NVLink) and P2P communications in high-level CP groups (e.g., via IBLink).

cuda_graph_retain_backward_graph: bool = False

When set to true, CUDA graph backward passes will be graph captured with retain_graph=True. This may enable CUDA graphs for certain modules that are not completely CUDA graph safe. For more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html.

disable_parameter_transpose_cache: bool = False

When set to true, the parameter transposes are not cached for subsequent iterations.

distribute_saved_activations: bool = None

If True, distribute recomputed activations across the model parallel group.

enable_cuda_graph: bool = False

When set to true, TransformerLayer layers are swapped with a CUDA graphed version.

external_cuda_graph: bool = False

When set to true, TransformerLayer layers are swapped with user provided CUDA graphs.

ffn_hidden_size: int = None

Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided.

first_pipeline_num_layers: int = None

Number of transformer layers on first pipeline stage. None implies equal layer division across PP ranks.

flash_decode: bool = False

Use the optimized flash decoding kernel during inference.

fp32_residual_connection: bool = False

If true, move residual connections to fp32.

fp8: str = None

If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices (1) ‘e4m3’ uniformly uses e4m3 for all FP8 tensors, (2) ‘hybrid’ uses e4m3 for all FP8 activation and weight tensors and e5m2 for all FP8 output activation gradient tensors.

fp8_amax_compute_algo: str = 'most_recent'

Algorithm used for choosing the amax value for the scaling factor computation. There are 2 predefined choices: max chooses the largest amax in the history window, while most_recent always chooses the most recently seen value.

fp8_amax_history_len: int = 1

The length of the amax history window used for scaling factor computation.

fp8_dot_product_attention: bool = False

When set to True, use the FP8 implementation of Dot Product Attention.

fp8_interval: int = 1

DEPRECATED since TransformerEngine v1.8.0; this flag is ignored. It previously controlled how often the scaling factor was recomputed.

fp8_margin: int = 0

Margin for the scaling factor computation.

fp8_multi_head_attention: bool = False

When set to True, use the FP8 implementation of Multi Head Attention.

fp8_wgrad: bool = True

When set to False, override FP8 config options and do the wgrad computation in higher precision.

gated_linear_unit: bool = False

Use a gated linear unit for the first linear layer in the MLP.

hidden_dropout: float = 0.1

Dropout probability for transformer hidden state.

hidden_size: int = 0

Transformer hidden size.

inference_rng_tracker: bool = False

Whether we should instantiate a separate RNG tracker for inference.

init_method: Callable = None

Method to initialize weights. Note that bias is always set to zero. Should be a function that takes a single Tensor and initializes it. If None, will be set to megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std.

init_method_std: float = 0.02

Standard deviation of the zero mean normal for the default initialization method, not used if init_method and output_layer_init_method are provided.

kv_channels: int = None

Projection weights dimension in multi-head attention. This is set to hidden_size // num_attention_heads if not provided.

last_pipeline_num_layers: int = None

Number of transformer layers on last pipeline stage. None implies equal layer division across PP ranks.

layernorm_epsilon: float = 1e-05

Epsilon value for any LayerNorm operations.

layernorm_zero_centered_gamma: bool = False

If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves numerical stability.

masked_softmax_fusion: bool = False

If True, uses softmax fusion.

memory_efficient_layer_norm: bool = False

If True, and using local layers (not from TransformerEngine), tells Apex to use the memory efficient fused LayerNorm kernel. Ignored if not using LayerNorm.

moe_aux_loss_coeff: float = 0

Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended.

moe_expert_capacity_factor: float = None

The capacity factor for each expert. None means no tokens will be dropped. The default is None.

moe_ffn_hidden_size: int = None

MoE Feed-Forward Network hidden size

moe_grouped_gemm: bool = False

When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).

moe_input_jitter_eps: float = None

Add noise to the input tensor by applying jitter with a specified epsilon value.

moe_layer_freq: int = 1

Frequency between MoE layers and dense layers. Accepts either:
  • An integer N: represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
  • A string containing a Python list expression that defines a custom pattern, e.g. “([1]*3+[0]*1)*3” evaluates to [1,1,1,0,1,1,1,0,1,1,1,0], where 1 indicates an expert layer and 0 indicates a dense layer.
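
A tiny sketch of how such a pattern string expands (the expression is the one from the description above):

# Sketch: expanding the custom moe_layer_freq pattern string shown above.
layer_pattern = eval("([1]*3+[0]*1)*3")
assert layer_pattern == [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0]  # 1 = MoE layer, 0 = dense layer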

moe_layer_recompute: bool = False

Memory optimization: checkpoint the MoE layer to save activation memory.

moe_pad_expert_input_to_capacity: bool = False

If True, pads the input for each expert to match the expert capacity length. Effective only after moe_expert_capacity_factor is set. The default setting is False.

moe_per_layer_logging: bool = False

Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.

moe_router_load_balancing_type: str = 'aux_loss'

The load balancing strategy for the router. “aux_loss” corresponds to the load balancing loss used in GShard and SwitchTransformer; “seq_aux_loss” corresponds to the loss used in DeepSeekV2, which computes the loss for each individual sample; “sinkhorn” corresponds to the balancing algorithm used in S-BASE, and “none” implies no load balancing. The default is “aux_loss”.

moe_router_pre_softmax: bool = False

Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.

moe_router_topk: int = 2

Number of experts to route to for each token.

moe_router_topk_limited_devices: int = None

Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. None means no device limitation.

moe_router_topk_scaling_factor: float = None

Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax enabled. Defaults to None, which means no scaling.

moe_shared_expert_intermediate_size: int = None

Shared expert total ffn hidden size. It should be equal to ‘num_shared_experts * ffn_size_of_each_shared_expert’ if there are multiple shared experts. None means no shared expert.

moe_shared_expert_overlap: bool = False

Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts.

moe_token_dispatcher_type: str = 'allgather'

The type of token dispatcher to use. The default is ‘allgather’. Options are ‘allgather’ and ‘alltoall’.

moe_token_drop_policy: str = 'probs'

The policy to drop tokens. Can be either “probs” or “position”. If “probs”, the tokens with the lowest probabilities will be dropped. If “position”, tokens at the end of each batch will be dropped.

moe_token_dropping: bool = False

This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is currently unsupported so should remain False.

moe_use_legacy_grouped_gemm: bool = False

Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.

moe_z_loss_coeff: float = None

Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended.

multi_latent_attention: bool = False

Whether to use multi-latent attention.

normalization: str = 'LayerNorm'

Which norm to use for normalization layers, valid options are LayerNorm and RMSNorm.

num_attention_heads: int = 0

Number of transformer attention heads.

num_layers: int = 0

Number of transformer layers in a transformer block.

num_moe_experts: int = None

Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None for no MoE.

num_query_groups: int = None

Number of query groups for group query attention. If None, normal attention is used.

output_layer_init_method: Callable = None

Method to initialize weights of the output layer of both attention and MLP blocks. If None, will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers).

persist_layer_norm: bool = False

If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set of hidden sizes.

qk_layernorm: bool = False

Whether to apply LayerNorm to the query and key embeddings.

recompute_granularity: str = None

Determines which type of activation recompute to use. Megatron-core supports ‘selective’ activation checkpointing where only the memory intensive part of attention is checkpointed. These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details. ‘full’ will checkpoint the entire transformer layer. If None, no recompute is performed and all activations are saved. If set, must be ‘selective’ or ‘full’. ‘selective’ always uses all layers.

recompute_method: str = None

Determines which transformer layers will be recomputed. uniform will uniformly divide the total number of transformer layers in a transformer block and recompute the input activation of each divided chunk at the specified granularity. block will recompute the input activations for only a set number of transformer layers per pipeline stage. The rest of the layers in the pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all layers will do recomputation. If set, must be ‘uniform’ or ‘block’.

recompute_num_layers: int = None

When recompute_method is uniform, recompute_num_layers is the number of transformer layers in each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is the number of transformer layers to recompute within each pipeline stage. Must be None for ‘selective’ activation checkpointing.
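
A sketch of how the three recompute fields above combine, using illustrative values (the shape fields are set only so the config is constructible):

# Sketch (illustrative values): full recomputation, uniformly dividing the
# block into chunks of one layer each. With 'selective' granularity,
# recompute_num_layers must instead stay None (see above).
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    recompute_granularity="full",
    recompute_method="uniform",
    recompute_num_layers=1,
)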

rotary_interleaved: bool = False

If True, rotate pairs of even and odd dimensions (RoFormer style); if False, rotate pairs of the first and second halves (LLaMA style). Defaults to False.

softmax_scale: float = None

Softmax scale for attention scaling.

test_mode: bool = False

Whether to run real-time tests.

tp_only_amax_red: bool = False

When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain.

use_te_rng_tracker: bool = False

Whether to use the TE or MCore version of the RNG tracker.

window_size: Optional[Tuple[int, int]] = None

If not None, use sliding window attention. The size of the window is specified by the numbers inside the tuple; -1 is a special value meaning “infinite window size”.

A single standard transformer layer including attention and MLP blocks.

class core.transformer.transformer_layer.BaseTransformerLayer

Bases: abc.ABC

A common parent class for TransformerLayer like implementations.

A dummy class that is subclassed by similar TransformerLayers, e.g. the TransformerLayer in this file, and possibly other TransformerLayer implementations that aim to use TransformerBlock as the base module. The main purpose is to check whether any layer (or module) provided in the spec is a subclass of this class, to allow fanning-out of that spec to all the layers in the TransformerBlock. See the _get_block_submodules method implementation in transformer_block.py for more details.

class core.transformer.transformer_layer.TransformerLayer(*args: Any, **kwargs: Any)

Bases: megatron.core.transformer.module.MegatronModule, core.transformer.transformer_layer.BaseTransformerLayer

A single transformer layer.

Transformer layer takes input with size [s, b, h] and returns an output of the same size.

forward(hidden_states, attention_mask=None, context=None, context_mask=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, attention_bias=None, inference_params=None, packed_seq_params=None, sequence_len_offset=None)

Perform a forward pass through the transformer layer.

This method implements the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations.

Parameters
  • hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is sequence length, b is batch size, and h is hidden size.

  • attention_mask (Tensor) – Mask tensor for self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention.

  • context_mask (Tensor, optional) – Mask tensor for cross-attention.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • attention_bias (Tensor, optional) – Bias tensor for Q * K.T.

  • inference_params (object, optional) – Parameters for inference-time optimizations.

  • packed_seq_params (object, optional) – Parameters for packed sequence processing.

Returns

A tuple containing:

  • output (Tensor): Transformed hidden states of shape [s, b, h].

  • context (Tensor): Updated context tensor if cross-attention is used, otherwise None.

Return type

Tuple[Tensor, Tensor]

sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Generate a sharded state dictionary for the transformer layer.

Parameters
  • prefix (str, optional) – Prefix to be added to all keys in the state dict.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (Optional[dict], optional) – Additional metadata for sharding.

Returns

A dictionary containing the sharded state of the transformer layer.

Return type

ShardedStateDict

class core.transformer.transformer_layer.TransformerLayerSubmodules(input_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attention: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attn_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_cross_attn_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attention: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attn_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_mlp_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, sharded_state_dict_keys_map: typing.Dict[str, str] = <factory>)

Bases: object

Configuration class for specifying the submodules of a transformer layer.

This class defines the structure and default implementations for various components of a transformer layer, allowing for flexible customization of the layer’s architecture.

Parameters
  • input_layernorm (Union[ModuleSpec, type]) – Specification for the input layer normalization.

  • self_attention (Union[ModuleSpec, type]) – Specification for the self-attention mechanism.

  • self_attn_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after self-attention.

  • pre_cross_attn_layernorm (Union[ModuleSpec, type]) – Specification for the layer normalization before cross-attention.

  • cross_attention (Union[ModuleSpec, type]) – Specification for the cross-attention mechanism.

  • cross_attn_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after cross-attention.

  • pre_mlp_layernorm (Union[ModuleSpec, type]) – Specification for the layer normalization before the MLP.

  • mlp (Union[ModuleSpec, type]) – Specification for the MLP in a dense (non-MoE) layer.

  • mlp_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after the MLP.

  • sharded_state_dict_keys_map (Dict[str, str]) – Mapping for sharded tensor keys to be applied in the sharded_state_dict method.

sharded_state_dict_keys_map: Dict[str, str]
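
Putting the pieces together, here is a sketch of a complete local TransformerLayer spec that reuses the self_attention_spec and mlp_spec from the earlier sketches on this page; the norm class and the bias-dropout-add fusion helper are assumptions chosen for illustration.

# Sketch (assumed norm and fusion helpers): a full TransformerLayer spec.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)

layer_spec = ModuleSpec(
    module=TransformerLayer,
    submodules=TransformerLayerSubmodules(
        input_layernorm=FusedLayerNorm,
        self_attention=self_attention_spec,  # from the attention sketch above
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=FusedLayerNorm,
        mlp=mlp_spec,                         # from the MLP sketch above
        mlp_bda=get_bias_dropout_add,
        # cross-attention fields keep their IdentityOp / IdentityFuncOp defaults
    ),
)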

Various utilities used in the transformer implementation.

Utilities for transformer layers.

core.transformer.utils.attention_mask_func(attention_scores, attention_mask)
core.transformer.utils.erf_gelu(x)
core.transformer.utils.gelu_impl(x)

OpenAI’s gelu implementation.

core.transformer.utils.get_default_causal_mask(sq: int) → torch.Tensor

Return the causal upper triangular mask for softmax input.
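
For reference, a sketch of the standard construction such a causal mask typically follows (True marks positions above the diagonal, i.e. future tokens to be masked); the cached implementation may differ in detail.

# Sketch: an upper-triangular boolean causal mask for a length-4 query.
import torch

sq = 4  # query sequence length
causal_mask = torch.triu(torch.ones(sq, sq), diagonal=1).bool()
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])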

core.transformer.utils.get_linear_layer(rows, columns, init_method, perform_initialization=True)

Simple linear layer with weight initialization.

core.transformer.utils.make_sharded_object_for_checkpoint(obj: Any, key: str, sharded_offsets: Iterable[Tuple[int, int, int]] = (), replica_id: Union[None, int, Tuple[int, ...]] = None, **kwargs)

Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).

Parameters
  • obj (object) – any object to be sharded

  • key (str) – unique identifier of the object

  • sharded_offsets (Iterable[Tuple[int, int, int]]) – offsets normally prepended to ShardedTensors, will be used as global offsets for ShardedObject

  • replica_id (Union[None, int, Tuple[int, ...]]) – replica id

core.transformer.utils.make_sharded_tensors_for_checkpoint(state_dict: megatron.core.dist_checkpointing.mapping.StateDict, prefix: str, tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None, sharded_offsets: Iterable[Tuple[int, int, int]] = (), extra_state_suffix: str = '_extra_state')

Wraps tensors from transformer layers with ShardedTensor or ShardedObject.

For a given state_dict, wraps:
  • all _extra_states with ShardedObject
  • all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
  • other values with DP sharded ShardedTensor

Parameters
  • state_dict (StateDict) – state_dict to convert

  • prefix (str) – prefix appended to keys in final state dict

  • tensor_parallel_layers_axis_map (Dict[str, int], optional) – dict mapping layer names to the axis for TP sharding

  • sharded_offsets (Iterable[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related), passed along to ShardedTensor

  • extra_state_suffix (str, default = '_extra_state') – layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor.

core.transformer.utils.openai_gelu(x)
core.transformer.utils.sharded_state_dict_default(module: torch.nn.Module, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Provides implementation for sharded_state_dict method for non-MegatronModules.

Tries to call module.sharded_state_dict when possible, otherwise uses regular state dict and assumes tensors are replicated across TP and DP.

keep_vars=True is passed to module.state_dict so that optimizer states can be sharded later on.

Parameters
  • module (torch.nn.Module) – module whose sharded state dict we want to obtain

  • prefix (str) – prefix for the state dict keys

  • sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor

  • metadata (dict, optional) – metadata passed to module sharded_state_dict method

Returns

dictionary of state dict keys mapped to ShardedTensors

Return type

dict
