core.extensions.transformer_engine#
Module Contents#
Classes#
| Class | Description |
|---|---|
| `TransformerEngineConfigType` | Configuration object types in config dictionary |
| `TEQuantizationRecipe` | Class to capture options for opening an autocast context in forward |
| `TEQuantizationParams` | Class to capture precision options for training and evaluation. |
| `TENorm` | A conditional wrapper to initialize an instance of Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. |
| `TELinear` | Wrapper for the Transformer-Engine's `Linear` layer. |
| `TELayerNormColumnParallelLinear` | Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines layernorm and linear layers. |
| `TEColumnParallelLinear` | Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly to Megatron's `ColumnParallelLinear` layer. |
| `TERowParallelLinear` | Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly to Megatron's `RowParallelLinear` layer. |
| `TEDotProductAttention` | Wrapper for the Transformer-Engine's `DotProductAttention` layer that also has "flash attention" enabled. |
| `TEDelayedScaling` | Wrapper for the Transformer-Engine's `DelayedScaling` recipe. |
| `TECudaRNGStatesTracker` | Wraps TransformerEngine's CudaRNGStatesTracker so that it is interchangeable with Megatron's RNG tracker |
Functions#
| Function | Description |
|---|---|
| `condition_init_method` | Condition TE init_method on config.perform_initialization. |
| `split_te_layernorm_column_parallel_linear` | Split a TELayerNormColumnParallelLinear into separate TENorm and TEColumnParallelLinear layers. |
| `te_checkpoint` | Checkpointing with Transformer-Engine. |
| `set_save_original_input` | Set the module to save the original input tensors. |
Data#
API#
- core.extensions.transformer_engine._TE_CONFIG_TYPE_KEY#
‘transformer_engine_config_type’
- class core.extensions.transformer_engine.TransformerEngineConfigType(*args, **kwds)#
Bases: `enum.Enum`

Configuration object types in config dictionary.

Initialization
- TEQuantizationParams#
‘TEQuantizationParams’
- class core.extensions.transformer_engine.TEQuantizationRecipe#
Class to capture options for opening an autocast context in forward
- fp8_quantization_recipe: Optional[megatron.core.enums.Fp8Recipe]#
None
An FP8 quantization override if the module should use FP8. If neither FP8 nor FP4 quantization is configured, the module executes in high precision (BF16).
- fp4_quantization_recipe: Optional[megatron.core.enums.Fp4Recipe]#
None
An FP4 quantization override if the module should use FP4. If neither FP8 nor FP4 quantization is configured, the module executes in high precision (BF16).
- custom_recipe_factory: Optional[str]#
None
The path to a custom recipe factory if a custom FP4 or FP8 recipe is configured.
- fp8_format: str#
‘e4m3’
A format to select from an FP8Recipe
- override_quantized_autocast: bool#
True
If the quantization autocast context for a targeted module is enabled, whether to override it and change (or disable) the quantization recipe.
- override_nonquantized_autocast: bool#
False
If the quantization autocast context for a targeted module is not enabled, whether to override it and enable a quantization recipe.
- tp_only_amax_red: bool#
False
If an amax reduction is applicable, such as with a per-tensor quantization recipe, whether to reduce only along TP groups.
- classmethod parse_from_config(
- quant_config: Dict[Any, Any],
Parse config from quantization dictionary.
- classmethod get_config_keys() Set[str]#
Get expected keys from the dataclass fields.
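As an illustration, a recipe override can also be built directly from the fields documented above. The sketch below is not taken from the library's examples; it assumes the class is importable as `megatron.core.extensions.transformer_engine.TEQuantizationRecipe`, that it is a dataclass accepting its fields as keyword arguments (implied by `get_config_keys`), and that `Fp8Recipe.delayed` is a valid member of `megatron.core.enums.Fp8Recipe`.

```python
# Minimal sketch, assuming TEQuantizationRecipe is a dataclass with the fields above.
from megatron.core.enums import Fp8Recipe
from megatron.core.extensions.transformer_engine import TEQuantizationRecipe

# Override a targeted module to run FP8 (e4m3) when its surrounding autocast
# context is already quantized; leave non-quantized contexts untouched.
recipe = TEQuantizationRecipe(
    fp8_quantization_recipe=Fp8Recipe.delayed,  # illustrative Fp8Recipe member (assumed)
    fp8_format="e4m3",
    override_quantized_autocast=True,
    override_nonquantized_autocast=False,
)
```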
- class core.extensions.transformer_engine.TEQuantizationParams#
Class to capture precision options for training and evaluation.
- training_recipe: core.extensions.transformer_engine.TEQuantizationRecipe#
None
Precision override for when self.training is True
- evaluation_recipe: Optional[core.extensions.transformer_engine.TEQuantizationRecipe]#
None
Precision override for when self.training is False. If None, training_recipe is used.
- static parse_from_config(
- quant_config: megatron.core.quantization.quant_config.QuantizationConfig,
Parses the quantization config for a layer or throws an error.
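The split between training and evaluation precision can then be expressed as below. This is a sketch under the same dataclass assumption, reusing the `recipe` built in the previous example; the evaluation recipe shown simply forces a quantized autocast context back to high precision.

```python
# Sketch: quantized training, high-precision evaluation.
from megatron.core.extensions.transformer_engine import (
    TEQuantizationParams,
    TEQuantizationRecipe,
)

params = TEQuantizationParams(
    training_recipe=recipe,  # override used while self.training is True
    # With no FP8/FP4 recipe configured, overriding a quantized context
    # means executing it in high precision (BF16) during evaluation.
    evaluation_recipe=TEQuantizationRecipe(override_quantized_autocast=True),
)
```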
- core.extensions.transformer_engine._get_fp8_autocast_for_quant_recipe( )#
- core.extensions.transformer_engine._get_fp8_autocast_for_quant_params(
- qparams: core.extensions.transformer_engine.TEQuantizationParams | None,
- training: bool,
- core.extensions.transformer_engine._get_should_context_be_quantized_recipe(
- qrecipe: core.extensions.transformer_engine.TEQuantizationRecipe,
- is_original_context_quantized: bool,
- core.extensions.transformer_engine._get_should_context_be_quantized_params(
- qparams: core.extensions.transformer_engine.TEQuantizationParams | None,
- training: bool,
- is_context_quantized: bool,
- core.extensions.transformer_engine._get_extra_te_kwargs(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- core.extensions.transformer_engine.condition_init_method(config, init_method)#
Condition TE init_method on config.perform_initialization.
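For example, a layer builder can wrap its chosen initializer so that weight initialization is skipped when the config disables it. This is a minimal sketch; `config` is an existing TransformerConfig assumed to be in scope, and the no-op substitution when `perform_initialization` is False is the assumed behavior.

```python
# Sketch: gate weight initialization on config.perform_initialization.
import torch

# If config.perform_initialization is False, the returned init_method is expected
# to leave weights untouched (assumed), e.g. when weights are loaded from a checkpoint.
init_method = condition_init_method(config, torch.nn.init.xavier_normal_)
```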
- core.extensions.transformer_engine.split_te_layernorm_column_parallel_linear(
- fused_layer,
- config,
- init_method: Optional[callable] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Split a TELayerNormColumnParallelLinear into separate TENorm and TEColumnParallelLinear layers.
- Parameters:
fused_layer – The fused TELayerNormColumnParallelLinear layer to split
config – TransformerConfig to use for creating the new layers
init_method – Initialization method for the linear layer (optional)
tp_group – Tensor parallel group (optional)
- Returns:
A tuple of (TENorm, TEColumnParallelLinear) with weights copied from the fused layer
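A typical use is converting a fused layer into its unfused equivalents, e.g. when a checkpoint was built with the fused module but the target model spec uses separate norm and linear layers. The sketch below follows the signature above; `fused_layer` and `config` are assumed to exist in the caller's scope.

```python
# Sketch: unfuse a TELayerNormColumnParallelLinear into TENorm + TEColumnParallelLinear,
# with weights copied from the fused layer (per the Returns note above).
norm, linear = split_te_layernorm_column_parallel_linear(
    fused_layer,       # existing TELayerNormColumnParallelLinear (assumed in scope)
    config,            # TransformerConfig used to create the new layers
    init_method=None,  # optional initialization method for the linear layer
    tp_group=None,     # optional tensor-parallel process group
)
```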
- class core.extensions.transformer_engine.TENorm#
A conditional wrapper to initialize an instance of Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.

- __new__(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- hidden_size: int,
- eps: float = 1e-05,
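For instance, a norm module can be created directly from a TransformerConfig. A brief sketch, assuming the choice between LayerNorm and RMSNorm is driven by the config's `normalization` field and that `config` and an activation tensor `x` already exist:

```python
# Sketch: TENorm.__new__ returns a TE LayerNorm or RMSNorm instance
# depending on the config (assumed: the `normalization` field).
norm = TENorm(config, hidden_size=config.hidden_size, eps=1e-5)
y = norm(x)  # x: [sequence, batch, hidden] activations (illustrative shape)
```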
- class core.extensions.transformer_engine.TELinear(
- input_size: int,
- output_size: int,
- *,
- parallel_mode: Optional[str],
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- skip_bias_add: bool,
- skip_weight_param_allocation: bool,
- tp_comm_buffer_name: Optional[str] = None,
- is_expert: bool = False,
- symmetric_ar_type: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Bases: `transformer_engine.pytorch.Linear`

Wrapper for the Transformer-Engine's `Linear` layer.

Note that if Megatron's parallel_state has not been initialized yet, the tp_group passed to TE will be None and must be set later via set_tensor_parallel_group().

parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along the output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along the input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism; the weight is duplicated across TP ranks

Note: For expert linear layers, communication logic is disabled here because TP communication is handled in the token_dispatcher.
Initialization
- finish_init(
- quantization_config: megatron.core.quantization.quant_config.QuantizationConfig,
Post-init of quantization override
- will_execute_quantized(is_context_quantized: bool) bool#
Returns whether the module is configured to execute quantized.
- forward(x)#
Forward.
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Replicate cross TP/DP.
- backward_dw()#
Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
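Putting the constructor arguments together, a non-parallel ("duplicated") instance might be created as follows. This is a sketch using the keyword-only signature above; `config`, `init_method`, and the input tensor `x` are assumed to be in scope, and the `(output, bias)` return convention is an assumption based on Megatron's other linear wrappers.

```python
# Sketch: a TELinear whose weight is duplicated across TP ranks.
linear = TELinear(
    input_size=1024,
    output_size=4096,
    parallel_mode="duplicated",      # "column" | "row" | "duplicated"
    config=config,                   # ModelParallelConfig (assumed in scope)
    init_method=init_method,
    bias=True,
    skip_bias_add=False,
    skip_weight_param_allocation=False,
)
out, bias_out = linear(x)  # (output, bias) return convention assumed
```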
- class core.extensions.transformer_engine.TELayerNormColumnParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Bases: `transformer_engine.pytorch.LayerNormLinear`

Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines layernorm and linear layers.

Initialization
- finish_init(
- quantization_config: megatron.core.quantization.quant_config.QuantizationConfig,
Post-init of quantization override
- will_execute_quantized(is_context_quantized: bool) bool#
Returns whether the module is configured to execute quantized.
- forward(x)#
Forward.
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Sharding along axis 0, bias sharded
- __repr__()#
- backward_dw()#
Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
- class core.extensions.transformer_engine.TEColumnParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Bases: `core.extensions.transformer_engine.TELinear`

Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly to Megatron's `ColumnParallelLinear` layer.

Initialization
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Sharding along axis 0, bias sharded
- __repr__()#
- backward_dw()#
Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
- class core.extensions.transformer_engine.TERowParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Bases: `core.extensions.transformer_engine.TELinear`

Wrapper for the Transformer-Engine's `Linear` layer, specialized similarly to Megatron's `RowParallelLinear` layer.

Initialization
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Sharding along axis 1, bias not sharded
- __repr__()#
- backward_dw()#
Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
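The two specializations are typically paired, as in a tensor-parallel MLP: the column-parallel layer shards its weight along the output dimension and the row-parallel layer shards along the input dimension. A sketch using the signatures above; `config` and `init_method` are assumed to be in scope and the sizes are illustrative.

```python
# Sketch: column-parallel fc1 followed by row-parallel fc2, as in a TP MLP block.
hidden = 4096

fc1 = TEColumnParallelLinear(
    input_size=hidden,
    output_size=4 * hidden,
    config=config,
    init_method=init_method,
    gather_output=False,     # keep the output sharded for the row-parallel layer
    bias=True,
    skip_bias_add=True,
    is_expert=False,
)
fc2 = TERowParallelLinear(
    input_size=4 * hidden,
    output_size=hidden,
    config=config,
    init_method=init_method,
    bias=True,
    input_is_parallel=True,  # consumes the sharded output of fc1 directly
    skip_bias_add=True,
    is_expert=False,
)
```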
- class core.extensions.transformer_engine.TEDotProductAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- softmax_scale: Optional[float] = None,
- k_channels: Optional[int] = None,
- v_channels: Optional[int] = None,
- num_splits: Optional[int] = None,
- cp_comm_type: str = 'p2p',
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
Bases: `transformer_engine.pytorch.DotProductAttention`

Wrapper for the Transformer-Engine's `DotProductAttention` layer that also has "flash attention" enabled.

Note that if Megatron's parallel_state has not been initialized yet, the tp_group and cp_group passed to TE will be None and must be set later via set_tensor_parallel_group() and set_context_parallel_group().
Initialization
- cp_stream: torch.cuda.Stream#
None
- forward(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_bias: torch.Tensor = None,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
- num_splits: Optional[int] = None,
Forward.
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int]] = (),
- metadata: Optional[dict] = None,
Sharded state dict for the learnable softmax offset parameter
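A self-attention core can be constructed and invoked with the signature above. This is a sketch; `config` and the query/key/value tensors (in the layout expected by the configured attention backend) are assumed to exist, and `attention_type="self"` is an assumed value mirroring Megatron's attention modules.

```python
# Sketch: TE dot-product attention for a self-attention layer with a causal mask.
from megatron.core.transformer.enums import AttnMaskType

attn = TEDotProductAttention(
    config=config,
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,
    attention_type="self",   # assumed value; "cross" would be the other common case
)
context = attn(
    query, key, value,
    attention_mask=None,                 # causal masking handled via attn_mask_type
    attn_mask_type=AttnMaskType.causal,
)
```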
- class core.extensions.transformer_engine.TEDelayedScaling(
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- fp8_format: int,
- override_linear_precision: tuple = (False, False, False),
Bases: `transformer_engine.common.recipe.DelayedScaling`

Wrapper for the Transformer-Engine's `DelayedScaling` recipe.

Initialization
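A delayed-scaling FP8 recipe can be derived from a Megatron config as sketched below. `Format.HYBRID` is one of Transformer-Engine's FP8 format enums and is used here illustratively; `config` is assumed to be in scope.

```python
# Sketch: build a delayed-scaling FP8 recipe from a Megatron config.
from transformer_engine.common.recipe import Format

fp8_recipe = TEDelayedScaling(
    config=config,               # ModelParallelConfig (assumed in scope)
    fp8_format=Format.HYBRID,    # e4m3 forward, e5m2 backward
    override_linear_precision=(False, False, False),
)
```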
- class core.extensions.transformer_engine.TECudaRNGStatesTracker(is_inference_rng_tracker=False)#
Bases: `transformer_engine.pytorch.distributed.CudaRNGStatesTracker`

Wraps TransformerEngine's CudaRNGStatesTracker so that it is interchangeable with Megatron's RNG tracker.
Initialization
- is_initialized()#
Checks if the internal RNG state has been set with set_states().
- reset()#
Reset the internal RNG state.
- set_states(states)#
Set the internal RNG state.
- add(name, seed)#
Track the rng state.
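In practice the tracker mirrors Megatron's RNG tracker interface. A brief sketch of the methods listed above; whether `is_initialized()` also reports True after `add()` (and not only after `set_states()`) is an assumption here.

```python
# Sketch: register a named CUDA RNG state and clear the tracker when re-seeding.
tracker = TECudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)  # seed and track a dedicated CUDA RNG state
if tracker.is_initialized():                  # True once states have been set (assumed)
    tracker.reset()                           # drop all tracked states
```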
- core.extensions.transformer_engine.te_checkpoint(
- forward_func,
- distribute_saved_activations,
- get_rng_state_tracker,
- tp_group,
- *args,
- **kwargs,
Checkpointing with Transformer-Engine.
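Typical usage checkpoints a layer's forward function so its activations are recomputed during the backward pass. The sketch below follows the positional signature above; the RNG tracker getter is assumed to come from `megatron.core.tensor_parallel`, and `layer`, `tp_group`, and `hidden_states` are assumed to be in scope.

```python
# Sketch: activation checkpointing of a custom forward function through TE.
from megatron.core.tensor_parallel import get_cuda_rng_tracker

def custom_forward(hidden_states):
    # the sub-graph whose activations are recomputed in the backward pass
    return layer(hidden_states)

output = te_checkpoint(
    custom_forward,
    False,                   # distribute_saved_activations
    get_cuda_rng_tracker,    # RNG state tracker getter (assumed helper)
    tp_group,                # tensor-parallel process group
    hidden_states,           # *args forwarded to custom_forward
)
```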
- core.extensions.transformer_engine.set_save_original_input(module)#
Set the module to save the original input tensors.
Some Transformer-Engine modules save quantized tensors by default during FP8 training. This function configures such modules to save the original input tensors instead.

This can reduce memory usage in some FP8 training scenarios, such as the attention linear_proj and the shared experts. The output-discarding recompute method also relies on this.
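For example, the attention output projection mentioned above can be switched to saving original inputs. A sketch with a hypothetical module-name filter; `model` is assumed to be an existing `torch.nn.Module`.

```python
# Sketch: opt selected submodules out of saving quantized (FP8) input tensors.
for name, module in model.named_modules():
    # hypothetical selection criterion: the attention output projection
    if name.endswith("linear_proj"):
        set_save_original_input(module)
```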