Jax

class transformer_engine.jax.MajorShardingType

The major sharding type indicating the overall sharding pattern.

Values
  • SINGLE – Single process training.

  • DP – Data parallel training.

  • TP – Standard tensor parallel training.

  • DPTP – Data and standard tensor parallel training.

class transformer_engine.jax.ShardingType

The sharding type indicating the sharding pattern.

Values
  • SINGLE – No sharding.

  • DP – Sharding along data parallelism.

  • TP_COL – Sharding along column-split tensor parallelism.

  • TP_ROW – Sharding along row-split tensor parallelism.

  • DP_TP_COL – Sharding along data and column-split tensor parallelism.

  • DP_TP_ROW – Sharding along data and row-split tensor parallelism.

class transformer_engine.jax.TransformerLayerType

The layer type of TransformerLayer.

Values
  • ENCODER – Encoder layer.

  • DECODER – Decoder layer, which adds a cross-attention block after self-attention (see the layer_type parameter of TransformerLayer).

class transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)

A data container indicating which axis of the Mesh is used for data parallelism and which is used for tensor parallelism.

Parameters
  • dp_resource (str, default = None) – The axis name in Mesh used to shard batches along. If it is None, then data parallelism is disabled.

  • tp_resource (str, default = None) – The axis name in Mesh used to split the hidden dimensions along. If it is None, then tensor parallelism is disabled.

transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, sharding_resource: Optional[transformer_engine.jax.sharding.ShardingResource] = None)

Context manager for FP8 usage.

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import pjit
from flax.linen import partitioning
from transformer_engine.jax import fp8_autocast, ShardingResource
from transformer_engine.jax.flax import TransformerLayer, extend_logical_axis_rules

# Build a 4x2 device mesh: 4-way data parallelism, 2-way tensor parallelism.
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)

with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
    sharding_resource = ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)

    with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
        rules = extend_logical_axis_rules(tuple())
        transformer = TransformerLayer()

        with partitioning.axis_rules(rules):
            pjit(transformer.init, ...)(...)

Note

Currently, only margin, fp8_format, interval, and amax_history_len in recipe.DelayedScaling are supported. Other parameters in recipe.DelayedScaling are ignored, even if set.

Parameters
  • enabled (bool, default = False) – Whether or not to enable FP8.

  • fp8_recipe (recipe.DelayedScaling, default = None) – Recipe used for FP8 training.

  • sharding_resource (ShardingResource, default = None) – Specify the mesh axes for data and tensor parallelism to shard along. If set to None, a default ShardingResource() is created.

transformer_engine.jax.update_collections(new: Collection, original: Collection)

A helper to update Flax’s Collection.

Collection = [dict, flax.core.frozen_dict.FrozenDict]

Parameters
  • new (Collection) – A collection that includes new data.

  • original (Collection) – The base collection.

Returns

outputs – The updated collection.

Return type

Collection
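
A minimal usage sketch; the dictionary contents are illustrative assumptions, only the call pattern follows the entry above:

from flax.core.frozen_dict import FrozenDict
from transformer_engine.jax import update_collections

original = FrozenDict({'fp8_metas': {'scale': 1.0}, 'params': {}})
new = {'fp8_metas': {'scale': 2.0}}
updated = update_collections(new, original)  # a Collection combining `original` with the data in `new`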

transformer_engine.jax.update_fp8_metas(state: Collection)

Calculate new FP8 scales and their inverses via the following formula:

exp = floor(log2(fp8_max / amax)) - margin
sf = round(power(2, abs(exp)))
sf = sf if amax > 0.0 else original_scale
sf = sf if isfinite(amax) else original_scale
updated_scale = 1/sf if exp < 0, else sf
updated_scale_inv = 1/updated_scale

Collection = [dict, flax.core.frozen_dict.FrozenDict]

Parameters

state (Collection) – A collection that includes FP8 metas.

Returns

outputs – The collection with updated FP8 metas.

Return type

Collection
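
For illustration, the update rule above can be written as the following scalar sketch (the numeric values are arbitrary; the actual function operates on a whole collection of FP8 metas, not scalars):

import jax.numpy as jnp

fp8_max, amax, margin, original_scale = 448.0, 3.5, 0.0, 1.0  # illustrative values

exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
sf = jnp.round(jnp.power(2.0, jnp.abs(exp)))
sf = jnp.where(amax > 0.0, sf, original_scale)
sf = jnp.where(jnp.isfinite(amax), sf, original_scale)
updated_scale = jnp.where(exp < 0, 1.0 / sf, sf)
updated_scale_inv = 1.0 / updated_scale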

class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)

Applies layer normalization over a mini-batch of inputs. Two types of normalization are supported by this module: regular layer normalization and root mean square layer normalization (RMSNorm).

The regular layer normalization is as described in the paper Layer Normalization

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

\(\gamma\) and \(\beta\) are learnable affine transform parameters matching the size of each input sample.

The root mean square layer normalization (RMSNorm) is as described in the paper Root Mean Square Layer Normalization

\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]
\[RMS = \sqrt{\mathrm{E}[x^2]}\]

\(\gamma\) is a learnable affine transform parameter matching the size of each input sample.

Parameters
  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when layernorm_type='layernorm'.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • sharding_type (ShardingType, default = ShardingType.SINGLE) – Indicate the sharding pattern.

__call__(x: jax.numpy.ndarray)

Applies layer normalization to the input.

Parameters

inputs (jax.numpy.ndarray) – Input tensors.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
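
A minimal usage sketch as a standard Flax module (the input shape and the RMSNorm choice are illustrative assumptions):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNorm

layer = LayerNorm(epsilon=1e-6, layernorm_type='rmsnorm')
x = jnp.ones((32, 128, 512))                  # (batch, seqlen, hidden), transpose_batch_sequence defaults to False
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)                    # same shape as x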

class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)

Applies a linear transformation to the incoming data \(y = xA^T + b\)

Parameters
  • features (Union[Iterable[int], int]) – The hidden size of each output sample.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh, only used when use_bias=True.

  • axis (Union[Iterable[int], int], default = -1) – An integer or a tuple of integers specifying the axes to apply the transformation on.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • sharding_type (ShardingType, default = ShardingType.SINGLE) – Indicate the sharding pattern.

__call__(inputs: Array)

Apply the linear transformation to the input.

Parameters

inputs (jax.numpy.ndarray) – Input tensors.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
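
A minimal usage sketch (the feature size and input shape are illustrative assumptions):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral

layer = DenseGeneral(features=1024)
x = jnp.ones((128, 32, 512))                  # (seqlen, batch, hidden), transpose_batch_sequence defaults to True
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)                    # shape (128, 32, 1024)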

class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

Applies layer normalization followed by linear transformation to the incoming data.

Parameters
  • features (Union[Iterable[int], int]) – The hidden size of each output sample.

  • enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for 'layernorm'. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh, only used when enable_layernorm=True.

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. It is only used when enable_layernorm=True and layernorm_type='layernorm'.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh, only used when use_bias=True.

  • return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.

  • axis (Union[Iterable[int], int], default = -1) – An integer or a tuple of integers specifying the axes to apply the transformation on.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • depth_scaling (float, default = None) – The factor used to scale the output from DenseGeneral. When set to None, no scaling is applied.

  • sharding_type (ShardingType, default = ShardingType.SINGLE) – Indicate the sharding pattern.

__call__(inputs: Array)

Apply layer normalization to the input followed by a linear transformation.

Parameters

inputs (jax.numpy.ndarray) – Input tensor.

Returns

  • outputs (jax.numpy.ndarray) – Output tensors.

  • ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, then this would be None.
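
A minimal usage sketch (the feature size and input shape are illustrative assumptions):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormDenseGeneral

layer = LayerNormDenseGeneral(features=1024, layernorm_type='layernorm')
x = jnp.ones((128, 32, 512))                  # (seqlen, batch, hidden), transpose_batch_sequence defaults to True
params = layer.init(jax.random.PRNGKey(0), x)
y, ln_out = layer.apply(params, x)            # ln_out is None when return_layernorm_output=False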

class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

Applies layer normalization on the input, followed by the MLP module consisting of two successive linear transformations separated by the given activations.

Parameters
  • intermediate_dim (int, default = 2048) – Intermediate size to which input samples are projected.

  • enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’. The default of scale_init will also be changed. See scale_init.

  • scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh, only used when enable_layernorm=True.

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\), only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when enable_layernorm=True and layernorm_type='layernorm'.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of both linear transformations. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • kernel_axes_1 (Tuple[str, ...], default = ('embed', 'act', 'mlp')) – The name of axes used to shard the weights of the first linear transformation with a corresponding mesh.

  • kernel_axes_2 (Tuple[str, ...], default = ('mlp', 'embed')) – The name of axes used to shard the weights of the second linear transformation with a corresponding mesh.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • bias_axes_1 (Tuple[str, ...], default = ('mlp',)) – The name of axes used to shard the bias of the first linear transformation with a corresponding mesh. Only used when use_bias=True.

  • bias_axes_2 (Tuple[str, ...], default = ('embed',)) – The name of axes used to shard the bias of the second linear transformation with a corresponding mesh. Only used when use_bias=True.

  • return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.

  • activations (Sequence[Union[str, Callable]], default = ('relu',)) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.

  • intermediate_dropout_rate (float, default = 0.1) – Dropout probability for the dropout op after the activations.

  • intermediate_hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden states.

  • axis (Union[Iterable[int], int], default = -1) – An integer or a tuple of integers specifying the axes to apply the transformation on.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • major_sharding_type (MajorShardingType, default = MajorShardingType.SINGLE) – Indicate the sharding pattern.

__call__(inputs: Array, deterministic: bool = False)

Apply layer normalization to the input followed by a feedforward network (MLP Block).

Parameters
  • inputs (jax.numpy.ndarray) – Input tensor.

  • deterministic (bool, default = False) – Disable dropout ops if set to True.

Returns

  • outputs (jax.numpy.ndarray) – Output tensors.

  • ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, then this would be None.
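
A minimal usage sketch with the default configuration (the input shape is an illustrative assumption); deterministic=True disables the intermediate dropout so no dropout RNG is needed:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

mlp = LayerNormMLP(intermediate_dim=2048)
x = jnp.ones((128, 32, 512))                  # (seqlen, batch, hidden), transpose_batch_sequence defaults to True
params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
y, ln_out = mlp.apply(params, x, deterministic=True)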

class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_attention_heads, **kwargs)

T5-style relative positional embeddings added to the attention logits.

Parameters
  • num_buckets (int) – The number of buckets to bucket distances between key and query positions into.

  • max_distance (int) – The maximum distance before everything is lumped into the last distance bucket.

  • num_attention_heads (int) – Number of attention heads in the transformer layer.

  • embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – Used for initializing relative embedding tables.

  • embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – The name of axes used to shard embedding attention bias with a corresponding mesh.

Optimization parameters

dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

__call__(q_seqlen, k_seqlen, bidirectional=True)

Generate relative position embedding attention biases.

Parameters
  • q_seqlen (int) – The sequence length of query.

  • k_seqlen (int) – The sequence length of key.

  • bidirectional (bool, default = True) – Indicate whether to allow positive memory-query relative position embeddings.

Returns

output – An attention bias with shape (1, num_attention_heads, q_seqlen, k_seqlen).

Return type

jax.numpy.ndarray
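
A minimal usage sketch (the bucket count, distance, head count, and sequence lengths are illustrative assumptions; the num_attention_heads keyword follows the parameter documented above):

import jax
from transformer_engine.jax.flax import RelativePositionBiases

rel_bias = RelativePositionBiases(num_buckets=32, max_distance=128,
                                  num_attention_heads=8)
params = rel_bias.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
bias = rel_bias.apply(params, 128, 128, bidirectional=True)   # shape (1, 8, 128, 128)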

class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)

Multi-head Attention (MHA), including Query, Key, Value and Output projection.

Parameters
  • head_dim (int) – The hidden dimension of each attention head.

  • num_heads (int) – The number of attention heads.

  • dropout_rate (float, default = 0.0) – Dropout probability for the dropout op during multi-head attention.

  • dropout_rng_name (str, default = 'dropout') – The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’.

  • kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and Output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • use_bias (bool, default = False) – Indicate whether or not to enable bias shifting for QKVO projections. If set to False, the layer will not learn additive biases.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections, only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • apply_residual_connection_post_layernorm (bool, default = False) – If set to True, the residual connection is taken from the output of layer normalization (by default it is taken from the input of layer normalization).

  • output_layernorm (bool, default = False) – Indicate whether to apply layer normalization at the end of the MHA block.

  • attn_type (AttentionType, default = AttentionType.PADDING) – Indicate the format of the attention mask in the core attention.

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • fuse_qkv (bool, default = True) – If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.

  • transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, the attention logits are computed as \(\frac{Q K}{\sqrt{head\_dim}}\); otherwise as \(Q K\).

  • scaled_query_init (bool, default = True) – Whether to scale WQ on initialization by \(\sqrt{head\_dim}\).

  • float32_logits (bool, default = False) – Whether to compute attention logits in float32.

__call__(inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, decode: bool = False, deterministic: bool = False)

MultiHeadAttention Layer: [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

Parameters
  • inputs_q (jax.numpy.ndarray) – Input tensor for query projection.

  • inputs_kv (jax.numpy.ndarray) – Input tensor for key/value projection.

  • mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out self-attention softmax input.

  • bias (jax.numpy.ndarray, default = None) – A tensor used to shift self-attention softmax input.

  • decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache.

  • deterministic (bool, default = False) – Disable dropout layers if set to True.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
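
A minimal self-attention usage sketch (the head sizes and input shape are illustrative assumptions, and no mask or bias is supplied); passing the same tensor as inputs_q and inputs_kv gives self-attention:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MultiHeadAttention

mha = MultiHeadAttention(head_dim=64, num_heads=8)
x = jnp.ones((128, 32, 512))                  # (seqlen, batch, hidden), transpose_batch_sequence defaults to True
params = mha.init(jax.random.PRNGKey(0), x, x, deterministic=True)
y = mha.apply(params, x, x, deterministic=True)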

class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)

TransformerLayer is made up of a relative embedding, an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.

Parameters
  • hidden_size (int, default = 512) – The hidden size of each input sample.

  • mlp_hidden_size (int, default = 2048) – Intermediate size to which input samples are projected.

  • num_attention_heads (int, default = 8) – Number of attention heads in the transformer layer.

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.

  • layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.

  • zero_centered_gamma (bool, default = False) –

    If set to True, the LayerNorm formula changes to

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    This parameter is only applicable for ‘layernorm’.

  • hidden_dropout (float, default = 0.1) – Dropout probability for the dropout op after the FC2 layer.

  • hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden states.

  • attention_dropout (float, default = 0.1) – Dropout probability for the dropout op during multi-head attention.

  • dropout_rng_name (str, default = 'dropout') – The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the Multi-Head Attention.

  • mha_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and Output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • mlp_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of the FC1 and FC2 layers. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • mlp_activations (Sequence[str], default = ('relu', )) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.

  • use_bias (bool, default = False) – Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases.

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections, FC1 and FC2. It is only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).

  • apply_residual_connection_post_layernorm (bool, default = False) – If set to True, residual connections are taken from the output of layer norm (default is taken from input of layer norm)

  • output_layernorm (bool, default = False) – If set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.

  • float32_attention_logits (bool, default = False) – If set to True, attention logits are executed in jax.numpy.float32.

  • layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – If set to TransformerLayerType.DECODER, an additional cross-attention block is added after self-attention. This can be used for structures like the T5 Transformer in conjunction with the TransformerLayerType.ENCODER option.

  • enable_relative_embedding (bool, default = True) – Whether to enable relative embedding, which shifts the attention logits.

  • relative_embedding (flax.linen.Module, default = None) – The module used for relative embedding execution, only used when enable_relative_embedding=True. If set to None, an instance of RelativePositionBiases is created when enable_relative_embedding=True, equivalent to: RelativePositionBiases(num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'), name='relpos_bias').

Optimization parameters
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

  • drop_path (float, default = 0.0) – When > 0.0, applies stochastic depth per sample in the main path of the residual block.

  • fuse_qkv_params (bool, default = True) – If set to True, TransformerLayer module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.

  • transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence-length axes of the input tensors are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).

  • scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, the attention logits are computed as \(\frac{Q K}{\sqrt{head\_dim}}\); otherwise as \(Q K\).

  • scaled_query_init (bool, default = True) – Whether to scale WQ on initialization by \(\sqrt{head\_dim}\).

__call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)

Transformer Layer: attention block and a feedforward network (MLP)

Parameters
  • inputs (jax.numpy.ndarray) – Input tensor.

  • encoded (jax.numpy.ndarray, default = None) – Output tensors of the encoder block to be fed into the decoder block if using layer_type=TransformerLayerType.DECODER.

  • attention_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out self-attention softmax input.

  • encoder_decoder_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out cross-attention softmax input when layer_type=TransformerLayerType.DECODER.

  • deterministic (bool, default = False) – Disable dropout layers if set to True.

  • decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache in Multi-Head Attention (MHA).

  • max_decode_length (bool, default = None) – The maximum length to generate relative embedding biases when layer_type=TransformerLayerType.DECODER and enable_relative_embedding=True.

Returns

outputs – Output tensors.

Return type

jax.numpy.ndarray
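
A minimal usage sketch for a single encoder layer (the sizes and input shape are illustrative assumptions; no attention mask is supplied and deterministic=True disables dropout):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer(hidden_size=512, mlp_hidden_size=2048,
                         num_attention_heads=8)   # encoder layer by default
x = jnp.ones((32, 128, 512))                       # (batch, seqlen, hidden), transpose_batch_sequence defaults to False
params = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
y = layer.apply(params, x, deterministic=True)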

transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules)

Extend the given Flax logical axis rules with the predefined TransformerLayer’s logical axis rules.

Note

We currently only support logical axis rules for single-GPU training, data parallel training, and 1D-sharding tensor parallel training. Refer to Figure 3 in Megatron-LM tensor parallel for details on 1D-sharding tensor parallelism.

Warning

Please make sure ShardingResource is set via fp8_autocast before calling this function.

Note

This function is only needed when using TransformerLayer. For other modules, such as DenseGeneral, please properly set axes of kernels and bias.

Parameters

rules (Sequence[Tuple[str, Union[str, None]]]) – the base Flax logical axis rules to extend.

Returns

extended_rules – the extended Flax logical axis rules.

Return type

Sequence[Tuple[str, Union[str, None]]]
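
A short sketch of extending user-defined rules, assuming a ShardingResource has already been set via fp8_autocast as required above (the axis names and base rules are illustrative assumptions):

from transformer_engine.jax import fp8_autocast, ShardingResource
from transformer_engine.jax.flax import extend_logical_axis_rules

with fp8_autocast(enabled=False,
                  sharding_resource=ShardingResource('data', 'model')):
    base_rules = (('batch', 'data'),)               # user-defined Flax logical axis rules
    rules = extend_logical_axis_rules(base_rules)   # base rules plus TransformerLayer's predefined rules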