Jax

Pre-defined Variables of Logical Axes

Variables are available in transformer_engine.jax.sharding.

  • BATCH_AXES: The logical axis of batch dimension. It is usually sharded along DP + FSDP on Mesh.

  • SEQLEN_AXES: The logical axis of sequence length dimension. It is usually not sharded.

  • SEQLEN_TP_AXES: The logical axis of sequence length dimension. It is usually sharded along TP on Mesh.

  • HEAD_AXES: The logical axis of head dimension of MHA. It is usually sharded along TP on Mesh.

  • HIDDEN_AXES: The logical axis of hidden dimension. It is usually not sharded.

  • HIDDEN_TP_AXES: The logical axis of hidden dimension. It is usually sharded along TP on Mesh.

  • JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.
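
These names are meant to be passed as sharding-constraint arguments such as layernorm_input_axes and dot_input_axes of the modules below. A minimal sketch, assuming the feature size and the chosen axes are purely illustrative:

from transformer_engine.jax.sharding import BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES
from transformer_engine.jax.flax import LayerNormDenseGeneral

# Ask for the layernorm and dot inputs to be partitioned according to the
# pre-defined logical axes; they resolve to mesh axes once the logical axis
# rules are in effect (see fp8_autocast and extend_logical_axis_rules below).
layer = LayerNormDenseGeneral(
    features=1024,
    layernorm_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
    dot_input_axes=(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES),
    transpose_batch_sequence=False,
)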

Modules

class transformer_engine.jax.flax.TransformerLayerType

TransformerLayerType is an Enum class to specify a type of TransformerLayer

Values
  • ENCODER – Encoder type of TransformerLayer.

  • DECODER – Decoder type of TransformerLayer.

class transformer_engine.jax.MeshResource

A data container indicating which axis in Mesh is used for data parallelism and which for tensor parallelism.

Parameters
  • dp_resource (str, default = None) – The axis name in Mesh used to shard batches along. If it is None, then data parallelism is disabled.

  • tp_resource (str, default = None) – The axis name in Mesh used to split the hidden dimensions along. If it is None, then tensor parallelism is disabled.

  • fsdp_resource (str, default = None) – The axis name in Mesh used to split the batch and weights along. If it is None, then fully sharded data parallelism is disabled.

  • pp_resource (str, default = None) – The axis name in Mesh used to split model layers along. If it is None, then pipeline parallelism is disabled.

transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, mesh_resource: Optional[transformer_engine.jax.sharding.MeshResource] = None)

Context manager for FP8 usage.

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import pjit
from flax.linen import partitioning
from transformer_engine.jax import MeshResource, fp8_autocast
from transformer_engine.jax.flax import TransformerLayer, extend_logical_axis_rules

# Build a 4x2 device mesh: the first axis is used for data parallelism,
# the second for tensor parallelism.
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)

with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
    mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

    with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
        rules = extend_logical_axis_rules(tuple())
        transformer = TransformerLayer()

        with partitioning.axis_rules(rules):
            pjit(transformer.init, ...)(...)

Note

We currently only support margin, fp8_format, interval, amax_history_len and amax_compute_algo (with value ‘max’ or ‘most_recent’) in recipe.DelayedScaling. Setting other parameters in recipe.DelayedScaling will trigger an assertion.

Parameters
  • enabled (bool, default = False) – Whether or not to enable FP8.

  • fp8_recipe (recipe.DelayedScaling, default = None) – Recipe used for FP8 training.

  • mesh_resource (MeshResource, default = None) – Specify the mesh axes for data and tensor parallelism to shard along. If set to None, then no data or tensor parallelism will be used.

transformer_engine.jax.update_collections(new: Collection, original: Collection)

A helper to update Flax’s Collection.

Collection = [dict, flax.core.frozen_dict.FrozenDict]

Parameters
  • new (Collection) – A collection that includes new data.

  • original (Collection) – The base collection.

Returns

outputs – The updated collection.

Return type

Collection

transformer_engine.jax.update_fp8_metas(state: Collection)

Calculate new FP8 scales and their inverses via the following formulas

sf = (fp8_max / amax) / (2 ^ margin)
sf = sf if amax > 0.0, else original_scale
updated_scale = sf if isfinite(amax), else original_scale
updated_scale_inv = 1 / updated_scale
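
The same update, written as a small JAX sketch for reference (the names fp8_max, amax, margin and original_scale stand in for the corresponding FP8 meta entries; this is illustrative, not the module's internal code):

import jax.numpy as jnp

def compute_updated_scale(fp8_max, amax, original_scale, margin=0.0):
    # sf = (fp8_max / amax) / 2**margin, falling back to the original scale
    # when amax is non-positive or non-finite.
    sf = (fp8_max / amax) / (2.0 ** margin)
    sf = jnp.where(amax > 0.0, sf, original_scale)
    updated_scale = jnp.where(jnp.isfinite(amax), sf, original_scale)
    return updated_scale, 1.0 / updated_scale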

Collection = [dict, flax.core.frozen_dict.FrozenDict]

Parameters

state (Collection) – A collection that includes FP8 metas.

Returns

outputs – The collection with updated FP8 metas.

Return type

Collection

class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)

Applies layer normalization over a mini-batch of inputs. There are two types of normalization supported by this module: regular layer normalization and root mean square layer normalization (RMSNorm).

The regular layer normalization is as described in the paper Layer Normalization

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

\(\gamma\) and \(\beta\) are learnable affine transform parameters of size of each input sample.

The root mean square layer normalization (RMSNorm) is as described in the paper Root Mean Square Layer Normalization

\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]
\[\mathrm{RMS} = \sqrt{\mathrm{E}[x^2]}\]

\(\gamma\) is a learnable affine transform parameter of size of each input sample.

Parameters
  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when layernorm_type='layernorm'.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

__call__(x: jax.numpy.ndarray)

Applies layer normalization to the input.

Parameters

x (jax.numpy.ndarray) – Input tensor.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
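
A minimal usage sketch, assuming the standard Flax init/apply workflow (shapes are illustrative):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNorm

x = jnp.ones((32, 128, 1024))                  # (batch, seqlen, hidden)
layer = LayerNorm(layernorm_type='rmsnorm')
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)                     # same shape as x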

class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)

Applies a linear transformation to the incoming data \(y = xA^T + b\)

Parameters
  • features (Union[Iterable[int], int]) – The hidden size of each output sample.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh, only used when use_bias=True.

  • axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

__call__(inputs: Array)

Apply the linear transformation to the input.

Parameters

inputs (jax.numpy.ndarray) – Input tensors.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
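
A minimal usage sketch, assuming the standard Flax init/apply workflow (shapes and the kernel_axes names are illustrative):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral

# transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
x = jnp.ones((128, 32, 1024))
layer = DenseGeneral(features=4096, kernel_axes=('embed', 'mlp'), use_bias=True)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)                     # -> (128, 32, 4096)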

class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

Applies layer normalization followed by linear transformation to the incoming data.

Parameters
  • features (Union[Iterable[int], int]) – The hidden size of each output sample.

  • enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh, only used when enable_layernorm=True.

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. It is only used when enable_layernorm=True and layernorm_type='layernorm'.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh, only used when use_bias=True.

  • return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.

  • axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.

  • layernorm_input_axes (Tuple[str, ...], default = None) – Indicate the logical axes of sharding constraint to the input of layernorm, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint.

  • dot_input_axes (Tuple[str, ...], default = None) – Indicate the logical axes of sharding constraint to the input of dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

  • depth_scaling (float, default = None) – The factor to scale the output from DenseGeneral. It should be a float value or None. When None is set, then no scaling is applied.

__call__(inputs: Array)

Apply layer normalization to the input followed by a linear transformation.

Parameters

inputs (jax.numpy.ndarray) – Input tensor.

Returns

  • outputs (jax.numpy.ndarray) – Output tensors.

  • ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, then this would be None.
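
A minimal usage sketch, assuming the standard Flax init/apply workflow; the second return value is the layer-norm output (or None when return_layernorm_output=False):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormDenseGeneral

# transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
x = jnp.ones((128, 32, 1024))
layer = LayerNormDenseGeneral(features=3072, layernorm_type='rmsnorm',
                              return_layernorm_output=True)
params = layer.init(jax.random.PRNGKey(0), x)
out, ln_out = layer.apply(params, x)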

class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

Applies layer normalization on the input followed by the MLP module, consisting of 2 successive linear transformations, separated by given activations.

Parameters
  • intermediate_dim (int, default = 2048) – Intermediate size to which input samples are projected.

  • enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh, only used when enable_layernorm=True.

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when enable_layernorm=True and layernorm_type='layernorm'.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of both linear transformations. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes_1 (Tuple[str, ...], default = ('embed', 'act', 'mlp')) – The name of axes used to shard the weight of the first linear transformation with a corresponding mesh.

  • kernel_axes_2 (Tuple[str, ...], default = ('mlp', 'embed')) – The name of axes used to shard the weight of the second linear transformation with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes_1 (Tuple[str, ...], default = ('mlp',)) – The name of axes used to shard the bias of the first linear transformation with a corresponding mesh. Only used when use_bias=True.

  • bias_axes_2 (Tuple[str, ...], default = ('embed',)) – The name of axes used to shard the bias of the second linear transformation with a corresponding mesh. Only used when use_bias=True.

  • return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.

  • activations (Sequence[Union[str, Callable]], default = ('relu',)) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.

  • intermediate_dropout_rng_name (str, default = 'dropout') – The key in the RNGs given via flax.linen.Module.apply that is used to generate Dropout masks.

  • intermediate_dropout_rate (float, default = 0.1) – Dropout probability for the dropout op after the activations.

  • intermediate_hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden-state dropout.

  • axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.

  • layernorm_input_axes (Tuple[str, ...], default = None) – Indicate the logical axes of sharding constraint to the input of layernorm, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint.

  • dot_1_input_axes (Tuple[str, ...], default = None) – Indicate the logical axes of sharding constraint to the input of 1st dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint.

  • dot_2_input_axes (Tuple[str, ...], default = None) – Indicate the logical axes of sharding constraint to the input of 2nd dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

__call__(inputs: Array, deterministic: bool = False)

Apply layer normalization to the input followed by a feedforward network (MLP Block).

Parameters
  • inputs (jax.numpy.ndarray) – Input tensor.

  • deterministic (bool, default = False) – Disable dropout ops if set to True.

Returns

  • outputs (jax.numpy.ndarray) – Output tensors.

  • ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, then this would be None.
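
A minimal usage sketch, assuming the standard Flax init/apply workflow; deterministic=True disables the intermediate dropout, so no dropout RNG is needed:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

# transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
x = jnp.ones((128, 32, 1024))
mlp = LayerNormMLP(intermediate_dim=4096)
params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
out, ln_out = mlp.apply(params, x, deterministic=True)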

class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)

T5-style relative positional embeddings added to the attention logits.

Parameters
  • num_buckets (int) – The number of buckets to bucket distances between key and query positions into.

  • max_distance (int) – The maximum distance before everything is lumped into the last distance bucket.

  • num_attention_heads (int) – Number of attention heads in the transformer layer.

  • embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – Used for initializing relative embedding tables.

  • embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – The name of axes used to shard embedding attention bias with a corresponding mesh.

Optimization parameters

dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

__call__(q_seqlen, k_seqlen, bidirectional=True)

Generate relative position embedding attention biases.

Parameters
  • q_seqlen (int) – The sequence length of query.

  • k_seqlen (int) – The sequence length of key.

  • bidirectional (bool, default = True) – Indicate whether to allow positive memory-query relative position embeddings.

Returns

output – An attention bias with shape (1, num_attention_heads, q_seqlen, k_seqlen).

Return type

jax.numpy.ndarray
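
A minimal usage sketch, assuming the standard Flax init/apply workflow and using the documented num_attention_heads keyword (sizes are illustrative):

import jax
from transformer_engine.jax.flax import RelativePositionBiases

rel_bias = RelativePositionBiases(num_buckets=32, max_distance=128,
                                  num_attention_heads=16)
params = rel_bias.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
bias = rel_bias.apply(params, 128, 128, bidirectional=True)   # (1, 16, 128, 128)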

class transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)

Dot Product Attention (DPA). Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Note

The DotProductAttention module supports two backends: the unfused and the fused attention mechanisms. The unfused attention is implemented using JAX native operations, providing broad compatibility and flexibility. In contrast, the fused attention uses cuDNN fused attention for higher performance and lower memory usage on supported hardware. Users can select between these two backends via the NVTE_FUSED_ATTN environment variable:

  • Set NVTE_FUSED_ATTN=0 for unfused attention (default).

  • Set NVTE_FUSED_ATTN=1 for fused attention. If the required cuDNN fused attention kernel is not available on the system, a warning will be issued, and the module will automatically fall back to the unfused backend.

Parameters
  • head_dim (int) – The hidden dimension of each attention head.

  • num_attention_heads (int) – The number of attention heads.

  • num_gqa_groups (int, default = None) – Number of GQA groups. When None is present, it is equal to num_attention_heads. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • attention_dropout (float, default = 0.0) – Dropout probability for the dropout op after the softmax.

  • attn_mask_type (str, default = 'causal') – Type of the attention mask passed into softmax operation in the self attention. Available options: {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’}. Introduced in v0.10.0.

  • attn_bias_type (Optional[str], default = None) – Type of the attention bias passed in the self attention. Available options: {‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}. When left as None, the type is automatically decided by the MHA’s bias parameter: post_scale_bias if a bias is given, otherwise no_bias.

  • dropout_rng_name (str, default = 'dropout') – The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention.

  • float32_logits (bool, default = False) – Whether to compute attention logits in float32 for the unfused attention backend. For fused attention backend, the accumulation is always float32 without the perf overhead.

  • qkv_layout (str, default = 'bshd_bshd_bshd') –

    Specifies the dimensional layout format for the query, key, and value tensors in __call__(). It indicates how the inputs are processed. Available options: {‘bs3hd’, ‘bshd_bs2hd’, ‘bshd_bshd_bshd’}. Where

    • bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d]. key and value arguments in __call__() are ignored in this layout.

    • bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked tensor with shape = [b, s, 2, h, d]. value argument in __call__() is ignored.

    • bshd_bshd_bshd: query, key, and value are separate tensors, each with shape = [b, s, h, d].

    Explanation of denotations:

    • b: batch size

    • s: sequence length

    • h: num_attention_heads or num_gqa_groups

    • d: head dimension

  • scale_factor (Optional[float], default = None) – Scale factor to apply on query. When None is present, the scale factor is equal to \(\frac{1}{\sqrt{head\_dim}}\). This is useful for models like T5X, which do not need to apply a scale on the query; in that case, set scale_factor=1.0.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, …), otherwise (batch, seqlen, …).

Optimization parameters

dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

__call__(query: Array, key: Array, value: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, deterministic: bool = False)
Parameters
  • query (jax.numpy.ndarray) – The details of the query tensor representation are described in qkv_layout.

  • key (jax.numpy.ndarray) – The details of the key tensor representation are described in qkv_layout.

  • value (jax.numpy.ndarray) – The details of the value tensor representation are described in qkv_layout.

  • mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out the attention softmax input. True means to mask out the corresponding values.

  • bias (jax.numpy.ndarray, default = None) – A tensor used to shift attention softmax input.

  • * – The parameters below are keyword-only.

  • deterministic (bool, default = False) – Disable dropout layers if set to True.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
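
A minimal usage sketch, assuming the standard Flax init/apply workflow and using the documented num_attention_heads keyword. With the default qkv_layout='bshd_bshd_bshd' and transpose_batch_sequence set to False here, each of query/key/value has shape (batch, seqlen, heads, head_dim):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention

q = k = v = jnp.ones((32, 128, 16, 64))        # (b, s, h, d)
attn = DotProductAttention(head_dim=64, num_attention_heads=16,
                           attn_mask_type='causal',
                           transpose_batch_sequence=False)
params = attn.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
out = attn.apply(params, q, k, v, deterministic=True)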

class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)

Multi-head Attention (MHA), including Query, Key, Value and Output projection.

Parameters
  • head_dim (int) – The hidden dimension of each attention head.

  • num_attention_heads (int) – The number of attention heads.

  • num_gqa_groups (int, default = None) –

    Number of GQA groups. When None is present, it is equal to num_attention_heads. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • attention_dropout (float, default = 0.0) – Dropout probability for the dropout op after the softmax.

  • attn_mask_type (str, default = 'causal') – Type of the attention mask passed into softmax operation in the attention. Available options: {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’}. Introduced in v0.10.0.

  • attn_bias_type (Optional[str], default = None) – Type of the attention bias passed in the attention. Available options: {‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}. When left as None, the type is automatically decided by the MHA’s bias parameter: post_scale_bias if a bias is given, otherwise no_bias.

  • dropout_rng_name (str, default = 'dropout') – The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • use_bias (bool, default = False) – Indicate whether or not to enable bias shifting for QKV and output projections. If set to False, the layer will not learn additive biases.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • input_layernorm (bool, default = True) – If set to False, layer normalization to the input is not applied.

  • return_layernorm_output (bool, default = False) – If set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.

  • enable_rotary_pos_emb (bool, default = False) – Whether to enable rotary position embedding to projected query and key.

  • rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – Indicate the min and max time-scales of rotary position embedding. Only used when enable_rotary_pos_emb=True.

  • rotary_pos_emb_group_method (str, default = 'consecutive') – Indicate the method used to couple the coordinates. It should be one of [‘consecutive’, ‘alternate’]. ‘alternate’ pairs index \(i\) with \(i + d/2\), where d is the hidden dimension. ‘consecutive’ pairs index \(i\) with \(i + 1\).

  • enable_sequence_parallel (bool, default = False) – Whether to enable sequence parallelism for operations other than dot.

  • num_heads (int, default = None) – Deprecated. Please refer to num_attention_heads.

  • dropout_rate (float, default = None) – Deprecated. Please refer to attention_dropout.

  • output_layernorm (bool, default = None) – Deprecated. Please refer to input_layernorm.

  • apply_residual_connection_post_layernorm (bool, default = None) – Deprecated. Please refer to return_layernorm_output.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • fuse_qkv_params (bool, default = True) – If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

  • scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, \(\frac{Q}{\sqrt{head\_dim}} * K\), else \(Q * K\).

  • scaled_query_init (bool, default = True) – Whether to scale WQ on initialization by \(\frac{1}{\sqrt{head\_dim}}\)

  • float32_logits (bool, default = False) – Whether to compute attention logits in float32 for the unfused attention backend. For fused attention backend, the accumulation is always float32 without the perf overhead.

  • fuse_qkv (bool, default = None) – Deprecated. Please refer to fuse_qkv_params.

__call__(inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, decode: bool = False, deterministic: bool = False)

MultiHeadAttention Layer: [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

Parameters
  • inputs_q (jax.numpy.ndarray) – Input tensor for query projection.

  • inputs_kv (jax.numpy.ndarray) – Input tensor for key/value projection.

  • mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out the attention softmax input. True means mask out the corresponding values.

  • bias (jax.numpy.ndarray, default = None) – A tensor used to shift the attention softmax input.

  • * – The parameters below are keyword-only.

  • decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache.

  • deterministic (bool, default = False) – Disable dropout layers if set to True.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
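
A minimal self-attention usage sketch, assuming the standard Flax init/apply workflow and using the documented num_attention_heads keyword (inputs_q and inputs_kv are the same tensor):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MultiHeadAttention

# transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
x = jnp.ones((128, 32, 1024))
mha = MultiHeadAttention(head_dim=64, num_attention_heads=16,
                         attn_mask_type='causal')
params = mha.init(jax.random.PRNGKey(0), x, x, deterministic=True)
out = mha.apply(params, x, x, deterministic=True)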

class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)

TransformerLayer is made up of a relative embedding, an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.

Parameters
  • hidden_size (int, default = 512) – The hidden size of each input sample.

  • mlp_hidden_size (int, default = 2048) – Intermediate size to which input samples are projected.

  • num_attention_heads (int, default = 8) – Number of attention heads in the transformer layer.

  • num_gqa_groups (int, default = None) –

    Number of GQA groups. When None is present, it is equal to num_attention_heads. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’.

  • hidden_dropout (float, default = 0.1) – Dropout probability for the dropout op after FC2 layer.

  • hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden-state dropout.

  • attention_dropout (float, default = 0.1) – Dropout probability for the dropout op during multi-head attention.

  • intermediate_dropout (float, default = 0.1) – Dropout probability for the dropout op after FC1 layer.

  • intermediate_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the dropout after the FC1 layer.

  • dropout_rng_name (str, default = 'dropout') – The key in the RNGs given via flax.linen.Module.apply that is used to generate Dropout masks in the Multi-Head Attention.

  • mha_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • mlp_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of the FC1 and FC2 layers. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • mlp_activations (Sequence[str], default = ('relu', )) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections, FC1 and FC2. It is only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • apply_residual_connection_post_layernorm (bool, default = False) – If set to True, residual connections are taken from the output of layer norm (the default is to take them from the input of layer norm).

  • output_layernorm (bool, default = False) – If set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.

  • float32_attention_logits (bool, default = False) – Whether to compute attention logits in float32 for the unfused attention backend. For fused attention backend, the accumulation is always float32 without the perf overhead.

  • layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – If set to TransformerLayerType.DECODER, an additional cross-attention block is added after self-attention. This can be used for structures like the T5 Transformer in conjunction with the TransformerLayerType.ENCODER option.

  • self_attn_mask_type (str, default = 'causal') – Type of the attention mask passed into softmax operation in the self attention. Available options: {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’}. Introduced in v0.10.0.

  • self_attn_bias_type (Optional[str], default = None) – Type of the attention bias passed into the self attention. Available options: {‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}. When left as None, the type is automatically decided by the MHA’s bias parameter: post_scale_bias if a bias is given, otherwise no_bias.

  • enable_relative_embedding (bool, default = True) – Whether to enable relative embedding as shifting of attention logits.

  • relative_embedding (flax.linen.Module, default = None) – The module for relative embedding execution, only used when enable_relative_embedding=True. Default is None, which will create an instance of RelativePositionBiases if enable_relative_embedding=True. Default: RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, ‘fan_avg’, ‘uniform’), name=’relpos_bias’)

  • enable_rotary_pos_emb (bool, default = False) – Whether to enable rotary position embedding to projected query and key in MHA.

  • rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – Indicate the min and max time-scales of rotary position embedding. Only used when enable_rotary_pos_emb=True.

  • rotary_pos_emb_group_method (str, default = 'consecutive') – Indicate the method used to couple the coordinates. It should be one of [‘consecutive’, ‘alternate’]. ‘alternate’ pairs index \(i\) with \(i + d/2\), where d is the hidden dimension. ‘consecutive’ pairs index \(i\) with \(i + 1\).

  • enable_sequence_parallel (bool, default = False) – Whether to enable sequence parallelism for operations other than dot.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • drop_path (float, default = 0.0) – When > 0.0, applies stochastic depth per sample in the main path of the residual block.

  • fuse_qkv_params (bool, default = True) – If set to True, TransformerLayer module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.

  • transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).

  • scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, \(\frac{Q}{\sqrt{head\_dim}} * K\), else \(Q * K\).

  • scaled_query_init (bool, default = True) – Whether to scale WQ on initialization by \(\sqrt{head_dim}\)

__call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)

Transformer Layer: attention block and a feedforward network (MLP)

Parameters
  • inputs (jax.numpy.ndarray) – Input tensor.

  • encoded (jax.numpy.ndarray, default = None) – Output tensors of the encoder block to be fed into the decoder block if using layer_type=TransformerLayerType.DECODER.

  • attention_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out self-attention softmax input.

  • encoder_decoder_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out cross-attention softmax input when layer_type=TransformerLayerType.DECODER.

  • deterministic (bool, default = False) – Disable dropout layers if set to True.

  • decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache in Multi-head attention (MHA).

  • max_decode_length (int, default = None) – The maximum length to generate relative embedding biases when layer_type=TransformerLayerType.DECODER and enable_relative_embedding=True.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
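
A minimal usage sketch of a single encoder layer, assuming the standard Flax init/apply workflow (shapes are illustrative; pass layer_type=TransformerLayerType.DECODER together with encoded and encoder_decoder_mask for a decoder layer):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType

# transpose_batch_sequence defaults to False, so inputs are (batch, seqlen, hidden).
x = jnp.ones((32, 128, 512))
layer = TransformerLayer(hidden_size=512, mlp_hidden_size=2048,
                         num_attention_heads=8,
                         layer_type=TransformerLayerType.ENCODER)
params = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
out = layer.apply(params, x, deterministic=True)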

transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules)

Extend the given Flax logical axis rules with the predefined TransformerLayer’s logical axis rules.

Note

We currently only support logical axis rules for single GPU training, data parallel training and 1D-sharding tensor parallel training. Refer to Figure 3 in Megatron-LM tensor parallel for 1D-sharding tensor parallelism.

Warning

Please make sure MeshResource is set via fp8_autocast before calling this function.

Note

This function is only needed when using TransformerLayer. For other modules, such as DenseGeneral, please properly set axes of kernels and bias.

Parameters

rules (Sequence[Tuple[str, Union[str, None]]]) – The base Flax logical axis rules to extend.

Returns

extended_rules – The extended Flax logical axis rules.

Return type

Sequence[Tuple[str, Union[str, None]]]
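
A minimal sketch of the intended call pattern (the base rule ('batch', 'data') is illustrative; the MeshResource must already be registered via fp8_autocast, as in the example above):

from flax.linen import partitioning
from transformer_engine.jax.flax import extend_logical_axis_rules

rules = extend_logical_axis_rules((('batch', 'data'),))
with partitioning.axis_rules(rules):
    ...   # pjit and run the init/apply of a TransformerLayer, as in the fp8_autocast example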