PyTorch

Modules

class transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)

Applies a linear transformation to the incoming data: \(y = xA^T + b\).

On NVIDIA GPUs it is a drop-in replacement for torch.nn.Linear.
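As with torch.nn.Linear, the weight A is stored with shape (out_features, in_features), which is why the formula uses its transpose. The pure-Python sketch below spells out that arithmetic for reference only; it does not use Transformer Engine and says nothing about how the fused GPU kernels are implemented. The helper name linear_reference is hypothetical.

```python
from typing import List, Optional

Matrix = List[List[float]]


def linear_reference(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """Compute y[i][o] = sum_k x[i][k] * weight[o][k] (+ bias[o]).

    x is [batch][in_features], weight is [out_features][in_features],
    bias is [out_features]; the result is [batch][out_features].
    """
    out = []
    for row in x:
        y_row = []
        for o, w_row in enumerate(weight):
            acc = sum(xi * wi for xi, wi in zip(row, w_row))  # one entry of x A^T
            if bias is not None:
                acc += bias[o]
            y_row.append(acc)
        out.append(y_row)
    return out


x = [[1.0, 2.0, 3.0]]            # batch = 1, in_features = 3
A = [[1.0, 0.0, 0.0],            # out_features = 2
     [0.0, 1.0, 1.0]]
b = [0.5, -1.0]
print(linear_reference(x, A, b))  # [[1.5, 4.0]]
```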

Parameters:
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • bias (bool, default = True) – if set to False, the layer will not learn an additive bias.

  • init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • parameters_split (Tuple[str, ...], default = None) – if a tuple of strings is provided, the weight and bias parameters of the module are exposed as N separate torch.nn.parameter.Parameter objects each, split along the first dimension, where N is the length of the argument and the strings contained are the names of the split parameters.
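A minimal sketch of what parameters_split means for the weight layout, assuming the first dimension (out_features) is divided into N equal chunks. The helper split_rows is hypothetical and returns plain list slices rather than torch.nn.parameter.Parameter objects.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def split_rows(weight: Matrix, names: Sequence[str]) -> Dict[str, Matrix]:
    """Split a [out_features][in_features] weight into len(names) named row chunks."""
    n = len(names)
    if len(weight) % n != 0:
        # Assumption: the split is only defined for equal-sized chunks.
        raise ValueError("out_features must be divisible by len(parameters_split)")
    chunk = len(weight) // n
    return {name: weight[i * chunk:(i + 1) * chunk] for i, name in enumerate(names)}


# out_features = 6, in_features = 2, exposed as three named parts
W = [[float(r), float(r)] for r in range(6)]
parts = split_rows(W, ("query", "key", "value"))
print({k: len(v) for k, v in parts.items()})  # {'query': 2, 'key': 2, 'value': 2}
```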

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

  • parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.

  • skip_weight_param_allocation (bool, default = False) – if set to True, weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass.
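The two parallel modes can be checked numerically on one process. The sketch below, with hypothetical helpers and no real communication, splits the weight by output rows for Column parallel (the shard outputs are then gathered) and by input columns for Row parallel (partial outputs are then summed, as an all-reduce would). Both reproduce the unsharded result. It assumes out_features and in_features are divisible by tp_size.

```python
from typing import List

Matrix = List[List[float]]


def matmul_t(x: Matrix, w: Matrix) -> Matrix:
    """x @ w^T, with w stored as [out_features][in_features]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def column_parallel(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    # Each simulated rank owns out_features / tp_size rows of the weight.
    k = len(w) // tp_size
    shards = [matmul_t(x, w[r * k:(r + 1) * k]) for r in range(tp_size)]
    # "All-gather" along the feature dimension restores the full output.
    return [sum((s[i] for s in shards), []) for i in range(len(x))]


def row_parallel(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    # Each simulated rank owns in_features / tp_size columns of the weight and of x.
    k = len(w[0]) // tp_size
    partials = [
        matmul_t([row[r * k:(r + 1) * k] for row in x],
                 [w_row[r * k:(r + 1) * k] for w_row in w])
        for r in range(tp_size)
    ]
    # "All-reduce" (sum) of the partial products gives the full output.
    return [[sum(p[i][j] for p in partials) for j in range(len(w))] for i in range(len(x))]


x = [[1.0, 2.0, 3.0, 4.0]]
w = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
full = matmul_t(x, w)
assert column_parallel(x, w, 2) == full
assert row_parallel(x, w, 2) == full
print(full)  # [[5.0, 5.0]]
```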

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.float32) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
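A sketch of the contract described for return_bias, in plain Python with hypothetical helpers: the module hands back the bias-free product and the bias separately, and the caller folds the bias into a later elementwise step. Here that step is a bias + residual add, standing in for whatever fused operation follows.

```python
from typing import List, Tuple

Vector = List[float]


def linear_no_bias(x: Vector, weight: List[Vector]) -> Vector:
    # x A^T for a single sample; weight is [out_features][in_features].
    return [sum(a * w for a, w in zip(x, row)) for row in weight]


def linear_return_bias(x: Vector, weight: List[Vector], bias: Vector) -> Tuple[Vector, Vector]:
    # Mirrors the documented behavior: output without bias, plus the bias itself.
    return linear_no_bias(x, weight), bias


def bias_residual_add(y: Vector, bias: Vector, residual: Vector) -> Vector:
    # One pass over the data instead of two separate elementwise steps.
    return [yi + bi + ri for yi, bi, ri in zip(y, bias, residual)]


x = [1.0, 2.0]
W = [[1.0, 1.0], [2.0, 0.0]]
b = [0.5, 0.5]
y, bias = linear_return_bias(x, W, b)
out = bias_residual_add(y, bias, residual=x)
print(out)  # [4.5, 4.5]
```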

forward(inp: Tensor, weight: Tensor | None = None, bias: Tensor | None = None, is_first_microbatch: bool | None = None) → Tensor | Tuple[Tensor, ...]

Apply the linear transformation to the input.

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • weight (torch.Tensor, default = None) – An optional weight tensor for the module. This argument is compulsory if the module is initialized with skip_weight_param_allocation=True.

  • bias (torch.Tensor, default = None) – An optional bias tensor for the module. This argument is compulsory if the module is initialized with skip_weight_param_allocation=True and either bias=True or return_bias=True.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
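The two optimizations listed above can be modelled with a toy class. It is hypothetical: scalars replace tensors, and a counted rounding step stands in for the FP8 cast. With is_first_microbatch=True the weight is cast once and the gradient buffer is overwritten; with False the cached cast is reused and gradients accumulate; with None neither shortcut is taken.

```python
from typing import Optional


class ToyLayer:
    """Hypothetical stand-in for a Transformer Engine module (scalars, not tensors)."""

    def __init__(self, weight: float) -> None:
        self.weight = weight
        self.main_grad = 123.0                    # stale value from the previous minibatch
        self._cached_cast: Optional[float] = None
        self.cast_count = 0

    def _cast(self) -> float:
        # Placeholder for an expensive low-precision cast of the weight.
        self.cast_count += 1
        return round(self.weight, 2)

    def step(self, x: float, is_first_microbatch: Optional[bool] = None) -> float:
        # Reuse the cached cast only for later microbatches of the same minibatch.
        use_cache = is_first_microbatch is False and self._cached_cast is not None
        if not use_cache:
            self._cached_cast = self._cast()
        y = self._cached_cast * x
        grad = x                                  # d(y)/d(weight) for an upstream gradient of 1
        if is_first_microbatch:
            self.main_grad = grad                 # first gradient: overwrite, no read-add
        else:
            self.main_grad += grad                # later (or unknown) microbatch: accumulate
        return y


layer = ToyLayer(weight=0.5)
for i, x in enumerate([1.0, 2.0, 3.0]):
    layer.step(x, is_first_microbatch=(i == 0))
print(layer.main_grad, layer.cast_count)  # 6.0 1
```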

class transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization.

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]

\(\gamma\) and \(\beta\) are learnable affine transform parameters of size hidden_size.

Parameters:
  • hidden_size (int) – size of each input sample.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • params_dtype (torch.dtype, default = torch.float32) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
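The zero_centered_gamma formula can be checked directly: with gamma initialized to 0, the (1 + gamma) form produces exactly the same output as the standard form with gamma initialized to 1. A pure-Python reference sketch (biased variance over the last dimension, as in the formula above; the helper name is hypothetical):

```python
import math
from typing import List


def layer_norm(x: List[float], gamma: List[float], beta: List[float],
               eps: float = 1e-5, zero_centered_gamma: bool = False) -> List[float]:
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n           # biased variance, E[(x - E[x])^2]
    inv_std = 1.0 / math.sqrt(var + eps)
    # Zero-centered variant multiplies by (1 + gamma) instead of gamma.
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, beta)]


x = [1.0, 2.0, 3.0, 4.0]
standard = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
zero_centered = layer_norm(x, gamma=[0.0] * 4, beta=[0.0] * 4, zero_centered_gamma=True)
assert standard == zero_centered  # identical at initialization
```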

class transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)

Applies layer normalization followed by linear transformation to the incoming data.

Parameters:
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • bias (bool, default = True) – if set to False, the layer will not learn an additive bias.

  • init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • parameters_split (Tuple[str, ...], default = None) – if a tuple of strings is provided, the weight and bias parameters of the module are exposed as N separate torch.nn.parameter.Parameter objects each, split along the first dimension, where N is the length of the argument and the strings contained are the names of the split parameters.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
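A short sketch of return_layernorm_output in plain Python, with hypothetical helpers and gamma and beta fixed to 1 and 0 for brevity: the module returns both the linear output and the LayerNorm output, so a caller can build the residual connection from the normalized input.

```python
import math
from typing import List, Tuple, Union

Vector = List[float]
Matrix = List[List[float]]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # gamma = 1 and beta = 0 to keep the sketch short.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def layernorm_linear(x: Vector, weight: Matrix, bias: Vector,
                     return_layernorm_output: bool = False) -> Union[Vector, Tuple[Vector, Vector]]:
    ln_out = layer_norm(x)
    y = [sum(a * w for a, w in zip(ln_out, row)) + b for row, b in zip(weight, bias)]
    if return_layernorm_output:
        return y, ln_out                 # caller can reuse ln_out as the residual
    return y


x = [1.0, 3.0]
W = [[1.0, 0.0], [0.0, 2.0]]
b = [0.0, 0.0]
y, ln_out = layernorm_linear(x, W, b, return_layernorm_output=True)
residual_out = [yi + ri for yi, ri in zip(y, ln_out)]   # residual taken post-LayerNorm
```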

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

  • parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.

  • skip_weight_param_allocation (bool, default = False) – if set to True, weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.float32) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

forward(inp: Tensor, weight: Tensor | None = None, bias: Tensor | None = None, is_first_microbatch: bool | None = None) → Tensor | Tuple[Tensor, ...]

Apply layer normalization to the input followed by a linear transformation.

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • weight (torch.Tensor, default = None) – An optional weight tensor for the module. This argument is compulsory if the module is initialized with skip_weight_param_allocation=True.

  • bias (torch.Tensor, default = None) – An optional bias tensor for the module. This argument is compulsory if the module is initialized with skip_weight_param_allocation=True and either bias=True or return_bias=True.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)

class transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)

Applies layer normalization on the input followed by the MLP module, consisting of two successive linear transformations separated by the GELU activation.
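The data flow of LayerNormMLP can be written as FC2(GELU(FC1(LayerNorm(x)))). The reference sketch below uses the tanh approximation of GELU; whether a particular Transformer Engine version uses this approximation or the exact erf form is an assumption here. The LayerNorm affine parameters are omitted and the helpers are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def gelu_tanh(v: float) -> float:
    # tanh approximation of GELU (assumed; see the lead-in).
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # Affine parameters omitted (gamma = 1, beta = 0).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(x: Vector, w: Matrix, b: Vector) -> Vector:
    # w is [out_features][in_features]; computes x @ w^T + b for one sample.
    return [sum(a * c for a, c in zip(x, row)) + bi for row, bi in zip(w, b)]


def layernorm_mlp(x: Vector, fc1_w: Matrix, fc1_b: Vector, fc2_w: Matrix, fc2_b: Vector) -> Vector:
    h = linear(layer_norm(x), fc1_w, fc1_b)   # hidden_size -> ffn_hidden_size
    h = [gelu_tanh(v) for v in h]             # elementwise activation
    return linear(h, fc2_w, fc2_b)            # ffn_hidden_size -> hidden_size


hidden_size, ffn_hidden_size = 2, 4
fc1_w = [[0.1 * (i + j) for j in range(hidden_size)] for i in range(ffn_hidden_size)]
fc2_w = [[0.1 * (i - j) for j in range(ffn_hidden_size)] for i in range(hidden_size)]
out = layernorm_mlp([1.0, -1.0], fc1_w, [0.0] * ffn_hidden_size, fc2_w, [0.0] * hidden_size)
print(len(out))  # 2 (back to hidden_size)
```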

Parameters:
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • bias (bool, default = True) – if set to False, the FC2 layer will not learn an additive bias.

  • init_method (Callable, default = None) – used for initializing FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • output_layer_init_method (Callable, default = None) – used for initializing FC2 weights in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

Parallelism parameters:
  • set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described here.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
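Why set_parallel_mode pairs a Column Parallel FC1 with a Row Parallel FC2: GELU is elementwise, so each rank can apply it to its own slice of the FC1 output, and FC2's input dimension is split the same way. The only communication needed is one sum (all-reduce) over FC2's partial outputs. The single-process simulation below (hypothetical helpers, biases omitted) checks this against the unsharded MLP; in a real row-parallel layer the FC2 bias would be added once after the reduction rather than on every rank.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def gelu_tanh(v: float) -> float:
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def matvec(w: Matrix, x: Vector) -> Vector:
    # w is stored as [out_features][in_features], so this computes w @ x.
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def mlp_reference(x: Vector, fc1: Matrix, fc2: Matrix) -> Vector:
    return matvec(fc2, [gelu_tanh(v) for v in matvec(fc1, x)])


def mlp_tensor_parallel(x: Vector, fc1: Matrix, fc2: Matrix, tp_size: int) -> Vector:
    k = len(fc1) // tp_size                                  # ffn_hidden_size per rank
    partials = []
    for r in range(tp_size):                                 # each iteration plays one rank
        fc1_shard = fc1[r * k:(r + 1) * k]                   # column parallel: split FC1 outputs
        h = [gelu_tanh(v) for v in matvec(fc1_shard, x)]     # elementwise, stays local
        fc2_shard = [row[r * k:(r + 1) * k] for row in fc2]  # row parallel: split FC2 inputs
        partials.append(matvec(fc2_shard, h))
    # The only collective in the block: an all-reduce (sum) of FC2 partial outputs.
    return [sum(p[j] for p in partials) for j in range(len(fc2))]


x = [0.5, -1.0]                                              # hidden_size = 2
fc1 = [[0.1 * (i + 1), -0.2 * i] for i in range(4)]          # ffn_hidden_size = 4
fc2 = [[0.3 * (j - i) for j in range(4)] for i in range(2)]
ref = mlp_reference(x, fc1, fc2)
tp = mlp_tensor_parallel(x, fc1, fc2, tp_size=2)
assert all(abs(a - b) < 1e-12 for a, b in zip(ref, tp))
```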

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient.

  • return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.

  • params_dtype (torch.dtype, default = torch.float32) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure that the same kernels are used for forward propagation and for the activation recompute phase.

  • micro_batch_size (int) – batch size per training step. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure that the same kernels are used for forward propagation and for the activation recompute phase.

forward(inp: Tensor, is_first_microbatch: bool | None = None) → Tensor | Tuple[Tensor, ...]

Apply layer normalization to the input followed by a feedforward network (MLP Block).

Parameters:
  • inp (torch.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)

class transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Note

Argument attention_mask will be ignored in the forward call when attn_mask_type is set to “causal”.

Warning

For the default attention mechanism, this module executes a non-deterministic version of flash-attn whenever possible in order to achieve optimal performance. To observe deterministic behavior, set the environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. In order to disable flash-attn entirely, set NVTE_FLASH_ATTN=0.

Parameters:
  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • kv_channels (int) – number of key-value channels.

  • attention_dropout (float, default = 0.0) – dropout probability for the dropout op during multi-head attention.

  • attn_mask_type ({‘causal’, ‘padding’}, default = causal) – type of attention mask passed into softmax operation.

Parallelism parameters:
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_size (int, default = 1) – tensor parallel world size.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

forward(query_layer: Tensor, key_layer: Tensor, value_layer: Tensor, attention_mask: Tensor | None = None, checkpoint_core_attention: bool = False) → Tensor

Dot Product Attention Layer.

Note

Argument attention_mask will be ignored when attn_mask_type is set to “causal”.

Note

Input tensors query_layer, key_layer, and value_layer must each be of shape (sequence_length, batch_size, num_attention_heads, kv_channels). Output of shape (sequence_length, batch_size, num_attention_heads * kv_channels) is returned.
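For the layout in the note above, a naive pure-Python reference of causal scaled dot-product attention (no dropout, no flash-attn) makes the shapes concrete: inputs are indexed [sequence][batch][head][channel] and the output flattens heads into the last dimension. The 1/sqrt(kv_channels) scaling and the head-major flattening order are standard choices assumed here, not read from the Transformer Engine source; the helpers are hypothetical.

```python
import math
from typing import List

Tensor4 = List[List[List[List[float]]]]  # [seq][batch][heads][kv_channels]
Tensor3 = List[List[List[float]]]        # [seq][batch][heads * kv_channels]


def causal_attention(q: Tensor4, k: Tensor4, v: Tensor4) -> Tensor3:
    s, b, h, d = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    scale = 1.0 / math.sqrt(d)
    out = [[[0.0] * (h * d) for _ in range(b)] for _ in range(s)]
    for bi in range(b):
        for hi in range(h):
            for i in range(s):
                # Causal mask: query position i sees key positions 0..i only.
                scores = [scale * sum(qa * ka for qa, ka in zip(q[i][bi][hi], k[j][bi][hi]))
                          for j in range(i + 1)]
                m = max(scores)                           # subtract max for stability
                weights = [math.exp(sc - m) for sc in scores]
                total = sum(weights)
                for c in range(d):
                    acc = sum(w * v[j][bi][hi][c] for j, w in enumerate(weights))
                    out[i][bi][hi * d + c] = acc / total  # head hi fills [hi*d, (hi+1)*d)
    return out


def make(s: int, b: int, h: int, d: int, offset: float) -> Tensor4:
    return [[[[0.1 * (i + hh + c) + offset for c in range(d)] for hh in range(h)]
             for _ in range(b)] for i in range(s)]


q, k, v = make(3, 1, 2, 2, 0.0), make(3, 1, 2, 2, 0.5), make(3, 1, 2, 2, 1.0)
out = causal_attention(q, k, v)
print(len(out), len(out[0]), len(out[0][0]))  # 3 1 4
```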

Parameters:
  • query_layer (torch.Tensor) – Query tensor.

  • key_layer (torch.Tensor) – Key tensor.

  • value_layer (torch.Tensor) – Value tensor.

  • attention_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out softmax input when not using flash-attn.

  • checkpoint_core_attention (bool, default = False) – If True, forward activations for attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

class transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)

TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.
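A schematic of how the attention block and the MLP are chained with residual connections in the default layout (layer normalization on the input side), including the effect of the apply_residual_connection_post_layernorm option listed below. The block functions are hypothetical placeholders, and dropout, drop path and output_layernorm are left out.

```python
from typing import Callable, List

Vector = List[float]
Block = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def transformer_layer(x: Vector, ln1: Block, attn: Block, ln2: Block, mlp: Block,
                      apply_residual_connection_post_layernorm: bool = False) -> Vector:
    # Attention block: normalize, attend, add the residual.
    h = ln1(x)
    residual = h if apply_residual_connection_post_layernorm else x
    x = add(residual, attn(h))
    # MLP block: same pattern.
    h = ln2(x)
    residual = h if apply_residual_connection_post_layernorm else x
    return add(residual, mlp(h))


def double(v: Vector) -> Vector:   # toy "layer norm"
    return [2.0 * t for t in v]


def ones(v: Vector) -> Vector:     # toy "attention" / "MLP"
    return [1.0 for _ in v]


print(transformer_layer([1.0, 2.0], double, ones, double, ones))        # [3.0, 4.0]
print(transformer_layer([1.0, 2.0], double, ones, double, ones, True))  # [7.0, 11.0]
```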

Note

Argument attention_mask will be ignored in the forward call when self_attn_mask_type is set to “causal”.

Parameters:
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • init_method (Callable, default = None) – used for initializing the QKV and FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • output_layer_init_method (Callable, default = None) – used for initializing the PROJ and FC2 weights in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).

  • apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of the layer norm (by default, they are taken from the input of the layer norm).

  • layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.

  • apply_query_key_layer_scaling (bool, default = False) – apply query-key layer scaling during BMM1 by a factor of layer_number.

  • output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.

  • attention_softmax_in_fp32 (bool, default = True) – if set to False, softmax is executed in the dtype of activation tensors.

  • layer_type ({‘encoder’, ‘decoder’}, default = encoder) – if set to decoder, an additional cross-attn block is added after self-attn. This can be used for structures like T5 Transformer in conjunction with the encoder option.

  • kv_channels (int, default = None) – number of key-value channels. Defaults to hidden_size / num_attention_heads if None.

  • self_attn_mask_type ({‘causal’, ‘padding’}, default = causal) – type of attention mask passed into softmax operation.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.
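The two QKV weight layouts that qkv_weight_interleaved chooses between differ only in row order along the 0th dimension. The sketch below (hypothetical helpers) labels each row of a fused weight with 3 * num_attention_heads * kv_channels rows; the per-head ordering q, k, v inside the interleaved layout is an assumption consistent with the description above.

```python
from typing import List

PARTS = ("q", "k", "v")


def row_index(part: str, head: int, channel: int, num_heads: int, kv_channels: int,
              interleaved: bool) -> int:
    p = PARTS.index(part)
    if interleaved:
        # [head 0: q k v][head 1: q k v] ...
        return head * 3 * kv_channels + p * kv_channels + channel
    # [q of all heads][k of all heads][v of all heads]
    return p * num_heads * kv_channels + head * kv_channels + channel


def layout(num_heads: int, kv_channels: int, interleaved: bool) -> List[str]:
    labels = [""] * (3 * num_heads * kv_channels)
    for part in PARTS:
        for head in range(num_heads):
            for c in range(kv_channels):
                labels[row_index(part, head, c, num_heads, kv_channels, interleaved)] = f"{part}{head}"
    return labels


print(layout(2, 1, interleaved=True))   # ['q0', 'k0', 'v0', 'q1', 'k1', 'v1']
print(layout(2, 1, interleaved=False))  # ['q0', 'q1', 'k0', 'k1', 'v0', 'v1']
```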

Parallelism parameters:
  • set_parallel_mode (bool, default = False) – if set to True, the QKV and FC1 layers are used as Column Parallel, whereas PROJ and FC2 are used as Row Parallel as described here.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.

Optimization parameters:
  • fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient.

  • params_dtype (torch.dtype, default = torch.float32) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure that the same kernels are used for forward propagation and for the activation recompute phase.

  • micro_batch_size (int) – batch size per training step. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure that the same kernels are used for forward propagation and for the activation recompute phase.

  • drop_path_rate (float, default = 0.0) – when > 0.0, applies stochastic depth per sample in the main path of the residual block.

  • fuse_qkv_params (bool, default = False) – if set to True, the TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
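drop_path_rate refers to stochastic depth. A common per-sample formulation is sketched below: during training, each sample's residual branch is either zeroed with probability drop_path_rate or scaled by 1 / (1 - drop_path_rate) so its expected value is unchanged. That Transformer Engine uses exactly this scaling is an assumption; the helper is hypothetical.

```python
import random
from typing import List

Batch = List[List[float]]  # [batch][features]


def drop_path(branch: Batch, drop_path_rate: float, training: bool, rng: random.Random) -> Batch:
    if not training or drop_path_rate == 0.0:
        return branch
    keep_prob = 1.0 - drop_path_rate
    out = []
    for sample in branch:
        if rng.random() < keep_prob:
            out.append([v / keep_prob for v in sample])  # rescale to preserve the expectation
        else:
            out.append([0.0] * len(sample))              # whole branch dropped for this sample
    return out


rng = random.Random(0)
x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
branch = [[0.5, 0.5]] * 4  # stand-in for f(x), e.g. the MLP output
# Residual connection: x + drop_path(f(x)), decided independently per sample.
y = [[xi + bi for xi, bi in zip(xs, bs)] for xs, bs in zip(x, drop_path(branch, 0.5, True, rng))]
```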

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, encoder_output: Tensor | None = None, enc_dec_attn_mask: Tensor | None = None, is_first_microbatch: bool | None = None, checkpoint_core_attention: bool = False, inference_params: Any | None = None) → Tensor

Transformer Layer: attention block and a feedforward network (MLP).

Note

Argument attention_mask will be ignored when self_attn_mask_type is set to “causal”.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor.

  • attention_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input.

  • encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type=“decoder”.

  • enc_dec_attn_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out inter-attention softmax input if using layer_type=“decoder”.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

    • it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)

  • checkpoint_core_attention (bool, default = False) – If True, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

Functions

transformer_engine.pytorch.fp8_autocast(enabled: bool = False, calibrating: bool = False, fp8_recipe: DelayedScaling | None = None, fp8_group: ProcessGroup | None = None) → None

Context manager for FP8 usage.

from transformer_engine.pytorch import fp8_autocast

with fp8_autocast(enabled=True):
    out = model(inp)

Note

Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.
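Padding arithmetic for the divisibility constraint above, in plain Python with hypothetical helpers. The shape check assumes the (sequence_length, batch_size, hidden_size) activation is flattened to (sequence_length * batch_size, hidden_size) before the GEMM; pad_id, and the need to mask padded positions in attention, are also assumptions about the surrounding model.

```python
from typing import List


def pad_to_multiple(n: int, multiple: int = 16) -> int:
    """Smallest value >= n that is divisible by `multiple` (ceiling division)."""
    return -(-n // multiple) * multiple


def fp8_gemm_shape_ok(rows: int, cols: int) -> bool:
    return rows % 16 == 0 and cols % 16 == 0


def pad_tokens(token_ids: List[int], pad_id: int, multiple: int = 16) -> List[int]:
    target = pad_to_multiple(len(token_ids), multiple)
    return token_ids + [pad_id] * (target - len(token_ids))


seq_len, batch, hidden = 100, 1, 1024
print(fp8_gemm_shape_ok(seq_len * batch, hidden))                   # False: 100 % 16 == 4
print(fp8_gemm_shape_ok(pad_to_multiple(seq_len) * batch, hidden))  # True: padded to 112
```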

Parameters:
  • enabled (bool, default = False) – whether or not to enable fp8

  • calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of FP8 tensors even when executing without FP8 enabled. This is useful for saving an inference-ready FP8 checkpoint while training in a higher precision.

  • fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.

  • fp8_group (torch._C._distributed_c10d.ProcessGroup, default = None) – distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step.
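A simplified model of what happens to the amax statistics mentioned above. Each rank observes an amax for a tensor, the values are reduced with a maximum over fp8_group, and a scale that maps that amax onto the FP8 range is derived from a short history. The window, the max-over-history rule, the zero-amax fallback, and the power-of-two margin below are assumptions standing in for the settings of recipe.DelayedScaling; 448 is the largest finite value of the FP8 E4M3 format.

```python
from typing import List

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def reduce_amax(per_rank_amax: List[float]) -> float:
    # Stand-in for an all-reduce with MAX over fp8_group.
    return max(per_rank_amax)


def compute_scale(amax_history: List[float], fp8_max: float = E4M3_MAX, margin: int = 0) -> float:
    # Delayed scaling: the scale comes from amaxes seen in earlier steps.
    amax = max(amax_history)
    if amax == 0.0:
        return 1.0  # assumption: keep a neutral scale until a nonzero amax is seen
    return fp8_max / amax / (2 ** margin)


history = [2.0, reduce_amax([3.5, 7.0])]   # previous step's amax, then this step's reduced amax
print(compute_scale(history))              # 64.0
```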

transformer_engine.pytorch.checkpoint(function: Callable, distribute_saved_activations: bool, get_cuda_rng_tracker: Callable, tp_group: ProcessGroup, *args: Tuple[Tensor, ...], **kwargs: Dict[str, Any]) → Tuple[Tensor, ...]

Checkpoint a part of the model by trading compute for memory. This function is based on torch.utils.checkpoint.checkpoint.

Warning

It is the user’s responsibility to ensure identical behavior when calling function from the forward and backward pass. If different output is produced (e.g. due to global state), then the checkpointed version won’t be numerically equivalent.

Warning

The tuple args must contain only tensors (or None) in order to comply with PyTorch’s save_for_backward method. function must produce valid outputs when called with args and kwargs.

Parameters:
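  • function (Callable) – the PyTorch module or function to checkpoint; it is run on args and kwargs in the forward pass and run again (recomputed) during the backward pass.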

  • distribute_saved_activations (bool) – if set to True, the first tensor argument is distributed across the specified tensor parallel group (tp_group) before saving it for the backward pass.

  • get_cuda_rng_tracker (Callable) – Python function that gives access to the RNG state: the state is retrieved via state = get_cuda_rng_tracker().get_states() and restored via get_cuda_rng_tracker().set_states(state). This is used to ensure that any extra CUDA RNG state, or general global state, can be reproduced across the two forward phases: the original pass and the recompute.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • args (tuple) – tuple of torch tensors for inputs to function.

  • kwargs (dict) – dictionary of string keys for keyword arguments to function.
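Why get_cuda_rng_tracker is required: a checkpointed function that uses randomness (for example dropout) must draw the same random numbers when it is recomputed in the backward pass. The toy below uses Python's random.Random getstate/setstate in place of the CUDA RNG tracker; the class is hypothetical and stores only the inputs, not the intermediate activations, mirroring the compute-for-memory trade.

```python
import random
from typing import Callable, List

Vector = List[float]


def dropout(x: Vector, p: float, rng: random.Random) -> Vector:
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in x]


class ToyCheckpoint:
    """Keeps only the inputs and the RNG state; the forward is recomputed on demand."""

    def __init__(self, fn: Callable[[Vector, random.Random], Vector], rng: random.Random) -> None:
        self.fn = fn
        self.rng = rng

    def forward(self, x: Vector) -> Vector:
        self.saved_input = list(x)              # intermediate activations are not stored
        self.saved_state = self.rng.getstate()  # analogous to get_cuda_rng_tracker().get_states()
        return self.fn(x, self.rng)

    def recompute(self) -> Vector:
        current = self.rng.getstate()
        self.rng.setstate(self.saved_state)     # analogous to get_cuda_rng_tracker().set_states(state)
        out = self.fn(self.saved_input, self.rng)
        self.rng.setstate(current)              # leave the ongoing random stream untouched
        return out


def scaled_dropout(x: Vector, r: random.Random) -> Vector:
    return dropout([2.0 * v for v in x], 0.5, r)


rng = random.Random(1234)
block = ToyCheckpoint(scaled_dropout, rng)
y = block.forward([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
assert block.recompute() == y                   # same dropout mask in both passes
```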