Common API


class transformer_engine.common.recipe.Format(value)

Supported FP8 formats.

  • E4M3 – All FP8 tensors are in e4m3 format

  • E5M2 – All FP8 tensors are in e5m2 format

  • HYBRID – FP8 tensors in the forward pass are in e4m3 format; FP8 tensors in the backward pass are in e5m2 format
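The two encodings trade precision for range. As a worked check, the largest finite values follow from the bit layouts used for FP8 training (E4M3 has no infinities and reserves only the all-ones exponent-and-mantissa pattern for NaN; E5M2 follows IEEE-754 conventions). This is an illustrative sketch: Fmt, fp8_max, and tensor_encoding are hypothetical helpers, not part of the Transformer Engine API.

```python
from enum import Enum


class Fmt(Enum):
    """Hypothetical stand-in for recipe.Format (not the real class)."""
    E4M3 = "E4M3"
    E5M2 = "E5M2"
    HYBRID = "HYBRID"


def fp8_max(encoding: str) -> float:
    """Largest finite value of an FP8 encoding, derived from its bit layout."""
    if encoding == "e4m3":
        # Bias 7; exponent field 0b1111 is still finite, but mantissa 0b111
        # there encodes NaN, so the top finite mantissa is 0b110 (= 1.75).
        return (1 + 6 / 8) * 2.0 ** (15 - 7)
    if encoding == "e5m2":
        # Bias 15; exponent field 0b11111 is reserved for inf/NaN, so the
        # top finite exponent field is 0b11110 = 30.
        return (1 + 3 / 4) * 2.0 ** (30 - 15)
    raise ValueError(f"unknown encoding: {encoding}")


def tensor_encoding(fmt: Fmt, backward: bool) -> str:
    """FP8 encoding a tensor uses under each recipe format."""
    if fmt is Fmt.HYBRID:
        return "e5m2" if backward else "e4m3"
    return fmt.value.lower()


print(fp8_max(tensor_encoding(Fmt.HYBRID, backward=False)))  # 448.0
print(fp8_max(tensor_encoding(Fmt.HYBRID, backward=True)))   # 57344.0
```

The wider range of e5m2 is why HYBRID uses it for gradients, whose magnitudes tend to vary more than activations and weights.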

class transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1, amax_compute_algo='most_recent', scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))

Use the delayed scaling factor strategy: use the scaling factor computed in a previous iteration, recompute it once every interval steps, and record an amax history of amax_history_len steps.
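The bookkeeping described above can be sketched for a single tensor as follows. This is a simplified, framework-free model, not how Transformer Engine implements it internally: ToyDelayedScaler is a hypothetical class, the initial scale of 1.0 is an assumption, the E4M3 maximum is hard-coded, and details such as distributed amax reduction are ignored.

```python
import math
from collections import deque
from typing import Sequence

FP8_MAX_E4M3 = 448.0  # largest finite E4M3 value


class ToyDelayedScaler:
    """Hypothetical, framework-free model of delayed scaling for one tensor."""

    def __init__(self, interval: int = 1, amax_history_len: int = 1,
                 amax_compute_algo: str = "most_recent", margin: int = 0):
        self.interval = interval
        self.amax_compute_algo = amax_compute_algo
        self.margin = margin
        self.history = deque(maxlen=amax_history_len)  # oldest entries drop off
        self.scale = 1.0  # initial scaling factor (assumption of this sketch)
        self.step = 0

    def observe(self, tensor: Sequence[float]) -> float:
        """Return the scale used for this step, then record amax and maybe update."""
        scale_used = self.scale  # the "delay": this step uses an earlier scale
        self.history.append(max(abs(x) for x in tensor))
        self.step += 1
        if self.step % self.interval == 0:
            self.scale = self._recompute()
        return scale_used

    def _recompute(self) -> float:
        if self.amax_compute_algo == "max":
            amax = max(self.history)
        else:  # "most_recent"
            amax = self.history[-1]
        if amax <= 0.0:
            return self.scale  # keep the old scale for all-zero tensors
        exp = math.floor(math.log2(FP8_MAX_E4M3 / amax)) - self.margin
        return 2.0 ** exp


scaler = ToyDelayedScaler(amax_history_len=2, amax_compute_algo="max")
print([scaler.observe(t) for t in ([3.0], [-1000.0], [3.0])])  # [1.0, 128.0, 0.25]
```

Because the history window holds two steps and the algorithm is "max", the spike of 1000.0 still dictates the scale at the third step; with "most_recent" the third step would instead move the scale back to 128.0. Setting interval > 1 simply skips recomputation on the intermediate steps.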

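  • margin (int, default = 0) – Margin for the scaling factor computation. In the default computation it is subtracted from the exponent, so each unit of margin halves the scaling factor.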

  • interval (int, default = 1) – Controls how often the scaling factor is recomputed.

  • fp8_format ({Format.E4M3, Format.HYBRID}, default = Format.E4M3) – Controls the FP8 data format used during forward and backward pass.

  • amax_history_len (int, default = 1) – The length of the amax history window used for scaling factor computation.

  • amax_compute_algo ({'max', 'most_recent', Callable}, default = 'most_recent') –

    Algorithm used for choosing the amax value for the scaling factor computation. There are 2 predefined choices: max chooses the largest amax in the history window, while most_recent always chooses the most recently seen value. Alternatively, one may pass a function of the signature:

    def amax_compute(amax_history: Tensor) -> Tensor

    where Tensor is a framework tensor type.

  • scaling_factor_compute_algo (Callable, default = None) –

    Algorithm used for computing the new scaling factor based on the value of amax. It should be a function of the signature:

    def scaling_factor_compute(amax: Tensor,
                               old_scaling_factor: Tensor,
                               fp8_max: Tensor,
                               recipe: DelayedScaling) -> Tensor

    where Tensor is a framework tensor type.

  • override_linear_precision (Tuple(bool, bool, bool), default = (False, False, False)) – Whether or not to execute the fprop, dgrad, and wgrad GEMMs (respectively) in higher precision when using FP8.
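The two callable hooks can be illustrated with NumPy arrays standing in for the framework tensor type. Both functions are hypothetical examples rather than built-in algorithms; the sketch assumes the amax history window runs along axis 0, and a SimpleNamespace with a margin attribute stands in for the DelayedScaling recipe object.

```python
from types import SimpleNamespace

import numpy as np


def second_largest_amax(amax_history: np.ndarray) -> np.ndarray:
    """Custom amax_compute_algo: ignore a single outlier spike in the window.

    Assumes the history window runs along axis 0 (1-D for one tensor, or
    2-D with one column per tensor).
    """
    if amax_history.shape[0] < 2:
        return amax_history.max(axis=0)
    return np.sort(amax_history, axis=0)[-2]


def damped_pow2_scale(amax, old_scaling_factor, fp8_max, recipe):
    """Custom scaling_factor_compute_algo: a power-of-2 scale that may shrink
    immediately (to avoid overflow) but grows by at most 2x per update."""
    with np.errstate(divide="ignore", invalid="ignore"):
        target_exp = np.floor(np.log2(fp8_max / amax)) - recipe.margin
        old_exp = np.log2(old_scaling_factor)  # old scale assumed a power of 2
        new_exp = np.minimum(target_exp, old_exp + 1)
        # Keep the previous scale where amax is zero or non-finite.
        valid = (amax > 0) & np.isfinite(amax)
        return np.where(valid, np.exp2(new_exp), old_scaling_factor)


recipe = SimpleNamespace(margin=0)  # stand-in for a DelayedScaling instance
history = np.array([[1.0, 10.0], [500.0, 20.0], [2.0, 0.0]])  # 3 steps x 2 tensors
amax = second_largest_amax(history)  # per-tensor amax, ignoring the 500.0 spike
print(amax)
print(damped_pow2_scale(amax, np.array([1.0, 64.0]), 448.0, recipe))
```

Here the first tensor's scale grows only from 1 to 2 even though amax alone would allow 128, while the second tensor's scale drops from 64 to 32 at once. With the NumPy calls replaced by their equivalents for the framework's tensor type, such functions would be passed as DelayedScaling(amax_compute_algo=..., scaling_factor_compute_algo=...).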


  • By default (when scaling_factor_compute_algo is left as None), the scaling factor is computed from the final amax value (the one selected by amax_compute_algo) using the formula:

    FP8_MAX = maximum_representable_value(fp8_format)
    exp = get_exponent(FP8_MAX / amax) - margin
    new_scaling_factor = 2.0 ^ exp
  • The scaling factor should always be a power of 2 so that it does not introduce numerical error during the conversion from FP8 to a higher-precision format.
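A runnable version of the default formula, as a scalar sketch. It assumes get_exponent(x) means floor(log2(x)) and that a zero or non-finite amax keeps the previous scale; both are assumptions of this sketch, and default_scaling_factor is a hypothetical name, not a Transformer Engine function.

```python
import math

# Largest finite values of the two FP8 encodings.
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


def default_scaling_factor(amax: float, old_scaling_factor: float,
                           fp8_format: str = "e4m3", margin: int = 0) -> float:
    """Power-of-2 scaling factor from amax.

    Assumes get_exponent(x) == floor(log2(x)), so amax * scale <= FP8_MAX.
    Falls back to the old scale when amax is zero or non-finite.
    """
    if amax <= 0.0 or not math.isfinite(amax):
        return old_scaling_factor
    exp = math.floor(math.log2(FP8_MAX[fp8_format] / amax)) - margin
    return 2.0 ** exp


s = default_scaling_factor(3.0, old_scaling_factor=1.0)
print(s)  # 128.0 (3.0 * 128 = 384 <= 448)
print(default_scaling_factor(3.0, 1.0, margin=1))  # 64.0

# Multiplying and dividing by a power of 2 only shifts the exponent, so the
# round trip is exact in float64 (barring overflow/underflow).
print(all(x * s / s == x for x in (0.1, 1 / 3, 2.718281828)))  # True
```

Taking the floor rather than rounding keeps amax * scale at or below FP8_MAX, and each unit of margin trades one more binade of headroom against overflow for one less binade of precision.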