PyTorch
- class transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
Applies a linear transformation to the incoming data \(y = xA^T + b\).
On NVIDIA GPUs it is a drop-in replacement for torch.nn.Linear.
- Parameters:
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
bias (bool, default = True) – if set to False, the layer will not learn an additive bias.
init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
get_rng_state_tracker (Callable, default = None) – used to get the random number generator state tracker for initializing weights.
rng_tracker_name (str, default = None) – the parameter passed to get_rng_state_tracker to select a specific RNG tracker.
parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – Configuration for splitting the weight and bias tensors along dim 0 into multiple PyTorch parameters. If a list or tuple of strings is provided, they are used as the names of equally sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and the values as split sizes along dim 0. The resulting parameters have names that end in _weight or _bias, so trailing underscores are stripped from any provided names.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
name (str, default = None) – name of the module, currently used for debugging purposes.
- Parallelism parameters:
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as the TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
parallel_mode ({None, ‘column’, ‘row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described in the Megatron-LM paper. When set to None, no communication is performed.
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad), which is a pre-allocated buffer of the correct size to accumulate gradients in. If, in addition, the weight tensor has an attribute overwrite_main_grad set to True, main_grad is overwritten instead of accumulated into.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused into subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
delay_wgrad_compute (bool, default = False) – whether to delay weight gradient computation. If set to True, it is the user’s responsibility to call module.backward_dw() to compute weight gradients.
symmetric_ar_type ({None, ‘multimem_all_reduce’, ‘two_shot’, ‘one_shot’}, default = None) – type of symmetric memory all-reduce to use during the forward pass. This can help in latency-bound communication situations. Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce is used.
save_original_input (bool, default = False) – if set to True, always saves the original input tensor rather than the cast tensor. When the input tensor is used by multiple modules, saving the original input tensor may reduce memory usage. Not supported with the FP8 DelayedScaling recipe.
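The following pure-Python sketch illustrates the parameters_split naming rules described above by computing the resulting parameter names and split sizes along dim 0. The helper split_parameter_names is hypothetical and is not part of Transformer Engine. Raising an error when the sizes do not add up is this sketch's own choice.

```python
from collections import OrderedDict
from typing import Dict, List, Sequence, Tuple, Union


def split_parameter_names(
    out_features: int,
    parameters_split: Union[Sequence[str], Dict[str, int]],
) -> List[Tuple[str, str, int]]:
    """Return (weight_name, bias_name, rows) for each split along dim 0.

    Hypothetical helper mirroring the documented parameters_split rules;
    it is not Transformer Engine code.
    """
    if isinstance(parameters_split, dict):
        # Dict (e.g. OrderedDict): keys are names, values are split sizes.
        splits = list(parameters_split.items())
    else:
        # Tuple/list of names: equally sized splits.
        names = list(parameters_split)
        if out_features % len(names) != 0:
            raise ValueError("out_features must be divisible by the number of splits")
        size = out_features // len(names)
        splits = [(name, size) for name in names]

    if sum(rows for _, rows in splits) != out_features:
        raise ValueError("split sizes must sum to out_features")

    result = []
    for name, rows in splits:
        base = name.rstrip("_")  # trailing underscores are stripped
        result.append((f"{base}_weight", f"{base}_bias", rows))
    return result


# Fused QKV projection: three equal splits of a 3 * 1024 output dimension.
print(split_parameter_names(3072, ("query_", "key_", "value_")))
# GQA-style fused QKV with unequal sizes, given as an ordered dict.
print(split_parameter_names(1536, OrderedDict(query=1024, key=256, value=256)))
```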
- forward(inp: torch.Tensor, is_first_microbatch: bool | None = None, fp8_output: bool | None = False, fp8_grad: bool | None = False) → torch.Tensor | Tuple[torch.Tensor, ...]
Apply the linear transformation to the input.
- Parameters:
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training with either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. The model weights are not updated between the microbatches of the same minibatch. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
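The plain-Python sketch below shows how is_first_microbatch is typically derived inside a gradient-accumulation loop. It also shows why the first microbatch needs no accumulation. The scalar gradients and the accumulate_minibatch helper are illustrative assumptions. With Transformer Engine, the flag would instead be passed to the module's forward call for each microbatch.

```python
from typing import List


def accumulate_minibatch(microbatch_grads: List[float]) -> float:
    """Accumulate per-microbatch weight gradients for one minibatch."""
    grad_buffer = 0.0
    for i, grad in enumerate(microbatch_grads):
        is_first_microbatch = i == 0
        # In a real training loop this flag would be passed as
        #   out = module(inp, is_first_microbatch=is_first_microbatch)
        if is_first_microbatch:
            # First gradient of the minibatch: write it, nothing to add to yet.
            grad_buffer = grad
        else:
            grad_buffer += grad
    return grad_buffer


print(accumulate_minibatch([1.0, 2.0, 3.5]))  # 6.5
```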
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
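Under the Megatron-LM style partitioning that parallel_mode refers to, the two modes shard the weight along different dimensions:

- A column-parallel layer shards the weight along out_features.
- A row-parallel layer shards the weight along in_features.

The sketch below computes the per-rank weight shape under that assumption, using the torch.nn.Linear weight layout [out_features, in_features]. local_weight_shape is a hypothetical helper, not a Transformer Engine function. It also assumes the sharded dimension is divisible by tp_size.

```python
from typing import Optional, Tuple


def local_weight_shape(
    in_features: int,
    out_features: int,
    tp_size: int,
    parallel_mode: Optional[str],
) -> Tuple[int, int]:
    """Per-rank (out, in) weight shape, following the torch.nn.Linear layout."""
    if parallel_mode is None:
        return (out_features, in_features)  # no sharding, no communication
    if parallel_mode == "column":
        # Each rank owns a slice of the output features (rows of the weight).
        if out_features % tp_size:
            raise ValueError("out_features must be divisible by tp_size")
        return (out_features // tp_size, in_features)
    if parallel_mode == "row":
        # Each rank owns a slice of the input features (columns of the weight).
        if in_features % tp_size:
            raise ValueError("in_features must be divisible by tp_size")
        return (out_features, in_features // tp_size)
    raise ValueError(f"unknown parallel_mode: {parallel_mode!r}")


# MLP pattern: column-parallel up-projection, then row-parallel down-projection.
print(local_weight_shape(1024, 4096, tp_size=8, parallel_mode="column"))  # (512, 1024)
print(local_weight_shape(4096, 1024, tp_size=8, parallel_mode="row"))  # (1024, 512)
```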
- class transformer_engine.pytorch.GroupedLinear(num_gemms, in_features, out_features, bias=True, **kwargs)
Applies linear transformations \(y_i = x_iA_i^T + b_i\) to the incoming list of inputs in a grouped way.
- Parameters:
num_gemms (int) – number of GEMMs to be performed simultaneously.
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
bias (bool, default = True) – if set to False, the layer will not learn an additive bias.
init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
get_rng_state_tracker (Callable, default = None) – used to get the random number generator state tracker for initializing weights.
rng_tracker_name (str, default = None) – the parameter passed to get_rng_state_tracker to select a specific RNG tracker.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad), which is a pre-allocated buffer of the correct size to accumulate gradients in. If, in addition, the weight tensor has an attribute overwrite_main_grad set to True, main_grad is overwritten instead of accumulated into.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused into subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
delay_wgrad_compute (bool, default = False) – whether to delay weight gradient computation.
save_original_input (bool, default = False) – if set to True, always saves the original input tensor rather than the cast tensor. When the input tensor is used by multiple modules, saving the original input tensor may reduce memory usage. Not supported with the FP8 DelayedScaling recipe.
Notes
GroupedLinear does not handle tensor-parallel (TP) communication internally. The tp_size and parallel_mode arguments are used only to determine the shapes of the weights and biases; the TP communication should be handled in the dispatch and combine stages of MoE models.
- forward(inp: torch.Tensor, m_splits: List[int], is_first_microbatch: bool | None = None) → torch.Tensor | Tuple[torch.Tensor, ...]
Apply the linear transformation to the input.
- Parameters:
inp (torch.Tensor) – Input tensor.
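m_splits (List[int]) – split sizes of the input tensor along its first dimension, one per GEMM; the i-th chunk is processed with the i-th weight (and bias). Its length should equal num_gemms.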
is_first_microbatch ({True, False, None}, default = None) –
During training with either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. The model weights are not updated between the microbatches of the same minibatch. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
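The plain-Python reference below makes the role of m_splits concrete by spelling out what a grouped linear computes. It assumes, as in MoE usage, that the rows of a 2D input are partitioned along the first dimension into len(m_splits) == num_gemms consecutive chunks. Chunk i is then multiplied by its own weight \(A_i\) and offset by its bias \(b_i\). grouped_linear_reference is an illustrative sketch, not the fused Transformer Engine kernel.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(x: Matrix, weight: Matrix, bias: Sequence[float]) -> Matrix:
    """y = x A^T + b with weight A of shape [out_features, in_features]."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


def grouped_linear_reference(
    inp: Matrix,
    m_splits: Sequence[int],
    weights: Sequence[Matrix],
    biases: Sequence[Sequence[float]],
) -> Matrix:
    """Apply y_i = x_i A_i^T + b_i to consecutive row chunks of inp."""
    if len(m_splits) != len(weights) or sum(m_splits) != len(inp):
        raise ValueError("need one split per GEMM, and splits must cover every row")
    out: Matrix = []
    start = 0
    for m, weight, bias in zip(m_splits, weights, biases):
        out.extend(linear(inp[start:start + m], weight, bias))
        start += m
    return out


# Two "experts" (num_gemms = 2): 2 tokens go to expert 0, 1 token to expert 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w0 = [[1.0, 0.0]]  # expert 0 keeps the first feature
w1 = [[0.0, 1.0]]  # expert 1 keeps the second feature
print(grouped_linear_reference(x, [2, 1], [w0, w1], [[0.0], [10.0]]))
# [[1.0], [3.0], [16.0]]
```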
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
Layer Normalization
Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization.
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]
\(\gamma\) and \(\beta\) are learnable affine transform parameters that match the inner-most dimensions of the input tensor.
- Parameters:
normalized_shape (int or iterable of int) – Inner dimensions of input tensor
eps (float, default = 1e-5) – A value added to the denominator of layer normalization for numerical stability
device (torch.device, default = default CUDA device) – Tensor device
dtype (torch.dtype, default = default dtype) – Tensor datatype
zero_centered_gamma (bool, default = False) –
If True, the \(\gamma\) parameter is initialized to zero and the calculation changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
sm_margin (int or dict, default = 0) – Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference").
sequence_parallel (bool) – Legacy parameter. Sets a bool attribute named sequence_parallel on the parameters. This is custom logic for Megatron-LM integration.
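A plain-Python reference for the LayerNorm formula above, applied to a single vector, shows that the zero-centered parameterization only changes how \(\gamma\) is stored. With \(\gamma = 0\) and zero_centered_gamma=True, the result matches \(\gamma = 1\) in the standard form. layer_norm is an illustrative helper, not the Transformer Engine kernel, and it ignores dtype, device and sm_margin.

```python
import math
from typing import List, Sequence


def layer_norm(
    x: Sequence[float],
    gamma: Sequence[float],
    beta: Sequence[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> List[float]:
    """Normalize one vector over its inner-most dimension."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as in Var[x]
    inv_std = 1.0 / math.sqrt(var + eps)
    # With zero_centered_gamma the learnable parameter is an offset from 1.
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, beta)]


x = [1.0, 2.0, 3.0, 4.0]
# Standard init (gamma = 1) and zero-centered init (gamma = 0) agree.
print(layer_norm(x, [1.0] * 4, [0.0] * 4))
print(layer_norm(x, [0.0] * 4, [0.0] * 4, zero_centered_gamma=True))
```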
- class transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
Root Mean Square Layer Normalization
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper Root Mean Square Layer Normalization.
\[y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma\]
where
\[\text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}\]
\(\gamma\) is a learnable affine transform parameter that matches the inner-most dimensions of the input tensor.
- Parameters:
normalized_shape (int or iterable of int) – Inner dimensions of input tensor
eps (float, default = 1e-5) – A value added to the denominator for numerical stability
device (torch.device, default = default CUDA device) – Tensor device
dtype (torch.dtype, default = default dtype) – Tensor datatype
zero_centered_gamma (bool, default = False) –
If True, the \(\gamma\) parameter is initialized to zero and the calculation changes to
\[y = \frac{x}{\text{RMS}_\varepsilon(x)} * (1 + \gamma)\]
sm_margin (int or dict, default = 0) – Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference").
sequence_parallel (bool) – Legacy parameter. Sets a bool attribute named sequence_parallel on the parameters. This is custom logic for Megatron-LM integration.
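The RMSNorm formula can be checked with a similar single-vector sketch. Unlike LayerNorm, there is no mean subtraction and no \(\beta\), so the output has unit root mean square rather than zero mean. rms_norm is an illustrative helper, not the Transformer Engine kernel.

```python
import math
from typing import List, Sequence


def rms_norm(
    x: Sequence[float],
    gamma: Sequence[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> List[float]:
    """y = x / RMS_eps(x) * gamma over the inner-most dimension."""
    n = len(x)
    rms = math.sqrt(sum(v * v for v in x) / n + eps)  # no mean subtraction, no beta
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [v / rms * s for v, s in zip(x, scale)]


print(rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0))
```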
- class transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
Applies layer normalization followed by linear transformation to the incoming data.
- Parameters:
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
bias (bool, default = True) – if set to False, the layer will not learn an additive bias.
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
return_layernorm_output (bool, default = False) – if set to True, the output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: the residual connection for the transformer module is taken post layernorm.
return_layernorm_output_gathered (bool, default = False) – if set to True, the output of layernorm is returned after the all-gather operation. Ignored if return_layernorm_output is False. Example use case: with sequence parallelism, the input to a residual connection for the transformer module (e.g. LoRA) needs to be gathered; returning the gathered layernorm output prevents a redundant gather.
parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – Configuration for splitting the weight and bias tensors along dim 0 into multiple PyTorch parameters. If a list or tuple of strings is provided, they are used as the names of equally sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and the values as split sizes along dim 0. The resulting parameters have names that end in _weight or _bias, so trailing underscores are stripped from any provided names.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
name (str, default = None) – name of the module, currently used for debugging purposes.
- Parallelism parameters:
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as the TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
parallel_mode ({None, ‘column’, ‘row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described in the Megatron-LM paper. When set to None, no communication is performed.
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad), which is a pre-allocated buffer of the correct size to accumulate gradients in. If, in addition, the weight tensor has an attribute overwrite_main_grad set to True, main_grad is overwritten instead of accumulated into.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused into subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
delay_wgrad_compute (bool, default = False) – whether to delay weight gradient computation. If set to True, it is the user’s responsibility to call module.backward_dw() to compute weight gradients.
symmetric_ar_type ({None, ‘multimem_all_reduce’, ‘two_shot’, ‘one_shot’}, default = None) – type of symmetric memory all-reduce to use during the forward pass. This can help in latency-bound communication situations. Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce is used.
- forward(inp: torch.Tensor, is_first_microbatch: bool | None = None, fp8_output: bool | None = False, fp8_grad: bool | None = False) → torch.Tensor | Tuple[torch.Tensor, ...]
Apply layer normalization to the input followed by a linear transformation.
- Parameters:
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training with either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. The model weights are not updated between the microbatches of the same minibatch. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
Applies layer normalization to the input, followed by the MLP module, which consists of two successive linear transformations separated by the activation function.
- Parameters:
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
bias (bool, default = True) – if set to False, the FC1 and FC2 layers will not learn an additive bias.
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
activation (str, default = 'gelu') – activation function used. Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', 'silu', 'swiglu', and 'clamped_swiglu'.
activation_params (dict, default = None) – additional parameters for the activation function. At the moment, only used for the 'clamped_swiglu' activation, which supports the 'limit' and 'alpha' parameters.
init_method (Callable, default = None) – used for initializing FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing FC2 weights in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
return_layernorm_output (bool, default = False) – if set to True, the output of layernorm is returned from the forward() method together with the output of the linear transformation. Example use case: the residual connection for the transformer module is taken post layernorm.
return_layernorm_output_gathered (bool, default = False) – if set to True, the output of layernorm is returned after the all-gather operation. Ignored if return_layernorm_output is False. Example use case: with sequence parallelism, the input to a residual connection for the transformer module (e.g. LoRA) needs to be gathered; returning the gathered layernorm output prevents a redundant gather.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
name (str, default = None) – name of the module, currently used for debugging purposes.
- Parallelism parameters:
set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described in the Megatron-LM paper.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as the TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad), which is a pre-allocated buffer of the correct size to accumulate gradients in. If, in addition, the weight tensor has an attribute 'overwrite_main_grad' set to True, main_grad is overwritten instead of accumulated into.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused into subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique in which JIT-fused functions are warmed up before training to ensure that the same kernels are used for forward propagation and activation recomputation.
micro_batch_size (int) – batch size per training step. Needed for JIT warmup, as described for seq_length.
delay_wgrad_compute (bool, default = False) – whether to delay weight gradient computation. If set to True, it is the user’s responsibility to call backward_dw() to compute weight gradients.
symmetric_ar_type ({None, ‘multimem_all_reduce’, ‘two_shot’, ‘one_shot’}, default = None) – type of symmetric memory all-reduce to use during the forward pass. This can help in latency-bound communication situations. Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce is used.
checkpoint (bool, default = False) – whether to use selective activation checkpointing. In this mode, activations are not saved for the backward pass but are recomputed instead, skipping FC2 because it is not needed for backward. This trades compute for memory. When False (the default), activations are saved during the forward pass. Not supported for the ONNX forward pass.
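For the gated options (the '...glu' variants), this sketch assumes that FC1 produces 2 * ffn_hidden_size features. These are split into two halves: one half passes through the base activation and gates the other half elementwise. The pure-Python sketch below shows 'geglu' and 'swiglu' for a single vector under that assumption. Which half acts as the gate, and the exact GELU form used by Transformer Engine, are details this sketch does not claim to reproduce. The 'clamped_swiglu' variant and its activation_params are not covered.

```python
import math
from typing import Callable, List, Sequence


def gelu(v: float) -> float:
    """Exact (erf-based) GELU."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def silu(v: float) -> float:
    """SiLU / swish, v * sigmoid(v), written to avoid exp overflow."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def gated(act: Callable[[float], float], fc1_out: Sequence[float]) -> List[float]:
    """Split FC1 output in half; act(first half) gates the second half (assumed order)."""
    if len(fc1_out) % 2:
        raise ValueError("gated activations need an even FC1 output size")
    half = len(fc1_out) // 2
    return [act(a) * b for a, b in zip(fc1_out[:half], fc1_out[half:])]


fc1_out = [1.0, -2.0, 0.5, 3.0]  # 2 * ffn_hidden_size values, ffn_hidden_size = 2
print(gated(gelu, fc1_out))  # "geglu": ffn_hidden_size values go into FC2
print(gated(silu, fc1_out))  # "swiglu"
```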
- forward(inp: torch.Tensor, is_first_microbatch: bool | None = None) → torch.Tensor | Tuple[torch.Tensor, ...]
Apply layer normalization to the input followed by a feedforward network (MLP Block).
- Parameters:
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training with either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. The model weights are not updated between the microbatches of the same minibatch. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.
Note
Argument attention_mask in the forward call is only used when attn_mask_type includes "padding" or "arbitrary".
Warning
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe deterministic behavior at the cost of performance, use FlashAttention version >= 2.4.1 and set the environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. To disable flash-attn entirely, set NVTE_FLASH_ATTN=0.
Note
Transformer Engine stores the FP8 metadata under a ._extra_state key when checkpointing. As FP8 attention support has expanded from one backend to multiple backends, the location of that key has also shifted (see FP8 checkpoint compatibility).
- Parameters:
num_attention_heads (int) – number of attention heads in the transformer layer.
kv_channels (Union[int, Tuple[int, int]]) – the head size in key and value tensors. If the two head sizes are the same, kv_channels can be an integer; if not, kv_channels should be a tuple of two integers.
num_gqa_groups (Optional[int], default = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.
attention_dropout (float, default = 0.0) – dropout probability for the dropout op during multi-head attention.
attn_mask_type (str, default = "causal") –
type of attention mask passed into the softmax operation. Options are "no_mask", "padding", "causal", "padding,causal", "causal,padding", "padding_causal", "causal_bottom_right", "padding_causal_bottom_right", and "arbitrary", where "padding,causal", "causal,padding" and "padding_causal" are equivalent. This argument can be overridden by attn_mask_type in the forward() method. The constructor argument is useful for cases involving compilation/tracing, e.g. ONNX export, while the forward argument is useful for dynamically changing mask types, e.g. a different mask for training and inference.
For "no_mask", no attention mask is applied.
For "causal", "causal_bottom_right", or the causal mask in "padding_causal" and "padding_causal_bottom_right", Transformer Engine calculates and applies an upper triangular mask to the softmax input. No user input is needed. Causal masks without the "bottom_right" suffix align the diagonal to the top-left corner of the softmax matrix. With "bottom_right", the causal mask is aligned to the bottom-right corner, which is often used in inference/KV caching.
For "padding", or the padding mask in "padding_causal" and "padding_causal_bottom_right", users need to provide the locations of padded tokens. This can be done either via cu_seqlens_q and cu_seqlens_kv (both of shape [batch_size + 1]) or via attention_mask. For self-attention, attention_mask is one tensor of shape [batch_size, 1, 1, max_seqlen_q]. For cross-attention, it is a tuple of two tensors of shapes [batch_size, 1, 1, max_seqlen_q] and [batch_size, 1, 1, max_seqlen_kv].
For "arbitrary", users need to provide a mask that is broadcastable to the shape of the softmax input, [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention, where the query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. The special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask, respectively. Both causal and causal_bottom_right masks map to window_size = (-1, 0), and Transformer Engine distinguishes them based on attn_mask_type. Similar to attn_mask_type, window_size can also be overridden by window_size in forward.
attention_type (str, default = "self") – type of attention, either "self" or "cross".
layer_number (int, default = None) – layer number of the current DotProductAttention when multiple such modules are concatenated, for instance in consecutive transformer blocks.
qkv_format (str, default = "sbhd") – dimension format for query_layer, key_layer and value_layer, {"sbhd", "bshd", "thd"}. The letters stand for:
s – sequence length
b – batch size
h – number of heads
d – head size
t – total number of tokens in a batch, with t = sum(s_i) for i = 0...b-1
The "sbhd" and "bshd" formats are used when sequences in a batch are of equal length or padded to equal length, and the "thd" format is used when sequences in a batch have different lengths. Note that these formats do not reflect how the tensors query_layer, key_layer and value_layer are laid out in memory; for that, use get_qkv_layout to obtain the layout information.
softmax_scale (Optional[float], default = None) – softmax scale for the attention scores. If None, defaults to 1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0]).
softmax_type (str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla') –
Softmax type as described in the paper Efficient Streaming Language Models with Attention Sinks.
For a given attention score \(S = Q \cdot K^T\) of shape [b, h, s_q, s_kv]:
'vanilla':
\[Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}\]
'off-by-one':
\[Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}\]
'learnable':
\[Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}\]
where \(\alpha\) is a learnable parameter of shape [h].
The 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink').
return_max_logit (Optional[bool], default = False) – if True, returns the maximum attention score, which can be used by the Muon optimizer to rescale the Q and K projection weights (see Muon is Scalable for LLM Training). \(\text{max\_logit} = \max(S)\), where \(S = \text{mask}(Q \cdot K^T \cdot \text{softmax\_scale} + \text{bias})\) has shape [b, h, s_q, s_kv] and \(\text{max\_logit}\) has shape [h]; that is, the maximum is taken over all dimensions except the head dimension.
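Two small pure-Python sketches illustrate the softmax_type and window_size semantics described above:

- sink_softmax evaluates the 'vanilla', 'off-by-one' and 'learnable' formulas for one row of scores of a single head.
- sliding_window_mask builds the boolean mask implied by the documented window_size formula, treating -1 as unbounded, where True marks an allowed (query, key) pair.

Both helpers are hypothetical illustrations, not Transformer Engine APIs. In particular, this sketch assumes that dropping the seqlen_k - seqlen_q offset models the top-left-aligned "causal" mask.

```python
import math
from typing import List, Optional, Sequence, Tuple


def sink_softmax(
    scores: Sequence[float],
    softmax_type: str = "vanilla",
    alpha_h: float = 0.0,
) -> List[float]:
    """Softmax over one row of S for a single head h."""
    sink: Optional[float]
    if softmax_type == "vanilla":
        sink = None
    elif softmax_type == "off-by-one":
        sink = 0.0  # contributes exp(0) = 1 to the denominator ("zero sink")
    elif softmax_type == "learnable":
        sink = alpha_h  # contributes exp(alpha_h) ("learnable sink")
    else:
        raise ValueError(f"unknown softmax_type: {softmax_type!r}")
    # Subtract the largest logit (including the sink) for numerical stability.
    m = max(scores) if sink is None else max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + (0.0 if sink is None else math.exp(sink - m))
    return [e / denom for e in exps]


def sliding_window_mask(
    seqlen_q: int,
    seqlen_k: int,
    window_size: Tuple[int, int],
    bottom_right: bool = True,
) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j (-1 = unbounded)."""
    left, right = window_size
    # The documented formula uses offset seqlen_k - seqlen_q (bottom-right
    # alignment); offset 0 is this sketch's stand-in for top-left "causal".
    offset = seqlen_k - seqlen_q if bottom_right else 0
    mask = []
    for i in range(seqlen_q):
        first = 0 if left < 0 else i + offset - left
        last = seqlen_k - 1 if right < 0 else i + offset + right
        mask.append([first <= j <= last for j in range(seqlen_k)])
    return mask


print(sink_softmax([0.0, 0.0]))  # [0.5, 0.5]
print(sum(sink_softmax([0.0, 0.0], "off-by-one")))  # < 1: the sink absorbs mass
# 2 new queries against 4 keys (e.g. a KV-cache step), window_size = (-1, 0):
for row in sliding_window_mask(2, 4, (-1, 0), bottom_right=False):  # "causal"
    print("".join("x" if ok else "." for ok in row))  # x... then xx..
for row in sliding_window_mask(2, 4, (-1, 0), bottom_right=True):  # "causal_bottom_right"
    print("".join("x" if ok else "." for ok in row))  # xxx. then xxxx
```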
- Parallelism parameters:
sequence_parallel (bool, default = False) – if set to
True, uses sequence parallelism.tp_size (int, default = 1) – tensor parallel world size.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
cp_group (Union[ProcessGroup, List[ProcessGroup]], default = None) – context parallel process group.
ProcessGroupis forcp_comm_typeof"p2p","all_gather", and"a2a".List[ProcessGroup]is forcp_comm_typeof"a2a+p2p", wherecp_group[0]andcp_group[1]are for"a2a"and"p2p"communications respectively.cp_global_ranks (list of global rank IDs, default = None) – global rank IDs of GPUs that are in
cp_group.cp_stream (CUDA stream, default = None) – context parallelism splits flash attention into multiple steps for compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels.
cp_comm_type (str, default = "p2p") – inter-GPU communication type for context parallelism. Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is asynchronous and can be overlapped with attention compute.
"all_gather": All-gather the full sequence of KV before attention. The all-gather is not asynchronous and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
"a2a+p2p": Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via InfiniBand).
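To make the "p2p" option concrete, the following sketch computes which KV chunk each CP rank holds at each step of a plain ring exchange. It is a simplified model under stated assumptions (one KV chunk per rank, no load balancing for causal masks, and a hypothetical helper ring_kv_schedule), not the schedule Transformer Engine uses internally.

```python
def ring_kv_schedule(cp_size):
    """Return schedule[k][r]: the KV chunk that CP rank r attends to at step k.

    Simplified ring: every rank starts with its own KV chunk, and after each
    step it forwards the chunk it holds to rank (r + 1) % cp_size, so at step k
    rank r holds the chunk that originated on rank (r - k) % cp_size.
    """
    return [[(r - k) % cp_size for r in range(cp_size)] for k in range(cp_size)]


schedule = ring_kv_schedule(4)
for step, chunks in enumerate(schedule):
    print(f"step {step}: rank -> KV chunk {chunks}")

# After cp_size steps, every rank has attended to every KV chunk exactly once.
for rank in range(4):
    assert sorted(step[rank] for step in schedule) == list(range(4))
```

While a rank computes attention on the chunk it currently holds, the send and receive for the next step can proceed in the background, which is why this mode can overlap communication with compute while "all_gather" cannot.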
- forward(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] = None, qkv_format: str = None, cu_seqlens_q: torch.Tensor = None, cu_seqlens_kv: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, max_seqlen_q: int = None, max_seqlen_kv: int = None, attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, fast_zero_fill: bool = True, inference_params: transformer_engine.pytorch.attention.inference.InferenceParams | None = None, pad_between_seqs: bool | None = None, fp8_output: bool | None = False, num_splits: int | None = 1) torch.Tensor
Dot Product Attention Layer.
Note
Argument attention_mask is only used when attn_mask_type includes "padding" or "arbitrary".
Note
DotProductAttention supports three backends: 1) FlashAttention, which calls HazyResearch/Dao-AILab’s flash-attn PyTorch API, 2) FusedAttention, which has multiple fused attention implementations based on the cuDNN Graph API (see FusedAttention for more details on FusedAttention backends), and 3) UnfusedDotProductAttention, which is the native PyTorch implementation with fused scaled masked softmax.
Note
Users can use the environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, and NVTE_FUSED_ATTN_BACKEND to control which DotProductAttention backend (and which FusedAttention backend, if applicable) to use. Transformer Engine prioritizes FlashAttention over FusedAttention, and FusedAttention over UnfusedDotProductAttention. If FusedAttention is being used, users can also switch to flash-attn’s implementation for the backward pass by setting NVTE_FUSED_ATTN_USE_FAv2_BWD=1 (default: 0), because performance differs between versions of flash-attn and FusedAttention. Further, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT can be used to enable (1) or disable (0) the workspace-related optimizations in FusedAttention. When unset, Transformer Engine determines the code path based on its internal logic. These optimizations trade memory for performance and should be used with care.
Note
When training data has variable sequence lengths, users have two options.
1. Manipulate the data and pad all sequences to the same length. Use qkv_format = {"bshd", "sbhd"} and attn_mask_type = {"padding", "padding_causal", "padding_causal_bottom_right"}. Pass in cu_seqlens_q and cu_seqlens_kv, or attention_mask (which will be converted to cu_seqlens_q and cu_seqlens_kv), to provide the real sequence length information. For example, a batch of 3 sequences [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative sequence length tensors would be cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9] for self-attention.
2. Do not perform padding on training data. Use qkv_format = "thd" and attn_mask_type = {"padding", "padding_causal", "padding_causal_bottom_right"}. Pass in cu_seqlens_q and cu_seqlens_kv, or attention_mask, as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed without any padding, and the sequence length tensors would be cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9] for self-attention.
In certain use cases, a varying number of identifier tokens are inserted between sequences. These tokens do not participate in the attention calculation. cu_seqlens_q_padded and cu_seqlens_kv_padded must be specified in such cases to correctly identify the start and end of each sequence in a batch. For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9], and cu_seqlens_q_padded = cu_seqlens_kv_padded = [0, 4, 8, 13] for self-attention.
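The cumulative sequence length tensors in the examples above can be reproduced with a few lines of plain Python. In real code these would be torch.int32 tensors of shape [batch_size + 1] on the GPU. Here, lists stand in for them, the padding mask is flattened from [batch_size, 1, 1, seqlen] to [batch_size][seqlen] for readability, and the helper names are hypothetical.

```python
from itertools import accumulate


def cu_seqlens(lengths):
    """Cumulative sequence lengths with a leading 0, i.e. shape [batch_size + 1]."""
    return [0] + list(accumulate(lengths))


def lengths_from_padding_mask(mask):
    """mask: [batch_size][seqlen] booleans, True = masked out (padding)."""
    return [sum(1 for masked in row if not masked) for row in mask]


# [a a a b b c c c c]: real lengths 3, 2, 4
print(cu_seqlens([3, 2, 4]))  # [0, 3, 5, 9]

# Padded to length 4 as in option 1: [a a a PAD] [b b PAD PAD] [c c c c]
mask = [
    [False, False, False, True],
    [False, False, True, True],
    [False, False, False, False],
]
print(cu_seqlens(lengths_from_padding_mask(mask)))  # [0, 3, 5, 9]

# [a a a 1 b b 2 2 c c c c 3]: each sequence plus its trailing identifier tokens
print(cu_seqlens([3 + 1, 2 + 2, 4 + 1]))  # [0, 4, 8, 13]
```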
Note
When qkv_format = {"bshd", "sbhd"}, sequences are of equal length in a batch. max_seqlen_q and max_seqlen_kv should be the same as the "s" dimension of the query_layer and key_layer tensors. When unset, Transformer Engine will infer them as such.
When qkv_format = "thd", sequences have varying lengths. max_seqlen_q and max_seqlen_kv should be the maximum query and key/value sequence length in a batch. When unset, Transformer Engine deduces them from cu_seqlens_q and cu_seqlens_kv. This deduction costs a small kernel and some CPU-GPU synchronization; to avoid this overhead, users are recommended to obtain the maximum sequence lengths from the data loaders and pass them in.
As the maximum sequence lengths, batch size, and number of tokens change from batch to batch, dynamic shapes need to be supported for tensor construction. FlashAttention and UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static to create graphs before performance heuristics analysis. To reduce the number of graphs created per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size, max_seqlen_q, max_seqlen_kv}, and for cuDNN >= 9.6, {"t" dimension of query_layer, "t" dimension of key_layer}.
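When qkv_format = "thd" and max_seqlen_q or max_seqlen_kv is left unset, the deduction amounts to taking the largest difference between consecutive cu_seqlens entries. The minimal sketch below uses hypothetical helper names and assumes there is no padding between sequences; otherwise the "t" dimension would come from the padded offsets instead.

```python
def max_seqlen_from_cu_seqlens(cu):
    """Longest sequence encoded by a cu_seqlens list such as [0, 3, 5, 9]."""
    return max(end - start for start, end in zip(cu, cu[1:]))


def total_tokens(cu):
    """Size of the "t" dimension for qkv_format = "thd" without inter-sequence padding."""
    return cu[-1]


cu = [0, 3, 5, 9]
print(max_seqlen_from_cu_seqlens(cu), total_tokens(cu))  # 4 9
```

On the GPU, this reduction is what requires a kernel launch plus a device-to-host copy. A data loader that already knows the per-sequence lengths can pass max(lengths) directly and skip it.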
- Parameters:
query_layer (torch.Tensor) – Query tensor.
key_layer (torch.Tensor) – Key tensor.
value_layer (torch.Tensor) – Value tensor.
attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensor(s) used to mask out attention softmax input. It should be None for causal masks and "no_mask". For padding masks, it should be a single tensor of shape [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors of shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the "arbitrary" mask, it should be of a shape broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A True value means the corresponding position is masked out, and a False value means that position is allowed to participate in attention.
qkv_format (str, default = None) – If provided, overrides qkv_format from initialization.
cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32. See note for more details.
cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32. See note for more details.
cu_seqlens_q_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, cu_seqlens_q_padded = cu_seqlens_q. See note for more details.
cu_seqlens_kv_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, cu_seqlens_kv_padded = cu_seqlens_kv. See note for more details.
max_seqlen_q (Optional[int], default = None) – Maximum sequence length in query_layer. See note for more details.
max_seqlen_kv (Optional[int], default = None) – Maximum sequence length in key_layer and value_layer. See note for more details.
attn_mask_type ({'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = None) – Type of attention mask passed into the softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal' are equivalent. By default, causal masks are aligned to the top left corner of the softmax matrix. When "bottom_right" is specified in the mask type, causal masks are aligned to the bottom right corner.
window_size (Optional[Tuple[int, int]], default = None) – Sliding window size for local attention.
checkpoint_core_attention (bool, default = False) – If true, forward activations for attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.
core_attention_bias_type (str, default = "no_bias") – Bias type, {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}.
core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for \(Q \cdot K^T\), shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be None for the "no_bias" and "alibi" bias types.
alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j.
fast_zero_fill (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.
inference_params (Optional[InferenceParams], default = None) – Optimizes execution performance during inference by caching the keys and values of the current decoding iteration. These cached values are appended to the K and V values computed in previous iterations, eliminating the need to recalculate them for the entire sequence. inference_params must be initialized before use to ensure sufficient memory allocation. Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
pad_between_seqs (Optional[bool], default = None) – If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If True, there are padding tokens between individual sequences in a packed batch.
fp8_output (Optional[bool], default = False) – Whether to enforce output to be in FP8 or not.
num_splits (Optional[int], default = 1) – Optional split control for FlashAttention-3 only. When set, this value is forwarded to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled.
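The difference between top-left and bottom-right causal alignment (attn_mask_type) and the ALiBi bias formula given for alibi_slopes both depend on the offset seqlen_kv - seqlen_q. The plain-Python sketch below uses hypothetical helpers, with True meaning masked out as for attention_mask. It prints both causal masks for 2 queries and 4 keys, a shape that arises, for example, when new tokens attend to a longer key/value sequence during inference.

```python
def causal_mask(seqlen_q, seqlen_kv, bottom_right=False):
    """Boolean mask [seqlen_q][seqlen_kv]; True = masked out, as for attention_mask.

    'causal' (top left): query i sees keys j <= i.
    'causal_bottom_right': query i sees keys j <= i + seqlen_kv - seqlen_q.
    """
    offset = seqlen_kv - seqlen_q if bottom_right else 0
    return [[j > i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]


def alibi_bias(slope, seqlen_q, seqlen_kv):
    """Bias added to score (i, j): -slope * (i + seqlen_kv - seqlen_q - j)."""
    return [
        [-slope * (i + seqlen_kv - seqlen_q - j) for j in range(seqlen_kv)]
        for i in range(seqlen_q)
    ]


# Two queries attending to four keys; "x" marks a masked-out position.
for name, br in (("causal", False), ("causal_bottom_right", True)):
    print(name)
    for row in causal_mask(2, 4, bottom_right=br):
        print("".join("x" if masked else "." for masked in row))
```

The ALiBi bias is zero exactly where j = i + seqlen_kv - seqlen_q, i.e. on the bottom-right diagonal, and becomes more negative linearly with distance to the left of it.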
- set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') None
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters:
cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] and cp_group[1] are for "a2a" and "p2p" communications respectively.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – CUDA stream for context parallel execution.
cp_comm_type (str, default = "p2p") – inter-GPU communication type for context parallelism. Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is asynchronous and can be overlapped with attention compute.
"all_gather": All-gather the full sequence of KV before attention. The all-gather is not asynchronous and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
"a2a+p2p": Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via InfiniBand).
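For "a2a+p2p", cp_group is a pair of groups. The sketch below shows one possible way to partition CP ranks 0..cp_size-1: consecutive blocks for "a2a" (e.g., within an NVLink domain) and strided groups for "p2p". This layout is an assumption for illustration, not a documented Transformer Engine requirement, and the helpers are hypothetical. In a real job, the rank lists would be mapped to global ranks and turned into process groups with torch.distributed.new_group, which every process must call for every group.

```python
def hierarchical_cp_groups(cp_size, a2a_size):
    """Split CP ranks 0..cp_size-1 into "a2a" sub-groups and "p2p" groups.

    Assumed layout: each a2a sub-group is a block of consecutive ranks (e.g. GPUs
    sharing NVLink); each p2p group links ranks at the same position in
    different sub-groups.
    """
    if cp_size % a2a_size != 0:
        raise ValueError("cp_size must be divisible by a2a_size")
    num_subgroups = cp_size // a2a_size
    a2a_groups = [
        list(range(s * a2a_size, (s + 1) * a2a_size)) for s in range(num_subgroups)
    ]
    p2p_groups = [list(range(p, cp_size, a2a_size)) for p in range(a2a_size)]
    return a2a_groups, p2p_groups


def groups_for_rank(rank, a2a_groups, p2p_groups):
    """[a2a group, p2p group] containing `rank`, in the cp_group[0], cp_group[1] order."""
    a2a = next(g for g in a2a_groups if rank in g)
    p2p = next(g for g in p2p_groups if rank in g)
    return [a2a, p2p]


a2a_groups, p2p_groups = hierarchical_cp_groups(cp_size=8, a2a_size=4)
print(a2a_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(p2p_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(groups_for_rank(5, a2a_groups, p2p_groups))  # [[4, 5, 6, 7], [1, 5]]
```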
- class transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
Multi-head Attention (MHA), including Query, Key, Value and Output projection.
Note
Argument attention_mask in the forward() method is only used when attn_mask_type includes "padding" or "arbitrary".
- Parameters:
hidden_size (int) – size of each input sample.
num_attention_heads (int) – number of attention heads in the transformer layer.
kv_channels (int, default = None) – number of key-value channels. Defaults to hidden_size / num_attention_heads if None.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
init_method (Callable, default = None) – used for initializing weights of QKV and FC1 in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.
attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = "causal") – type of attention mask passed into the softmax operation. Overridden by attn_mask_type in the forward() method. The forward() argument is useful for dynamically changing mask types, e.g., a different mask for training and inference. The __init__() argument is useful for cases involving compilation or tracing, e.g., ONNX export.
window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention, where the query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. The special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask, respectively. Both "causal" and "causal_bottom_right" masks map to window_size = (-1, 0), and Transformer Engine distinguishes them based on attn_mask_type. Similar to attn_mask_type, window_size can be overridden by window_size in forward() as well.
num_gqa_groups (int, default = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e., num_gqa_groups = num_attention_heads.
return_layernorm_output (bool, default = False) – if set to True, the output of layernorm is returned from the forward() method together with the output of the linear transformation. Example use case: the residual connection for the transformer module is taken post layernorm.
input_layernorm (bool, default = False) – if set to True, layer normalization is applied to the input.
attention_type ({'self', 'cross'}, default = 'self') – type of attention applied.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
normalization ({'LayerNorm', 'RMSNorm'}, default = 'LayerNorm') – type of normalization applied.
qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.
rotary_pos_interleaved (bool, default = False) – whether to use interleaved rotary position embeddings.
bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
qkv_format (str, default = "sbhd") – dimension format for query_layer, key_layer and value_layer, {"sbhd", "bshd"}. s stands for the sequence length, b batch size, h the number of heads and d head size. The "sbhd" and "bshd" formats are used when sequences in a batch are of equal length or padded to equal length. Note that these formats do not reflect how the tensors query_layer, key_layer and value_layer are laid out in memory. For that, use get_qkv_layout to obtain the layout information.
name (str, default = None) – name of the module, currently used for debugging purposes.
softmax_type (str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla') –
Softmax type as described in the paper Efficient Streaming Language Models with Attention Sinks.
For a given attention score \(S = Q \cdot K^T\) of shape [b, h, s_q, s_kv]:
'vanilla': \[Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}\]
'off-by-one': \[Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}\]
'learnable': \[Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}\]
where \(\alpha\) is a learnable parameter of shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink').
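Two pieces of index arithmetic from the parameters above are shown here as a plain-Python sketch with hypothetical helper names. The first is the key range selected by window_size, where -1 means unbounded on that side. The second is the KV head used by each query head under num_gqa_groups. The GQA mapping assumes the usual convention that consecutive query heads share a KV head.

```python
def attended_keys(i, seqlen_q, seqlen_kv, window_size=(-1, -1)):
    """Key indices that query i may attend to under sliding-window attention.

    Window: [i + seqlen_kv - seqlen_q - left, i + seqlen_kv - seqlen_q + right],
    inclusive and clipped to valid key positions; -1 means unbounded on that side.
    """
    left, right = window_size
    center = i + seqlen_kv - seqlen_q
    lo = 0 if left == -1 else max(0, center - left)
    hi = seqlen_kv - 1 if right == -1 else min(seqlen_kv - 1, center + right)
    return list(range(lo, hi + 1))


def kv_group_for_head(head, num_attention_heads, num_gqa_groups):
    """Index of the KV head shared by query head `head` (consecutive-heads convention)."""
    if num_attention_heads % num_gqa_groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_gqa_groups")
    heads_per_group = num_attention_heads // num_gqa_groups
    return head // heads_per_group


print(attended_keys(3, 4, 4, window_size=(-1, 0)))  # causal: [0, 1, 2, 3]
print(attended_keys(3, 4, 4, window_size=(1, 0)))  # local window: [2, 3]
print([kv_group_for_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With window_size = (-1, 0), the helper reproduces the bottom-right-aligned causal mask. This is why Transformer Engine needs attn_mask_type to tell it apart from the top-left "causal" case when seqlen_q and seqlen_kv differ.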
- Parallelism parameters:
set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel, whereas PROJ and FC2 are used as Row Parallel as described here.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad), which is a pre-allocated buffer of the correct size to accumulate gradients in.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead returns the bias value from the forward() method together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused into subsequent operations.
fuse_qkv_params (bool, default = False) – if set to True, the TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
qk_norm_type (Optional[str], default = None) – type of normalization to apply to query and key tensors. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. When 'L2Normalization', L2 normalization is applied to query and key tensors. When 'RMSNorm', RMS normalization is applied to query and key tensors. When 'LayerNorm', layer normalization is applied to query and key tensors. When qk_norm_before_rope is False, normalization is applied after RoPE (if applicable) but before the attention computation. This follows, e.g., the Llama 4 approach to QK normalization for improving training stability and model performance.
qk_norm_eps (float, default = 1e-6) – epsilon value for normalization of query and key tensors. Only used when qk_norm_type is not None.
qk_norm_before_rope (bool, default = False) – if set to True, query and key normalization is applied before rotary position embedding. When False (default), normalization is applied after RoPE. This parameter allows supporting different architectural variants that apply QK normalization at different points.
seq_length (Optional[int], default = None) – sequence length of input samples. Needed for JIT Warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.
micro_batch_size (Optional[int], default = None) – batch size per training step. Needed for JIT Warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.
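The qk_norm_type options and the zero_centered_gamma LayerNorm formula can be illustrated on a single vector, for QK normalization presumably one head's query or key vector. This is a plain-Python reference sketch with hypothetical helper names, not Transformer Engine's fused implementation. In particular, the way eps enters the L2 normalization is an assumption.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5, zero_centered_gamma=False):
    """LayerNorm over one vector; with zero_centered_gamma the scale is (1 + gamma)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as in LayerNorm
    scale = [1.0 + g if zero_centered_gamma else g for g in gamma]
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, beta)]


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps) * gamma (no mean subtraction, no beta)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def l2_normalize(x, eps=1e-6):
    """Scale x to unit L2 norm; eps only guards against division by zero (assumed form)."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


head_vec = [1.0, 2.0, 3.0, 4.0]
# With zero_centered_gamma=True, gamma initialized to 0 gives the same result as
# an ordinary LayerNorm with gamma initialized to 1.
zc = layer_norm(head_vec, [0.0] * 4, [0.0] * 4, zero_centered_gamma=True)
plain = layer_norm(head_vec, [1.0] * 4, [0.0] * 4)
print(zc == plain)  # True
print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```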
- forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, encoder_output: torch.Tensor | None = None, attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, is_first_microbatch: bool | None = None, checkpoint_core_attention: bool = False, inference_params: transformer_engine.pytorch.attention.inference.InferenceParams | None = None, rotary_pos_emb: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_kv: torch.Tensor | None = None, cu_seqlens_q_padded: torch.Tensor | None = None, cu_seqlens_kv_padded: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_kv: int | None = None, fast_zero_fill: bool = True, pad_between_seqs: bool | None = None) Tuple[torch.Tensor | None, Ellipsis]
Forward propagation for MultiheadAttention layer.
Note
Argument attention_mask is only used when attn_mask_type includes "padding" or "arbitrary".
- Parameters:
hidden_states (torch.Tensor) – Input tensor.
attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensor(s) used to mask out attention softmax input. It should be None for causal masks and "no_mask". For padding masks, it should be a single tensor of shape [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors of shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the "arbitrary" mask, it should be of a shape broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A True value means the corresponding position is masked out, and a False value means that position is allowed to participate in attention.
attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = None) – type of attention mask passed into the softmax operation. By default, causal masks are aligned to the top left corner of the softmax matrix. When "bottom_right" is specified in the mask type, causal masks are aligned to the bottom right corner.
window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention.
encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type="decoder".
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
checkpoint_core_attention (bool, default = False) – If True, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.
rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default, no input embedding is applied.
core_attention_bias_type (str, default = "no_bias") – Bias type, {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}.
core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for \(Q \cdot K^T\), shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be None for the "no_bias" and "alibi" bias types.
alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j.
cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q (Optional[int], default = None) – Maximum sequence length in query_layer. Calculated from cu_seqlens_q if not provided.
max_seqlen_kv (Optional[int], default = None) – Maximum sequence length in key_layer and value_layer. Calculated from cu_seqlens_kv if not provided.
fast_zero_fill (bool, default = True) – Whether to set output tensors to 0 or not before use.
pad_between_seqs (Optional[bool], default = None) – If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If True, there are padding tokens between individual sequences in a packed batch.
- set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') None
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters:
cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] and cp_group[1] are for "a2a" and "p2p" communications respectively.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – CUDA stream for context parallel execution.
cp_comm_type (str, default = "p2p") – inter-GPU communication type for context parallelism. Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is asynchronous and can be overlapped with attention compute.
"all_gather": All-gather the full sequence of KV before attention. The all-gather is not asynchronous and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
"a2a+p2p": Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via InfiniBand).
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.
Note
Argument attention_mask in the forward() call is only used when self_attn_mask_type includes "padding" or "arbitrary".
- Parameters:
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
num_attention_heads (int) – number of attention heads in the transformer layer.
num_gqa_groups (int, default = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e., num_gqa_groups = num_attention_heads.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
init_method (Callable, default = None) – used for initializing weights of QKV and FC1 in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of layer norm (by default, they are taken from the input of layer norm).
layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.
output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.
parallel_attention_mlp (bool, default = False) – if set to True, self-attention and the feedforward network are computed based on the same input (in parallel) instead of sequentially. Both blocks have an independent normalization. This architecture is used in Falcon models.
layer_type ({'encoder', 'decoder'}, default = "encoder") – if set to "decoder", an additional cross-attention block is added after self-attention. This can be used for structures like the T5 Transformer in conjunction with the "encoder" option.
kv_channels (int, default = None) – number of query-key-value channels per attention head. Defaults to hidden_size / num_attention_heads if None.
self_attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = "causal") – type of attention mask passed into the softmax operation for the encoder. Overridden by self_attn_mask_type in the forward() method. The forward() argument is useful for dynamically changing mask types, e.g., a different mask for training and inference. The __init__() argument is useful for cases involving compilation or tracing, e.g., ONNX export.
window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in the encoder, where the query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. The special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask, respectively. Both "causal" and "causal_bottom_right" masks map to window_size = (-1, 0), and Transformer Engine distinguishes them based on self_attn_mask_type or enc_dec_attn_mask_type. Similar to self_attn_mask_type, window_size can be overridden by window_size in forward() as well.
enc_dec_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask") – type of attention mask passed into the softmax operation for the decoder.
enc_dec_window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in decoder.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
normalization ({'LayerNorm', 'RMSNorm'}, default = 'LayerNorm') – type of normalization applied.
qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.
rotary_pos_interleaved (bool, default = False) – whether to use interleaved rotary position embeddings.
bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.
activation (str, default = 'gelu') – type of activation used in the MLP block. Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', 'silu', 'swiglu', and 'clamped_swiglu'.
activation_params (Optional[dict], default = None) – additional parameters for the activation function. At the moment, only used for the 'clamped_swiglu' activation, which supports the 'limit' and 'alpha' parameters. You can set these as activation_params={'limit': 7.0, 'alpha': 1.702}.
device (Union[torch.device, str], default = "cuda") – the device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
attn_input_format ({'sbhd', 'bshd', 'thd'}, default = 'sbhd') – controls whether the dimensions of the intermediate hidden states are ‘sequence first’ ('sbhd'), ‘batch first’ ('bshd'), or ‘token first’ ('thd'). s stands for the sequence length, b the batch size, t the total number of tokens, h the number of heads, and d the head size. Note that these formats are very closely related to the qkv_format parameter in the MultiHeadAttention and DotProductAttention modules.
name (str, default = None) – name of the module, currently used for debugging purposes.
softmax_type (str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla') –
Softmax type as described in the paper Efficient Streaming Language Models with Attention Sinks.
For a given attention score \(S = Q \cdot K^T\) of shape [b, h, s_q, s_kv]:
'vanilla':
\[\mathrm{Softmax}(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}\]
'off-by-one':
\[\mathrm{Softmax}(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}\]
'learnable':
\[\mathrm{Softmax}(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}\]
where \(\alpha\) is a learnable parameter of shape [h].
The 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink').
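To make the three softmax_type normalizations concrete, here is a minimal pure-Python sketch that applies them to a single row of attention scores (one batch entry, one head, one query position). It only mirrors the formulas above; the helper name sink_softmax and its alpha_h argument (standing in for \(\alpha_h\) of one head) are hypothetical and not part of Transformer Engine.

```python
import math
from typing import List, Optional


def sink_softmax(scores: List[float], softmax_type: str = "vanilla",
                 alpha_h: Optional[float] = None) -> List[float]:
    """Normalize one row of scores per the 'vanilla' / 'off-by-one' / 'learnable' formulas.

    Hypothetical illustration only, not a Transformer Engine API.
    """
    if softmax_type == "vanilla":
        sink_logit = None          # no extra term in the denominator
    elif softmax_type == "off-by-one":
        sink_logit = 0.0           # exp(0) = 1 in the denominator
    elif softmax_type == "learnable":
        if alpha_h is None:
            raise ValueError("'learnable' requires alpha_h")
        sink_logit = alpha_h       # exp(alpha_h) in the denominator
    else:
        raise ValueError(f"unknown softmax_type: {softmax_type!r}")

    # Shift by the largest logit (sink included) for numerical stability;
    # numerator and denominator are scaled by the same factor, so the result is unchanged.
    m = max(scores if sink_logit is None else scores + [sink_logit])
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    if sink_logit is not None:
        denom += math.exp(sink_logit - m)
    return [e / denom for e in exps]


row = [2.0, 1.0, 0.5]
print(sum(sink_softmax(row)))                      # 1 up to rounding
print(sum(sink_softmax(row, "off-by-one")))        # strictly less than 1
print(sink_softmax([-50.0, -50.0], "off-by-one"))  # almost all mass goes to the sink
```

The last line shows why the sink variants are useful: when every score is very negative, the weights can collectively approach zero instead of being forced to sum to 1. With alpha_h = 0, 'learnable' reduces to 'off-by-one'.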
- Parallelism parameters:
set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group() method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
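The reason a column-parallel FC1 can feed a row-parallel FC2 without communication in between is that splitting FC1 by output columns and FC2 by input rows turns the full product into a sum of per-rank partial products, and an elementwise activation can be applied to each column shard locally. The pure-Python sketch below simulates this on one process with nested lists; the helper names are hypothetical, the ReLU stands in for whatever activation is configured, and the final sum only stands in for the all-reduce a real tensor parallel group would perform.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (rows x k) @ (k x cols) product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def relu(a: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_cols(w: Matrix, parts: int) -> List[Matrix]:
    """Column-parallel split: each rank owns a contiguous block of output columns."""
    n = len(w[0]) // parts
    return [[row[r * n:(r + 1) * n] for row in w] for r in range(parts)]


def split_rows(w: Matrix, parts: int) -> List[Matrix]:
    """Row-parallel split: each rank owns a contiguous block of input rows."""
    n = len(w) // parts
    return [w[r * n:(r + 1) * n] for r in range(parts)]


def tp_mlp(x: Matrix, w1: Matrix, w2: Matrix, tp_size: int) -> Matrix:
    """Simulate FC1 (column parallel) -> ReLU -> FC2 (row parallel) on one process."""
    partials = []
    for w1_r, w2_r in zip(split_cols(w1, tp_size), split_rows(w2, tp_size)):
        h_r = relu(matmul(x, w1_r))         # purely local on rank r
        partials.append(matmul(h_r, w2_r))  # partial output of rank r
    out = partials[0]
    for p in partials[1:]:                  # stands in for the all-reduce
        out = add(out, p)
    return out


x = [[1.0, -2.0], [0.5, 3.0]]                           # 2 tokens, hidden size 2
w1 = [[1.0, -1.0, 2.0, 0.5], [0.5, 1.0, -1.0, 2.0]]     # FC1: 2 -> 4
w2 = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [1.0, 1.0]]  # FC2: 4 -> 2
reference = matmul(relu(matmul(x, w1)), w2)
print(tp_mlp(x, w1, w2, tp_size=2) == reference)        # True
```

The values are chosen to be exactly representable in binary so the comparison can use ==; with arbitrary floats the two results would agree only up to rounding.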
- Optimization parameters:
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.
micro_batch_size (int) – batch size per training step. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.
drop_path_rate (float, default = 0.0) – when > 0.0, applies stochastic depth per sample in the main path of the residual block.
fuse_qkv_params (bool, default = False) – if set to True, the TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
qk_norm_type (Optional[str], default = None) – type of normalization to apply to query and key tensors. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. When 'L2Normalization', L2 normalization is applied to query and key tensors. When 'RMSNorm', RMS normalization is applied to query and key tensors. When 'LayerNorm', layer normalization is applied to query and key tensors. When qk_norm_before_rope is False, normalization is applied after RoPE (if applicable) but before the attention computation. This follows the approach to QK normalization used in, e.g., Llama4 to improve training stability and model performance.
qk_norm_eps (float, default = 1e-6) – epsilon value for normalization of query and key tensors. Only used when qk_norm_type is not None.
qk_norm_before_rope (bool, default = False) – if set to True, query and key normalization is applied before rotary position embedding. When False (default), normalization is applied after RoPE. This parameter allows supporting different architectural variants that apply QK normalization at different points.
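The window_size rule described for the constructor (and for forward()) can be checked by hand with a small helper that builds the boolean mask for one (seqlen_q, seqlen_k) pair, using the same convention as attention_mask in forward(): True means the key is masked out. This is a pure-Python sketch; the helper name is hypothetical, and treating -1 as "no limit on that side" is an assumption drawn from the special cases (-1, -1) and (-1, 0). Note that the formula as written is aligned to the bottom-right corner; when seqlen_q == seqlen_k, top-left and bottom-right alignment coincide.

```python
from typing import List, Tuple


def sliding_window_mask(seqlen_q: int, seqlen_k: int,
                        window_size: Tuple[int, int]) -> List[List[bool]]:
    """Return a [seqlen_q][seqlen_k] mask where True marks a masked-out key.

    Query i may attend to keys j with
        i + seqlen_k - seqlen_q - window_size[0] <= j <= i + seqlen_k - seqlen_q + window_size[1]
    and -1 on either side is treated as "no limit" (assumption).
    """
    left, right = window_size
    offset = seqlen_k - seqlen_q  # bottom-right alignment of the diagonal
    mask = []
    for i in range(seqlen_q):
        lo = 0 if left < 0 else i + offset - left
        hi = seqlen_k - 1 if right < 0 else i + offset + right
        mask.append([not (lo <= j <= hi) for j in range(seqlen_k)])
    return mask


# (-1, 0): causal mask aligned to the bottom-right corner, here with seqlen_q < seqlen_k.
for row in sliding_window_mask(2, 4, (-1, 0)):
    print(["x" if masked else "." for masked in row])
# ['.', '.', '.', 'x']
# ['.', '.', '.', '.']
```

A finite window such as (1, 1) restricts each query to its immediate neighbours, which is the local attention pattern the parameter is meant to express.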
- forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, self_attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, encoder_output: torch.Tensor | None = None, enc_dec_attn_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, enc_dec_attn_mask_type: str | None = None, enc_dec_window_size: Tuple[int, int] | None = None, is_first_microbatch: bool | None = None, checkpoint_core_attention: bool = False, inference_params: transformer_engine.pytorch.attention.inference.InferenceParams | None = None, rotary_pos_emb: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_kv: torch.Tensor | None = None, cu_seqlens_q_padded: torch.Tensor | None = None, cu_seqlens_kv_padded: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_kv: int | None = None, fast_zero_fill: bool = True, pad_between_seqs: bool | None = None) torch.Tensor
Transformer Layer: attention block and a feedforward network (MLP)
Note
Argument attention_mask is only used when self_attn_mask_type includes "padding" or "arbitrary".
- Parameters:
hidden_states (torch.Tensor) – Input tensor.
attention_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input. It should be in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for the "arbitrary" mask. It should be None for causal masks and the "no_mask" type. A True value means the corresponding position is masked out and a False value means that position is allowed to participate in attention.
self_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = "causal") – type of attention mask passed into the softmax operation for the encoder. By default, causal masks are aligned to the top left corner of the softmax matrix. When "bottom_right" is specified in the mask type, causal masks are aligned to the bottom right corner.
window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in the encoder.
encoder_output (Optional[torch.Tensor], default = None) – output of the encoder block to be fed into the decoder block if using layer_type="decoder".
enc_dec_attn_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensors used to mask out inter-attention softmax input if using layer_type="decoder". It should be a tuple of two masks in [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks. It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for the "arbitrary" mask. It should be None for causal masks and "no_mask". A True value means the corresponding position is masked out and a False value means that position is allowed to participate in attention.
enc_dec_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = None) – type of attention mask passed into the softmax operation for the decoder.
enc_dec_window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in the decoder.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. This parameter indicates whether the current microbatch is the first in a minibatch. When set, it enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
checkpoint_core_attention (bool, default = False) – if True, forward activations for core attention are recomputed during the backward pass in order to save the memory that would otherwise be needed to store the forward activations until backprop.
rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – embeddings for query and key tensors for applying rotary position embedding. By default, no input embedding is applied.
core_attention_bias_type (str, default = "no_bias") – bias type, one of {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}.
core_attention_bias (Optional[torch.Tensor], default = None) – bias tensor for \(Q \cdot K^T\).
alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of \(-\text{alibi\_slope} \cdot (i + \text{seqlen\_k} - \text{seqlen\_q} - j)\) to the attention score of query i and key j.
cu_seqlens_q (Optional[torch.Tensor], default = None) – cumulative sum of sequence lengths (without offset) in a batch for the query layer, with shape [batch_size + 1] and dtype torch.int32. Used by encoders, or decoders’ self-attention.
cu_seqlens_kv (Optional[torch.Tensor], default = None) – cumulative sum of sequence lengths (without offset) in a batch for the key and value layers, with shape [batch_size + 1] and dtype torch.int32. Used by decoders’ cross-attention.
cu_seqlens_q_padded (Optional[torch.Tensor], default = None) – cumulative sum of sequence lengths (with offset) in a batch for the query layer, with shape [batch_size + 1] and dtype torch.int32. Set to cu_seqlens_q if None. Used by encoders, or decoders’ self-attention.
cu_seqlens_kv_padded (Optional[torch.Tensor], default = None) – cumulative sum of sequence lengths (with offset) in a batch for the key and value layers, with shape [batch_size + 1] and dtype torch.int32. Set to cu_seqlens_kv if None. Used by decoders’ cross-attention.
max_seqlen_q (Optional[int], default = None) – maximum sequence length in the query layer. Calculated from cu_seqlens_q_padded if not provided.
max_seqlen_kv (Optional[int], default = None) – maximum sequence length in the key and value layers. Calculated from cu_seqlens_kv_padded if not provided.
fast_zero_fill (bool, default = True) – whether to set output tensors to 0 before use.
inference_params (InferenceParams, default = None) – inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.
pad_between_seqs (Optional[bool], default = None) – if None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If True, there are padding tokens between individual sequences in a packed batch, i.e. qkv_format='thd'.
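For packed ('thd') inputs, the cu_seqlens* and max_seqlen* arguments are prefix sums over per-sequence lengths. The pure-Python sketch below derives them for one batch; in real use they would be torch.int32 tensors of shape [batch_size + 1], as described above. The padding layout (each sequence padded up to a multiple of a hypothetical pad_to) and the helper name are assumptions for illustration, as is taking max_seqlen to be the largest gap between consecutive padded offsets.

```python
from itertools import accumulate
from typing import List, Tuple


def packed_seqlen_metadata(seqlens: List[int], pad_to: int = 1
                           ) -> Tuple[List[int], List[int], int]:
    """Return (cu_seqlens, cu_seqlens_padded, max_seqlen) for a packed batch.

    cu_seqlens counts only real tokens; cu_seqlens_padded gives each sequence's
    offset in the packed buffer when every sequence is padded up to a multiple
    of pad_to (pad_to=1 means no padding between sequences).
    """
    padded = [-(-n // pad_to) * pad_to for n in seqlens]  # round up to a multiple
    cu_seqlens = [0] + list(accumulate(seqlens))
    cu_seqlens_padded = [0] + list(accumulate(padded))
    # Assumption: max_seqlen is the largest padded sequence length.
    max_seqlen = max(b - a for a, b in zip(cu_seqlens_padded, cu_seqlens_padded[1:]))
    return cu_seqlens, cu_seqlens_padded, max_seqlen


cu, cu_padded, max_len = packed_seqlen_metadata([3, 5, 2], pad_to=4)
print(cu)         # [0, 3, 8, 10]
print(cu_padded)  # [0, 4, 12, 16]
print(max_len)    # 8
```

In this sketch, padding between sequences exists exactly when the two cumulative lists differ, which is the situation pad_between_seqs=True describes.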
- set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') None
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters:
cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. A ProcessGroup is used for cp_comm_type of "p2p", "all_gather", and "a2a". A List[ProcessGroup] is used for cp_comm_type of "a2a+p2p", where cp_group[0] and cp_group[1] are for a2a and p2p communications respectively.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – CUDA stream for context parallel execution.
cp_comm_type (str, default = "p2p") –
inter-GPU communication type for context parallelism. Can be "p2p", "all_gather", "a2a", or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.
"all_gather": All-gather to get the full sequence of KV before attention. The all-gather is not async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.
"a2a+p2p": hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via IBLink).
- set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) None
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters:
tp_group (ProcessGroup, default = None) – tensor parallel process group.
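For set_context_parallel_group with cp_comm_type="p2p", KV chunks travel around a ring. The single-process simulation below (pure Python, no real communication; the function name is hypothetical) shows why cp_size - 1 exchange steps suffice for every rank to see every KV chunk: at each step a rank processes the chunk it currently holds, then sends it to rank + 1 and receives from rank - 1. In a real implementation the transfer of the next chunk can run while attention is computed on the current one, which is the overlap the "p2p" description refers to.

```python
from typing import List


def ring_kv_schedule(cp_size: int) -> List[List[int]]:
    """Return, for each rank, the order in which it processes KV chunk ids."""
    held = list(range(cp_size))           # rank r starts with its own KV chunk r
    seen: List[List[int]] = [[] for _ in range(cp_size)]
    for step in range(cp_size):
        for rank in range(cp_size):
            seen[rank].append(held[rank])  # "attention" with the chunk held now
        if step < cp_size - 1:
            # Every rank sends its chunk to rank + 1 and receives from rank - 1.
            held = [held[(rank - 1) % cp_size] for rank in range(cp_size)]
    return seen


for rank, order in enumerate(ring_kv_schedule(4)):
    print(rank, order)  # rank r processes chunks r, r-1, r-2, ... (mod 4)
```

This sketch ignores load balancing for causal masks and any other detail of Transformer Engine's scheduling; it only illustrates the ring exchange pattern.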
- class transformer_engine.pytorch.CudaRNGStatesTracker
For model parallelism, multiple RNG states need to exist simultaneously in order to execute operations in or out of the model parallel region. This class keeps track of the various RNG states and provides utility methods to maintain them and execute parts of the model under a given RNG setting. Using the add() method, a CUDA RNG state is initialized based on the input seed and is assigned to name. Later, by forking the RNG state, we can perform operations and return to our starting CUDA state.
- add(name: str, seed: int) None
Adds a new RNG state.
- Parameters:
name (str) – string identifier for the RNG state.
seed (int) – PyTorch seed for the RNG state.
- fork(name: str = 'model-parallel-rng')
Fork the CUDA RNG state, perform operations, and exit with the original state.
- Parameters:
name (str) – string identifier for the RNG state.
- get_states() Dict[str, torch.Tensor]
Get RNG states. The dictionary is copied so that the caller holds direct pointers to the states, not just a pointer to the dictionary.
- reset()
Set to the initial state (no tracker).
- set_states(states: Dict[str, torch.Tensor]) None
Set the RNG states. For efficiency purposes, the size of the seed is not checked for compatibility.
- Parameters:
states (Dict[str, torch.Tensor]) – A mapping from string names to RNG states.
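The fork pattern at the heart of CudaRNGStatesTracker is: swap in a named generator state, run some work, save the advanced state back under the same name, and restore the original state. The sketch below reproduces that pattern with Python's standard random module instead of CUDA generators, so it runs anywhere. The class name NamedRNGTracker is hypothetical, and the code is an analogy to the behavior described above, not Transformer Engine's implementation.

```python
import random
from contextlib import contextmanager
from typing import Dict, Iterator


class NamedRNGTracker:
    """Keeps several named RNG states and runs code under one of them.

    Analogy built on Python's global `random` generator; not Transformer Engine code.
    """

    def __init__(self) -> None:
        self.states: Dict[str, object] = {}

    def add(self, name: str, seed: int) -> None:
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        orig = random.getstate()
        random.seed(seed)                      # create the new state...
        self.states[name] = random.getstate()  # ...record it under `name`...
        random.setstate(orig)                  # ...and leave the global state untouched

    @contextmanager
    def fork(self, name: str = "model-parallel-rng") -> Iterator[None]:
        orig = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            # Save the advanced named state so the next fork continues from it.
            self.states[name] = random.getstate()
            random.setstate(orig)


tracker = NamedRNGTracker()
tracker.add("model-parallel-rng", seed=1234)

random.seed(0)
with tracker.fork():               # draws come from the "model-parallel-rng" stream
    tp_draw = random.random()
global_draw = random.random()      # continues the seed-0 stream as if no fork happened
print(tp_draw, global_draw)
```

Saving the advanced state on exit is what makes consecutive forks produce a continuing sequence rather than replaying the same numbers each time.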
- transformer_engine.pytorch.autocast(enabled: bool = True, calibrating: bool = False, recipe: transformer_engine.common.recipe.Recipe | None = None, amax_reduction_group: transformer_engine.pytorch.constants.dist_group_type | None = None, _graph: bool = False) None
Context manager for quantization schemes like FP8 or FP4.
with autocast(enabled=True):
    out = model(inp)
Note
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.
Note
When recipe.reduce_amax==True, no module may be invoked more than once inside a single autocast region. This is unsupported because the amax reduction is handled during the exit of the autocast context: calling the same module more than once inside an autocast region overwrites the amax tensors before the reduction can occur.
- Parameters:
enabled (bool, default = True) – whether or not to enable low precision quantization (FP8/FP4).
calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of quantized tensors even when executing without quantization enabled. This is useful for saving an inference ready checkpoint while training using a higher precision.
recipe (recipe.Recipe, default = None) – recipe used for low precision quantization.
amax_reduction_group (torch._C._distributed_c10d.ProcessGroup, default = None) – distributed group over which amaxes for the quantized tensors are reduced at the end of each training step.
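The FP8 shape note above usually translates into padding the sequence dimension before the forward pass. A minimal pure-Python sketch of the arithmetic follows; the helper names are hypothetical, and in practice you would pad a torch tensor (for example with torch.nn.functional.pad) and mask out the padded positions.

```python
from typing import List


def pad_to_multiple(n: int, multiple: int = 16) -> int:
    """Smallest value >= n that is divisible by `multiple`."""
    return -(-n // multiple) * multiple


def fp8_shape_ok(shape: List[int], multiple: int = 16) -> bool:
    """Check the divisibility condition from the note for a 2-D input shape."""
    return all(d % multiple == 0 for d in shape)


seq_len, hidden = 1000, 768
padded = pad_to_multiple(seq_len)
print(padded, padded - seq_len)                                          # 1008 8
print(fp8_shape_ok([seq_len, hidden]), fp8_shape_ok([padded, hidden]))  # False True
```

Here 8 padding positions are added; they must be excluded from attention (e.g. via a padding mask) and from the loss.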
- transformer_engine.pytorch.quantized_model_init(enabled: bool = True, recipe: transformer_engine.common.recipe.Recipe | None = None, preserve_high_precision_init_val: bool = False) None
Context manager for initialization of quantized parameters.
Example usage:
with quantized_model_init(enabled=True):
    model = transformer_engine.pytorch.Linear(768, 768)

# Preserving high precision initial value to initialize master weight
with quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
- Parameters:
enabled (bool, default = True) –
when enabled, Transformer Engine modules created inside this quantized_model_init region will hold only quantized copies of their parameters, as opposed to the default behavior where both higher precision and quantized copies are present. Setting this option to True may result in lower memory consumption and is especially useful for scenarios like:
full model training using optimizer with master weights, where the high precision copies of weights are already present in the optimizer.
inference, where only the quantized copies of the parameters are used.
LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe (transformer_engine.common.recipe.Recipe, default = None) – Recipe used to create the parameters. If left to None, it uses the default recipe.
preserve_high_precision_init_val (bool, default = False) –
when enabled, store the high precision tensor used to initialize quantized parameters in CPU memory, and add two function attributes named get_high_precision_init_val() and clear_high_precision_init_val() to quantized parameters to get/clear this high precision tensor. The purpose is that users can use this high-precision copy to initialize master weights, avoiding the loss of precision that can occur when using quantized parameters directly. Note that after the master weights are initialized, users should call clear_high_precision_init_val() to release this CPU memory.
This functionality is EXPERIMENTAL.
- transformer_engine.pytorch.checkpoint(function: Callable, *args: Tuple[torch.Tensor, Ellipsis], **kwargs: Dict[str, Any]) Tuple[torch.Tensor, Ellipsis]
Checkpoint a part of the model by trading compute for memory. This function is based on torch.utils.checkpoint.checkpoint.
Warning
It is the user’s responsibility to ensure identical behavior when calling function from the forward and backward pass. If different output is produced (e.g. due to global state), then the checkpointed version won’t be numerically equivalent.
Warning
use_reentrant=False does not support early stopping, and will execute the entire forward pass for the checkpointed module when recomputing activations in the backward pass.
- Parameters:
function (Callable) – PyTorch module used to run the forward and backward passes using the specified args and kwargs.
distribute_saved_activations (bool, default = False) – if set to True and use_reentrant=True, the first tensor argument is distributed across the specified tensor parallel group (tp_group) before saving it for the backward pass. This has no effect when use_reentrant=False.
get_rng_state_tracker (Callable, default = None) – Python callable which returns an instance of CudaRNGStatesTracker.
tp_group (ProcessGroup, default = None) – tensor parallel process group. Used only when distribute_saved_activations=True and use_reentrant=True. If None, it falls back to the default group.
use_reentrant (bool, default = True) – perform checkpointing in reentrant mode.
args (tuple) – tuple of torch tensors for inputs to function.
kwargs (dict) – dictionary of string keys for keyword arguments to function.
- transformer_engine.pytorch.make_graphed_callables(modules: SingleOrTuple[Callable], sample_args: SingleOrTuple[Tuple[torch.Tensor, Ellipsis]], num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: SingleOrTuple[Dict[str, Any]] | None = None, fp8_enabled: SingleOrTuple[bool] | None = None, fp8_calibrating: bool | None = None, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp8_group: transformer_engine.pytorch.constants.dist_group_type | None = None, fp8_weight_caching: bool | None = None, enabled: SingleOrTuple[bool] | None = None, calibrating: bool | None = None, recipe: transformer_engine.common.recipe.Recipe | None = None, amax_reduction_group: transformer_engine.pytorch.constants.dist_group_type | None = None, cache_quantized_params: bool | None = None, _order: List[int] | None = None, _num_layers_per_chunk: List[int] | None = None, pool: Tuple[int, Ellipsis] | None = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False) Callable | Tuple[Callable, Ellipsis]
Make CUDA graph version of Transformer Engine modules
A variation of PyTorch’s make_graphed_callables utility function with support for Transformer Engine modules and FP8. Please see the original PyTorch implementation for more documentation.
Warning
Arguments ‘fp8_enabled’, ‘fp8_calibrating’, ‘fp8_recipe’, ‘fp8_group’, and ‘fp8_weight_caching’ are deprecated. Use arguments ‘enabled’, ‘calibrating’, ‘recipe’, ‘amax_reduction_group’, and ‘cache_quantized_params’ instead.
- Graphing parameters:
modules ((tuple of) callable) – Callable or callables to graph.
sample_args ((tuple of) tuple of torch.Tensor) – Positional arguments to callable(s).
num_warmup_iters (int, default = 3) – Number of warmup iterations.
allow_unused_input (bool, default = False) – Whether to handle the case where callable inputs and outputs are disconnected in the compute graph.
sample_kwargs ((tuple of) dict, optional) – Keyword arguments to callable(s)
pool ((tuple of) int, default = None, optional) – An instance returned from function torch.cuda.graph_pool_handle that hints this graph may share memory with the indicated pool.
retain_graph_in_backward (bool, default = False) – Whether to set retain_graph=True in backward graph capture.
_reuse_graph_input_output_buffers (bool, default = False) – Reduce memory usage by reusing input/output data buffers between graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. when _order is provided. All callables in modules are assumed to have inputs and outputs with the same dtype and shape.
- Quantization parameters:
enabled ((tuple of) bool, default = False) – whether or not to enable low precision quantization (FP8/FP4). If tuple, the length must match the number of modules.
calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of quantized tensors even when executing without quantization enabled. This is useful for saving an inference ready checkpoint while training using a higher precision.
recipe (recipe.Recipe, default = None) – recipe used for low precision quantization.
amax_reduction_group (torch._C._distributed_c10d.ProcessGroup, default = None) – distributed group over which amaxes for the quantized tensors are reduced at the end of each training step.
cache_quantized_params (bool, default = False) – Whether or not to cache quantized weights across microbatches. If set to True, the is_first_microbatch boolean argument must be passed into the forward method for Transformer Engine modules. When storing primary weights in low precision using TE’s quantized_model_init API and using a quantization-aware optimizer, this argument must be set to False if weight transposes are calculated outside TE, e.g., in the optimizer step.
- transformer_engine.pytorch.get_cpu_offload_context(enabled: bool = False, num_layers: int | None = 1, model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = False, double_buffering: bool = False, manual_synchronization: bool = False, retain_pinned_cpu_buffers: bool = False, offload_stream: torch.cuda.Stream | None = None)
CPU offloading feature for sequences of layers. It can be used for arbitrary layers, not only those provided by TE.
Usage:
cpu_offload_context, sync_function = get_cpu_offload_context(...)
for i in range(num_layers):
    with cpu_offload_context:
        x = layers[i].forward(x)
    x = sync_function(x)
- Parameters:
enabled (bool, default = False) – When set to True, CPU Offloading functionality is enabled.
num_layers (int, default = 1) – Determines the number of layers you want to offload activations/weights for.
model_layers (int, default = 1) – Number of layers in the model that will be used under this context.
offload_activations (bool, default = True) – Deprecated.
offload_weights (bool, default = False) – Deprecated.
double_buffering (bool, default = False) – Deprecated.
retain_pinned_cpu_buffers (bool, default = False) – If True, the pinned CPU buffers are retained after offloading and reused for the next iteration. This is useful for CUDA graph capture.
manual_synchronization (bool, default = False) – If True, the synchronization is done manually by the user, and an additional return value, manual_controller, is returned. See the manual synchronization notes below.
offload_stream (torch.cuda.Stream, default = None) – If provided, the offload stream is used for offloading and reloading. Otherwise, a new stream is allocated internally. It can be other than None only if manual_synchronization is True.
Notes
Manual synchronization:
By default, layers are offloaded/reloaded asynchronously with respect to the current forward/backward stream with predefined synchronization, to ensure that activation memory usage is equal to (num_layers - num_offloaded_layers) * T, where T is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set manual_synchronization=True. In this case, an additional return value, manual_controller, is returned.
The manual_controller provides the following methods:
- start_offload_layer(layer_id: int)
- release_activation_forward_gpu_memory(layer_id: int)
- start_reload_layer(layer_id: int)
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded. If start_offload_layer() is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored. To release this memory, you need to call release_activation_forward_gpu_memory(layer_id). This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The start_reload_layer() method is used to start reloading a layer. Each tensor reload is awaited before tensor_pop() for that tensor is called on the current stream.
You can provide an offload_stream to be used for offload and reload operations. This allows for more detailed synchronization, such as delaying the start of offloading.
Example:
offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
    enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)

# Forward: run each layer, then start offloading its activations.
for i in range(num_layers):
    with cpu_offload_context:
        out[i] = layers[i].forward(inp[i])
    out[i] = sync_function(out[i])
    manual_controller.start_offload_layer(i)

# Release GPU copies once the offloads have finished.
offload_stream.synchronize()
for i in range(num_layers):
    manual_controller.release_activation_forward_gpu_memory(i)

# Reload in reverse order before the backward pass.
for i in range(num_layers - 1, -1, -1):
    manual_controller.start_reload_layer(i)

offload_stream.synchronize()
for i in range(num_layers):
    out[i].sum().backward()
V1 code path:
If you want to use the v1 code path for offloading, please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1.
- transformer_engine.pytorch.parallel_cross_entropy(inp: torch.Tensor, target: torch.Tensor, label_smoothing: float = 0.0, reduce_loss: bool = False, dist_process_group: torch.distributed.ProcessGroup | None = None, ignore_idx: int = -100, is_cg_capturable: bool = False, *, _input: torch.Tensor | None = None) torch.Tensor
Cross Entropy loss with optional distributed reduction.
The input tensor can be in BF16 or FP32; the loss and gradient calculations happen in FP32 only. The returned loss is always in FP32, and the input gradients are cast back to the datatype of the input.
If dist_process_group is passed for distributed loss calculation, the input to each distributed rank should be (*, V/world_size). Note that each of the ranks should get equal shards along the V dimension.
- Parameters:
inp (torch.Tensor) – The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is the batch size, SQ is the sequence length, and V is the vocab size.
target (torch.Tensor) – The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].
label_smoothing (float, default = 0.0) – The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool, default = False) – If True, returns the averaged loss across the B*SQ dimension.
dist_process_group (torch.distributed.ProcessGroup, default = None) – The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx (int, default = -100) – The target index for which the loss and gradients are set to zero.
is_cg_capturable (bool, default = False) – Whether the operation is CUDA graph capturable.
- Returns:
The computed loss.
- Return type:
torch.Tensor
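When dist_process_group is set, each rank holds only a (*, V/world_size) slice of the logits, yet the loss needs a log-sum-exp over the whole vocabulary. One standard way to get it is to combine per-shard maxima and per-shard sums of exponentials, plus the target logit from whichever shard owns the target index. The pure-Python sketch below simulates this for a single token without any communication. It illustrates the math only: it ignores label_smoothing and reduce_loss, the function names are hypothetical, and it is not necessarily how Transformer Engine's fused kernel is organized.

```python
import math
from typing import List


def full_cross_entropy(logits: List[float], target: int, ignore_idx: int = -100) -> float:
    """Reference loss for one token: logsumexp(logits) - logits[target]."""
    if target == ignore_idx:
        return 0.0
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return lse - logits[target]


def sharded_cross_entropy(shards: List[List[float]], target: int,
                          ignore_idx: int = -100) -> float:
    """Same loss when the vocab dimension is split into equal shards, one per rank."""
    if target == ignore_idx:
        return 0.0
    shard_size = len(shards[0])
    if any(len(s) != shard_size for s in shards):
        raise ValueError("shards must be equal along V")
    # In a real setup this would be an all-reduce MAX across ranks.
    global_max = max(max(s) for s in shards)
    # ...and this an all-reduce SUM of each rank's local sum of exponentials.
    sum_exp = sum(sum(math.exp(x - global_max) for x in s) for s in shards)
    # Only the rank that owns the target index contributes the target logit.
    owner, local = divmod(target, shard_size)
    return global_max + math.log(sum_exp) - shards[owner][local]


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]  # V = 6
shards = [logits[0:3], logits[3:6]]        # world_size = 2, V / world_size = 3
print(full_cross_entropy(logits, 4), sharded_cross_entropy(shards, 4))  # equal up to rounding
```

The equal-shard requirement in the description shows up here as the divmod that maps a global target index to its owning rank and local position.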
Recipe availability
- transformer_engine.pytorch.is_fp8_available(return_reason: bool = False) bool | Tuple[bool, str]
Determine if FP8 support is available for the delayed scaling and per-tensor current scaling recipes.
- Parameters:
return_reason (bool, optional) – If False (default), return only a boolean indicating availability. If True, return a tuple (is_available, reason) where reason provides a human-readable explanation when required support is not available. The reason will be an empty string if support for FP8 is available.
- transformer_engine.pytorch.is_mxfp8_available(return_reason: bool = False) bool | Tuple[bool, str]
Determine if support is available for the MXFP8 recipe.
- Parameters:
return_reason (bool, optional) – If False (default), return only a boolean indicating availability. If True, return a tuple (is_available, reason) where reason provides a human-readable explanation when required support is not available. The reason will be an empty string if support for MXFP8 is available.
- transformer_engine.pytorch.is_fp8_block_scaling_available(return_reason: bool = False) bool | Tuple[bool, str]
Determine if support is available for the FP8 block scaling recipe.
- Parameters:
return_reason (bool, optional) – If False (default), return only a boolean indicating availability. If True, return a tuple (is_available, reason) where reason provides a human-readable explanation when required support is not available. The reason will be an empty string if support for FP8 block scaling is available.
- transformer_engine.pytorch.is_nvfp4_available(return_reason: bool = False) bool | Tuple[bool, str]
Determine if support is available for the NVFP4 recipe.
- Parameters:
return_reason (bool, optional) – If False (default), return only a boolean indicating availability. If True, return a tuple (is_available, reason) where reason provides a human-readable explanation when required support is not available. The reason will be an empty string if support for NVFP4 is available.
- transformer_engine.pytorch.is_bf16_available(return_reason: bool = False) bool | Tuple[bool, str]
Determine whether bfloat16 (BF16) computation is supported on the current device.
- Parameters:
return_reason (bool, optional) – If False (default), return only a boolean indicating BF16 availability. If True, return a tuple (is_available, reason) where reason provides a human-readable explanation when BF16 is not available. When BF16 is available, the reason will be an empty string.
- transformer_engine.pytorch.get_cudnn_version() Tuple[int, int, int]
Runtime cuDNN version (major, minor, patch)
- transformer_engine.pytorch.get_device_compute_capability() Tuple[int, int]
CUDA compute capability of current GPU
- transformer_engine.pytorch.get_default_recipe() transformer_engine.common.recipe.Recipe
Returns the default training recipe based on the available device.
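The return_reason flag is convenient when probing several recipes in turn and reporting why some were skipped. A minimal sketch: pick_recipe is a hypothetical helper, and the preference order shown in the usage comment is an assumption rather than a library recommendation; get_default_recipe() remains the documented way to let the library choose.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Each check is expected to behave like te.is_*_available(return_reason=True),
# i.e. return an (is_available, reason) tuple.
AvailabilityCheck = Callable[..., Tuple[bool, str]]


def pick_recipe(checks: Dict[str, AvailabilityCheck]) -> Tuple[Optional[str], List[str]]:
    """Return the first available recipe name (in dict order) and reasons for skipped ones."""
    skipped: List[str] = []
    for name, check in checks.items():
        available, reason = check(return_reason=True)
        if available:
            return name, skipped
        skipped.append(f"{name}: {reason}")
    return None, skipped


# Hypothetical usage on a GPU machine:
# import transformer_engine.pytorch as te
# name, skipped = pick_recipe({
#     "nvfp4": te.is_nvfp4_available,
#     "mxfp8": te.is_mxfp8_available,
#     "fp8_block_scaling": te.is_fp8_block_scaling_available,
#     "fp8": te.is_fp8_available,
# })
```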
Mixture of Experts (MoE) functions
- transformer_engine.pytorch.moe_permute(inp: torch.Tensor, routing_map: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1, map_type: str = 'mask') Tuple[torch.Tensor, torch.Tensor]
Permute the tokens based on the routing_map. Tokens with the same index are grouped together, and tokens with the same designated expert are grouped together. The routing_map indicates which experts were selected by each token.
- Parameters:
inp (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size], on which permutation will be applied.
routing_map (torch.Tensor) – The token-to-expert mapping tensor. If map_type is ‘mask’, routing_map has shape [num_tokens, num_experts] and dtype ‘int32’; a value of 1 means the token is routed to that expert and 0 means it is not. If map_type is ‘index’, routing_map has shape [num_tokens, topK] and dtype ‘int32’, and its values are the indices of the routed experts.
num_out_tokens (int, default = -1) – The effective output token count, representing the number of tokens not dropped. By default, set to ‘-1’, meaning no tokens are dropped.
max_token_num (int, default = -1) – The maximum number of tokens, used for workspace allocation. The default of ‘-1’ means the operator determines the workspace size automatically.
map_type (str, default = 'mask') – Type of the routing map tensor. Options are: ‘mask’, ‘index’. Refer to routing_map for more details.
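The grouping performed by moe_permute can be pictured with a pure-Python sketch on nested lists. It assumes the output is expert-major (all tokens for expert 0, then expert 1, and so on) with token order preserved within each expert, it does not model token dropping via num_out_tokens, and it returns a list of (token, expert) origins as a hypothetical stand-in for the library's row_id_map, whose layout is not described here. index_to_mask shows how an ‘index’ routing map corresponds to the equivalent ‘mask’ map. Both function names are hypothetical.

```python
from typing import List, Tuple

Row = List[float]


def index_to_mask(index_map: List[List[int]], num_experts: int) -> List[List[int]]:
    """Turn an 'index' map [num_tokens, topK] into a 'mask' map [num_tokens, num_experts]."""
    mask = [[0] * num_experts for _ in index_map]
    for token, experts in enumerate(index_map):
        for expert in experts:
            mask[token][expert] = 1
    return mask


def permute_by_mask(inp: List[Row], routing_map: List[List[int]]
                    ) -> Tuple[List[Row], List[Tuple[int, int]]]:
    """Group token rows by expert; a token routed to k experts appears k times."""
    num_experts = len(routing_map[0]) if routing_map else 0
    out: List[Row] = []
    origin: List[Tuple[int, int]] = []  # (token, expert) for each output row
    for expert in range(num_experts):
        for token, flags in enumerate(routing_map):
            if flags[expert]:
                out.append(list(inp[token]))
                origin.append((token, expert))
    return out, origin
```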
- transformer_engine.pytorch.moe_permute_with_probs(inp: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor, num_out_tokens: int = -1) Tuple[torch.Tensor, torch.Tensor]
Permute the tokens and probs based on the routing_map. Tokens with the same index are grouped together, and tokens with the same designated expert are grouped together. The routing_map indicates which experts were selected by each token.
- Parameters:
inp (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size], on which permutation will be applied.
probs (torch.Tensor) – The probabilities corresponding to the tokens, of shape [num_tokens, num_experts]. They are permuted together with the tokens according to the routing_map.
routing_map (torch.Tensor) – The token-to-expert mapping tensor of shape [num_tokens, num_experts] and dtype ‘int32’. A value of 1 means the token is routed to that expert and 0 means it is not.
num_out_tokens (int, default = -1) – The effective output token count, representing the number of tokens not dropped. By default, set to ‘-1’, meaning no tokens are dropped.
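moe_permute_with_probs additionally carries each token's routing probability through the same reordering. A minimal sketch in expert-major order, assuming the permuted probabilities come out as one value per output row (probs[token][expert] for the copy sent to that expert); the function name is hypothetical and the exact output layout of the library call is not confirmed here.

```python
from typing import List, Tuple


def permute_with_probs(inp: List[List[float]], probs: List[List[float]],
                       routing_map: List[List[int]]
                       ) -> Tuple[List[List[float]], List[float]]:
    """Reorder token rows and their routing probabilities expert by expert."""
    num_experts = len(routing_map[0]) if routing_map else 0
    rows: List[List[float]] = []
    row_probs: List[float] = []
    for expert in range(num_experts):
        for token, flags in enumerate(routing_map):
            if flags[expert]:
                rows.append(list(inp[token]))
                row_probs.append(probs[token][expert])  # prob travels with its row
    return rows, row_probs
```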
- transformer_engine.pytorch.moe_unpermute(inp: torch.Tensor, row_id_map: torch.Tensor, merging_probs: torch.Tensor | None = None, restore_shape: torch.Size | None = None, map_type: str = 'mask', probs: torch.Tensor | None = None) torch.Tensor
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their corresponding probabilities.
- Parameters:
inp (torch.Tensor) – Input tensor with permuted tokens of shape [num_tokens, hidden_size] to be unpermuted.
row_id_map (torch.Tensor) – The mapping table of sorted indices used to unpermute the tokens, i.e. the second output tensor of moe_permute.
merging_probs (torch.Tensor, default = None) – The probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens are merged weighted by their respective probabilities. If None (the default), the tokens are merged by plain accumulation.
restore_shape (torch.Size, default = None) – The output shape after the unpermute operation.
map_type (str, default = 'mask') – Type of the routing map tensor. Should be the same as the value passed to moe_permute. Options are: ‘mask’, ‘index’.
probs (torch.Tensor, default = None) – Former name of merging_probs, kept for backward compatibility.
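The inverse step can be sketched the same way: each permuted row is scattered back to its source token, and rows belonging to the same token are summed, optionally weighted by merging_probs. This is a pure-Python illustration, not the library kernel; it takes an explicit list of (token, expert) origins as a hypothetical stand-in for row_id_map, and the helper name is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def unpermute(permuted: Sequence[Sequence[float]], origin: Sequence[Tuple[int, int]],
              num_tokens: int,
              merging_probs: Optional[Sequence[Sequence[float]]] = None
              ) -> List[List[float]]:
    """Scatter rows back to token order, merging the copies of each token."""
    hidden = len(permuted[0]) if permuted else 0
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for row, (token, expert) in zip(permuted, origin):
        # Without merging_probs the copies are simply accumulated.
        weight = 1.0 if merging_probs is None else merging_probs[token][expert]
        for j, value in enumerate(row):
            out[token][j] += weight * value
    return out
```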
- transformer_engine.pytorch.moe_sort_chunks_by_index(inp: torch.Tensor, split_sizes: torch.Tensor, sorted_index: torch.Tensor) Tuple[torch.Tensor, torch.Tensor]
Split and sort the input tensor based on split_sizes and the sorted indices. The inp tensor is split along dim 0 according to the split_sizes list and then sorted according to sorted_index.
- Parameters:
inp (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size], on which permutation will be applied.
split_sizes (torch.Tensor) – Chunk sizes of the inp tensor along the 0-th dimension.
sorted_index (torch.Tensor) – Chunk indices used to permute the chunks.
- transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs(inp: torch.Tensor, probs: torch.Tensor, split_sizes: torch.Tensor, sorted_index: torch.Tensor) Tuple[torch.Tensor, torch.Tensor]
Split and sort the input tensor and probs based on split_sizes and the sorted indices. The inp tensor is split along dim 0 according to the split_sizes list and then sorted according to sorted_index.
- Parameters:
inp (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size], on which permutation will be applied.
probs (torch.Tensor) – The probabilities corresponding to the tokens, of shape [num_tokens]. They are permuted together with the tokens according to split_sizes and sorted_index.
split_sizes (torch.Tensor) – Chunk sizes of the inp tensor along the 0-th dimension.
sorted_index (torch.Tensor) – Chunk indices used to permute the chunks.
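Both chunk-sorting functions split along dim 0 by split_sizes and then emit the chunks in the order given by the index tensor. The sketch below assumes the gather convention "output chunk i is input chunk sorted_index[i]", which is not spelled out above; probs, when given, are reordered identically. The function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def _chunks(items: Sequence[T], split_sizes: Sequence[int]) -> List[List[T]]:
    """Split items into consecutive chunks of the given sizes."""
    if sum(split_sizes) != len(items):
        raise ValueError("split_sizes must sum to the number of rows")
    chunks: List[List[T]] = []
    start = 0
    for size in split_sizes:
        chunks.append(list(items[start:start + size]))
        start += size
    return chunks


def sort_chunks_by_index(inp: Sequence[T], split_sizes: Sequence[int],
                         sorted_index: Sequence[int],
                         probs: Optional[Sequence[float]] = None
                         ) -> Tuple[List[T], Optional[List[float]]]:
    """Reorder dim-0 chunks so that output chunk i is input chunk sorted_index[i]."""
    row_chunks = _chunks(inp, split_sizes)
    out = [row for i in sorted_index for row in row_chunks[i]]
    if probs is None:
        return out, None
    prob_chunks = _chunks(probs, split_sizes)
    return out, [p for i in sorted_index for p in prob_chunks[i]]
```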
Communication-computation overlap
- transformer_engine.pytorch.initialize_ub(shape: list, tp_size: int, use_fp8: bool = False, quantization_modes: List[UserBufferQuantizationMode] = None, dtype: torch.dtype = torch.bfloat16, ub_cfgs: dict | List[dict] | None = None, bootstrap_backend: str | torch.distributed.Backend = None) None
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.
- Parameters:
shape (list) – shape of the communication buffer, typically set to be the same as the global shape of the input tensor to a te.TransformerLayer forward pass, with the sequence and batch dimensions collapsed together, i.e. (sequence_length * batch_size, hidden_size).
tp_size (int) – number of GPUs in the tensor-parallel process group.
use_fp8 (bool = False) – allocate the communication buffer for FP8 GEMM inputs/outputs. DEPRECATED: Please use quantization_modes instead.
quantization_modes (List[UserBufferQuantizationMode] = None) – if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. Falls back to the legacy use_fp8 parameter if None is provided.
dtype (torch.dtype = torch.bfloat16) – non-FP8 data type of the communication buffer when use_fp8 = False.
ub_cfgs (dict = None) –
Configuration dictionary with the structure:
{ <gemm_name> : { "method": <"ring_exchange" or "pipeline">, "is_reduce_scatter": bool, "num_sm": int, "cga_size": int, "set_sm_margin": bool, "num_splits": int, "aggregate": bool, "atomic_gemm": bool, "use_ce": bool, "fp8_buf": bool, } }
for te.TransformerLayer GEMM layers in ["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "fc2_fprop", "fc2_wgrad"]. A list may be provided to specify different overlap configurations for the different quantization settings in quantization_modes.
bootstrap_backend (str = None) – torch.distributed communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are valid for every cluster configuration and distributed launch method, even if they are available in PyTorch. When left unset, the initialization prefers to use the MPI backend, falling back first on Gloo and then NCCL if MPI is not available. Setting NVTE_UB_WITH_MPI=1 when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, which also requires MPI_HOME=/path/to/mpi/root to be set at compile time.
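A sketch of assembling the shape and ub_cfgs arguments. The GEMM names and the "ring_exchange"/"pipeline" methods come from the description above; the helper names and the override values (num_splits and so on) are hypothetical placeholders, not tuned recommendations. The te.initialize_ub call appears only in a comment because it needs an initialized multi-GPU torch.distributed job.

```python
from typing import Dict, List

# GEMM names accepted as ub_cfgs keys, as listed in the initialize_ub description.
UB_GEMM_NAMES = {
    "qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad",
    "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "fc2_fprop", "fc2_wgrad",
}
UB_METHODS = {"ring_exchange", "pipeline"}


def ub_buffer_shape(seq_len: int, batch_size: int, hidden_size: int) -> List[int]:
    """Global TransformerLayer input shape with sequence and batch collapsed."""
    return [seq_len * batch_size, hidden_size]


def make_ub_cfgs(overrides: Dict[str, dict]) -> Dict[str, dict]:
    """Validate per-GEMM overlap overrides before passing them as ub_cfgs."""
    for gemm, cfg in overrides.items():
        if gemm not in UB_GEMM_NAMES:
            raise KeyError(f"unknown GEMM name: {gemm}")
        method = cfg.get("method")
        if method is not None and method not in UB_METHODS:
            raise ValueError(f"{gemm}: method must be one of {sorted(UB_METHODS)}")
    return {gemm: dict(cfg) for gemm, cfg in overrides.items()}


# Hypothetical usage inside an initialized torch.distributed job:
# import transformer_engine.pytorch as te
# cfgs = make_ub_cfgs({"proj_fprop": {"method": "pipeline", "num_splits": 4},
#                      "fc1_fprop": {"method": "ring_exchange"}})
# te.initialize_ub(ub_buffer_shape(2048, 2, 4096), tp_size=8, ub_cfgs=cfgs)
```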
- transformer_engine.pytorch.destroy_ub()
Destroy all allocated userbuffer communicators.
Quantized tensors
- class transformer_engine.pytorch.QuantizedTensorStorage
Base class for all TensorStorage classes.
This class and its subclasses are an optimization for cases where the full QuantizedTensor is not needed, i.e. when the tensor is fully contained inside a torch.autograd function and is not visible to PyTorch’s autograd.
When creating a new tensor type X, one should create both an XTensorStorage class inheriting from QuantizedTensorStorage and an XTensor class inheriting from XTensorStorage and QuantizedTensor. XTensorStorage should contain all data members needed to implement the functionality of the tensor, while XTensor should only implement the functionality needed to behave like a regular torch.Tensor (such as __torch_dispatch__).
- abstract prepare_for_saving() Tuple[list[torch.Tensor | None], QuantizedTensorStorage]
Prepare the tensor base for saving for backward
- abstract restore_from_saved(tensors: list[torch.Tensor | None]) list[torch.Tensor | None]
Restore the tensor base data from the saved tensors list
- abstract update_usage(rowwise_usage: bool | None = None, columnwise_usage: bool | None = None)
Generate or remove quantized data based on provided usage.
- Parameters:
rowwise_usage (Optional[bool], default = None) – Whether to create or keep the data needed for using the tensor in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as None preserves the original value in the tensor.
columnwise_usage (Optional[bool], default = None) – Whether to create or keep the data needed for using the tensor in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as None preserves the original value in the tensor.
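The save/restore and usage contract can be illustrated with a toy class. This is a hypothetical pure-Python stand-in (nested lists instead of tensors, a transpose instead of real columnwise quantization); the convention that restore_from_saved consumes its own entries and returns the leftover list is inferred from the signatures above, not confirmed from the library source.

```python
from typing import List, Optional, Tuple

Matrix = List[List[int]]


def _transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


class ToyTensorStorage:
    """Hypothetical stand-in for a QuantizedTensorStorage subclass."""

    def __init__(self, rowwise: Optional[Matrix], columnwise: Optional[Matrix] = None):
        self.rowwise = rowwise
        self.columnwise = columnwise

    def prepare_for_saving(self) -> Tuple[List[Optional[Matrix]], "ToyTensorStorage"]:
        # Hand the raw buffers to the caller and drop this object's references.
        tensors = [self.rowwise, self.columnwise]
        self.rowwise = self.columnwise = None
        return tensors, self

    def restore_from_saved(self, tensors: List[Optional[Matrix]]) -> List[Optional[Matrix]]:
        # Take the buffers this object owns; return the rest for the next object.
        self.rowwise, self.columnwise = tensors[0], tensors[1]
        return tensors[2:]

    def update_usage(self, rowwise_usage: Optional[bool] = None,
                     columnwise_usage: Optional[bool] = None) -> None:
        # Create missing layouts first, so one can be derived from the other.
        if columnwise_usage and self.columnwise is None and self.rowwise is not None:
            self.columnwise = _transpose(self.rowwise)
        if rowwise_usage and self.rowwise is None and self.columnwise is not None:
            self.rowwise = _transpose(self.columnwise)
        # False drops a layout; None leaves it untouched.
        if rowwise_usage is False:
            self.rowwise = None
        if columnwise_usage is False:
            self.columnwise = None
```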
- class transformer_engine.pytorch.QuantizedTensor(shape, dtype, *, requires_grad=False, device=None)
Abstract base class for tensor with quantized data
This is a proxy class with the interface of a standard PyTorch tensor, but with data that has been encoded with some quantization scheme. Derived classes should implement the quantization scheme by overriding the quantize_ and dequantize functions.
- abstract dequantize(*, dtype: torch.dtype | None = None) torch.Tensor
Convert quantized data to standard PyTorch tensor
- abstract quantize_(tensor: torch.Tensor) QuantizedTensor
Update quantized data in-place
- class transformer_engine.pytorch.Float8TensorStorage(data, fp8_scale_inv, fp8_dtype, data_transpose=None, quantizer=None)
Mixin class that holds data attributes of Float8Tensor.
Float8Tensor inherits from the PyTorch tensor class and this mixin class. If this class is instantiated directly, it has the same data, lower CPU overhead, and less functionality. It should only be instantiated directly for performance-critical internal usage.
- class transformer_engine.pytorch.MXFP8TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)
Mixin class that holds data attributes of MXFP8Tensor.
MXFP8Tensor inherits from the PyTorch tensor class and this mixin class. If this class is instantiated directly, it has the same data, lower CPU overhead, and less functionality. It should only be instantiated directly for performance-critical internal usage.
- class transformer_engine.pytorch.Float8BlockwiseQTensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)
Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this mixin class. If this class is instantiated directly, it has the same data, lower CPU overhead, and less functionality. It should only be instantiated directly for performance-critical internal usage.
- class transformer_engine.pytorch.NVFP4TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)
Mixin class that holds data attributes of NVFP4Tensor.
NVFP4Tensor inherits from the PyTorch tensor class and this mixin class. If this class is instantiated directly, it has the same data, lower CPU overhead, and less functionality. It should only be instantiated directly for performance-critical internal usage.
- class transformer_engine.pytorch.Float8Tensor(shape, dtype, data, fp8_scale_inv, fp8_dtype, requires_grad=False, data_transpose=None, quantizer=None)
Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP8. For most tensor operations, the data will be cast to the nominal dtype before performing the operation.
- Parameters:
shape (int or iterable of int) – Tensor dimensions.
dtype (torch.dtype) – Nominal tensor datatype.
requires_grad (bool, optional = False) – Whether to compute gradients for this tensor.
data (torch.Tensor) – Raw FP8 data in a uint8 tensor
fp8_scale_inv (torch.Tensor) – Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher precision.
fp8_dtype (transformer_engine_torch.DType) – FP8 format.
data_transpose (torch.Tensor, optional) – FP8 transpose data in a uint8 tensor
quantizer (Float8Quantizer, Float8CurrentScalingQuantizer, optional) – Builder class for FP8 tensors
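To make the relationship between the raw uint8 data and fp8_scale_inv concrete, here is a pure-Python decoder for one FP8 E4M3 byte (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and only the all-ones exponent-and-mantissa pattern reserved for NaN), followed by dequantization as decoded value times fp8_scale_inv. It illustrates the number format only; it is not how Float8Tensor.dequantize is implemented, E5M2 is not covered, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 bit pattern to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # only S.1111.111 is NaN; E4M3 has no infinities
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize(data: Sequence[int], scale_inv: float) -> List[float]:
    """Recover high-precision values: decoded FP8 value times the reciprocal scale."""
    return [decode_e4m3(b) * scale_inv for b in data]
```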
- class transformer_engine.pytorch.MXFP8Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)
Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP8. For most tensor operations, the data will be cast to the nominal dtype before performing the operation.
- Parameters:
data (torch.Tensor) – Raw FP8 data in a uint8 tensor
fp8_dtype (transformer_engine_torch.DType, default = kFloat8E4M3) – FP8 format.
fp8_scale_inv (torch.Tensor) – Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher precision.
dtype (torch.dtype, default = torch.float32) – Nominal tensor datatype.
- class transformer_engine.pytorch.Float8BlockwiseQTensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)
Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP8. For most tensor operations, the data will be cast to the nominal dtype before performing the operation.
- Parameters:
rowwise_data (torch.Tensor) – FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv (torch.Tensor) – FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data (Optional[torch.Tensor]) – FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv (Optional[torch.Tensor]) – FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype (transformer_engine_torch.DType, default = kFloat8E4M3) – FP8 format.
quantizer (Quantizer) – the Float8BlockQuantizer that quantized this tensor; holds configuration about quantization and dequantization modes.
- class transformer_engine.pytorch.NVFP4Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)
Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP4. For most tensor operations, the data will be cast to the nominal dtype before performing the operation.
- Parameters:
rowwise_data (torch.Tensor) – Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv (torch.Tensor) – Reciprocal of the scaling factor applied when casting to FP4, i.e. the scaling factor that must be applied when casting from FP4 to higher precision (rowwise).
columnwise_data (torch.Tensor, optional) – Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv (torch.Tensor, optional) – Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise (torch.Tensor, optional) – Rowwise amax tracking tensor.
amax_columnwise (torch.Tensor, optional) – Columnwise amax tracking tensor.
fp4_dtype (TE_DType) – The FP4 data type used for quantization.
quantizer (Quantizer) – The quantizer instance used for this tensor.
dtype (torch.dtype, default = torch.float32) – Nominal tensor datatype, used in dequantize.
Quantizers
- class transformer_engine.pytorch.Quantizer(rowwise, columnwise)
Builder class for quantized tensors.
This class is typically used to convert a high-precision tensor (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).
- quantize(tensor: torch.Tensor, *, out: QuantizedTensor | None = None, dtype: torch.dtype | None = None) QuantizedTensor
Quantize tensor
- abstract update_quantized(src: torch.Tensor, dst: QuantizedTensor, *, noop_flag: torch.Tensor | None = None) QuantizedTensor
Quantize tensor in-place
- class transformer_engine.pytorch.Float8Quantizer(scale, amax, fp8_dtype, *, rowwise=True, columnwise=True)
Builder class for FP8 tensors with per-tensor delayed scaling
High-precision tensors (e.g. in FP32 or BF16) are quantized by multiplying with a scaling factor and casting to FP8. The max-abs value (“amax”) in the tensor is also computed, which can be used for updating the scaling factor (handled externally by DelayedScalingRecipeState and FP8GlobalStateManager).
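In delayed scaling, the scale used for a step is derived from previously recorded amax values rather than from the tensor being quantized. A minimal sketch, assuming the formula scale = (fp8_max / amax) / 2**margin, with amax taken either as the maximum of the history or as its most recent entry, and assuming the previous scale is kept when no positive amax has been recorded. delayed_scale is a hypothetical helper; as stated above, the real update is handled by DelayedScalingRecipeState and FP8GlobalStateManager.

```python
from typing import Sequence

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def delayed_scale(amax_history: Sequence[float], prev_scale: float,
                  fp8_max: float = E4M3_MAX, margin: int = 0,
                  algo: str = "max") -> float:
    """Next-step scaling factor computed from recorded amax values (last entry = newest)."""
    if algo == "max":
        amax = max(amax_history)
    elif algo == "most_recent":
        amax = amax_history[-1]
    else:
        raise ValueError("algo must be 'max' or 'most_recent'")
    if amax <= 0.0:
        return prev_scale  # nothing useful observed yet: keep the previous scale
    return (fp8_max / amax) / (2.0 ** margin)
```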
- class transformer_engine.pytorch.Float8CurrentScalingQuantizer(fp8_dtype, device, *, rowwise=True, columnwise=True, **kwargs)
Builder class for FP8 tensors with per-tensor current scaling
High-precision tensors (e.g. in FP32 or BF16) are quantized by multiplying with a scaling factor and casting to FP8. The max-abs value (“amax”) in the tensor is computed directly by scanning the input high-precision tensor, without the need of any history window.
Unlike with delayed scaling, scale and amax tensors are not needed to initialize the quantizer: they are simply GPU buffers filled by the current-scaling quantization kernels, rather than values taken from a delayed-scaling history window. The device parameter is therefore needed to allocate them.
Both Float8CurrentScalingQuantizer and Float8Quantizer produce Float8Tensor, because both use per-tensor scaling, i.e. one scaling factor per tensor.
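Current scaling derives the scale from the tensor being quantized. A sketch of the per-tensor arithmetic, assuming scale = fp8_max / amax of the current tensor and a fallback scale of 1.0 for an all-zero tensor; the final FP8 cast is omitted (values are only scaled), and the function name is hypothetical.

```python
from typing import List, Sequence, Tuple

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def current_scaling(x: Sequence[float], fp8_max: float = E4M3_MAX
                    ) -> Tuple[List[float], float, float]:
    """Return (scaled values, scale, scale_inv) using this tensor's own amax."""
    amax = max((abs(v) for v in x), default=0.0)
    scale = fp8_max / amax if amax > 0.0 else 1.0  # avoid dividing by zero
    scaled = [v * scale for v in x]  # a real quantizer would now cast these to FP8
    return scaled, scale, 1.0 / scale
```

Because amax comes from the same tensor, the largest-magnitude element always lands exactly on the FP8 maximum, at the cost of an extra pass over the data.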
- class transformer_engine.pytorch.MXFP8Quantizer(fp8_dtype, *, rowwise=True, columnwise=True)
Builder class for FP8 tensors with MX block scaling
High-precision tensors (e.g. in FP32 or BF16) are quantized by dividing them into groups of 32 elements, each scaled and cast separately using current data.
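The MX block layout can be sketched as follows: the flattened data is cut into groups of 32, and each group gets its own power-of-two scale. The scale rule used here, 2**(floor(log2(amax)) - 8) for E4M3 elements, is taken as an assumption from the OCP Microscaling (MX) specification; Transformer Engine's kernels may round or clamp differently, and the FP8 cast itself (including any clamping) is omitted. The function name is hypothetical.

```python
import math
from typing import List, Sequence, Tuple

MX_BLOCK = 32  # elements sharing one scale
E4M3_EMAX = 8  # exponent of the largest normal E4M3 value (448 = 1.75 * 2**8)


def mx_block_scales(x: Sequence[float], block: int = MX_BLOCK
                    ) -> Tuple[List[float], List[float]]:
    """Return (per-block power-of-two scales, values divided by their block scale)."""
    scales: List[float] = []
    scaled: List[float] = []
    for start in range(0, len(x), block):
        group = x[start:start + block]
        amax = max(abs(v) for v in group)
        if amax == 0.0:
            shared_exp = 0  # all-zero block: any scale works
        else:
            # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1
            shared_exp = math.frexp(amax)[1] - 1 - E4M3_EMAX
        scale = 2.0 ** shared_exp
        scales.append(scale)
        scaled.extend(v / scale for v in group)  # these would then be cast to E4M3
    return scales, scaled
```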
- class transformer_engine.pytorch.Float8BlockQuantizer(fp8_dtype, *, rowwise, columnwise, **kwargs)
Builder class for tensors quantized with current scaling using NxN quantization tilings to choose scale.
This class is typically used to convert a high-precision tensor (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).
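For block scaling, each tile of the matrix gets its own scale. The sketch computes one scale per tile of shape tile_rows x tile_cols, so tile_rows=1 models the 1xN case and tile_rows=tile_cols=N the NxN case mentioned for Float8BlockwiseQTensor. It assumes scale = fp8_max / tile_amax without power-of-two rounding; the concrete block sizes used by Transformer Engine are not stated here, and the function name is hypothetical.

```python
from typing import List, Sequence

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def tile_scales(m: Sequence[Sequence[float]], tile_rows: int, tile_cols: int,
                fp8_max: float = E4M3_MAX) -> List[List[float]]:
    """One scaling factor per tile; edge tiles may be smaller than the tile shape."""
    n_rows = len(m)
    n_cols = len(m[0]) if n_rows else 0
    scales: List[List[float]] = []
    for r0 in range(0, n_rows, tile_rows):
        row_of_scales: List[float] = []
        for c0 in range(0, n_cols, tile_cols):
            amax = max(abs(m[r][c])
                       for r in range(r0, min(r0 + tile_rows, n_rows))
                       for c in range(c0, min(c0 + tile_cols, n_cols)))
            row_of_scales.append(fp8_max / amax if amax > 0.0 else 1.0)
        scales.append(row_of_scales)
    return scales
```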
- class transformer_engine.pytorch.NVFP4Quantizer(fp4_dtype, *, rowwise=True, columnwise=True, **kwargs)
Builder class for NVFP4 tensors with NV block scaling
Tensor saving and restoring functions
- transformer_engine.pytorch.prepare_for_saving(*tensors: torch.Tensor | QuantizedTensorStorage) Tuple[list[torch.Tensor | torch.nn.Parameter | None], list[QuantizedTensorStorage | None]]
Prepare tensors for saving. Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save the internal TensorStorage types too.
- transformer_engine.pytorch.restore_from_saved(tensors: list[torch.Tensor | QuantizedTensorStorage | None], saved_tensors: list[torch.Tensor | torch.nn.Parameter | None], return_saved_tensors: bool = False) list[torch.Tensor | QuantizedTensorStorage | None] | tuple[list[torch.Tensor | QuantizedTensorStorage | None], list[torch.Tensor | None]]
Recombine the tensor data and metadata during backward pass.
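The flattening these two helpers perform can be sketched in plain Python: plain tensors pass through unchanged with a None placeholder in the object list, while storage objects contribute their internal buffers and keep their metadata for later reattachment. The sketch is modelled on the signatures above and is not the library's implementation; ToyStorage, flatten_for_saving and unflatten_saved are hypothetical names.

```python
from typing import Any, List, Optional, Tuple


class ToyStorage:
    """Hypothetical storage object holding two raw buffers plus metadata."""

    def __init__(self, data: Any, scale_inv: Any):
        self.data, self.scale_inv = data, scale_inv

    def prepare_for_saving(self) -> Tuple[List[Any], "ToyStorage"]:
        tensors = [self.data, self.scale_inv]
        self.data = self.scale_inv = None
        return tensors, self

    def restore_from_saved(self, tensors: List[Any]) -> List[Any]:
        self.data, self.scale_inv = tensors[0], tensors[1]
        return tensors[2:]


def flatten_for_saving(*tensors: Any) -> Tuple[List[Any], List[Optional[ToyStorage]]]:
    """Split inputs into plain buffers (for save_for_backward) and storage objects."""
    flat: List[Any] = []
    objects: List[Optional[ToyStorage]] = []
    for t in tensors:
        if isinstance(t, ToyStorage):
            buffers, obj = t.prepare_for_saving()
            flat.extend(buffers)
            objects.append(obj)
        else:
            flat.append(t)  # plain tensor: saved directly
            objects.append(None)
    return flat, objects


def unflatten_saved(objects: List[Optional[ToyStorage]], saved: List[Any]) -> List[Any]:
    """Rebuild the original sequence of tensors and storage objects during backward."""
    restored: List[Any] = []
    for obj in objects:
        if obj is None:
            restored.append(saved[0])
            saved = saved[1:]
        else:
            saved = obj.restore_from_saved(saved)
            restored.append(obj)
    return restored
```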
Deprecated functions
- transformer_engine.pytorch.fp8_autocast(enabled: bool = True, calibrating: bool = False, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp8_group: transformer_engine.pytorch.constants.dist_group_type | None = None, _graph: bool = False) None
Warning
fp8_autocast is deprecated and will be removed in a future release. Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
- transformer_engine.pytorch.fp8_model_init(enabled: bool = True, recipe: transformer_engine.common.recipe.Recipe | None = None, preserve_high_precision_init_val: bool = False) None
Warning
fp8_model_init is deprecated and will be removed in a future release. Use quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.