paddle

class transformer_engine.paddle.Linear(in_features, out_features, **kwargs)

Applies a linear transformation to the incoming data: \(y = xA^T + b\).
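As a plain-Python reference for the formula only (not Transformer Engine's fused kernel), the sketch below computes \(y = xA^T + b\) for a small batch. It assumes A is laid out as (out_features, in_features) to match the \(A^T\) in the formula; the module's actual weight storage layout may differ.

```python
def linear(x, A, b=None):
    """Compute y = x A^T + b for x of shape (batch, in_features).

    A has shape (out_features, in_features): row j of A produces output feature j.
    """
    out = []
    for row in x:
        # Dot product of the input row with each row of A (each column of A^T).
        y = [sum(xi * aij for xi, aij in zip(row, a_row)) for a_row in A]
        if b is not None:
            y = [yj + bj for yj, bj in zip(y, b)]
        out.append(y)
    return out


x = [[1.0, 2.0, 3.0]]        # batch = 1, in_features = 3
A = [[1.0, 0.0, 0.0],        # out_features = 2
     [0.0, 1.0, 1.0]]
b = [0.5, -1.0]
print(linear(x, A, b))       # [[1.5, 4.0]]
```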

Parameters
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.

  • backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only path without FP8 support is executed, with limited optimization.

Parallelism parameters
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – decides whether this Linear layer is a Column Parallel Linear or a Row Parallel Linear layer, as described in the Megatron-LM paper. When set to None, no communication is performed.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
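To illustrate the two parallel_mode settings, the sketch below simulates a tensor parallel group of two ranks in plain Python. Column parallel splits the weight along its output features, so each rank computes a slice of the output (concatenated here, as an all-gather would if the full output is needed); row parallel splits it along its input features, so each rank computes a partial sum that an all-reduce adds up. The sharding follows the Megatron-LM scheme; communication is only simulated, and none of these names are Transformer Engine APIs.

```python
def matvec(A, x):
    # A has shape (out_features, in_features); returns A x, i.e. x A^T for one sample.
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


A = [[1.0, 2.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 0.0],
     [3.0, 0.0, 1.0, 2.0],
     [1.0, 1.0, 1.0, 1.0]]
x = [1.0, 2.0, 3.0, 4.0]
world_size = 2
full = matvec(A, x)

# Column parallel: rank r owns a block of rows of A (a slice of the output features).
rows = len(A) // world_size
col_shards = [matvec(A[r * rows:(r + 1) * rows], x) for r in range(world_size)]
col_out = [v for shard in col_shards for v in shard]   # concatenate (all-gather)

# Row parallel: rank r owns a block of columns of A and the matching slice of x.
cols = len(x) // world_size
partials = []
for r in range(world_size):
    A_r = [row[r * cols:(r + 1) * cols] for row in A]
    x_r = x[r * cols:(r + 1) * cols]
    partials.append(matvec(A_r, x_r))
row_out = [sum(vals) for vals in zip(*partials)]        # sum across ranks (all-reduce)

print(full, col_out == full, row_out == full)           # [9.0, 5.0, 14.0, 10.0] True True
```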

forward(*args, **kwargs)

Apply the linear transformation to the input.

Parameters
  • inp (paddle.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights
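The caching behavior described above can be modeled with a small bookkeeping class. This is a hypothetical sketch (FP8WeightCache and _cast_to_fp8 are illustrative names, not Transformer Engine internals): the FP8 cast is redone on the first microbatch of each minibatch, reused on later microbatches, and, in this model, redone on every call when is_first_microbatch is None.

```python
class FP8WeightCache:
    """Hypothetical model of FP8 weight caching driven by is_first_microbatch."""

    def __init__(self):
        self.cached = None
        self.num_casts = 0

    def _cast_to_fp8(self, weight):
        # Stand-in for a real FP8 cast; it only counts how often it runs.
        self.num_casts += 1
        return tuple(weight)

    def get(self, weight, is_first_microbatch=None):
        if is_first_microbatch is None:
            # No hint given: cast on every call and cache nothing.
            return self._cast_to_fp8(weight)
        if is_first_microbatch or self.cached is None:
            # First microbatch: the optimizer may have just updated the weights.
            self.cached = self._cast_to_fp8(weight)
        # Later microbatches of the same minibatch reuse the cached copy.
        return self.cached


cache = FP8WeightCache()
weight = [0.1, 0.2]
for minibatch in range(2):
    for micro in range(4):
        cache.get(weight, is_first_microbatch=(micro == 0))
print(cache.num_casts)  # 2
```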

class transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)

Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]

\(\gamma\) and \(\beta\) are learnable affine transform parameters of size hidden_size.
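The formula can be checked against a direct plain-Python implementation for a single sample. This is a reference sketch of the math, not the fused kernel; like the formula, it uses the biased variance (dividing by hidden_size).

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta over one sample's features."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n   # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(xi - mean) * inv_std * g + b for xi, g, b in zip(x, gamma, beta)]


hidden_size = 4
y = layer_norm([1.0, 2.0, 3.0, 4.0], gamma=[1.0] * hidden_size, beta=[0.0] * hidden_size)
print([round(v, 4) for v in y])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```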

Parameters
  • hidden_size (int) – size of each input sample.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for the layer normalization operation.

Parallelism parameters
  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
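A quick way to see why zero_centered_gamma does not change the initial behavior: with gamma initialized to 0, the factor \((1 + \gamma)\) in the modified formula equals 1, which matches the standard formula with gamma initialized to 1 (the usual initialization, assumed here). The sketch is plain Python that only mirrors the two formulas.

```python
import math


def normalize(x, eps=1e-5):
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def layer_norm(x, gamma, beta, zero_centered_gamma=False, eps=1e-5):
    x_hat = normalize(x, eps)
    # With zero_centered_gamma, the stored gamma is an offset from 1: scale = 1 + gamma.
    scale = [1.0 + g for g in gamma] if zero_centered_gamma else gamma
    return [xh * s + b for xh, s, b in zip(x_hat, scale, beta)]


x = [0.5, -1.0, 2.0]
beta = [0.0] * 3
standard = layer_norm(x, gamma=[1.0] * 3, beta=beta)      # usual init: gamma = 1
zero_centered = layer_norm(x, gamma=[0.0] * 3, beta=beta, zero_centered_gamma=True)
print(standard == zero_centered)  # True
```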

class transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)

Applies layer normalization followed by linear transformation to the incoming data.

Parameters
  • in_features (int) – size of each input sample.

  • out_features (int) – size of each output sample.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only path without FP8 support is executed, with limited optimization.

Parallelism parameters
  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – decides whether this Linear layer is a Column Parallel Linear or a Row Parallel Linear layer, as described in the Megatron-LM paper. When set to None, no communication is performed.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
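The composition performed by this module, including both normalization options and return_layernorm_output, can be sketched in plain Python. The RMSNorm branch uses the standard RMSNorm formula (no mean subtraction and no bias term), which is assumed here rather than stated on this page; the function names are illustrative.

```python
import math


def norm(x, gamma, beta, normalization="LayerNorm", eps=1e-5):
    n = len(x)
    if normalization == "LayerNorm":
        mean = sum(x) / n
        var = sum((xi - mean) ** 2 for xi in x) / n
        return [(xi - mean) / math.sqrt(var + eps) * g + b
                for xi, g, b in zip(x, gamma, beta)]
    if normalization == "RMSNorm":
        # Standard RMSNorm: divide by the root mean square; no mean subtraction, no beta.
        rms = math.sqrt(sum(xi * xi for xi in x) / n + eps)
        return [xi / rms * g for xi, g in zip(x, gamma)]
    raise ValueError(f"unsupported normalization: {normalization}")


def layernorm_linear(x, gamma, beta, A, b, normalization="LayerNorm",
                     return_layernorm_output=False):
    ln_out = norm(x, gamma, beta, normalization)
    # Linear part: y = ln_out A^T + b, with A of shape (out_features, in_features).
    out = [sum(v * a for v, a in zip(ln_out, row)) + bj for row, bj in zip(A, b)]
    return (out, ln_out) if return_layernorm_output else out


x, gamma, beta = [3.0, 4.0], [1.0, 1.0], [0.0, 0.0]
A, b = [[1.0, 1.0]], [0.0]
out, ln_out = layernorm_linear(x, gamma, beta, A, b, return_layernorm_output=True)
rms_out = layernorm_linear(x, gamma, beta, A, b, normalization="RMSNorm")
print(out, ln_out, rms_out)
```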

forward(*args, **kwargs)

Apply layer normalization to the input followed by a linear transformation.

Parameters
  • inp (paddle.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

class transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)

Applies layer normalization on the input followed by the MLP module, consisting of 2 successive linear transformations separated by an activation function (GeLU by default; see activation).

Parameters
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • activation (str, default = 'gelu') – activation function used. Options: ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, ‘swiglu’.

  • return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the MLP. Example use case: residual connection for transformer module is taken post layernorm.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only path without FP8 support is executed, with limited optimization.

Parallelism parameters
  • set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel, as described in the Megatron-LM paper.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (paddle.distributed.collective.Group, default = None) – tensor parallel process group.
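The sketch below shows why pairing a column parallel FC1 with a row parallel FC2 works: the activation is elementwise, so each rank can apply it to its own slice of the intermediate features, and a single all-reduce after FC2 recovers the unsharded result. It is plain Python with simulated ranks; biases and layer normalization are omitted, and the exact erf-based GeLU used here may differ from the fused kernel's variant.

```python
import math


def gelu(v):
    # Exact erf-based GeLU; the fused kernel may use an approximation instead.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mlp(x, W1, W2):
    # FC1: hidden_size -> ffn_hidden_size, activation, FC2: ffn_hidden_size -> hidden_size.
    return matvec(W2, [gelu(h) for h in matvec(W1, x)])


def mlp_tensor_parallel(x, W1, W2, world_size):
    shard = len(W1) // world_size
    partials = []
    for r in range(world_size):
        lo, hi = r * shard, (r + 1) * shard
        W1_r = W1[lo:hi]                           # FC1 column parallel: slice of output features
        W2_r = [row[lo:hi] for row in W2]          # FC2 row parallel: matching input features
        h_r = [gelu(h) for h in matvec(W1_r, x)]   # elementwise activation stays local
        partials.append(matvec(W2_r, h_r))
    return [sum(vals) for vals in zip(*partials)]  # one all-reduce (sum) at the end


x = [0.5, -1.0]                                          # hidden_size = 2
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]   # ffn_hidden_size = 4
W2 = [[1.0, 0.5, -1.0, 2.0], [0.0, 1.0, 1.0, -0.5]]
print(mlp(x, W1, W2))
print(mlp_tensor_parallel(x, W1, W2, world_size=2))     # equal up to rounding
```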

forward(*args, **kwargs)

Apply layer normalization to the input followed by a feedforward network (MLP Block).

Parameters
  • inp (paddle.Tensor) – Input tensor.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

class transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)

Scaled and masked softmax module for paddle with fused optimizations.

Parameters
  • attn_mask_type (str, default = causal) – type of attention mask, can be ‘causal’, ‘padding’, or ‘no_mask’.

  • mask_func (callable) – custom callable for applying the mask to the softmax input. masked_input=mask_func(inp, mask).

  • softmax_in_fp32 (bool, default = True) – perform softmax computation in fp32.

  • layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for operation.
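The scale, mask, and softmax steps can be sketched in plain Python for a small score matrix with a causal mask. Two conventions here are assumptions for illustration: True in the mask marks a position to be masked out, and the example mask_func fills masked scores with -10000.0. The real module fuses these steps into one kernel.

```python
import math


def causal_mask(n):
    # True marks a masked-out position: key j lies in the future of query i.
    return [[j > i for j in range(n)] for i in range(n)]


def mask_func(inp, mask):
    # Example mask_func: replace masked scores with a large negative value.
    return [[-10000.0 if m else v for v, m in zip(row, mrow)]
            for row, mrow in zip(inp, mask)]


def scale_mask_softmax(inp, mask, scale=None):
    if scale is not None:
        inp = [[v * scale for v in row] for row in inp]
    inp = mask_func(inp, mask)
    out = []
    for row in inp:
        m = max(row)                       # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


scores = [[1.0, 2.0, 3.0]] * 3
probs = scale_mask_softmax(scores, causal_mask(3), scale=0.5)
print(probs[0])  # [1.0, 0.0, 0.0]: the first query can only attend to itself
```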

forward(inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None)

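Forward pass of FusedScaleMaskSoftmax: scales inp by scale (if given), applies mask through mask_func, and computes the softmax.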

class transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Note

Argument attention_mask will be ignored in the forward call when attn_mask_type is set to “causal”.

Parameters
  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • kv_channels (int) – number of channels in the key and value tensors.

  • num_gqa_groups (Optional[int] = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in the GQA paper (Ainslie et al., 2023). This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • attn_mask_type ({‘causal’, ‘padding’, ‘no_mask’}, default = causal) – type of attention mask passed into softmax operation.

  • attention_type ({‘self’, ‘cross’}, default = self) – type of attention operation.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for attention operation.
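To make the num_gqa_groups semantics concrete, the helper below maps each query head to the key/value group it attends with. It assumes consecutive query heads share a group (the usual GQA layout), that None falls back to one group per head as in MHA, and that num_attention_heads is divisible by num_gqa_groups; kv_group_for_head is a hypothetical name.

```python
def kv_group_for_head(head, num_attention_heads, num_gqa_groups=None):
    """Index of the key/value group that query head `head` attends with."""
    if num_gqa_groups is None:
        num_gqa_groups = num_attention_heads   # plain MHA: one K/V head per query head
    if num_attention_heads % num_gqa_groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_gqa_groups")
    heads_per_group = num_attention_heads // num_gqa_groups
    return head // heads_per_group


H = 8
print([kv_group_for_head(h, H, 2) for h in range(H)])  # GQA-2: [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_group_for_head(h, H, 1) for h in range(H)])  # GQA-1 (MQA): [0, 0, 0, 0, 0, 0, 0, 0]
print([kv_group_for_head(h, H) for h in range(H)])     # GQA-H (MHA): [0, 1, 2, 3, 4, 5, 6, 7]
```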

forward(query_layer: paddle.Tensor, key_layer: paddle.Tensor, value_layer: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True)

Dot Product Attention Layer.

Note

Argument attention_mask will be ignored when attn_mask_type is set to “causal”.

Parameters
  • query_layer (paddle.Tensor) – Query tensor.

  • key_layer (paddle.Tensor) – Key tensor.

  • value_layer (paddle.Tensor) – Value tensor.

  • attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out softmax input. Ignored when attn_mask_type is “causal”.

  • core_attention_bias_type (str, default = no_bias) – type of bias applied to Q * K.T; only no_bias is currently supported.

  • core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T

  • set_zero (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.
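For a single head, the core computation is softmax(QK^T / sqrt(d)) V, as in Attention Is All You Need. The plain-Python sketch below implements that with an optional causal mask; it assumes the usual 1/sqrt(d) scaling with d equal to the head dimension (kv_channels), and omits dropout, padding masks, and bias.

```python
import math


def attention(q, k, v, causal=True):
    """Single-head softmax(q k^T / sqrt(d)) v; q, k, v are lists of shape (seq, d)."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        if causal:
            # Query i may only attend to keys 0..i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(attention(q, k, v))
```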

class transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)

Multi-head Attention (MHA), including Query, Key, Value and Output projection.

Parameters
  • hidden_size (int) – hidden size of the model.

  • num_attention_heads (int) – number of attention heads.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • layernorm_epsilon (float, default = 1e-5) – epsilon to use in the layer norm operations.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – paddle.ParamAttr object for the weight parameter.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – paddle.ParamAttr object for the bias parameter.

  • attn_mask_type ({‘causal’, ‘padding’, ‘no_mask’}, default = causal) – type of attention mask passed into softmax operation.

  • params_dtype (Optional[paddle.dtype], default = None) – data type for the weights and biases.

  • return_layernorm_output (bool, default = False) – whether to return the output of the layernorm operation.

  • input_layernorm (bool, default = False) – whether to apply layernorm to the input.

  • attention_type ({‘self’, ‘cross’}, default = self) – type of attention operation.

  • normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.

  • zero_centered_gamma (bool, default = False) – whether to zero initialize the gamma of the layernorm operation.

  • backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for attention operation. If set to ‘paddle’, a framework-only path without FP8 support is executed, with limited optimization.

Parallelism parameters
  • set_parallel_mode (bool, default = False) – if set to True, the QKV projection is used as Column Parallel and the output projection (PROJ) as Row Parallel, as described in the Megatron-LM paper.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • num_gqa_groups (int, default = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in the GQA paper (Ainslie et al., 2023). This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • rng_state_name (str, default = local_seed) – controls the RNG state used for dropout on attention probs. The specified RNG state should be seeded differently on different TP ranks. It is ignored if set_parallel_mode is False. The specified name should be registered through paddle.distributed.fleet.meta_parallel.get_rng_state_tracker().add(rng_state_name, seed).

forward(hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None, rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True, recompute_core_attention: bool = False, is_first_microbatch: Optional[bool] = None)

MultiHeadAttention Layer.

Parameters
  • hidden_states (paddle.Tensor) – Input tensor.

  • attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out softmax input. Ignored when attn_mask_type is “causal”.

  • encoder_output (Optional[paddle.Tensor], default = None) – Output of the encoder layer.

  • rotary_pos_emb (Tuple[paddle.Tensor, paddle.Tensor], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.

  • core_attention_bias_type (str, default = no_bias) – type of bias applied to Q * K.T; only no_bias is currently supported.

  • core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T

  • set_zero (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.

  • recompute_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

class transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)

TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.

Parameters
  • hidden_size (int) – size of each input sample.

  • ffn_hidden_size (int) – intermediate size to which input samples are projected.

  • num_attention_heads (int) – number of attention heads in the transformer layer.

  • num_gqa_groups (Optional[int], default = None) –

    number of GQA groups in the transformer layer. Grouped Query Attention is described in the GQA paper (Ainslie et al., 2023). This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.

  • hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.

  • attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.

  • weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.

  • bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.

  • self_attn_mask_type ({‘causal’, ‘padding’}, default = causal) – type of attention mask passed into softmax operation.

  • apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of layer norm (by default, they are taken from the input of layer norm).

  • output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. By default, layer normalization is applied on the input side, before the QKV transformation.

  • layer_type ({‘encoder’, ‘decoder’}, default = encoder) – if set to decoder, an additional cross-attn block is added after self-attn. This can be used for structures like T5 Transformer in conjunction with the encoder option.

  • normalization ({‘LayerNorm’, ‘RMSNorm’}, default = LayerNorm) – type of normalization applied.

  • zero_centered_gamma (bool, default = False) –

    if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]

  • activation (str, default = 'gelu') – Type of activation used in MLP block. Options are: ‘gelu’, ‘relu’, ‘reglu’, ‘geglu’ and ‘swiglu’.

  • params_dtype (paddle.dtype, default = paddle.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.

  • backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only path without FP8 support is executed, with limited optimization.
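The effect of apply_residual_connection_post_layernorm on the default input-side (pre-LN) data flow can be sketched with pluggable sublayers. This is a simplified illustration, not the module's code: dropout, biases, cross-attention, and output_layernorm are omitted, and the sublayers are stand-in functions chosen so the residual source is easy to see.

```python
def transformer_layer(x, ln1, attn, ln2, mlp,
                      apply_residual_connection_post_layernorm=False):
    """Simplified pre-LN encoder data flow; dropout and biases are omitted."""
    ln_out = ln1(x)
    residual = ln_out if apply_residual_connection_post_layernorm else x
    h = [r + a for r, a in zip(residual, attn(ln_out))]       # attention block + residual

    ln_out2 = ln2(h)
    residual2 = ln_out2 if apply_residual_connection_post_layernorm else h
    return [r + m for r, m in zip(residual2, mlp(ln_out2))]   # MLP block + residual


# Stand-in sublayers (hypothetical) that make the residual source visible.
def double(t):
    return [2.0 * v for v in t]


def ones(t):
    return [1.0 for _ in t]


def zeros(t):
    return [0.0 for _ in t]


print(transformer_layer([1.0], double, ones, double, zeros))   # [2.0]
print(transformer_layer([1.0], double, ones, double, zeros,
                        apply_residual_connection_post_layernorm=True))  # [6.0]
```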

Parallelism parameters
  • set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel, as described in the Megatron-LM paper.

  • sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.

  • tp_group (ProcessGroup, default = None) – tensor parallel process group.

  • attention_dropout_rng_state_name (str, default = local_seed) – controls the RNG state used for dropout on attention probs. The specified RNG state should be seeded differently on different TP ranks. It is ignored if set_parallel_mode is False.

  • hidden_dropout_rng_state_name (str, default = global_seed) – controls the RNG state used for dropout on hidden states. The specified RNG state should be seeded identically on all TP ranks. It is ignored if set_parallel_mode is False. The specified name should be registered through paddle.distributed.fleet.meta_parallel.get_rng_state_tracker().add(rng_state_name, seed).
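The two seeding rules can be illustrated with Python's random module standing in for per-rank RNG states. A local_seed-style state uses a different seed on each TP rank, so ranks holding different attention heads draw independent dropout masks; a global_seed-style state uses the same seed everywhere, so dropout on hidden states agrees across ranks. The seed offset scheme is hypothetical, and this sketch does not call Paddle's get_rng_state_tracker().

```python
import random


def dropout_keep_mask(rng, n, p):
    # 1 keeps an element, 0 drops it (drop probability p).
    return [0 if rng.random() < p else 1 for _ in range(n)]


base_seed = 1234
tp_world_size = 2
attn_masks, hidden_masks = [], []
for tp_rank in range(tp_world_size):
    # "local_seed"-style state: a different seed per TP rank (hypothetical offset scheme).
    local_rng = random.Random(base_seed + 1 + tp_rank)
    # "global_seed"-style state: the same seed on every TP rank.
    global_rng = random.Random(base_seed)
    attn_masks.append(dropout_keep_mask(local_rng, 64, 0.5))
    hidden_masks.append(dropout_keep_mask(global_rng, 64, 0.5))

print(hidden_masks[0] == hidden_masks[1])  # True: hidden-state dropout matches across ranks
print(attn_masks[0] == attn_masks[1])      # attention-prob masks are drawn independently per rank
```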

forward(hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None, enc_dec_attn_mask: Optional[paddle.Tensor] = None, rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True, recompute_core_attention: bool = False, is_first_microbatch: Optional[bool] = None)

Transformer Layer: attention block and a feedforward network (MLP).

Note

Argument attention_mask will be ignored when self_attn_mask_type is set to “causal”.

Parameters
  • hidden_states (paddle.Tensor) – Input tensor.

  • attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input.

  • encoder_output (Optional[paddle.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type="decoder".

  • enc_dec_attn_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out inter-attention softmax input if using layer_type="decoder".

  • rotary_pos_emb (Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.

  • core_attention_bias_type (str, default = no_bias) – type of bias applied to Q * K.T; only no_bias is currently supported.

  • core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T

  • set_zero (bool, default = True) – Whether to set output tensors to 0 or not before use.

  • recompute_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.

  • is_first_microbatch ({True, False, None}, default = None) –

    During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch, the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:

    • during FP8 training, it allows caching of the FP8 versions of the weights

transformer_engine.paddle.fp8_autocast(enabled: bool = False, calibrating: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, fp8_group: Optional[transformer_engine.paddle.constants.dist_group_type] = None)

Context manager for FP8 usage.

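from transformer_engine.paddle import fp8_autocast

with fp8_autocast(enabled=True):
    out = model(inp)  # model: a network built from Transformer Engine layers; inp: its input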

Note

Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.
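A small helper can compute the padded length and check the divisibility constraint. It assumes the leading input dimensions (sequence length and batch) are flattened into one before the GEMM, so the checked shape is (seq_len * batch, hidden); both helper names are hypothetical.

```python
def pad_to_multiple(n, multiple=16):
    """Round n up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def fp8_gemm_shape_ok(rows, cols, multiple=16):
    """Hypothetical check: both GEMM dimensions divisible by 16."""
    return rows % multiple == 0 and cols % multiple == 0


seq_len, batch, hidden = 250, 4, 1024
print(fp8_gemm_shape_ok(seq_len * batch, hidden))   # False: 1000 rows is not a multiple of 16
padded_len = pad_to_multiple(seq_len)
print(padded_len, fp8_gemm_shape_ok(padded_len * batch, hidden))  # 256 True
```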

Note

When fp8_recipe.reduce_amax==True, no module may be invoked more than once inside a single fp8_autocast region. This is unsupported because the amax reduction is handled during the exit of the fp8_autocast context. Calling the same module more than once inside an fp8_autocast region overrides the amax tensors before reduction can occur.

Parameters
  • enabled (bool, default = False) – whether or not to enable fp8

  • calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of FP8 tensors even when executing without FP8 enabled. This is useful for saving an inference-ready FP8 checkpoint while training using a higher precision.

  • fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.

  • fp8_group (paddle.distributed.collective.Group, default = None) – distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step.

transformer_engine.paddle.recompute(function, *args, **kwargs)

This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary state information for fp8 layers.

Parameters
  • function (Callable) – paddle module used to run the forward and backward passes using the specified args and kwargs.

  • args (tuple) – tuple of paddle tensors for inputs to function.

  • kwargs (dict) – dictionary of string keys for keyword arguments to function.