PyTorch

class transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)

Applies a linear transformation to the incoming data \(y = xA^T + b\)

On NVIDIA GPUs it is a drop-in replacement for torch.nn.Linear.

Parameters:
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • bias (bool, default = True) – if set to False, the layer will not learn an additive bias.

  • init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • get_rng_state_tracker (Callable, default = None) – used to get the random number generator state tracker for initializing weights.

  • rng_tracker_name (str, default = None) – the param passed to get_rng_state_tracker to get the specific rng tracker.

  • parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – Configuration for splitting the weight and bias tensors along dim 0 into multiple PyTorch parameters. If a list or tuple of strings is provided, they are used to make the names of equally-sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and values as split sizes along dim 0. The resulting parameters will have names that end in _weight or _bias, so trailing underscores are stripped from any provided names.

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

  • parallel_mode ({None, ‘column’, ‘row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

forward(inp: torch.Tensor, is_first_microbatch: bool | None = None, fp8_output: bool | None = False) → torch.Tensor | Tuple[torch.Tensor, Ellipsis]

Apply the linear transformation to the input.

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
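
A minimal usage sketch of the module (shapes are illustrative; FP8 execution additionally requires FP8-capable hardware, and with no explicit recipe fp8_autocast falls back to the default delayed-scaling recipe):

    import torch
    import transformer_engine.pytorch as te

    # 1024 -> 1024 projection; parameters are allocated on the default CUDA device.
    layer = te.Linear(1024, 1024, bias=True)
    inp = torch.randn(128, 1024, device="cuda")

    # Higher-precision execution, interchangeable with torch.nn.Linear.
    out = layer(inp)

    # FP8 execution using the default recipe.
    with te.fp8_autocast(enabled=True):
        out = layer(inp, is_first_microbatch=True)

    out.sum().backward()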

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.
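
A sketch of the deferred-group pattern described above (the process-group setup is illustrative and assumes torch.distributed has already been initialized across the tensor-parallel ranks):

    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # Two-way tensor parallelism; the group is not known at module construction.
    tp_group = dist.new_group(ranks=[0, 1])

    layer = te.Linear(
        4096, 4096,
        tp_size=2,               # TP world size supplied up front
        parallel_mode="column",  # shard out_features across the TP group
    )

    # The group must be supplied before the first forward pass.
    layer.set_tensor_parallel_group(tp_group)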

class transformer_engine.pytorch.GroupedLinear(num_gemms, in_features, out_features, bias=True, **kwargs)

Applies linear transformations \(y_i = x_i A_i^T + b_i\) to a list of inputs in a grouped way.

Parameters:
  • num_gemms (int) – number of GEMMs to be performed simultaneously.

  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • bias (bool, default = True) – if set to False, the layer will not learn an additive bias.

  • init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • get_rng_state_tracker (Callable, default = None) – used to get the random number generator state tracker for initializing weights.

  • rng_tracker_name (str, default = None) – the param passed to get_rng_state_tracker to get the specific rng tracker.

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

forward(inp: torch.Tensor, m_splits: List[int], is_first_microbatch: bool | None = None) → torch.Tensor | Tuple[torch.Tensor, Ellipsis]

Apply the linear transformation to the input.

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • m_splits (List[int]) – List of integers specifying how many rows of the input tensor (along dim 0) belong to each GEMM.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
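
A minimal usage sketch (shapes and split sizes are illustrative):

    import torch
    import transformer_engine.pytorch as te

    # Three independent 256 -> 512 GEMMs executed as one grouped call.
    layer = te.GroupedLinear(3, 256, 512, bias=True)

    # Inputs for the three GEMMs are concatenated along dim 0;
    # m_splits gives the number of rows belonging to each GEMM.
    m_splits = [16, 32, 16]
    inp = torch.randn(sum(m_splits), 256, device="cuda")

    out = layer(inp, m_splits)  # shape: [64, 512]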

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

class transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

Layer Normalization

Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]

\(\gamma\) and \(\beta\) are learnable affine transform parameters that match the inner-most dimensions of the input tensor.

Parameters:
  • normalized_shape (int or iterable of int) – Inner dimensions of input tensor

  • eps (float, default = 1e-5) – A value added to the denominator of layer normalization for numerical stability

  • device (torch.device, default = default CUDA device) – Tensor device

  • dtype (torch.dtype, default = default dtype) – Tensor datatype

  • zero_centered_gamma (bool, default = 'False') –

    If True, the \(\gamma\) parameter is initialized to zero and the calculation changes to

    \[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • sm_margin (int or dict, default = 0) – Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage (“forward”, “backward”, “inference”).

Legacy parameters:
  • sequence_parallel (bool) – Set a bool attr named sequence_parallel in the parameters. This is custom logic for Megatron-LM integration.
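
A minimal usage sketch (normalization is applied over the inner dimension; shapes are illustrative):

    import torch
    import transformer_engine.pytorch as te

    ln = te.LayerNorm(1024, eps=1e-5)
    x = torch.randn(8, 32, 1024, device="cuda")
    y = ln(x)  # same shape as x

    # With default initialization (gamma = 1, beta = 0) the result should match
    # the PyTorch reference up to kernel-level rounding.
    ref = torch.nn.LayerNorm(1024, eps=1e-5).to("cuda")(x)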

class transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)

Root Mean Square Layer Normalization

Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper Root Mean Square Layer Normalization

\[y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma\]

where

\[\text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}\]

\(\gamma\) is a learnable affine transform parameter that matches the inner-most dimensions of the input tensor.

Parameters:
  • normalized_shape (int or iterable of int) – Inner dimensions of input tensor

  • eps (float, default = 1e-5) – A value added to the denominator for numerical stability

  • device (torch.device, default = default CUDA device) – Tensor device

  • dtype (torch.dtype, default = default dtype) – Tensor datatype

  • zero_centered_gamma (bool, default = 'False') –

    If True, the \(\gamma\) parameter is initialized to zero and the calculation changes to

    \[y = \frac{x}{\text{RMS}_\varepsilon(x)} * (1 + \gamma)\]

  • sm_margin (int or dict, default = 0) – Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage (“forward”, “backward”, “inference”).

Legacy parameters:
  • sequence_parallel (bool) – Set a bool attr named sequence_parallel in the parameters. This is custom logic for Megatron-LM integration.
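
A minimal sketch that checks the module against the definition above (gamma is initialized to ones, so it drops out of the reference computation; shapes are illustrative):

    import torch
    import transformer_engine.pytorch as te

    rms = te.RMSNorm(1024, eps=1e-5)
    x = torch.randn(4, 1024, device="cuda")
    y = rms(x)

    # Reference computation written out from the RMS definition above.
    rms_x = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5)
    y_ref = x / rms_x  # should match y up to kernel-level rounding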

class transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)

Applies layer normalization followed by linear transformation to the incoming data.

Parameters:
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • bias (bool, default = True) – if set to False, the layer will not learn an additive bias.

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • return_layernorm_output_gathered (bool, default = False) – if set to True, output of layernorm is returned after the all gather operation. Ignored if return_layernorm_output is False. Example use case: with sequence parallel, input to residual connection for transformer module (e.g. LoRA) will need to be gathered. Returning layernorm output gathered will prevent a redundant gather.

  • parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – Configuration for splitting the weight and bias tensors along dim 0 into multiple PyTorch parameters. If a list or tuple of strings is provided, they are used to make the names of equally-sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and values as split sizes along dim 0. The resulting parameters will have names that end in _weight or _bias, so trailing underscores are stripped from any provided names.

  • zero_centered_gamma (bool, default = 'False') –

    if set to ‘True’, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

  • parallel_mode ({None, ‘column’, ‘row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

forward(inp: torch.Tensor, is_first_microbatch: bool | None = None, fp8_output: bool | None = False) → torch.Tensor | Tuple[torch.Tensor, Ellipsis]

Apply layer normalization to the input followed by a linear transformation.

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
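
A minimal usage sketch showing the return_layernorm_output option (shapes are illustrative):

    import torch
    import transformer_engine.pytorch as te

    # Fused LayerNorm + 1024 -> 3072 projection; also return the normalized
    # input so a residual connection can be taken post-layernorm.
    layer = te.LayerNormLinear(1024, 3072, return_layernorm_output=True)
    x = torch.randn(128, 1024, device="cuda")

    out, ln_out = layer(x)  # out: [128, 3072], ln_out: [128, 1024]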

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

class transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)

Applies layer normalization on the input followed by the MLP module, consisting of two successive linear transformations separated by an activation function (GeLU by default).

Parameters:
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • bias (bool, default = True) – if set to False, the FC1 and FC2 layers will not learn an additive bias.

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • activation (str, default = 'gelu') – activation function used. Options: ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, ‘swiglu’, ‘qgelu’, ‘srelu’.

  • init_method (Callable, default = None) – used for initializing FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • output_layer_init_method (Callable, default = None) – used for initializing FC2 weights in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • return_layernorm_output_gathered (bool, default = False) – if set to True, output of layernorm is returned after the all gather operation. Ignored if return_layernorm_output is False. Example use case: with sequence parallel, input to residual connection for transformer module (e.g. LoRA) will need to be gathered. Returning layernorm output gathered will prevent a redundant gather.

  • zero_centered_gamma (bool, default = 'False') –

    if set to ‘True’, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

Parallelism parameters:
  • set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described here.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • seq_length (int) – sequence length of input samples. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.

  • micro_batch_size (int) – batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.

forward(inp: torch.Tensor, is_first_microbatch: bool | None = None) → torch.Tensor | Tuple[torch.Tensor, Ellipsis]

Apply layer normalization to the input followed by a feedforward network (MLP Block).

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
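
A minimal usage sketch (hidden sizes are illustrative):

    import torch
    import transformer_engine.pytorch as te

    # LayerNorm -> FC1 (1024 -> 4096) -> GeLU -> FC2 (4096 -> 1024).
    mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="gelu")
    x = torch.randn(512, 1024, device="cuda")

    y = mlp(x)  # same shape as the input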

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

class transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Note

Argument attention_mask in the forward call is only used when attn_mask_type includes “padding” or “arbitrary”.

Warning

FlashAttention uses a non-deterministic algorithm for optimal performance. To observe deterministic behavior at the cost of performance, use FlashAttention version >= 2.4.1 and set the environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. In order to disable flash-attn entirely, set NVTE_FLASH_ATTN=0.

Note

Transformer Engine stores the FP8 metadata under a ._extra_state key when checkpointing. As the FP8 attention support expands from one backend to multiple backends, the location of that key has also shifted (see FP8 checkpoint compatibility).

Parameters:
  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • kv_channels (Union[int, Tuple[int, int]]) – the head size in key and value tensors. If the same, kv_channels can be an integer; if not, kv_channels should be a tuple of two integers.

  • num_gqa_groups (Optional[int] = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • attention_dropout (float, default = 0.0) – dropout probability for the dropout op during multi-head attention.

  • attn_mask_type (str, default = causal) – type of attention mask passed into softmax operation, options are “no_mask”, “padding”, “causal”, “padding,causal”, “causal,padding”, “padding_causal”, “causal_bottom_right”, “padding_causal_bottom_right”, and “arbitrary”, where “padding,causal”, “causal,padding” and “padding_causal” are equivalent. This arg can be overridden by attn_mask_type in the forward method. It is useful for cases involving compilation/tracing, e.g. ONNX export, and the forward arg is useful for dynamically changing mask types, e.g. a different mask for training and inference.

    1. For “no_mask”, no attention mask is applied.

    2. For “causal”, “causal_bottom_right”, or the causal mask in “padding_causal” and “padding_causal_bottom_right”, Transformer Engine calculates and applies an upper triangular mask to the softmax input. No user input is needed. Causal masks without the “bottom_right” appendix align the diagonal line to the top left corner of the softmax matrix. With “bottom_right”, the causal mask is aligned to the bottom right corner, which is often used in inference/KV caching.

    3. For “padding”, or the padding mask in “padding_causal” and “padding_causal_bottom_right”, users need to provide the locations of padded tokens, either via cu_seqlens_q and cu_seqlens_kv (both in shape [batch_size + 1]), or via attention_mask (one tensor for self-attention in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and [batch_size, 1, 1, max_seqlen_kv]).

    4. For “arbitrary”, users need to provide a mask that is broadcastable to the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

  • window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. Both causal and causal_bottom_right masks map to window_size = (-1, 0) and Transformer Engine distinguishes them based on attn_mask_type. Similar to attn_mask_type, window_size can be overridden by window_size in forward as well.

  • attention_type (str, default = self) – type of attention, either “self” or “cross”.

  • layer_number (int, default = None) – layer number of the current DotProductAttention when multiple such modules are concatenated, for instance in consecutive transformer blocks.

  • qkv_format (str, default = sbhd) – dimension format for query_layer, key_layer and value_layer, {sbhd, bshd, thd}. s stands for the sequence length, b batch size, h the number of heads, d head size, and t the total number of tokens in a batch, with t = sum(s_i), for i = 0…b-1. sbhd and bshd formats are used for when sequences in a batch are of equal length or padded to equal length, and the thd format is used for when sequences in a batch have different lengths. Please note that these formats do not reflect how tensors query_layer, key_layer, value_layer are laid out in memory. For that, please use get_qkv_layout to gain the layout information.

  • softmax_scale (Optional[float], default = None) – softmax scale for the attention scores. If None, defaults to 1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0]).

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_size (int, default = 1) – tensor parallel world size.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • cp_group (Union[ProcessGroup, List[ProcessGroup]], default = None) – context parallel process group. ProcessGroup is for cp_comm_type of “p2p”, “all_gather”, and “a2a”. List[ProcessGroup] is for cp_comm_type of “a2a+p2p”, where cp_group[0] and cp_group[1] are for a2a and p2p communications respectively.

  • cp_global_ranks (list of global rank IDs, default = None) – global rank IDs of GPUs that are in cp_group.

  • cp_stream (CUDA stream, default = None) – context parallelism splits flash attention into multiple steps for compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels.

  • cp_comm_type (str, default = p2p) – inter-GPU communication type for context parallelism. Can be “p2p”, “all_gather”, “a2a”, or “a2a+p2p”.

    “p2p”: Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.

    “all_gather”: All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.

    “a2a”: Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.

    “a2a+p2p”: Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via IBLink).

forward(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, qkv_format: str | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_kv: torch.Tensor | None = None, cu_seqlens_q_padded: torch.Tensor | None = None, cu_seqlens_kv_padded: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_kv: int | None = None, attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, fast_zero_fill: bool = True, inference_params: InferenceParams | None = None, is_first_microbatch: bool | None = None) → torch.Tensor

Dot Product Attention Layer.

Note

Argument attention_mask is only used when attn_mask_type includes “padding” or “arbitrary”.

Note

DotProductAttention supports three backends: 1) FlashAttention which calls HazyResearch/Dao-AILab’s flash-attn PyTorch API, 2) FusedAttention which has multiple fused attention implementations based on cuDNN Graph API (see FusedAttention for more details on FusedAttention backends), and 3) UnfusedDotProductAttention which is the native PyTorch implementation with fused scaled masked softmax.

Note

Users can use environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, and NVTE_FUSED_ATTN_BACKEND to control which DotProductAttention backend, and FusedAttention backend if applicable, to use. Transformer Engine prioritizes FlashAttention over FusedAttention and over UnfusedDotProductAttention. If FusedAttention is being used, users can also choose to switch to flash-attn’s implementation for backward by setting NVTE_FUSED_ATTN_USE_FAv2_BWD=1 (default: 0), because of the performance differences between various versions of flash-attn and FusedAttention. Further, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT can be used to enable (1) or disable (0) the workspace related optimizations in FusedAttention. When unset, Transformer Engine determines the code path based on its internal logic. These optimizations trade memory for performance and should be used with care.

Note

When training data has variable sequence lengths, users have two options.

  1. Manipulate the data and pad all sequences to the same length. Use qkv_format = {“bshd”, “sbhd”} and attn_mask_type = {“padding”, “padding_causal”, “padding_causal_bottom_right”}. Pass in cu_seqlens_q and cu_seqlens_kv, or attention_mask (which will be converted to cu_seqlens_q and cu_seqlens_kv), to provide the real sequence length information. For example, a batch of 3 sequences [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative sequence length tensors would be cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9] for self-attention.

  2. Do not perform padding on training data. Use qkv_format = “thd” and attn_mask_type = {“padding”, “padding_causal”, “padding_causal_bottom_right”}. Pass in cu_seqlens_q and cu_seqlens_kv, or attention_mask, as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed without any padding, and the sequence length tensors would be cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9] for self-attention.

    In certain use cases, a varying number of identifier tokens are inserted between sequences. These tokens do not participate in the attention calculation. cu_seqlens_q_padded and cu_seqlens_kv_padded must be specified in such cases to correctly identify the start and end of each sequence in a batch. For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have cu_seqlens_q = cu_seqlens_kv = [0, 3, 5, 9], and cu_seqlens_q_padded = cu_seqlens_kv_padded = [0, 4, 8, 13] for self-attention.

Note

When qkv_format = {“bshd”, “sbhd”}, sequences are of equal length in a batch. max_seqlen_q and max_seqlen_kv should be the same as the “s” dimension of the query_layer and key_layer tensors. When unset, Transformer Engine infers them from those dimensions.

When qkv_format = “thd”, sequences have varying lengths. max_seqlen_q and max_seqlen_kv should be the maximum query and key/value sequence length in a batch. When unset, Transformer Engine deduces them from cu_seqlens_q and cu_seqlens_kv. This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this overhead, users are recommended to obtain the maximum sequence lengths from the data loaders and pass them in.

As the maximum sequence lengths, batch size, and number of tokens change from batch to batch, dynamic shapes need to be supported for tensor construction. FlashAttention and UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static to create graphs before performance heuristics analysis. To reduce the number of graphs created per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size, max_seqlen_q, max_seqlen_kv}, and for cuDNN >= 9.6, {“t” dimension of query_layer, “t” dimension of key_layer}.
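
A short sketch of how the cumulative sequence length tensors in option 2 can be built from per-sequence lengths (the lengths correspond to the [a a a b b c c c c] example above):

    import torch

    # Three sequences of lengths 3, 2 and 4 packed without padding (qkv_format = "thd").
    seq_lens = torch.tensor([3, 2, 4], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    # cu_seqlens -> tensor([0, 3, 5, 9]); for self-attention pass it as both
    # cu_seqlens_q and cu_seqlens_kv, along with max_seqlen_q = max_seqlen_kv = 4.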

Parameters:
  • query_layer (torch.Tensor) – Query tensor.

  • key_layer (torch.Tensor) – Key tensor.

  • value_layer (torch.Tensor) – Value tensor.

  • attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],) – default = None. Boolean tensor(s) used to mask out attention softmax input. It should be None for causal masks and “no_mask”. For padding masks, it should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For “arbitrary” mask, it should be in a shape broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A True value means the corresponding position is masked out and a False means that position is allowed to participate in attention.

  • qkv_format (str, default = None) – If provided, overrides qkv_format from initialization.

  • cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32. See note for more details.

  • cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32. See note for more details.

  • cu_seqlens_q_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, cu_seqlens_q_padded = cu_seqlens_q. See note for more details.

  • cu_seqlens_kv_padded (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (with offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, cu_seqlens_kv_padded = cu_seqlens_kv. See note for more details.

  • max_seqlen_q (Optional[int], default = None) – Maximum sequence length in query_layer. See note for more details.

  • max_seqlen_kv (Optional[int], default = None) – Maximum sequence length in key_layer and value_layer. See note for more details.

  • attn_mask_type ({'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = None) – Type of attention mask passed into softmax operation. ‘padding,causal’, ‘causal,padding’ and ‘padding_causal’ are equivalent. By default, causal masks are aligned to the top left corner of the softmax matrix. When “bottom_right” is specified in the mask type, causal masks are aligned to the bottom right corner.

  • window_size (Optional[Tuple[int, int]], default = None) – Sliding window size for local attention.

  • checkpoint_core_attention (bool, default = False) – If true, forward activations for attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

  • core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}

  • core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be ‘None’ for ‘no_bias’ and ‘alibi’ bias types.

  • alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j.

  • fast_zero_fill (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.

  • inference_params (Optional[InferenceParams], default = None) – Optimizes execution performance during inference by caching Keys and Values of the current decoding iteration. These cached values are appended to the K and V values computed in previous iterations, eliminating the need to recalculate them for the entire sequence. Initialization of inference_params is required prior to use to ensure sufficient memory allocation. Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports “sbhd” and “bshd” layouts, with the “sbhd” layout being more efficient.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
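
A minimal self-attention sketch in the default “sbhd” layout (head counts, sizes, and the bfloat16 dtype are illustrative; the backend actually used depends on the hardware and the environment variables described above):

    import torch
    import transformer_engine.pytorch as te

    # 16 heads of size 64; query/key/value in [seq, batch, heads, head_dim].
    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64,
                                  attn_mask_type="causal", qkv_format="sbhd")

    s, b, h, d = 512, 2, 16, 64
    q = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)

    out = attn(q, k, v)  # [s, b, h * d]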

set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') → None

Set the context parallel attributes for the given module before executing the forward pass.

Parameters:
  • cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. ProcessGroup is for cp_comm_type of “p2p”, “all_gather”, and “a2a”. List[ProcessGroup] is for cp_comm_type of “a2a+p2p”, where cp_group[0] and cp_group[1] are for a2a and p2p communications respectively.

  • cp_global_ranks (List[int]) – list of global ranks in the context group.

  • cp_stream (torch.cuda.Stream) – cuda stream for context parallel execution.

  • cp_comm_type (str, default = p2p) –

    inter-GPU communication type for context parallelism. Can be “p2p”, “all_gather”, “a2a”, or “a2a+p2p”.

    “p2p”: Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.

    “all_gather”: All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.

    “a2a”: Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.

    “a2a+p2p”: Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via IBLink).

class transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)

Multi-head Attention (MHA), including Query, Key, Value and Output projection.

Note

Argument attention_mask in the forward call is only used when attn_mask_type includes “padding” or “arbitrary”.

Parameters:
  • hidden_size (int) – size of each input sample.

  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • kv_channels (int, default = None) – number of key-value channels. Defaults to hidden_size / num_attention_heads if None.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • init_method (Callable, default = None) – used for initializing weights of QKV and FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.

  • attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = causal) – type of attention mask passed into softmax operation. Overridden by attn_mask_type in the forward method. The forward arg is useful for dynamically changing mask types, e.g. a different mask for training and inference. The init arg is useful for cases involving compilation/tracing, e.g. ONNX export.

  • window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. Both causal and causal_bottom_right masks map to window_size = (-1, 0) and Transformer Engine distinguishes them based on attn_mask_type. Similar to attn_mask_type, window_size can be overridden by window_size in forward as well.

  • num_gqa_groups (int, default = None) –

    number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • input_layernorm (bool, default = False) – if set to True, layer normalization to the input is applied.

  • attention_type ({ 'self', 'cross' }, default = 'self') – type of attention applied.

  • zero_centered_gamma (bool, default = 'False') –

    if set to ‘True’, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.

  • bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

  • qkv_format (str, default = sbhd) – dimension format for query_layer, key_layer and value_layer, {sbhd, bshd}. s stands for the sequence length, b batch size, h the number of heads and d head size. sbhd and bshd formats are used for when sequences in a batch are of equal length or padded to equal length. Please note that these formats do not reflect how tensors query_layer, key_layer, value_layer are laid out in memory. For that, please use get_qkv_layout to gain the layout information.

Parallelism parameters:
  • set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • fuse_qkv_params (bool, default = ‘False’) – if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.

forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, encoder_output: torch.Tensor | None = None, attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, is_first_microbatch: bool | None = None, checkpoint_core_attention: bool = False, inference_params: InferenceParams | None = None, rotary_pos_emb: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_kv: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_kv: int | None = None, fast_zero_fill: bool = True) → Tuple[torch.Tensor | None, Ellipsis]

Forward propagation for MultiheadAttention layer.

Note

Argument attention_mask is only used when attn_mask_type includes “padding” or “arbitrary”.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor.

  • attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],) – default = None. Boolean tensor(s) used to mask out attention softmax input. It should be None for causal masks and “no_mask”. For padding masks, it should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For “arbitrary” mask, it should be in a shape broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A True value means the corresponding position is masked out and a False means that position is allowed to participate in attention.

  • attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = None) – type of attention mask passed into softmax operation. By default, causal masks are aligned to the top left corner of the softmax matrix. When “bottom_right” is specified in the mask type, causal masks are aligned to the bottom right corner.

  • window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention.

  • encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type=”decoder”.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)

  • checkpoint_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

  • rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.

  • core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}

  • core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be ‘None’ for ‘no_bias’ and ‘alibi’ bias types.

  • alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j.

  • cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32.

  • cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32.

  • max_seqlen_q (Optional[int], default = None) – Maximum sequence length in query_layer. Calculated from cu_seqlens_q if not provided.

  • max_seqlen_kv (Optional[int], default = None) – Maximum sequence length in key_layer and value_layer. Calculated from cu_seqlens_kv if not provided.

  • fast_zero_fill (bool, default = True) – Whether to set output tensors to 0 or not before use.
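
A minimal self-attention usage sketch (sizes are illustrative; the default “sbhd” input layout and causal masking are assumed):

    import torch
    import transformer_engine.pytorch as te

    # QKV projection, core attention, and output projection in one module.
    mha = te.MultiheadAttention(hidden_size=1024, num_attention_heads=16)

    # hidden_states in [seq, batch, hidden] for the default "sbhd" format.
    x = torch.randn(512, 2, 1024, device="cuda")
    out = mha(x)  # projection output; a tuple is returned only when
                  # return_bias or return_layernorm_output is enabled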

set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') → None

Set the context parallel attributes for the given module before executing the forward pass.

Parameters:
  • cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. ProcessGroup is for cp_comm_type of “p2p”, “all_gather”, and “a2a”. List[ProcessGroup] is for cp_comm_type of “a2a+p2p”, where cp_group[0] and cp_group[1] are for a2a and p2p communications respectively.

  • cp_global_ranks (List[int]) – list of global ranks in the context group.

  • cp_stream (torch.cuda.Stream) – cuda stream for context parallel execution.

  • cp_comm_type (str, default = p2p) –

    inter-GPU communication type for context parallelism. Can be “p2p”, “all_gather”, “a2a”, or “a2a+p2p”.

    “p2p”: Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be overlapped with attention compute.

    “all_gather”: All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.

    “a2a”: Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.

    “a2a+p2p”: Hierarchical CP implementation. First apply a2a to QKV across each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via IBLink).

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) → None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

class transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)

TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.

Note

Argument attention_mask in the forward call is only used when self_attn_mask_type includes “padding” or “arbitrary”.

Parameters:
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • num_gqa_groups (int, default = None) –

    number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • init_method (Callable, default = None) – used for initializing weights of QKV and FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of layer norm (default is taken from input of layer norm)

  • layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.

  • output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. default behavior is to apply layer normalization on the input side, before the QKV transformation.

  • parallel_attention_mlp (bool, default = False) – if set to True, self-attention and feedforward network are computed based on the same input (in parallel) instead of sequentially. Both blocks have an independent normalization. This architecture is used in Falcon models.

  • layer_type ({‘encoder’, ‘decoder’}, default = encoder) – if set to decoder, an additional cross-attn block is added after self-attn. This can be used for structures like T5 Transformer in conjunction with the encoder option.

  • kv_channels (int, default = None) – number of query-key-value channels per attention head. defaults to hidden_size / num_attention_heads if None.

  • self_attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = causal) – type of attention mask passed into softmax operation for encoder. Overridden by self_attn_mask_type in the forward method. The forward arg is useful for dynamically changing mask types, e.g. a different mask for training and inference. The init arg is useful for cases involving compilation/tracing, e.g. ONNX export.

  • window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in encoder, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. Both causal and causal_bottom_right masks map to window_size = (-1, 0) and Transformer Engine distinguishes them based on self_attn_mask_type or enc_dec_attn_mask_type. Similar to self_attn_mask_type, window_size can be overridden by window_size in forward as well.

  • enc_dec_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = no_mask) – type of attention mask passed into softmax operation for decoder.

  • enc_dec_window_size (Optional[Tuple[int, int]], default = None) – sliding window size for local attention in decoder.

  • zero_centered_gamma (bool, default = 'False') –

    if set to ‘True’, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.

  • bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.

  • activation (str, default = 'gelu') – Type of activation used in MLP block. Options are: ‘gelu’, ‘relu’, ‘reglu’, ‘geglu’, ‘swiglu’, ‘qgelu’ and ‘srelu’.

  • device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.

  • attn_input_format ({'sbhd', 'bshd'}, default = 'sbhd') – This controls whether the dimensions of the intermediate hidden states are 'batch first' ('bshd') or 'sequence first' ('sbhd'), where s stands for the sequence length, b the batch size, h the number of heads, and d the head size. Note that these formats are very closely related to the qkv_format in the MultiHeadAttention and DotProductAttention modules.

Parallelism parameters:
  • set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • seq_length (int) – sequence length of input samples. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.

  • micro_batch_size (int) – batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure the same kernels are used for forward propagation and the activation recompute phase.

  • drop_path_rate (float, default = 0.0) – when > 0.0, applies stochastic depth per sample in the main path of the residual block.

  • fuse_qkv_params (bool, default = ‘False’) – if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
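
As a point of reference, a minimal construction sketch is shown below; the sizes and options are illustrative only and not taken from this documentation.

import torch
import transformer_engine.pytorch as te

# Illustrative hyperparameters; choose values to match your model.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    attn_input_format="sbhd",
    params_dtype=torch.bfloat16,
)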

forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, self_attn_mask_type: str | None = None, window_size: Tuple[int, int] | None = None, encoder_output: torch.Tensor | None = None, enc_dec_attn_mask: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, enc_dec_attn_mask_type: str | None = None, enc_dec_window_size: Tuple[int, int] | None = None, is_first_microbatch: bool | None = None, checkpoint_core_attention: bool = False, inference_params: transformer_engine.pytorch.attention.InferenceParams | None = None, rotary_pos_emb: torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, cu_seqlens_kv: torch.Tensor | None = None, max_seqlen_q: int | None = None, max_seqlen_kv: int | None = None, fast_zero_fill: bool = True) torch.Tensor

Transformer Layer: attention block and a feedforward network (MLP)

Note

Argument attention_mask is only used when self_attn_mask_type includes “padding” or “arbitrary”.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor.

  • attention_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input. It should be in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for “arbitrary” mask. It should be None for causal masks and “no_mask” type. A True value means the corresponding position is masked out and a False means that position is allowed to participate in attention.

  • self_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = causal) – Type of attention mask passed into softmax operation for encoder. By default, causal masks are aligned to the top left corner of the softmax matrix. When “bottom_right” is specified in the mask type, causal masks are aligned to the bottom right corner.

  • window_size (Optional[Tuple[int, int]], default = None) – Sliding window size for local attention in encoder.

  • encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type=”decoder”.

  • enc_dec_attn_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensors used to mask out inter-attention softmax input if using layer_type=”decoder”. It should be a tuple of two masks in [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks. It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for “arbitrary” mask. It should be None for causal masks and “no_mask”. A True value means the corresponding position is masked out and a False means that position is allowed to participate in attention.

  • enc_dec_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = None) – Type of attention mask passed into softmax operation for decoder.

  • enc_dec_window_size (Optional[Tuple[int, int]], default = None) – Sliding window size for local attention in decoder.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)

  • checkpoint_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

  • rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.

  • core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}

  • core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T

  • alibi_slopes (Optional[torch.Tensor], default = None) – ALiBi slopes in FP32 with shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j.

  • cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32.

  • cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths (without offset) in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32.

  • max_seqlen_q (Optional[int], default = None) – Maximum sequence length in query_layer. Calculated from cu_seqlens_q if not provided.

  • max_seqlen_kv (Optional[int], default = None) – Maximum sequence length in key_layer and value_layer. Calculated from cu_seqlens_kv if not provided.

  • fast_zero_fill (bool, default = True) – Whether to set output tensors to 0 or not before use.

  • inference_params (InferenceParams, default = None) – Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.
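
Continuing the construction sketch above, a forward call with 'sbhd' inputs might look like the following; the shapes are illustrative.

# [seq_len, batch_size, hidden_size] for attn_input_format="sbhd".
hidden_states = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")

# Causal self-attention: attention_mask may be left as None.
out = layer(hidden_states, self_attn_mask_type="causal")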

set_context_parallel_group(cp_group: transformer_engine.pytorch.constants.dist_group_type | List[transformer_engine.pytorch.constants.dist_group_type] | None, cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, cp_comm_type: str = 'p2p') None

Set the context parallel attributes for the given module before executing the forward pass.

Parameters:
  • cp_group (Union[ProcessGroup, List[ProcessGroup]]) – context parallel process group. ProcessGroup is for cp_comm_type of “p2p”, “all_gather”, and “a2a”. List[ProcessGroup] is for cp_comm_type of “a2a+p2p”, where cp_group[0] and cp_group[1] are for a2a and p2p communications respectively.

  • cp_global_ranks (List[int]) – list of global ranks in the context group.

  • cp_stream (torch.cuda.Stream) – cuda stream for context parallel execution.

  • cp_comm_type (str, default = p2p) –

    inter-gpu communication type for context parallelism. Can be "p2p", "all_gather", "a2a", or "a2a+p2p".

    "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute.

    "all_gather": All-gather to get the full sequence of KV before attention. The all-gather is not async and cannot be overlapped.

    "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get the full sequence of QKV.

    "a2a+p2p": hierarchical CP implementation. First apply a2a to QKV within each CP sub-group (e.g., via NVLink), then exchange KV with p2p between sub-groups (e.g., via IBLink).
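
A hedged wiring sketch, reusing layer from the TransformerLayer example above and assuming torch.distributed is already initialized; cp_group and cp_global_ranks are assumed to come from your own parallel-state bookkeeping and are not defined by this API.

cp_stream = torch.cuda.Stream()

layer.set_context_parallel_group(
    cp_group,          # ProcessGroup over the context-parallel ranks (assumed)
    cp_global_ranks,   # global ranks in that group (assumed)
    cp_stream,
    cp_comm_type="p2p",
)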

set_tensor_parallel_group(tp_group: transformer_engine.pytorch.constants.dist_group_type | None) None

Set the tensor parallel group for the given module before executing the forward pass.

Parameters:

tp_group (ProcessGroup, default = None) – tensor parallel process group.

class transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)

Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.

Parameters:
  • max_batch_size (int) – maximum batch size during inference.

  • max_sequence_length (int) – maximum sequence length during inference.
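
A minimal sketch, reusing layer and hidden_states from the TransformerLayer examples above: one InferenceParams object is created per generation session and passed to every forward call (per-step offset bookkeeping for incremental decoding is not shown).

inference_params = te.InferenceParams(max_batch_size=4, max_sequence_length=2048)

# Reuse the same object across decoding steps so the KV context is cached.
out = layer(hidden_states, inference_params=inference_params)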

class transformer_engine.pytorch.CudaRNGStatesTracker

For model parallelism, multiple RNG states need to simultaneously exist in order to execute operations in or out of the model parallel region. This class keeps track of the various RNG states and provides utility methods to maintain them and execute parts of the model under a given RNG setting. Using the add method, a cuda rng state is initialized based on the input seed and is assigned to name. Later, by forking the rng state, we can perform operations and return to our starting cuda state.

add(name: str, seed: int) None

Adds a new RNG state.

name: str

string identifier for the RNG state.

seed: int

PyTorch seed for the RNG state.

fork(name: str = 'model-parallel-rng')

Fork the cuda rng state, perform operations, and exit with the original state.

name: str

string identifier for the RNG state.

get_states() Dict[str, torch.Tensor]

Get rng states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.

reset()

Set to the initial state (no tracker).

set_states(states: Dict[str, torch.Tensor]) None

Set the rng states. For efficiency purposes, we do not check the size of seed for compatibility.

states: Dict[str, torch.Tensor]

A mapping from string names to RNG states.
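
A brief sketch of the add/fork pattern; the state name and seed are illustrative, and transformer_engine.pytorch is assumed to be imported as te.

tracker = te.CudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)

# Operations inside the block draw from the named CUDA RNG state;
# the original state is restored on exit.
with tracker.fork("model-parallel-rng"):
    mask = torch.rand(8, 8, device="cuda")

A callable returning such a tracker can be passed to TE modules and to checkpoint via the get_rng_state_tracker argument.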

transformer_engine.pytorch.fp8_autocast(enabled: bool = True, calibrating: bool = False, fp8_recipe: transformer_engine.common.recipe.DelayedScaling | None = None, fp8_group: transformer_engine.pytorch.constants.dist_group_type | None = None, _graph: bool = False) None

Context manager for FP8 usage.

with fp8_autocast(enabled=True):
    out = model(inp)

Note

Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to be a multiple of 16.

Note

When fp8_recipe.reduce_amax==True, any module must not be invoked more than once inside a single fp8_autocast region. This is unsupported behavior because the amax reduction is handled during the exit of the fp8_autocast context. Calling the same module more than once inside an fp8_autocast region overrides the amax tensors before reduction can occur.

Parameters:
  • enabled (bool, default = True) – whether or not to enable fp8

  • calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision.

  • fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.

  • fp8_group (torch._C._distributed_c10d.ProcessGroup, default = None) – distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step.
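
A slightly fuller sketch with an explicit delayed-scaling recipe, assuming transformer_engine.pytorch is imported as te and model/inp are defined as in the example above; the recipe values shown are common defaults, not a recommendation.

from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    amax_history_len=1024,
    amax_compute_algo="max",
)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)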

transformer_engine.pytorch.fp8_model_init(enabled: bool = True) None

Context manager for FP8 initialization of parameters.

Example usage:

with fp8_model_init(enabled=True):
    model = transformer_engine.pytorch.Linear(768, 768)
Parameters:

enabled (bool, default = True) –

when enabled, Transformer Engine modules created inside this fp8_model_init region will hold only FP8 copies of their parameters, as opposed to the default behavior where both higher precision and FP8 copies are present. Setting this option to True may result in lower memory consumption and is especially useful for scenarios like:

  • full model training using optimizer with master weights, where the high precision copies of weights are already present in the optimizer.

  • inference, where only the FP8 copies of the parameters are used.

  • LoRA-like fine-tuning, where the main parameters of the model do not change.

This functionality is EXPERIMENTAL.

transformer_engine.pytorch.checkpoint(function: Callable, *args: Tuple[torch.Tensor, Ellipsis], **kwargs: Dict[str, Any]) Tuple[torch.Tensor, Ellipsis]

Checkpoint a part of the model by trading compute for memory. This function is based on torch.utils.checkpoint.checkpoint.

Warning

It is the user’s responsibility to ensure identical behavior when calling function from the forward and backward pass. If different output is produced (e.g. due to global state), then the checkpointed version won’t be numerically equivalent.

Warning

use_reentrant=False does not support early stopping, and will execute the entire forward pass for the checkpointed module when recomputing activations in the backward pass.

Parameters:
  • function (Callable) – pytorch module used to run the forward and backward passes using the specified args and kwargs.

  • distribute_saved_activations (bool, default = False) – if set to True and use_reentrant=True, first tensor argument is distributed across the specified tensor parallel group (tp_group) before saving it for the backward pass. This has no effect when use_reentrant=False.

  • get_rng_state_tracker (Callable, default = None) – python callable which returns an instance of CudaRNGStatesTracker().

  • tp_group (ProcessGroup, default = None) – tensor parallel process group. Used only when distribute_saved_activations=True and use_reentrant=True. If None, it falls back to the default group.

  • use_reentrant (bool, default = True) – perform checkpointing in reentrant mode.

  • args (tuple) – tuple of torch tensors for inputs to function.

  • kwargs (dict) – dictionary of string keys for keyword arguments to function.
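
A hedged usage sketch, reusing layer from the TransformerLayer example above and assuming transformer_engine.pytorch is imported as te; the input shape is illustrative.

inp = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Forward activations inside `layer` are recomputed during the backward
# pass instead of being stored.
out = te.checkpoint(layer, inp)
out.float().sum().backward()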

transformer_engine.pytorch.onnx_export(enabled: bool = False) None

Context manager for exporting to ONNX.

with onnx_export(enabled=True):
    torch.onnx.export(model)
Parameters:

enabled (bool, default = False) – whether or not to enable export

transformer_engine.pytorch.make_graphed_callables(modules: SingleOrTuple[Callable], sample_args: SingleOrTuple[Tuple[torch.Tensor, Ellipsis]], num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: SingleOrTuple[Dict[str, Any]] | None = None, fp8_enabled: bool = False, fp8_calibrating: bool = False, fp8_recipe: transformer_engine.common.recipe.DelayedScaling | None = None, fp8_weight_caching: bool = False, _order: List[int] | None = None, pool: Tuple[int, Ellipsis] | None = None) Callable | Tuple[Callable, Ellipsis]

Make CUDA graph version of Transformer Engine modules

A variation of PyTorch’s make_graphed_callables utility function with support for Transformer Engine modules and FP8. Please see the original PyTorch implementation for more documentation.

Graphing parameters:
  • modules ((tuple of) callable) – Callable or callables to graph.

  • sample_args ((tuple of) tuple of torch.Tensor) – Positional arguments to callable(s).

  • num_warmup_iters (int, default = 3) – Number of warmup iterations.

  • allow_unused_input (bool, default = False) – Whether to handle case where callable inputs and outputs are disconnected in compute graph.

  • sample_kwargs ((tuple of) dict, optional) – Keyword arguments to callable(s)

  • pool ((tuple of) int, default = None, optional) – An instance returned from function torch.cuda.graph_pool_handle that hints this graph may share memory with the indicated pool.

FP8-related parameters:
  • fp8_enabled (bool, default = False) – whether or not to enable fp8

  • fp8_calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision.

  • fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.

  • fp8_weight_caching (bool, default = False) – Whether or not to cache FP8 weights across microbatches. if set to True, the is_first_microbatch boolean argument must be passed into the forward method for TransformerEngine modules. When storing primary weights in FP8 using TE’s fp8_model_init API and using an FP8 aware optimizer, this arg must be set to False if calculating weight transposes’ outside TE, e.g., in the optimizer step.
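
A minimal graphing sketch under the assumption of static input shapes (required for CUDA graphs); the sizes are illustrative and layer is the TransformerLayer from the earlier example.

sample_input = torch.randn(
    128, 2, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    num_warmup_iters=3,
    fp8_enabled=False,
)

out = graphed_layer(sample_input)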

transformer_engine.pytorch.get_cpu_offload_context(enabled: bool = False, num_layers: int = 1, model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = True)

This function returns the CPU Offload context and the synchronizer function that needs to be used after every transformer layer. Returns nullcontext() if offloading is not enabled.

Usage:

cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)

with cpu_offload_context:
    te_layer.forward(inp_tensor)
cpu_offload_synchronizer()
Parameters:
  • enabled (bool, default = False) – When set to True, CPU Offloading functionality is enabled.

  • num_layers (int, default = 1) – Determines the number of transformer layers you want to offload activations/weights for.

  • model_layers (int, default = 1) – Number of layers in the model that will be used under this context.

  • offload_activations (bool, default = True) – When set to True, offloads the activations for the TE layer.

  • offload_weights (bool, default = True) – When set to True, offloads the weights for the TE layer.
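
An expanded sketch of the per-layer pattern, assuming a Python list of TE transformer layers named te_layers and an input tensor inp_tensor (both hypothetical names).

context, sync_fn = te.get_cpu_offload_context(
    enabled=True,
    num_layers=2,                 # offload for 2 layers at a time (illustrative)
    model_layers=len(te_layers),
)

hidden = inp_tensor
for te_layer in te_layers:
    with context:
        hidden = te_layer(hidden)
    sync_fn()                     # synchronize after every transformer layer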

transformer_engine.pytorch.moe_permute(inp: torch.Tensor, indices: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1) Tuple[torch.Tensor, torch.Tensor]

Permute the tokens based on the indices. Tokens with the same index will be grouped together.

Parameters:
  • inp (torch.Tensor) – Input tensor of shape [num_tokens, hidden_size], on which permutation will be applied.

  • indices (torch.Tensor) – The token to expert indices tensor of shape [num_tokens, topK] and dtype ‘int32’.

  • num_out_tokens (int, default = -1) – The effective output token count, representing the number of tokens not dropped. By default, set to ‘-1’, meaning no tokens are dropped.

  • max_token_num (int, default = -1) – The maximum number of tokens, used for workspace allocation. By default, set to ‘-1’, meaning the calculation of the size of workspace is automatically taken over by the operator.

transformer_engine.pytorch.moe_unpermute(inp: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor = None) torch.Tensor

Unpermute a tensor with permuted tokens, and optionally merge the tokens with their corresponding probabilities.

Parameters:
  • inp (torch.Tensor) – Input tensor with permuted tokens of shape [num_tokens, hidden_size] to be unpermuted.

  • row_id_map (torch.Tensor) – The tensor of a mapping table for sorted indices used to unpermute the tokens; this is the second output tensor of moe_permute.

  • probs (torch.Tensor) – The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
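
A round-trip sketch; sizes and topK are illustrative, and in practice the routing indices would come from your MoE router. transformer_engine.pytorch is assumed to be imported as te.

num_tokens, hidden_size, num_experts, topk = 16, 64, 4, 2

inp = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda")

# Group tokens by expert, run the expert MLPs on `permuted` (not shown),
# then restore the original token order (merged by accumulation here;
# pass `probs` to merge with routing probabilities instead).
permuted, row_id_map = te.moe_permute(inp, indices)
out = te.moe_unpermute(permuted, row_id_map)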

transformer_engine.pytorch.initialize_ub(shape: list, tp_size: int, use_fp8: bool = False, dtype: torch.dtype = torch.bfloat16, ub_cfgs: dict | None = None, bootstrap_backend: str | torch.distributed.Backend = None) None

Initialize the Userbuffers communicator for overlapping tensor-parallel communications with GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

Parameters:
  • shape (list) – shape of the communication buffer, typically set to be the same as the global shape of the input tensor to a te.TransformerLayer forward pass, with the sequence and batch dimensions collapsed together – i.e.: (sequence_length * batch_size, hidden_size)

  • tp_size (int) – number of GPUs in the tensor-parallel process group

  • use_fp8 (bool = False) – allocate the communication buffer for FP8 GEMM inputs/outputs

  • dtype (torch.dtype = torch.bfloat16) – non-FP8 data type of the communication buffer when use_fp8 = False

  • ub_cfgs (dict = None) –

    Configuration dictionary with the structure:

    {
        <gemm_name> : {
            "method": <"ring_exchange" or "pipeline">,
            "is_reduce_scatter": bool,
            "num_sm": int,
            "cga_size": int,
            "set_sm_margin": bool,
            "num_splits": int,
            "aggregate": bool,
            "atomic_gemm": bool,
            "use_ce": bool,
            "fp8_buf": bool,
        }
    }

    for te.TransformerLayer GEMM layers in [“qkv_fprop”, “qkv_dgrad”, “qkv_wgrad”, “proj_fprop”, “proj_dgrad”, “proj_wgrad”, “fc1_fprop”, “fc1_dgrad”, “fc2_fprop”, “fc2_dgrad”].

  • bootstrap_backend (str = None) – torch.distributed communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are valid for every cluster configuration and distributed launch method even if they are available in PyTorch. When left unset, the initialization prefers to use the MPI backend, falling back first on Gloo and then NCCL if MPI is not available. Setting NVTE_UB_WITH_MPI=1 when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, which also requires MPI_HOME=/path/to/mpi/root to be set at compile time.
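
A hedged initialization sketch, assuming torch.distributed is already initialized, that the buffer shape matches the collapsed sequence and batch dimensions of the TransformerLayer inputs used earlier, and that the chosen bootstrap backend is valid for your cluster.

seq_len, batch_size, hidden_size, tp_size = 128, 2, 1024, 8

te.initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)

# ... training with Userbuffers-overlapped tensor-parallel communication ...

te.destroy_ub()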

transformer_engine.pytorch.destroy_ub()

Destroy all allocated userbuffer communicators.