Common API

class transformer_engine.common.recipe.Format

Supported FP8 formats.

  • E4M3 – All FP8 tensors are in e4m3 format

  • E5M2 – All FP8 tensors are in e5m2 format

  • HYBRID – FP8 tensors in the forward pass are in e4m3 format; FP8 tensors in the backward pass are in e5m2 format
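
The following sketch shows which FP8 encoding each recipe format selects for the forward and backward passes. The `Format` enum and the `fp8_encoding` helper are hypothetical stand-ins written for illustration, not Transformer Engine's own classes. They only encode the three rules listed above.

```python
from enum import Enum


class Format(Enum):
    """Hypothetical stand-in mirroring the three documented recipe formats."""
    E4M3 = "e4m3"
    E5M2 = "e5m2"
    HYBRID = "hybrid"


def fp8_encoding(fmt: Format, backward: bool) -> str:
    """Return the FP8 encoding used for tensors in the given pass."""
    if fmt is Format.HYBRID:
        # HYBRID: e4m3 in the forward pass, e5m2 in the backward pass.
        return "e5m2" if backward else "e4m3"
    # E4M3 and E5M2 use one encoding for both passes.
    return fmt.value


for fmt in Format:
    print(fmt.name, fp8_encoding(fmt, backward=False), fp8_encoding(fmt, backward=True))
```

HYBRID pairs the extra mantissa bit of e4m3 (largest finite value 448) for forward tensors with the wider exponent range of e5m2 (largest finite value 57344) for gradients.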

class transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo='max', scaling_factor_compute_algo=None, override_linear_precision=(False, False, False), reduce_amax=True, fp8_dpa=False, fp8_mha=False)

Use the delayed scaling factor strategy: use the scaling factor from the previous iteration, recompute it once every interval steps, and record an amax history of amax_history_len steps.
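
A minimal, framework-free sketch of the delayed scaling loop described above. It assumes a single tensor, plain Python floats in place of framework tensors, an initial scaling factor of 1.0, and the 'max' amax algorithm. The `DelayedScaler` class is hypothetical, and the exact update order inside Transformer Engine may differ.

```python
from collections import deque


class DelayedScaler:
    """Hypothetical, framework-free sketch of delayed scaling for one tensor."""

    def __init__(self, fp8_max: float, margin: int = 0, interval: int = 1,
                 amax_history_len: int = 1024):
        self.fp8_max = fp8_max
        self.margin = margin
        self.interval = interval
        # Rolling window: the oldest amax drops out once the window is full.
        self.history = deque(maxlen=amax_history_len)
        self.scale = 1.0  # assumed initial scaling factor
        self.steps = 0

    def step(self, tensor: list[float]) -> float:
        """Return the scale used for this step, then update history and scale."""
        scale_used = self.scale  # delayed: computed in an earlier step
        self.history.append(max(abs(x) for x in tensor))
        self.steps += 1
        if self.steps % self.interval == 0:
            amax = max(self.history)  # the 'max' amax_compute_algo
            if amax > 0.0:  # assumption: keep the old scale for all-zero inputs
                self.scale = (self.fp8_max / amax) / (2 ** self.margin)
        return scale_used


# e4m3 limit (448.0) with a 2-step history window.
scaler = DelayedScaler(fp8_max=448.0, amax_history_len=2)
for t in ([1.0, -2.0], [4.0], [0.5], [0.5], [0.5]):
    print(scaler.step(t))  # 1.0, 224.0, 112.0, 112.0, 896.0 on successive lines
```

The third step still uses 112.0 because the amax of 4.0 remains in the two-step window. The scale grows to 896.0 only after 4.0 leaves the window, and that larger scale takes effect one step later.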

  • margin (int, default = 0) – Margin for the scaling factor computation. With the default formula, each unit of margin halves the resulting scaling factor.

  • interval (int, default = 1) – Controls how often the scaling factor is recomputed.

  • fp8_format ({Format.E4M3, Format.HYBRID}, default = Format.HYBRID) – Controls the FP8 data format used during the forward and backward passes.

  • amax_history_len (int, default = 1024) – The length of the amax history window used for scaling factor computation.

  • amax_compute_algo ({'max', 'most_recent', Callable}, default = 'max') –

    Algorithm used for choosing the amax value for the scaling factor computation. There are two predefined choices: 'max' chooses the largest amax in the history window, while 'most_recent' always chooses the most recently seen value. Alternatively, one may pass a function with the signature:

    def amax_compute(amax_history: Tensor) -> Tensor

    where Tensor is a framework tensor type.

  • scaling_factor_compute_algo (Callable, default = None) –

    Algorithm used for computing the new scaling factor based on the value of amax. It should be a function of the signature:

    def scaling_factor_compute(amax: Tensor,
                               old_scaling_factor: Tensor,
                               fp8_max: Tensor,
                               recipe: DelayedScaling) -> Tensor

    where Tensor is a framework tensor type.

  • override_linear_precision (Tuple(bool, bool, bool), default = (False, False, False)) – Whether to execute the fprop, dgrad, and wgrad GEMMs (respectively) in higher precision when using FP8.

  • reduce_amax (bool, default = True) – By default, if torch.distributed is initialized, the amax value for FP8 tensors is reduced across the fp8_group (specified in the fp8_autocast call). This keeps the amaxes and scaling factors synced across the given distributed group. If set to False, this reduction is skipped and every GPU maintains local amaxes and scaling factors. To ensure results are numerically identical across checkpointing boundaries in this case, all ranks must checkpoint in order to store the local tensors.

  • fp8_dpa (bool, default = False) – Whether to enable FP8 dot product attention (DPA). When the model is placed in an fp8_autocast(enabled=True) region and fp8_dpa is set to True, DPA casts the inputs from higher precision to FP8, performs attention in FP8, and casts the outputs back to higher precision. FP8 DPA is currently supported only in the FusedAttention backend.

  • fp8_mha (bool, default = False) – Whether to enable FP8 multi-head attention (MHA). When True, it removes the casting operations mentioned above at the DPA boundaries. Currently, only standard MHA modules, i.e. LayerNormLinear/Linear + DPA + Linear, are supported for this feature. When fp8_mha = False and fp8_dpa = True, a typical MHA module works as LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear. When fp8_mha = True and fp8_dpa = True, it becomes LayerNormLinear (FP8 output) -> FP8 DPA -> Linear.
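
The callables accepted by amax_compute_algo and scaling_factor_compute_algo can be sketched as follows. This is an illustration under stated assumptions: plain Python floats and a 1-D list stand in for framework tensors (the real layout of amax_history is not specified here), `SimpleNamespace(margin=1)` stands in for a DelayedScaling instance, and the mean-of-history rule is a made-up alternative to 'max' and 'most_recent'. With real tensors, the function bodies would use the framework's own tensor operations.

```python
from statistics import mean
from types import SimpleNamespace


# Hypothetical custom algorithms following the documented signatures.
# Assumption: floats and lists stand in for framework tensors.

def amax_compute(amax_history: list[float]) -> float:
    """Use the mean of the history window instead of 'max' or 'most_recent'."""
    return mean(amax_history)


def scaling_factor_compute(amax: float, old_scaling_factor: float,
                           fp8_max: float, recipe) -> float:
    """Default-style formula, but keep the old scale when amax is zero."""
    if amax <= 0.0:
        return old_scaling_factor
    return (fp8_max / amax) / (2 ** recipe.margin)


recipe = SimpleNamespace(margin=1)  # stand-in for a DelayedScaling instance
amax = amax_compute([2.0, 6.0])     # 4.0
print(scaling_factor_compute(amax, 1.0, 448.0, recipe))  # 56.0
```

Framework-native versions of such functions would be passed as DelayedScaling(amax_compute_algo=amax_compute, scaling_factor_compute_algo=scaling_factor_compute).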


Notes:

  • By default (when scaling_factor_compute_algo is left as None), the scaling factor is computed from the amax value selected by amax_compute_algo, using the formula:

    FP8_MAX = maximum_representable_value(fp8_format)
    new_scaling_factor = (FP8_MAX / amax) / (2 ** margin)

  • fp8_dpa and fp8_mha are Beta features, and their API and functionality are subject to change in future Transformer Engine releases.
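
A worked instance of the default scaling factor formula, written as a small Python function. The FP8_MAX values are the largest finite magnitudes of the e4m3 (448) and e5m2 (57344) encodings. The `default_scaling_factor` name is hypothetical.

```python
# Largest finite magnitudes of the two FP8 encodings.
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


def default_scaling_factor(amax: float, encoding: str, margin: int = 0) -> float:
    """new_scaling_factor = (FP8_MAX / amax) / 2**margin."""
    return (FP8_MAX[encoding] / amax) / (2 ** margin)


# Under HYBRID, forward tensors use the e4m3 limit and gradients the e5m2 limit.
print(default_scaling_factor(3.5, "e4m3"))            # 128.0
print(default_scaling_factor(3.5, "e4m3", margin=2))  # 32.0
print(default_scaling_factor(3.5, "e5m2"))            # 16384.0
```

Each increment of margin halves the scaling factor, so the scaled amax lands at FP8_MAX / 2**margin and leaves headroom below the format's limit.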