transformer package
The transformer package provides a customizable and configurable implementation of the transformer model architecture. Each component of a transformer stack, from entire layers down to individual linear layers, can be customized by swapping in different PyTorch modules using the "spec" parameters. The configuration of the transformer (hidden size, number of layers, number of attention heads, etc.) is provided via a TransformerConfig object.
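For example, a minimal configuration for a small model can be constructed as follows (a sketch: only the core size parameters documented below are set explicitly, and every other field keeps its dataclass default):

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Minimal sketch of a TransformerConfig for a small model.
config = TransformerConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    ffn_hidden_size=3072,  # defaults to 4 * hidden_size if omitted
)
```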
This is the entire attention portion of a transformer layer, either self- or cross-attention, including the query, key, and value projections; a "core" attention calculation (e.g., dot product attention); and the final output linear projection.
- class core.transformer.attention.Attention(*args: Any, **kwargs: Any)
Bases:
megatron.core.transformer.module.MegatronModule, abc.ABC
Attention layer abstract class.
This layer only contains common modules required for the “self attn” and “cross attn” specializations.
- flash_decoding(sequence_len_offset: torch.Tensor, query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, inference_key_memory: torch.Tensor, inference_value_memory: torch.Tensor, rotary_cos: torch.Tensor, rotary_sin: torch.Tensor)
The flash decoding kernel does the following in a single execution:
1. Compute the RoPE embedding with precomputed cos & sin tensors.
2. Update the KV cache.
3. Perform the flash attention operation.
- forward(hidden_states, attention_mask, key_value_states=None, inference_params=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, attention_bias=None, packed_seq_params=None, sequence_len_offset=None)
Perform a forward pass through the attention module.
- abstract get_query_key_value_tensors(hidden_states, key_value_states)
This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.
- class core.transformer.attention.CrossAttention(*args: Any, **kwargs: Any)
Bases:
core.transformer.attention.Attention
Cross-attention layer class
Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.
- get_query_key_value_tensors(hidden_states, key_value_states)
Derives query tensor from hidden_states, and key/value tensors from key_value_states.
- class core.transformer.attention.CrossAttentionSubmodules(linear_q: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_kv: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, core_attention: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_proj: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None)
Bases:
object
Configuration class for specifying the submodules of a cross-attention.
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- class core.transformer.attention.SelfAttention(*args: Any, **kwargs: Any)
Bases:
core.transformer.attention.Attention
Self-attention layer class
Self-attention layer takes input with size [s, b, h] and returns output of the same size.
- get_query_key_value_tensors(hidden_states, key_value_states=None)
Derives query, key and value tensors from hidden_states.
- run_realtime_tests()
Performs a consistency check.
This function makes sure that tensors across devices are the same during an experiment. This is often not guaranteed because of silent hardware failures (e.g., memory corruption when loading a checkpoint, or network traffic corruption encountered during data transmission).
(TODO) In the future, more tensors should be checked across the training run, every X iterations; this is left for future work. Exact equality of tensors is probably not required; transmitting hashes would be sufficient.
- class core.transformer.attention.SelfAttentionSubmodules(linear_qkv: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, core_attention: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, linear_proj: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, q_layernorm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None, k_layernorm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, type]] = None)
Bases:
object
Configuration class for specifying the submodules of a self-attention.
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
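As a sketch of how these submodule slots are typically filled (modeled loosely on megatron-core's local, non-Transformer-Engine layer specs; the exact module choices and the ModuleSpec fields used here are assumptions about your megatron-core version):

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

# Each slot is either a class or a nested ModuleSpec; IdentityOp disables a slot.
self_attention_spec = ModuleSpec(
    module=SelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,
        q_layernorm=IdentityOp,
        k_layernorm=IdentityOp,
    ),
)
```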
This is a PyTorch-only implementation of dot product attention. More efficient implementations, such as those provided by FlashAttention or cuDNN's FusedAttention, are typically used when training speed is important.
- class core.transformer.dot_product_attention.DotProductAttention(*args: Any, **kwargs: Any)
Bases:
megatron.core.transformer.module.MegatronModule
Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details.
- We use the following notation:
h: hidden size
n: number of attention heads
p: number of tensor model parallel partitions
b: batch size
s: sequence length
- forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, attn_mask_type: Optional[megatron.core.transformer.enums.AttnMaskType] = None, attention_bias: Optional[torch.Tensor] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None)
Forward.
- class core.transformer.enums.AttnBackend(value)
Bases:
enum.Enum
Attention Backend
- auto = 5
- flash = 1
- fused = 2
- local = 4
- unfused = 3
- class core.transformer.enums.AttnMaskType(value)
Bases:
enum.Enum
Attention Mask Type
- arbitrary = 5
- causal = 2
- no_mask = 3
- padding = 1
- padding_causal = 4
- class core.transformer.enums.AttnType(value)
Bases:
enum.Enum
Attention type
- cross_attn = 2
- self_attn = 1
- class core.transformer.enums.ModelType(value)
Bases:
enum.Enum
Model Type
encoder_or_decoder for BERT, GPT, etc.; encoder_and_decoder for multimodal, T5, etc.
- encoder_and_decoder = 2
- encoder_or_decoder = 1
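A brief illustration of how these enums are referenced in code (the member values are those listed above):

```python
from megatron.core.transformer.enums import AttnBackend, AttnMaskType, AttnType, ModelType

backend = AttnBackend.auto        # let the backend be selected automatically
mask_type = AttnMaskType.causal   # standard causal (left-to-right) masking
assert AttnType.self_attn.value == 1 and ModelType.encoder_or_decoder.value == 1
```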
This provides a pass-through module that can be used in specs to indicate that the operation should not be performed. For example, when the LayerNorm is fused with the subsequent linear layer, an IdentityOp can be passed in as the LayerNorm module to use.
- class core.transformer.identity_op.IdentityFuncOp(*args: Any, **kwargs: Any)
Bases:
core.transformer.identity_op.IdentityOp
This is a placeholder for IdentityFuncOp(…)(x) -> IdentityOp(x) -> x. Such a function is handy for ops like bias_dropout_fusion, which themselves return a function at runtime based on the passed arguments.
- forward(*args, **kwargs)
- class core.transformer.identity_op.IdentityOp(*args: Any, **kwargs: Any)
Bases:
torch.nn.Module
This is a placeholder for IdentityOp(x) -> x
- forward(x, *args, **kwargs)
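A minimal sketch of the pass-through behavior described above; extra positional and keyword arguments are simply ignored:

```python
import torch

from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp

x = torch.randn(4, 8)

identity = IdentityOp()
assert torch.equal(identity(x), x)   # IdentityOp(x) -> x

# Calling an IdentityFuncOp instance returns a callable that behaves like
# IdentityOp, which is why it can stand in for ops such as bias_dropout_fusion
# that return a function at runtime.
identity_func = IdentityFuncOp()
fused_fn = identity_func()           # any arguments are ignored
assert torch.equal(fused_fn(x), x)
```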
This is the entire MLP portion of the transformer layer with an input projection, non-linearity, and output projection.
- class core.transformer.mlp.MLP(*args: Any, **kwargs: Any)
Bases:
megatron.core.transformer.module.MegatronModule
The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, performs a nonlinear transformation, and projects the state back to hidden size h.
Returns an output and a bias to be added to the output. If config.add_bias_linear is False, the bias returned is None.
- We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
- forward(hidden_states)
Perform the forward pass through the MLP block.
- sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
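Conceptually, the data flow is equivalent to the following plain-PyTorch sketch (illustrative only; the real MLP uses the tensor-parallel linear layers supplied via MLPSubmodules and returns the output together with a separate bias term):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMLP(nn.Module):
    """Plain-PyTorch illustration of the h -> 4h -> h MLP data flow."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear_fc1 = nn.Linear(hidden_size, 4 * hidden_size)  # h -> 4h
        self.linear_fc2 = nn.Linear(4 * hidden_size, hidden_size)  # 4h -> h

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear_fc2(F.gelu(self.linear_fc1(hidden_states)))


out = ToyMLP(hidden_size=64)(torch.randn(16, 2, 64))  # [s, b, h] -> [s, b, h]
```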
- class core.transformer.mlp.MLPSubmodules(linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None, linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None)
Bases:
object
- linear_fc1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- linear_fc2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None
- core.transformer.mlp.apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets)
This provides a common base class for all modules used in the transformer that contains some common functionality.
Megatron Module.
- class core.transformer.module.Float16Module(*args: Any, **kwargs: Any)
Bases:
core.transformer.module.MegatronModule
Float 16 Module.
- config
Transformer config
- Type
TransformerConfig
- fp16
Specifies if the model runs in fp16 mode
- Type
bool
- bf16
Specifies if the model runs in bf16 mode
- Type
bool
- Parameters
config (TransformerConfig) – The transformer config used to initialize the model
- forward(*inputs, **kwargs)
- load_state_dict(state_dict, strict=True)
- set_input_tensor(input_tensor)
- sharded_state_dict(prefix='', *args, **kwargs)
Retrieve sharded_state_dict from the module being wrapped.
- state_dict(destination=None, prefix='', keep_vars=False)
- state_dict_for_save_checkpoint(prefix='', keep_vars=False)
Retrieve state_dict from the module being wrapped.
- class core.transformer.module.MegatronModule(*args: Any, **kwargs: Any)
Bases:
torch.nn.Module
Base Megatron module inherited by all models.
Megatron-specific extensions of torch.nn.Module with support for pipelining
- Parameters
config (TransformerConfig) – Transformer config
- set_is_first_microbatch()
Sets the is_first_microbatch flag if it exists and config.fp8==True. When this flag is set, TE modules will update their fp8 parameter cache.
- sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
Default implementation for sharded state dict for distributed checkpointing.
The general definition of sharded_state_dict simply calls sharded_state_dict_default (which calls the sharded_state_dict method if possible, or a default implementation otherwise) recursively on all submodules.
- Parameters
prefix (str) – prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional) – metadata passed recursively to sharded_state_dict methods
- Returns
dictionary of state dict keys mapped to ShardedTensors
- Return type
dict
- state_dict_for_save_checkpoint(prefix: str = '', keep_vars: bool = False)
Use this function to override the state dict for saving checkpoints.
- Parameters
prefix (str, optional) – prefix for the state dict keys. Defaults to ''.
keep_vars (bool, optional) – whether to keep Variables rather than detaching tensors. Defaults to False.
- Returns
The module's state dict.
- Return type
dict
- core.transformer.module.conversion_helper(val, conversion)
- core.transformer.module.float16_to_fp32(val)
- core.transformer.module.fp32_to_float16(val, float16_convertor)
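These helpers are undocumented here; conceptually, they apply a dtype conversion recursively over possibly nested containers of tensors, along the lines of the following sketch (an illustrative re-implementation for orientation, not the actual source):

```python
import torch


def conversion_helper(val, conversion):
    # Apply `conversion` to a value or, recursively, to every element of a
    # (possibly nested) tuple/list of values.
    if isinstance(val, (tuple, list)):
        return type(val)(conversion_helper(v, conversion) for v in val)
    return conversion(val)


def fp32_to_float16(val, float16_convertor):
    # float16_convertor is e.g. lambda t: t.half() or lambda t: t.bfloat16()
    def convert(v):
        if torch.is_tensor(v) and v.dtype == torch.float32:
            return float16_convertor(v)
        return v

    return conversion_helper(val, convert)


def float16_to_fp32(val):
    def convert(v):
        if torch.is_tensor(v) and v.dtype in (torch.float16, torch.bfloat16):
            return v.float()
        return v

    return conversion_helper(val, convert)
```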
A block, or stack, of several transformer layers. The layers can all be the same or each can be unique.
- class core.transformer.transformer_block.TransformerBlock(*args: Any, **kwargs: Any)
Bases:
megatron.core.transformer.module.MegatronModule
Transformer class.
- forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, context: Optional[torch.Tensor] = None, context_mask: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_cos: Optional[torch.Tensor] = None, rotary_pos_sin: Optional[torch.Tensor] = None, attention_bias: Optional[torch.Tensor] = None, inference_params: Optional[megatron.core.InferenceParams] = None, packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None, sequence_len_offset: Optional[torch.Tensor] = None)
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.
- Parameters
hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.
context (Tensor, optional) – Context tensor for cross-attention.
context_mask (Tensor, optional) – Mask for cross-attention context
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to applying an attention mask for TE cuDNN attention.
inference_params (InferenceParams, optional) – Parameters for inference-time optimizations.
packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.
- Returns
The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.
- Return type
Union[Tensor, Tuple[Tensor, Tensor]]
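The following single-process sketch shows the expected shapes through forward() end to end. It assumes a CUDA device, a world size of 1, and that the GPT local layer-spec helper (get_gpt_layer_local_spec) and the model-parallel RNG seeding helper are available under the import paths used in recent megatron-core releases; treat the setup calls as assumptions rather than a canonical recipe:

```python
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec  # assumed path
from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

# Single-GPU "distributed" setup so that megatron-core's parallel state exists.
torch.distributed.init_process_group(
    backend="nccl", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0
)
parallel_state.initialize_model_parallel(tensor_model_parallel_size=1)
model_parallel_cuda_manual_seed(123)

config = TransformerConfig(
    num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
)
block = TransformerBlock(config=config, spec=get_gpt_layer_local_spec()).cuda()

s, b, h = 16, 2, config.hidden_size
hidden_states = torch.randn(s, b, h, device="cuda")  # [s, b, h]
# Boolean mask where True marks positions that must NOT be attended to (causal).
attention_mask = torch.ones(1, 1, s, s, device="cuda").triu(diagonal=1).bool()

output = block(hidden_states=hidden_states, attention_mask=attention_mask)
print(output.shape)  # torch.Size([16, 2, 64])
```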
- get_cuda_graph_optional_args(attention_mask: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_bias: torch.Tensor, inference_params: megatron.core.InferenceParams, packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams)
Get optional tensor arguments for CUDA graph.
- set_input_tensor(input_tensor: torch.Tensor)
Set input tensor to be used instead of forward()’s input.
When doing pipeline parallelism, the input from the previous stage comes from communication, not from the input, so the model's forward_step_func won't have it. This function is thus used by internal code to bypass the input provided by the forward_step_func.
- sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
Generate a sharded state dictionary for the transformer block.
- Parameters
prefix (str, optional) – Prefix to be added to all keys in the state dict. Defaults to an empty string.
sharded_offsets (tuple, optional) – Tuple of sharding offsets.
metadata (dict, optional) – Additional metadata for sharding. Can specify if layers are non-homogeneous. Defaults to None.
- Returns
A dictionary containing the sharded state of the model.
- Return type
ShardedStateDict
- class core.transformer.transformer_block.TransformerBlockSubmodules(layer_specs: Optional[List[megatron.core.transformer.spec_utils.ModuleSpec]] = None, layer_norm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, torch.nn.Module]] = None)
Bases:
object
Dataclass for specifying the submodules of a transformer block.
This class defines the structure for configuring the layers and normalization within a transformer block, allowing for flexible and customizable architecture designs.
- Parameters
layer_specs (List[ModuleSpec], optional) – A list of module specifications for the layers within the transformer block. Each specification typically defines a complete transformer layer (e.g., self-attention, feed-forward network).
layer_norm (Optional[Union[ModuleSpec, torch.nn.Module]], optional) – Specification or instance of the layer normalization to be applied.
- layer_norm: Optional[Union[megatron.core.transformer.spec_utils.ModuleSpec, torch.nn.Module]] = None
- layer_specs: List[megatron.core.transformer.spec_utils.ModuleSpec] = None
- core.transformer.transformer_block.get_num_layers_to_build(config: megatron.core.transformer.transformer_config.TransformerConfig) → int
Determine the number of transformer layers to build for the current pipeline stage.
- Parameters
config (TransformerConfig) – Configuration object containing transformer model parameters.
- Returns
The number of layers to be built for the current pipeline stage.
- Return type
int
This contains all of the configuration options for the transformer. Using a dataclass reduces code bloat by keeping all arguments together in a dataclass instead of passing several arguments through multiple layers of function calls.
- class core.transformer.transformer_config.MLATransformerConfig(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, sequence_parallel: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[list[int]] = None, expert_model_parallel_size: int = 1, expert_tensor_parallel_size: Optional[int] = None, moe_extended_tp: bool = False, perform_initialization: bool = True, use_cpu_initialization: bool = False, fp16: bool = False, bf16: bool = False, params_dtype: torch.dtype = torch.float32, timers: Optional[Callable] = None, finalize_model_grads_func: Optional[Callable] = None, grad_scale_func: Optional[Callable] = None, no_sync_func: Optional[Callable] = None, grad_sync_func: Optional[Callable] = None, param_sync_func: Optional[Callable] = None, deterministic_mode: bool = False, enable_autocast: bool = False, autocast_dtype: Optional[torch.dtype] = None, num_microbatches_with_partial_activation_checkpoints: Optional[int] = None, gradient_accumulation_fusion: bool = False, async_tensor_model_parallel_allreduce: bool = False, use_te_rng_tracker: bool = False, tp_comm_overlap: bool = False, tp_comm_bulk_wgrad: bool = True, tp_comm_bulk_dgrad: bool = True, tp_comm_overlap_ag: bool = True, tp_comm_overlap_rs: bool = True, tp_comm_overlap_rs_dgrad: bool = False, tp_comm_split_ag: bool = True, tp_comm_atomic_ag: bool = False, tp_comm_split_rs: bool = True, tp_comm_atomic_rs: bool = False, cross_entropy_loss_fusion: bool = False, tp_comm_overlap_disable_qkv: bool = False, tp_comm_overlap_disable_fc1: bool = False, tp_comm_bootstrap_backend: str = 'nccl', pipeline_dtype: Optional[torch.dtype] = None, variable_seq_lengths: bool = False, overlap_p2p_comm: bool = False, batch_p2p_comm: bool = True, batch_p2p_sync: bool = True, use_ring_exchange_p2p: bool = False, deallocate_pipeline_outputs: bool = False, defer_embedding_wgrad_compute: bool = False, wgrad_deferral_limit: int = 0, pipeline_model_parallel_split_rank: Optional[int] = None, overlap_p2p_comm_warmup_flush: bool = False, microbatch_group_size_per_vp_stage: Optional[int] = None, cpu_offloading: bool = False, cpu_offloading_num_layers: int = 0, _cpu_offloading_context: Optional[ContextManager] = None, cpu_offloading_activations: bool = True, cpu_offloading_weights: bool = True, barrier_with_L1_time: bool = True, num_layers: int = 0, first_pipeline_num_layers: Optional[int] = None, last_pipeline_num_layers: Optional[int] = None, hidden_size: int = 0, num_attention_heads: int = 0, attention_backend: megatron.core.transformer.enums.AttnBackend = megatron.core.transformer.enums.AttnBackend.auto, softmax_scale: Optional[float] = None, num_query_groups: Optional[int] = None, ffn_hidden_size: Optional[int] = None, kv_channels: Optional[int] = None, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, fp32_residual_connection: bool = False, apply_residual_connection_post_layernorm: bool = False, layernorm_epsilon: float = 1e-05, layernorm_zero_centered_gamma: bool = False, add_bias_linear: bool = True, add_qkv_bias: bool = False, gated_linear_unit: bool = False, activation_func: Callable = torch.nn.functional.gelu, activation_func_fp8_input_store: bool = False, num_moe_experts: Optional[int] = None, rotary_interleaved: bool = False, window_size: Optional[Tuple[int, int]] = None, normalization: str = 'RMSNorm', qk_layernorm: bool = False, test_mode: bool = False, calculate_per_token_loss: bool = False, multi_latent_attention: bool = 
True, init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, init_method_std: float = 0.02, apply_query_key_layer_scaling: bool = False, attention_softmax_in_fp32: bool = True, bias_activation_fusion: bool = False, masked_softmax_fusion: bool = False, persist_layer_norm: bool = False, memory_efficient_layer_norm: bool = False, bias_dropout_fusion: bool = False, apply_rope_fusion: bool = False, recompute_granularity: Optional[str] = None, recompute_method: Optional[str] = None, recompute_num_layers: Optional[int] = None, distribute_saved_activations: Optional[bool] = None, fp8: Optional[str] = None, fp8_margin: int = 0, fp8_interval: int = 1, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = 'most_recent', fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, tp_only_amax_red: bool = False, moe_shared_expert_intermediate_size: Optional[int] = None, moe_shared_expert_overlap: bool = False, moe_layer_freq: int = 1, moe_ffn_hidden_size: Optional[int] = None, moe_router_load_balancing_type: str = 'aux_loss', moe_router_topk: int = 2, moe_router_topk_limited_devices: Optional[int] = None, moe_router_pre_softmax: bool = False, moe_router_topk_scaling_factor: Optional[float] = None, moe_grouped_gemm: bool = False, moe_use_legacy_grouped_gemm: bool = False, moe_aux_loss_coeff: float = 0, moe_z_loss_coeff: Optional[float] = None, moe_input_jitter_eps: Optional[float] = None, moe_token_dropping: bool = False, moe_token_dispatcher_type: str = 'allgather', moe_per_layer_logging: bool = False, moe_expert_capacity_factor: Optional[float] = None, moe_pad_expert_input_to_capacity: bool = False, moe_token_drop_policy: str = 'probs', moe_layer_recompute: bool = False, cp_comm_type: Optional[Union[str, List[str]]] = None, clone_scatter_output_in_embedding: bool = True, disable_parameter_transpose_cache: bool = False, enable_cuda_graph: bool = False, cuda_graph_retain_backward_graph: bool = False, external_cuda_graph: bool = False, config_logger_dir: str = '', flash_decode: bool = False, inference_rng_tracker: bool = False, q_lora_rank: int = 512, kv_lora_rank: int = 512, qk_head_dim: int = 128, qk_pos_emb_head_dim: int = 64, v_head_dim: int = 128, rotary_base: float = 10000, rotary_scaling_factor: float = 40, max_position_embeddings: int = 163840, beta_fast: float = 32, beta_slow: float = 1, mscale: float = 0.707, mscale_all_dim: float = 0.707)
Bases:
core.transformer.transformer_config.TransformerConfig
Configuration object for megatron-core Multi-Latent Attention (MLA) transformers.
The initialization function has an argument for each parameter, including those in ModelParallelConfig, and also includes the YaRN RoPE parameters that are fused into MLA.
- beta_fast: float = 32
Beta fast for YaRN RoPE.
- beta_slow: float = 1
Beta slow for YaRN RoPE.
- kv_lora_rank: int = 512
Rank of Key and Value tensors’ low rank representation.
- max_position_embeddings: int = 163840
Maximum position embeddings for the original model.
- mscale: float = 0.707
Mscale for YaRN RoPE in Multi-Latent Attention.
- mscale_all_dim: float = 0.707
Mscale all dimensions for YaRN RoPE in Multi-Latent Attention.
- multi_latent_attention: bool = True
Whether to use Multi-Latent Attention.
- normalization: str = 'RMSNorm'
Default normalization layer for MLA models is RMSNorm.
- q_lora_rank: int = 512
Rank of Query tensor’s low rank representation.
- qk_head_dim: int = 128
Dimension of the head in the QK projection. q_head_dim = qk_head_dim + qk_pos_emb_head_dim
- qk_pos_emb_head_dim: int = 64
Dimension of the position embedding in the QK projection.
- rotary_base: float = 10000
Rotary base for the rotary embeddings.
- rotary_scaling_factor: float = 40
Rotary scaling factor for the rotary embeddings.
- v_head_dim: int = 128
Dimension of the head in the V projection.
- class core.transformer.transformer_config.TransformerConfig(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, sequence_parallel: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[list[int]] = None, expert_model_parallel_size: int = 1, expert_tensor_parallel_size: Optional[int] = None, moe_extended_tp: bool = False, perform_initialization: bool = True, use_cpu_initialization: bool = False, fp16: bool = False, bf16: bool = False, params_dtype: torch.dtype = torch.float32, timers: Optional[Callable] = None, finalize_model_grads_func: Optional[Callable] = None, grad_scale_func: Optional[Callable] = None, no_sync_func: Optional[Callable] = None, grad_sync_func: Optional[Callable] = None, param_sync_func: Optional[Callable] = None, deterministic_mode: bool = False, enable_autocast: bool = False, autocast_dtype: Optional[torch.dtype] = None, num_microbatches_with_partial_activation_checkpoints: Optional[int] = None, gradient_accumulation_fusion: bool = False, async_tensor_model_parallel_allreduce: bool = False, use_te_rng_tracker: bool = False, tp_comm_overlap: bool = False, tp_comm_bulk_wgrad: bool = True, tp_comm_bulk_dgrad: bool = True, tp_comm_overlap_ag: bool = True, tp_comm_overlap_rs: bool = True, tp_comm_overlap_rs_dgrad: bool = False, tp_comm_split_ag: bool = True, tp_comm_atomic_ag: bool = False, tp_comm_split_rs: bool = True, tp_comm_atomic_rs: bool = False, cross_entropy_loss_fusion: bool = False, tp_comm_overlap_disable_qkv: bool = False, tp_comm_overlap_disable_fc1: bool = False, tp_comm_bootstrap_backend: str = 'nccl', pipeline_dtype: Optional[torch.dtype] = None, variable_seq_lengths: bool = False, overlap_p2p_comm: bool = False, batch_p2p_comm: bool = True, batch_p2p_sync: bool = True, use_ring_exchange_p2p: bool = False, deallocate_pipeline_outputs: bool = False, defer_embedding_wgrad_compute: bool = False, wgrad_deferral_limit: int = 0, pipeline_model_parallel_split_rank: Optional[int] = None, overlap_p2p_comm_warmup_flush: bool = False, microbatch_group_size_per_vp_stage: Optional[int] = None, cpu_offloading: bool = False, cpu_offloading_num_layers: int = 0, _cpu_offloading_context: Optional[ContextManager] = None, cpu_offloading_activations: bool = True, cpu_offloading_weights: bool = True, barrier_with_L1_time: bool = True, num_layers: int = 0, first_pipeline_num_layers: Optional[int] = None, last_pipeline_num_layers: Optional[int] = None, hidden_size: int = 0, num_attention_heads: int = 0, attention_backend: megatron.core.transformer.enums.AttnBackend = megatron.core.transformer.enums.AttnBackend.auto, softmax_scale: Optional[float] = None, num_query_groups: Optional[int] = None, ffn_hidden_size: Optional[int] = None, kv_channels: Optional[int] = None, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, fp32_residual_connection: bool = False, apply_residual_connection_post_layernorm: bool = False, layernorm_epsilon: float = 1e-05, layernorm_zero_centered_gamma: bool = False, add_bias_linear: bool = True, add_qkv_bias: bool = False, gated_linear_unit: bool = False, activation_func: Callable = torch.nn.functional.gelu, activation_func_fp8_input_store: bool = False, num_moe_experts: Optional[int] = None, rotary_interleaved: bool = False, window_size: Optional[Tuple[int, int]] = None, normalization: bool = 'LayerNorm', qk_layernorm: bool = False, test_mode: bool = False, calculate_per_token_loss: bool = False, multi_latent_attention: bool = 
False, init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, init_method_std: float = 0.02, apply_query_key_layer_scaling: bool = False, attention_softmax_in_fp32: bool = True, bias_activation_fusion: bool = False, masked_softmax_fusion: bool = False, persist_layer_norm: bool = False, memory_efficient_layer_norm: bool = False, bias_dropout_fusion: bool = False, apply_rope_fusion: bool = False, recompute_granularity: Optional[str] = None, recompute_method: Optional[str] = None, recompute_num_layers: Optional[int] = None, distribute_saved_activations: Optional[bool] = None, fp8: Optional[str] = None, fp8_margin: int = 0, fp8_interval: int = 1, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = 'most_recent', fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, tp_only_amax_red: bool = False, moe_shared_expert_intermediate_size: Optional[int] = None, moe_shared_expert_overlap: bool = False, moe_layer_freq: int = 1, moe_ffn_hidden_size: Optional[int] = None, moe_router_load_balancing_type: str = 'aux_loss', moe_router_topk: int = 2, moe_router_topk_limited_devices: Optional[int] = None, moe_router_pre_softmax: bool = False, moe_router_topk_scaling_factor: Optional[float] = None, moe_grouped_gemm: bool = False, moe_use_legacy_grouped_gemm: bool = False, moe_aux_loss_coeff: float = 0, moe_z_loss_coeff: Optional[float] = None, moe_input_jitter_eps: Optional[float] = None, moe_token_dropping: bool = False, moe_token_dispatcher_type: str = 'allgather', moe_per_layer_logging: bool = False, moe_expert_capacity_factor: Optional[float] = None, moe_pad_expert_input_to_capacity: bool = False, moe_token_drop_policy: str = 'probs', moe_layer_recompute: bool = False, cp_comm_type: Optional[Union[str, List[str]]] = None, clone_scatter_output_in_embedding: bool = True, disable_parameter_transpose_cache: bool = False, enable_cuda_graph: bool = False, cuda_graph_retain_backward_graph: bool = False, external_cuda_graph: bool = False, config_logger_dir: str = '', flash_decode: bool = False, inference_rng_tracker: bool = False)
Bases:
core.model_parallel_config.ModelParallelConfig
Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter, including those in ModelParallelConfig.
- activation_func: Callable
Activation function to use for the non-linearity in the MLP.
- activation_func_fp8_input_store: bool = False
Store the input of the MLP activation function in FP8 for backprop to save memory. The stored input is cast back to the original precision before the backprop computation.
- add_bias_linear: bool = True
Include a bias term in all linear layers (QKV projections, after core attention, and two in MLP layer).
- add_qkv_bias: bool = False
Add a bias term only for QKV projections.
- apply_query_key_layer_scaling: bool = False
If true, scale Q * K^T by 1 / layer-number. This improves numeric stability when training with fp16.
- apply_residual_connection_post_layernorm: bool = False
If True, uses the original BERT residual connection ordering.
- apply_rope_fusion: bool = False
If True, use fused RoPE kernel.
- attention_backend: megatron.core.transformer.enums.AttnBackend
Attention backend to run. By default we let Transformer Engine decide the best backend to run (except in the case of local). If the attention backend is local, we use the local PyTorch implementation in mcore. Users can specify the exact backend by changing this config.
- attention_dropout: float = 0.1
Post attention dropout probability.
- attention_softmax_in_fp32: bool = True
If True, run attention masking and softmax in fp32. This should be True if apply_query_key_layer_scaling is True.
- bias_activation_fusion: bool = False
If True, fuses bias addition and the activation function when possible.
- bias_dropout_fusion: bool = False
If True, uses bias dropout fusion.
- calculate_per_token_loss: bool = False
Whether cross entropy loss is calculated over the actual number of non-padded tokens in the global batch, versus the default behavior of assuming all tokens are non-padded.
- clone_scatter_output_in_embedding: bool = True
When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer to facilitate garbage collection of input.
- config_logger_dir: str = ''
When non-empty, dumps entry-point configs to config_logger_dir
- cp_comm_type: Union[str, List[str]] = None
Inter-GPU communication type for context parallelism. str: all layers share the same communication type; List[str]: each layer has its own communication type. The cp_comm_type of each layer can be one of:
- "p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.
- "all_gather": All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.
- "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
- "a2a+p2p": A hierarchical implementation of context parallelism for attention. It uses A2A communications in low-level CP groups (e.g., via NVLink) and P2P communications in high-level CP groups (e.g., via IBLink).
- cuda_graph_retain_backward_graph: bool = False
When set to true, cudagraph backward passes will be graph captured with ‘retain_grad=True’ This may enable cudagraphs for certain modules that are not completely cudagraph safe. For more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html.
- disable_parameter_transpose_cache: bool = False
When set to true, the parameter transposes are not cached for subsequent iterations.
- distribute_saved_activations: bool = None
If True, distribute recomputed activations across the model parallel group.
- enable_cuda_graph: bool = False
When set to true, TransformerLayer layers are swapped with a CUDA graphed version.
- external_cuda_graph: bool = False
When set to true, TransformerLayer layers are swapped with user provided CUDA graphs.
- ffn_hidden_size: int = None
Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided.
- first_pipeline_num_layers: int = None
Number of transformer layers on first pipeline stage. None implies equal layer division across PP ranks.
- flash_decode: bool = False
Use the optimized flash decoding kernel during inference.
- fp32_residual_connection: bool = False
If true, move residual connections to fp32.
- fp8: str = None
If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices (1) ‘e4m3’ uniformly uses e4m3 for all FP8 tensors, (2) ‘hybrid’ uses e4m3 for all FP8 activation and weight tensors and e5m2 for all FP8 output activation gradient tensors.
- fp8_amax_compute_algo: str = 'most_recent'
Algorithm used for choosing the amax value for the scaling factor computation. There are 2 predefined choices: max chooses the largest amax in the history window, while most_recent always chooses the most recently seen value.
- fp8_amax_history_len: int = 1
The length of the amax history window used for scaling factor computation.
- fp8_dot_product_attention: bool = False
When set to True, use the FP8 implementation of Dot Product Attention.
- fp8_interval: int = 1
DEPRECATED from TransformerEngine v1.8.0. This flag is ignored. Controls how often the scaling factor is recomputed.
- fp8_margin: int = 0
Margin for the scaling factor computation.
- fp8_multi_head_attention: bool = False
When set to True, use the FP8 implementation of Multi Head Attention.
- fp8_wgrad: bool = True
When set to False, override FP8 config options and do the wgrad computation in higher precision.
- gated_linear_unit: bool = False
Use a gated linear unit for the first linear layer in the MLP.
- hidden_dropout: float = 0.1
Dropout probability for transformer hidden state.
- hidden_size: int = 0
Transformer hidden size.
- inference_rng_tracker: bool = False
Whether we should instantiate a separate RNG tracker for inference.
- init_method: Callable = None
Method to initialize weights. Note that bias is always set to zero. Should be a function that takes a single Tensor and initializes it. If None, will be set to megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std.
- init_method_std: float = 0.02
Standard deviation of the zero mean normal for the default initialization method, not used if init_method and output_layer_init_method are provided.
- kv_channels: int = None
Projection weights dimension in multi-head attention. This is set to hidden_size // num_attention_heads if not provided.
- last_pipeline_num_layers: int = None
Number of transformer layers on last pipeline stage. None implies equal layer division across PP ranks.
- layernorm_epsilon: float = 1e-05
Epsilon value for any LayerNorm operations.
- layernorm_zero_centered_gamma: bool = False
If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves numerical stability.
- masked_softmax_fusion: bool = False
If True, uses softmax fusion.
- memory_efficient_layer_norm: bool = False
If True, and using local layers (not from TransformerEngine), tells Apex to use the memory efficient fused LayerNorm kernel. Ignored if not using LayerNorm.
- moe_aux_loss_coeff: float = 0
Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended.
- moe_expert_capacity_factor: float = None
The capacity factor for each expert, None means no token will be dropped. The default is None.
- Type
moe_expert_capacity_factor (float)
- moe_ffn_hidden_size: int = None
MoE Feed-Forward Network hidden size.
- moe_grouped_gemm: bool = False
When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
- moe_input_jitter_eps: float = None
Add noise to the input tensor by applying jitter with a specified epsilon value.
- moe_layer_freq: int = 1
Frequency between MoE layers and Dense layers. Accepts either:
- An integer N: represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
- A string containing a Python list expression that defines a custom pattern, e.g.: "([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0], where 1 indicates an expert layer and 0 indicates a dense layer.
- moe_layer_recompute: bool = False
Checkpointing the MoE layer to save activation memory.
- Type
Memory optimization
- moe_pad_expert_input_to_capacity: bool = False
If True, pads the input for each expert to match the expert capacity length, effective only after the moe_expert_capacity_factor is set. The default setting is False.
- Type
moe_pad_expert_input_to_capacity (bool)
- moe_per_layer_logging: bool = False
Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.
- moe_router_load_balancing_type: str = 'aux_loss'
The load balancing strategy for the router. “aux_loss” corresponds to the load balancing loss used in GShard and SwitchTransformer; “seq_aux_loss” corresponds to the loss used in DeepSeekV2, which computes the loss for each individual sample; “sinkhorn” corresponds to the balancing algorithm used in S-BASE, and “none” implies no load balancing. The default is “aux_loss”.
- moe_router_pre_softmax: bool = False
Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.
- moe_router_topk: int = 2
Number of experts to route to for each token.
- moe_router_topk_limited_devices: int = None
Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. None means no device limitation.
- moe_router_topk_scaling_factor: float = None
Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax enabled. Defaults to None, which means no scaling.
- moe_shared_expert_intermediate_size: int = None
Shared expert total ffn hidden size. It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if there are multiple shared experts. None means no shared expert.
- moe_shared_expert_overlap: bool = False
Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts.
- moe_token_dispatcher_type: str = 'allgather'
The type of token dispatcher to use. The default is ‘allgather’. Options are ‘allgather’ and ‘alltoall’.
- moe_token_drop_policy: str = 'probs'
The policy to drop tokens. Can be either “probs” or “position”. If “probs”, the tokens with the lowest probabilities will be dropped. If “position”, tokens at the end of each batch will be dropped.
- moe_token_dropping: bool = False
This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is currently unsupported so should remain False.
- moe_use_legacy_grouped_gemm: bool = False
Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.
- moe_z_loss_coeff: float = None
Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended.
- multi_latent_attention: bool = False
Whether to use multi-latent attention.
- normalization: str = 'LayerNorm'
Which norm to use for normalization layers, valid options are LayerNorm and RMSNorm.
- num_attention_heads: int = 0
Number of transformer attention heads.
- num_layers: int = 0
Number of transformer layers in a transformer block.
- num_moe_experts: int = None
Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None for no MoE.
- num_query_groups: int = None
Number of query groups for group query attention. If None, normal attention is used.
- output_layer_init_method: Callable = None
Method to initialize weights of the output layer of both attention and MLP blocks. If None, will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers).
- persist_layer_norm: bool = False
If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set of hidden sizes.
- qk_layernorm: bool = False
Whether to apply LayerNorm to the query and key embeddings.
- recompute_granularity: str = None
Determines which type of activation recompute to use. Megatron-core supports ‘selective’ activation checkpointing where only the memory intensive part of attention is checkpointed. These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details. ‘full’ will checkpoint the entire transformer layer. If None, no recompute is performed and all activations are saved. If set, must be ‘selective’ or ‘full’. ‘selective’ always uses all layers.
- recompute_method: str = None
Determines which transformer layers will be recomputed. uniform will uniformly divide the total number of transformer layers in a transformer block and recompute the input activation of each divided chunk at the specified granularity. block will recompute the input activations for only a set number of transformer layers per pipeline stage. The rest of the layers in the pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all layers will do recomputation. If set, must be ‘uniform’ or ‘block’.
- recompute_num_layers: int = None
When recompute_method is uniform, recompute_num_layers is the number of transformer layers in each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is the number of transformer layers to recompute within each pipeline stage. Must be None for ‘selective’ activation checkpointing.
- rotary_interleaved: bool = False
If True, rotate pairs of even and odd dimensions (RoFormer style); if False, rotate pairs of the first half and second half (LLaMA style). Defaults to False.
- softmax_scale: float = None
Softmax scale for attention scaling.
- test_mode: bool = False
Whether to run real-time tests.
- tp_only_amax_red: bool = False
When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain
- use_te_rng_tracker: bool = False
Whether to use the TE or MCore version of the RNG tracker.
- window_size: Optional[Tuple[int, int]] = None
If not None, sliding window attention will be used. The size of the window is specified by the numbers inside the tuple; -1 is a special value meaning "infinite window size".
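For example, full activation recomputation can be enabled purely through the fields documented above (a sketch; the model sizes and recompute values are illustrative):

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Full recomputation of one-layer chunks, divided uniformly across the block.
config = TransformerConfig(
    num_layers=24,
    hidden_size=2048,
    num_attention_heads=16,
    recompute_granularity="full",
    recompute_method="uniform",
    recompute_num_layers=1,
)
```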
A single standard transformer layer including attention and MLP blocks.
- class core.transformer.transformer_layer.BaseTransformerLayer
Bases:
abc.ABC
A common parent class for TransformerLayer like implementations.
A dummy class that is subclassed by similar `TransformerLayer` implementations, e.g. the `TransformerLayer` in this file and possibly other TransformerLayer implementations that aim to use TransformerBlock as the base module. The main purpose is to check whether any layer (or module) provided in the spec is a subclass of this class, to allow fanning-out of that spec for all the layers in the TransformerBlock. See the _get_block_submodules method implementation in transformer_block.py for more details.
- class core.transformer.transformer_layer.TransformerLayer(*args: Any, **kwargs: Any)
Bases:
megatron.core.transformer.module.MegatronModule, core.transformer.transformer_layer.BaseTransformerLayer
A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an output of the same size.
- forward(hidden_states, attention_mask=None, context=None, context_mask=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, attention_bias=None, inference_params=None, packed_seq_params=None, sequence_len_offset=None)
Perform a forward pass through the transformer layer.
This method implements the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations.
- Parameters
hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is sequence length, b is batch size, and h is hidden size.
attention_mask (Tensor) – Mask tensor for self-attention.
context (Tensor, optional) – Context tensor for cross-attention.
context_mask (Tensor, optional) – Mask tensor for cross-attention.
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
attention_bias (Tensor, optional) – Bias tensor for Q * K.T.
inference_params (object, optional) – Parameters for inference-time optimizations.
packed_seq_params (object, optional) – Parameters for packed sequence processing.
- Returns
- A tuple containing:
output (Tensor): Transformed hidden states of shape [s, b, h].
context (Tensor): Updated context tensor if cross-attention is used, otherwise None.
- Return type
Tuple[Tensor, Tensor]
- sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
Generate a sharded state dictionary for the transformer layer.
- Parameters
prefix (str, optional) – Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional) – Tuple of sharding offsets.
metadata (Optional[dict], optional) – Additional metadata for sharding.
- Returns
A dictionary containing the sharded state of the transformer layer.
- Return type
ShardedStateDict
- class core.transformer.transformer_layer.TransformerLayerSubmodules(input_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attention: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, self_attn_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_cross_attn_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attention: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, cross_attn_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, pre_mlp_layernorm: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityOp, mlp_bda: typing.Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = megatron.core.transformer.identity_op.IdentityFuncOp, sharded_state_dict_keys_map: typing.Dict[str, str] = <factory>)
Bases:
object
Configuration class for specifying the submodules of a transformer layer.
This class defines the structure and default implementations for various components of a transformer layer, allowing for flexible customization of the layer’s architecture.
- Parameters
input_layernorm (Union[ModuleSpec, type]) – Specification for the input layer normalization.
self_attention (Union[ModuleSpec, type]) – Specification for the self-attention mechanism.
self_attn_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after self-attention.
pre_cross_attn_layernorm (Union[ModuleSpec, type]) – Specification for the layer normalization before cross-attention.
cross_attention (Union[ModuleSpec, type]) – Specification for the cross-attention mechanism.
cross_attn_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after cross-attention.
pre_mlp_layernorm (Union[ModuleSpec, type]) – Specification for the layer normalization before the MLP.
mlp (Union[ModuleSpec, type]) – Specification for the MLP in Dense layer.
mlp_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after the MLP.
sharded_state_dict_keys_map (Dict[str, str]) – Mapping for sharded tensor keys to be applied in the sharded_state_dict method.
- sharded_state_dict_keys_map: Dict[str, str]
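Putting the pieces together, a complete decoder-style layer spec might be assembled along these lines (a sketch modeled on megatron-core's local GPT layer spec; the module choices and the fused layer-norm / bias-dropout-add import paths are assumptions about your installation, and FusedLayerNorm additionally requires Apex):

```python
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add  # assumed path
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm          # assumed path, needs Apex
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

layer_spec = ModuleSpec(
    module=TransformerLayer,
    submodules=TransformerLayerSubmodules(
        input_layernorm=FusedLayerNorm,
        self_attention=ModuleSpec(
            module=SelfAttention,
            params={"attn_mask_type": AttnMaskType.causal},
            submodules=SelfAttentionSubmodules(
                linear_qkv=ColumnParallelLinear,
                core_attention=DotProductAttention,
                linear_proj=RowParallelLinear,
            ),
        ),
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=FusedLayerNorm,
        mlp=ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=ColumnParallelLinear,
                linear_fc2=RowParallelLinear,
            ),
        ),
        mlp_bda=get_bias_dropout_add,
        # Cross-attention slots keep their IdentityOp / IdentityFuncOp defaults.
    ),
)
```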
Various utilities used in the transformer implementation.
Utilities for transformer layers.
- core.transformer.utils.attention_mask_func(attention_scores, attention_mask)
- core.transformer.utils.erf_gelu(x)
- core.transformer.utils.gelu_impl(x)
OpenAI’s gelu implementation.
- core.transformer.utils.get_default_causal_mask(sq: int) → torch.Tensor
Return the causal upper triangular mask for softmax input.
- core.transformer.utils.get_linear_layer(rows, columns, init_method, perform_initialization=True)
Simple linear layer with weight initialization.
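For example (init_method_normal is the helper referenced elsewhere in these docs; treat the exact import path as an assumption):

```python
from megatron.core.transformer.utils import get_linear_layer
from megatron.core.utils import init_method_normal  # assumed path

# A 1024 -> 4096 torch.nn.Linear whose weight is initialized from N(0, 0.02).
linear = get_linear_layer(1024, 4096, init_method=init_method_normal(0.02))
```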
- core.transformer.utils.make_sharded_object_for_checkpoint(obj: Any, key: str, sharded_offsets: Iterable[Tuple[int, int, int]] = (), replica_id: Union[None, int, Tuple[int, ...]] = None, **kwargs)
Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
- Parameters
obj (object) – any object to be sharded
key (str) – unique identifier of the object
sharded_offsets (Iterable[Tuple[int, int, int]]) – offsets normally prepended to ShardedTensors, will be used as global offsets for ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]) – replica id
- core.transformer.utils.make_sharded_tensors_for_checkpoint(state_dict: megatron.core.dist_checkpointing.mapping.StateDict, prefix: str, tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None, sharded_offsets: Iterable[Tuple[int, int, int]] = (), extra_state_suffix: str = '_extra_state')
Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
For a given state_dict, wraps:
- all _extra_states with ShardedObject
- all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- other values with DP sharded ShardedTensor
- Parameters
state_dict (StateDict) – state_dict to convert
prefix (str) – prefix appended to keys in final state dict
tensor_parallel_layers_axis_map (Dict[str, int], optional) – dict mapping layer names to the axis for TP sharding
sharded_offsets (Iterable[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related), passed along to ShardedTensor
extra_state_suffix (str, default = '_extra_state') – layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor.
- core.transformer.utils.openai_gelu(x)
- core.transformer.utils.sharded_state_dict_default(module: torch.nn.Module, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None) → megatron.core.dist_checkpointing.mapping.ShardedStateDict
Provides implementation for sharded_state_dict method for non-MegatronModules.
Tries to call module.sharded_state_dict when possible, otherwise uses regular state dict and assumes tensors are replicated across TP and DP.
keep_vars=True is passed to module.state_dict so that optimizer states can be sharded later on.
- Parameters
module (torch.nn.Module) – module which sharded state dict we want to obtain
prefix (str) – prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional) – sharding already applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional) – metadata passed to module sharded_state_dict method
- Returns
dictionary of state dict keys mapped to ShardedTensors
- Return type
dict