core.extensions.transformer_engine#

Module Contents#

Classes#

TransformerEngineConfigType

Configuration object types in config dictionary

TEQuantizationRecipe

Class to capture options for opening an autocast context in forward

TEQuantizationParams

Class to capture precision options for training and evaluation.

TENorm

A conditional wrapper to initialize an instance of Transformer-Engine’s LayerNorm or RMSNorm based on input.

TELinear

Wrapper for the Transformer-Engine’s Linear layer.

TELayerNormColumnParallelLinear

Wrapper for the Transformer-Engine’s LayerNormLinear layer that combines layernorm and linear layers.

TEColumnParallelLinear

Wrapper for the Transformer-Engine’s Linear layer, specialized to match Megatron’s ColumnParallelLinear layer.

TERowParallelLinear

Wrapper for the Transformer-Engine’s Linear layer, specialized to match Megatron’s RowParallelLinear layer.

TEDotProductAttention

Wrapper for the Transformer-Engine’s DotProductAttention layer that also has “flash attention” enabled.

TEDelayedScaling

Wrapper for the Transformer-Engine’s DelayedScaling recipe.

TECudaRNGStatesTracker

Wraps TransformerEngine’s CudaRNGStatesTracker so that it is interchangeable with Megatron’s RNG tracker

Functions#

_get_fp8_autocast_for_quant_recipe

_get_fp8_autocast_for_quant_params

_get_should_context_be_quantized_recipe

_get_should_context_be_quantized_params

_get_extra_te_kwargs

condition_init_method

Condition TE init_method on config.perform_initialization.

split_te_layernorm_column_parallel_linear

Split a TELayerNormColumnParallelLinear into separate TENorm and TEColumnParallelLinear layers.

te_checkpoint

Checkpointing with Transformer-Engine.

set_save_original_input

Set the module to save the original input tensors.

Data#

API#

core.extensions.transformer_engine._TE_CONFIG_TYPE_KEY#

‘transformer_engine_config_type’

class core.extensions.transformer_engine.TransformerEngineConfigType(*args, **kwds)#

Bases: enum.Enum

Configuration object types in config dictionary

Initialization

TEQuantizationParams#

‘TEQuantizationParams’

class core.extensions.transformer_engine.TEQuantizationRecipe#

Class to capture options for opening an autocast context in forward

fp8_quantization_recipe: Optional[megatron.core.enums.Fp8Recipe]#

None

An FP8 quantization override if the module should use FP8. If neither FP8 nor FP4 quantization is configured, execution falls back to high precision (BF16).

fp4_quantization_recipe: Optional[megatron.core.enums.Fp4Recipe]#

None

An FP4 quantization override if the module should use FP4. If neither FP8 nor FP4 quantization is configured, execution falls back to high precision (BF16).

custom_recipe_factory: Optional[str]#

None

The path to a custom recipe factory if a custom FP4 or FP8 recipe is configured.

fp8_format: str#

‘e4m3’

The FP8 format to use for the FP8 recipe.

override_quantized_autocast: bool#

True

If the quantization autocast context for a targeted module is enabled, whether to override it and change (or disable) the quantization recipe.

override_nonquantized_autocast: bool#

False

If the quantization autocast context for a targeted module is not enabled, whether to override it and enable a quantization recipe.

tp_only_amax_red: bool#

False

If an amax reduction is applicable, such as with a per-tensor quantization recipe, whether to reduce only along TP groups.

classmethod parse_from_config(
quant_config: Dict[Any, Any],
) → core.extensions.transformer_engine.TEQuantizationRecipe#

Parse config from quantization dictionary.

classmethod get_config_keys() → Set[str]#

Get expected keys from the dataclass fields.
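Example (illustrative sketch, not taken from the library’s tests): constructing a TEQuantizationRecipe from the fields documented above. The Fp8Recipe member name and the chosen values are assumptions.

```python
from megatron.core.enums import Fp8Recipe
from core.extensions.transformer_engine import TEQuantizationRecipe

# Construct the dataclass directly, overriding a few of the documented defaults.
# Fp8Recipe.delayed is an assumed member name for a delayed-scaling FP8 recipe.
recipe = TEQuantizationRecipe(
    fp8_quantization_recipe=Fp8Recipe.delayed,
    fp8_format="e4m3",       # documented default FP8 format
    tp_only_amax_red=False,  # reduce amax beyond TP groups when applicable
)

# The keys parse_from_config(...) expects are reported by get_config_keys(),
# which derives them from the dataclass fields.
expected_keys = TEQuantizationRecipe.get_config_keys()
```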

class core.extensions.transformer_engine.TEQuantizationParams#

Class to capture precision options for training and evaluation.

training_recipe: core.extensions.transformer_engine.TEQuantizationRecipe#

None

Precision override for when self.training is True

evaluation_recipe: Optional[core.extensions.transformer_engine.TEQuantizationRecipe]#

None

Precision override for when self.training is False. If None, training_recipe is used.

static parse_from_config(
quant_config: megatron.core.quantization.quant_config.QuantizationConfig,
) → core.extensions.transformer_engine.TEQuantizationParams#

Parses the quantization config for a layer or throws an error.
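Example (sketch, assuming keyword construction of the dataclass): pairing a training-time recipe with an evaluation-time recipe. Leaving evaluation_recipe as None reuses the training recipe, as documented above.

```python
from core.extensions.transformer_engine import (
    TEQuantizationParams,
    TEQuantizationRecipe,
)

params = TEQuantizationParams(
    training_recipe=TEQuantizationRecipe(fp8_format="e4m3"),  # placeholder recipe
    evaluation_recipe=None,  # None: training_recipe is also used when self.training is False
)
```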

core.extensions.transformer_engine._get_fp8_autocast_for_quant_recipe(
qrecipe: core.extensions.transformer_engine.TEQuantizationRecipe,
)#
core.extensions.transformer_engine._get_fp8_autocast_for_quant_params(
qparams: core.extensions.transformer_engine.TEQuantizationParams | None,
training: bool,
)#
core.extensions.transformer_engine._get_should_context_be_quantized_recipe(
qrecipe: core.extensions.transformer_engine.TEQuantizationRecipe,
is_original_context_quantized: bool,
)#
core.extensions.transformer_engine._get_should_context_be_quantized_params(
qparams: core.extensions.transformer_engine.TEQuantizationParams | None,
training: bool,
is_context_quantized: bool,
)#
core.extensions.transformer_engine._get_extra_te_kwargs(
config: megatron.core.transformer.transformer_config.TransformerConfig,
)#
core.extensions.transformer_engine.condition_init_method(config, init_method)#

Condition TE init_method on config.perform_initialization.

core.extensions.transformer_engine.split_te_layernorm_column_parallel_linear(
fused_layer,
config,
init_method: Optional[callable] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Split a TELayerNormColumnParallelLinear into separate TENorm and TEColumnParallelLinear layers.

Parameters:
  • fused_layer – The fused TELayerNormColumnParallelLinear layer to split

  • config – TransformerConfig to use for creating the new layers

  • init_method – Initialization method for the linear layer (optional)

  • tp_group – Tensor parallel group (optional)

Returns:

A tuple of (TENorm, TEColumnParallelLinear) with weights copied from the fused layer
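Example (usage sketch; fused_layer and config are placeholders for an existing TELayerNormColumnParallelLinear and its TransformerConfig):

```python
from core.extensions.transformer_engine import (
    split_te_layernorm_column_parallel_linear,
)

norm, linear = split_te_layernorm_column_parallel_linear(
    fused_layer,        # existing fused TELayerNormColumnParallelLinear
    config,             # TransformerConfig used to build the new layers
    init_method=None,   # optional init method for the linear layer
    tp_group=None,      # optional tensor-parallel process group
)
# norm is a TENorm and linear a TEColumnParallelLinear, with weights copied from fused_layer.
```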

class core.extensions.transformer_engine.TENorm#

A conditional wrapper to initialize an instance of Transformer-Engine’s LayerNorm or RMSNorm based on input.

__new__(
config: megatron.core.transformer.transformer_config.TransformerConfig,
hidden_size: int,
eps: float = 1e-05,
)#
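Example (construction sketch; the TransformerConfig values are placeholders, and selecting RMSNorm via the config’s normalization setting is an assumption based on the class description):

```python
from megatron.core.transformer.transformer_config import TransformerConfig
from core.extensions.transformer_engine import TENorm

config = TransformerConfig(
    num_layers=2, hidden_size=1024, num_attention_heads=8, normalization="RMSNorm"
)
# Returns a Transformer-Engine norm instance (LayerNorm or RMSNorm) over hidden_size features.
norm = TENorm(config=config, hidden_size=config.hidden_size, eps=1e-5)
```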
class core.extensions.transformer_engine.TELinear(
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
symmetric_ar_type: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: transformer_engine.pytorch.Linear

Wrapper for the Transformer-Engine’s Linear layer.

Note that if Megatron’s parallel_state has not been initialized yet, the tp_group passed to TE will be None and must be set later via set_tensor_parallel_group().

parallel_mode currently supports 3 different values:

  • “column”: Split the weight matrix along output dimension (used in TEColumnParallelLinear)

  • “row”: Split the weight matrix along input dimension (used in TERowParallelLinear)

  • “duplicated”: No tensor parallelism and weight is duplicated across TP ranks

Note: For expert linear layers, we will disable communication logic here as TP communication is handled in token_dispatcher.

Initialization

finish_init(
quantization_config: megatron.core.quantization.quant_config.QuantizationConfig,
)#

Post-init of quantization override

will_execute_quantized(is_context_quantized: bool) → bool#

Returns whether the module is configured to execute quantized.

forward(x)#

Forward.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Replicate across TP/DP.

backward_dw()#

Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
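Example (construction sketch following the keyword-only signature above; it assumes Transformer Engine is installed and Megatron’s parallel state, or an explicit tp_group, is available; sizes and config are placeholders):

```python
import torch
from megatron.core.model_parallel_config import ModelParallelConfig
from core.extensions.transformer_engine import TELinear

config = ModelParallelConfig()  # placeholder model-parallel settings

linear = TELinear(
    input_size=1024,
    output_size=1024,
    parallel_mode="duplicated",  # weight replicated across TP ranks (see list above)
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    bias=True,
    skip_bias_add=False,
    skip_weight_param_allocation=False,
    tp_comm_buffer_name=None,
)
```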

class core.extensions.transformer_engine.TELayerNormColumnParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: transformer_engine.pytorch.LayerNormLinear

Wrapper for the Transformer-Engine’s LayerNormLinear layer that combines layernorm and linear layers.

Initialization

finish_init(
quantization_config: megatron.core.quantization.quant_config.QuantizationConfig,
)#

Post-init of quantization override

will_execute_quantized(is_context_quantized: bool) → bool#

Returns whether the module is configured to execute quantized.

forward(x)#

Forward.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Sharding along axis 0, bias sharded

__repr__()#
backward_dw()#

Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.

class core.extensions.transformer_engine.TEColumnParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: core.extensions.transformer_engine.TELinear

Wrapper for the Transformer-Engine’s Linear layer, specialized to match Megatron’s ColumnParallelLinear layer.

Initialization

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Sharding along axis 0, bias sharded

__repr__()#
backward_dw()#

Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.

class core.extensions.transformer_engine.TERowParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: core.extensions.transformer_engine.TELinear

Wrapper for the Transformer-Engine’s Linear layer, specialized to match Megatron’s RowParallelLinear layer.

Initialization

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Sharding along axis 1, bias not sharded

__repr__()#
backward_dw()#

Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.
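Example (sketch of the usual column-then-row pairing, as in a tensor-parallel MLP, following the documented constructor signatures; sizes, init method, and config are placeholders):

```python
import torch
from megatron.core.model_parallel_config import ModelParallelConfig
from core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TERowParallelLinear,
)

config = ModelParallelConfig()  # placeholder

# Column-parallel projection: the output dimension is split across TP ranks.
fc1 = TEColumnParallelLinear(
    input_size=1024,
    output_size=4096,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    gather_output=False,
    bias=True,
    skip_bias_add=True,
    is_expert=False,
)

# Row-parallel projection: the input dimension is split; partial outputs are reduced.
fc2 = TERowParallelLinear(
    input_size=4096,
    output_size=1024,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    bias=True,
    input_is_parallel=True,
    skip_bias_add=True,
    is_expert=False,
)
```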

class core.extensions.transformer_engine.TEDotProductAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
softmax_scale: Optional[float] = None,
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
num_splits: Optional[int] = None,
cp_comm_type: str = 'p2p',
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: transformer_engine.pytorch.DotProductAttention

Wrapper for the Transformer-Engine’s DotProductAttention layer that also has “flash attention” enabled.

Note that if Megatron’s parallel_state has not been initialized yet, the tp_group and cp_group passed to TE will be None and must be set later via set_tensor_parallel_group() and set_context_parallel_group().

Initialization

cp_stream: torch.cuda.Stream#

None

forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_bias: torch.Tensor = None,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
num_splits: Optional[int] = None,
)#

Forward.

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Sharded state dict for the learnable softmax offset parameter
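Example (construction sketch; the config, mask type, and attention_type string are illustrative assumptions):

```python
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from core.extensions.transformer_engine import TEDotProductAttention

config = TransformerConfig(num_layers=2, hidden_size=1024, num_attention_heads=8)

attn = TEDotProductAttention(
    config=config,
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,  # enum member assumed
    attention_type="self",               # illustrative value
    cp_comm_type="p2p",                  # documented default
)
```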

class core.extensions.transformer_engine.TEDelayedScaling(
config: megatron.core.model_parallel_config.ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
)#

Bases: transformer_engine.common.recipe.DelayedScaling

Wrapper for the Transformer-Engine’s DelayedScaling recipe.

Initialization
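Example (sketch of passing the recipe to Transformer Engine’s fp8_autocast context; the config values are placeholders and Format comes from TE’s public recipe module):

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format
from megatron.core.transformer.transformer_config import TransformerConfig
from core.extensions.transformer_engine import TEDelayedScaling

config = TransformerConfig(num_layers=2, hidden_size=1024, num_attention_heads=8)

recipe = TEDelayedScaling(
    config=config,
    fp8_format=Format.HYBRID,                         # E4M3 forward, E5M2 backward
    override_linear_precision=(False, False, False),  # keep all linear GEMMs in FP8
)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ...  # run forward passes of TE-backed modules under the FP8 recipe
```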

class core.extensions.transformer_engine.TECudaRNGStatesTracker(is_inference_rng_tracker=False)#

Bases: transformer_engine.pytorch.distributed.CudaRNGStatesTracker

Wraps TransformerEngine’s CudaRNGStatesTracker so that it is interchangeable with Megatron’s RNG tracker

Initialization

is_initialized()#

Checks if the internal RNG state has been set with set_states().

reset()#

Reset the internal RNG state.

set_states(states)#

Set the internal RNG state.

add(name, seed)#

Track the RNG state.
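Example (small sketch using only the methods listed above; the state name and seed are illustrative):

```python
from core.extensions.transformer_engine import TECudaRNGStatesTracker

tracker = TECudaRNGStatesTracker()
tracker.add("model-parallel-rng", 1234)  # register a seeded CUDA RNG state under a name
tracker.reset()                          # drop all tracked states
```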

core.extensions.transformer_engine.te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
*args,
**kwargs,
)#

Checkpointing with Transformer-Engine.
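Example (sketch; the RNG tracker getter and TP group are assumed to come from Megatron’s tensor_parallel and parallel_state utilities, and some_te_module / hidden_states are placeholders):

```python
from megatron.core import parallel_state, tensor_parallel
from core.extensions.transformer_engine import te_checkpoint

def custom_forward(hidden_states):
    # Placeholder: any callable whose activations should be recomputed in backward.
    return some_te_module(hidden_states)

output = te_checkpoint(
    custom_forward,
    False,                                             # distribute_saved_activations
    tensor_parallel.get_cuda_rng_tracker,              # get_rng_state_tracker
    parallel_state.get_tensor_model_parallel_group(),  # tp_group
    hidden_states,                                     # forwarded to custom_forward
)
```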

core.extensions.transformer_engine.set_save_original_input(module)#

Set the module to save the original input tensors.

Some Transformer-Engine modules save the quantized tensors by default during FP8 training. This function sets such modules to save the original input tensors instead.

This can reduce memory usage in some FP8 training scenarios, such as the attention linear_proj and the shared experts. The output-discarding recompute method also relies on this.
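Example (sketch; the model traversal and attribute names are placeholders for wherever the attention linear_proj or shared-expert linears live in a given model):

```python
from core.extensions.transformer_engine import set_save_original_input

# Placeholder traversal; adjust the attribute paths to the actual model structure.
for layer in model.decoder.layers:
    set_save_original_input(layer.self_attention.linear_proj)
```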