Debug features

class transformer_engine.debug.features.log_tensor_stats.LogTensorStats

This feature handles the logging of basic tensor statistics.

In a distributed setting, the auxiliary stats are computed on each rank and gathered after the debug_api.step() call. Do not forget to invoke debug_api.step() at every step to log stats!

LogTensorStats supports micro-batching. If multiple forward/backward passes are invoked per debug_api.step(), then stats for all tensors except weights will be accumulated.
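Accumulation across micro-batches can be pictured with a minimal sketch. It assumes that each statistic in the stats list is rebuilt at debug_api.step() time from a few running aggregates: count, sum, sum of squares, sum of absolute values, min, max, and the smallest nonzero absolute value, which is assumed here to be the amin used by dynamic_range. The RunningStats class and its methods are hypothetical, and the sketch works on flat lists of Python floats instead of torch tensors. It is not Transformer Engine's implementation.

```python
import math
from dataclasses import dataclass


@dataclass
class RunningStats:
    """Hypothetical per-tensor accumulator, filled by every micro-batch of one step."""
    count: int = 0
    total: float = 0.0
    total_sq: float = 0.0
    total_abs: float = 0.0
    vmin: float = math.inf
    vmax: float = -math.inf
    amin: float = math.inf  # smallest nonzero |x| (assumed meaning of amin)

    def update(self, values):
        """Fold one micro-batch (a flat list of floats) into the aggregates."""
        for v in values:
            self.count += 1
            self.total += v
            self.total_sq += v * v
            self.total_abs += abs(v)
            self.vmin = min(self.vmin, v)
            self.vmax = max(self.vmax, v)
            if v != 0.0:
                self.amin = min(self.amin, abs(v))

    def compute(self):
        """Derive the LogTensorStats statistics from the aggregates (needs count >= 2)."""
        mean = self.total / self.count
        # Unbiased (n - 1) variance; whether TE divides by n or n - 1 is not stated.
        var = (self.total_sq - self.count * mean * mean) / (self.count - 1)
        amax = max(abs(self.vmin), abs(self.vmax))
        return {
            "min": self.vmin,
            "max": self.vmax,
            "mean": mean,
            "std": math.sqrt(max(var, 0.0)),  # clamp tiny negative rounding error
            "l1_norm": self.total_abs,
            "l2_norm": math.sqrt(self.total_sq),
            "cur_amax": amax,
            "dynamic_range": math.log2(amax) - math.log2(self.amin),
        }


# Two micro-batches before a single debug_api.step() call.
acc = RunningStats()
acc.update([1.0, -2.0])
acc.update([4.0, 0.0])
print(acc.compute())
```

Keeping only sums means memory stays constant no matter how many micro-batches run per step. A production version would likely prefer a numerically safer variance update (for example Chan's parallel algorithm) over the sum-of-squares formula used here.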

LogTensorStats can induce significant overhead. To mitigate this issue, logging stats with freq > 1 is recommended. If LogTensorStats is not used in a given step, the overhead is smaller. Moreover, if no other feature is used for the layer, the TE layer will run as fast as it would without debug_api initialized.

Parameters:
  • stats (List[str]) –

    list of statistics to log

    • min

    • max

    • mean

    • std

    • l1_norm

    • l2_norm

    • cur_amax – maximum absolute value of the tensor

    • dynamic_range – equal to torch.log2(amax) - torch.log2(amin)

  • tensors/tensors_struct (List[str]) –

    list of tensors to log

    • activation

    • gradient

    • weight

    • output

    • wgrad

    • dgrad

  • freq (Optional[int], default = 1) – logging frequency; stats are logged every freq steps

  • start_step (Optional[int], default = None) – start step of logging stats

  • end_step (Optional[int], default = None) – end step of logging stats

  • start_end_list (Optional[list([int, int])], default = None) – non-overlapping list of (start, end) pairs in increasing order. If set, start_step and end_step are ignored.
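How freq, start_step, end_step and start_end_list combine can be sketched as a single predicate. This is a hypothetical helper, not part of debug_api, and it rests on two assumptions the parameter descriptions do not settle: ranges are inclusive at both ends, and freq is counted from step 0.

```python
from typing import List, Optional, Sequence


def should_log(
    step: int,
    freq: int = 1,
    start_step: Optional[int] = None,
    end_step: Optional[int] = None,
    start_end_list: Optional[Sequence[Sequence[int]]] = None,
) -> bool:
    """Hypothetical check: does one tensors_struct entry log stats at `step`?"""
    # Assumption: freq is counted from step 0.
    if step % freq != 0:
        return False
    if start_end_list is not None:
        # start_step and end_step are ignored when start_end_list is set.
        # Assumption: both ends of each (start, end) pair are inclusive.
        return any(start <= step <= end for start, end in start_end_list)
    if start_step is not None and step < start_step:
        return False
    if end_step is not None and step > end_step:
        return False
    return True


def logged_steps(last_step: int, **schedule) -> List[int]:
    """List every step in [0, last_step] at which the entry would log."""
    return [s for s in range(last_step + 1) if should_log(s, **schedule)]


# The activation and gradient entries from the example config.
print(logged_steps(100, freq=10, start_step=5, end_step=100))
print(logged_steps(100, freq=2, start_end_list=[[0, 20], [80, 100]]))
```

Under these assumptions the gradient entry of the example logs at steps 0, 2, ..., 20 and 80, 82, ..., 100.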

Example

example_tensor_stat_collection:
    enabled: True
    layers:
        layer_name_regex_pattern: .*(fc1|self_attention).*
    transformer_engine:
        LogTensorStats:
            enabled: True
            tensors_struct:
                - tensor: activation
                  stats: [mean]
                  freq: 10
                  start_step: 5
                  end_step: 100
                - tensor: gradient
                  stats: [mean, max, min]
                  freq: 2
                  start_end_list: [[0, 20], [80, 100]]
                - tensor: weight
                  stats: [dynamic_range]

class transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats

This feature handles logging of FP8 tensor stats.

In a distributed setting, the auxiliary stats are computed on each rank and gathered after the debug_api.step() call. Do not forget to invoke debug_api.step() at every step to log stats!

LogFp8TensorStats supports micro-batching. If multiple forward/backward passes are invoked per debug_api.step(), then stats for all tensors except weights will be accumulated.

LogFp8TensorStats can induce significant overhead. To mitigate this issue, logging stats with freq > 1 is recommended. If LogFp8TensorStats is not used in a given step, the overhead is smaller. If no other feature is used for the layer, the TE layer will run as fast as it would without debug_api initialized.

Parameters:
  • stats (List[str]) –

    list of statistics to log

    • underflows% – percentage of elements of the tensor equal to 0

  • tensors/tensors_struct (List[str]) –

    list of tensors to log

    • activation

    • gradient

    • weight

  • freq (Optional[int], default = 1) – logging frequency; stats are logged every freq steps

  • start_step (Optional[int], default = None) – start step of logging stats

  • end_step (Optional[int], default = None) – end step of logging stats

  • start_end_list (Optional[list([int, int])], default = None) – non-overlapping list of (start, end) pairs in increasing order. If set, start_step and end_step are ignored.
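A rough sketch of what underflows% measures follows, under stated assumptions: the tensor is multiplied by a per-tensor scale and cast to FP8 with round-to-nearest-even, so any scaled magnitude of at most half the smallest positive FP8 value becomes 0. The helper name is hypothetical, and treating the stat as "zeros after the FP8 cast" is an interpretation of the one-line definition above, not a description of Transformer Engine's code.

```python
E4M3_MIN_SUBNORMAL = 2.0 ** -9   # smallest positive FP8 E4M3 value
E5M2_MIN_SUBNORMAL = 2.0 ** -16  # smallest positive FP8 E5M2 value


def underflows_percent(values, scale, min_subnormal=E4M3_MIN_SUBNORMAL):
    """Hypothetical helper: percentage of elements that are 0 once cast to FP8.

    Assumes round-to-nearest-even, so a scaled magnitude of at most half the
    smallest subnormal becomes 0. Elements that were already 0 count too,
    because the stat counts every element equal to 0.
    """
    if not values:
        raise ValueError("empty tensor")
    zeros = sum(1 for v in values if abs(v * scale) <= min_subnormal / 2)
    return 100.0 * zeros / len(values)


values = [0.0, 1e-6, 0.5, 3.0]
print(underflows_percent(values, scale=1.0))     # 1e-6 underflows to 0
print(underflows_percent(values, scale=1000.0))  # a larger scale keeps 1e-6 nonzero
```

The second call shows why the stat is worth tracking: the fraction of zeros depends on the scaling factor, not only on the tensor itself.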

Example

example_fp8_tensor_stat_collection:
    enabled: True
    layers:
        layer_types: [layernorm_linear]
    transformer_engine:
        LogFp8TensorStats:
            enabled: True
            tensors_struct:
                - tensor: activation
                  stats: [underflows%]
                  freq: 1
                - tensor: gradient
                  stats: [underflows%]
                  freq: 5
                  start_step: 0
                  end_step: 80

class transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM

Executes the selected GEMMs in higher precision, even when FP8 autocast is enabled.

Parameters:
  • gemms (List[str]) –

    list of gemms to disable

    • fprop

    • dgrad

    • wgrad
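For reference, the three GEMM names correspond to the three matrix products of a linear layer y = x W^T (the weight layout torch.nn.Linear uses). The plain-Python sketch below only shows which tensors each GEMM consumes: fprop uses activation and weight, dgrad uses gradient and weight, wgrad uses gradient and activation. The function names are illustrative, and the sketch says nothing about how Transformer Engine dispatches these GEMMs.

```python
def matmul(a, b):
    """Plain-Python matrix product of nested lists: (n x k) @ (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


# Linear layer y = x @ W^T with x: (batch, in_features), W: (out_features, in_features).
def fprop(activation, weight):
    return matmul(activation, transpose(weight))    # y = x W^T


def dgrad(gradient, weight):
    return matmul(gradient, weight)                 # dL/dx = dy W


def wgrad(gradient, activation):
    return matmul(transpose(gradient), activation)  # dL/dW = dy^T x


x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
dy = [[1.0, 1.0, 1.0]]
print(fprop(x, w), dgrad(dy, w), wgrad(dy, x))
```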

Example

example_disable_fp8_gemm:
    enabled: True
    layers:
        layer_types: [fc1]
    transformer_engine:
        DisableFP8GEMM:
            enabled: True
            gemms: [dgrad, wgrad]

class transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer

Disables all FP8 GEMMs in the layer.

Example

example_disable_fp8_layer:
    enabled: True
    layers:
        layer_types: [fc1]
    transformer_engine:
        DisableFP8Layer:
            enabled: True

class transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling

Enables per-tensor current scaling for the selected tensors.

Can be used only within FP8 autocast with the DelayedScaling recipe.

Parameters:
  • gemms/gemms_struct (List[str]) –

    list of gemms to enable per-tensor current scaling for

    • fprop

    • dgrad

    • wgrad

  • tensors/tensors_struct (List[str]) –

    list of tensors to enable per-tensor current scaling for

    • activation

    • gradient

    • weight
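The difference between delayed and per-tensor current scaling can be illustrated with a simplified sketch. It assumes the scale is fp8_max / amax with fp8_max = 448 (the largest finite FP8 E4M3 value), and it models delayed scaling as taking the maximum of an amax history from earlier steps, ignoring any further options of the DelayedScaling recipe. All function names are hypothetical.

```python
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def scale_from_amax(amax, fp8_max=E4M3_MAX):
    """Map amax onto the top of the FP8 range; fall back to 1.0 for all-zero tensors."""
    return fp8_max / amax if amax > 0 else 1.0


def current_scale(values):
    # Per-tensor current scaling: amax of the tensor being cast right now.
    return scale_from_amax(max(abs(v) for v in values))


def delayed_scale(amax_history):
    # Delayed scaling (simplified): amax values recorded in earlier iterations.
    return scale_from_amax(max(amax_history))


def saturated_fraction(values, scale, fp8_max=E4M3_MAX):
    """Share of elements whose scaled magnitude exceeds the FP8 range."""
    return sum(1 for v in values if abs(v * scale) > fp8_max) / len(values)


history = [1.0, 2.0]  # amax values seen in previous steps
values = [0.5, 8.0]   # this step's tensor contains an outlier
print(saturated_fraction(values, delayed_scale(history)))  # stale scale clips 8.0
print(saturated_fraction(values, current_scale(values)))   # fresh scale fits it
```

The trade-off this illustrates: current scaling adapts to outliers in the tensor being cast, at the cost of an extra amax reduction over that tensor before the cast.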

Example

example_per_tensor_scaling:
    enabled: True
    layers:
        layer_types: [transformer_layer.self_attn.layernorm_q]
    transformer_engine:
        PerTensorScaling:
            enabled: True
            gemms: [dgrad]
            tensors: [weight, activation]

class transformer_engine.debug.features.fake_quant.FakeQuant

Disables the FP8 GEMM: fake-quantizes the chosen tensors to FP8 using a per-tensor scaling factor (not delayed scaling) and runs the GEMM in high precision.


Fig 1: Comparison of an FP8 FPROP GEMM with the same GEMM run in BF16 with fake quantization of the activation tensor. Tensors shown in green hold the same values but have different dtypes.

Parameters:
  • gemms/gemms_struct (List[str]) –

    list of gemms to fake quantize

    • fprop

    • dgrad

    • wgrad

  • tensors/tensors_struct (List[str]) –

    list of tensors to fake quantize

    • activation

    • gradient

    • weight

    • output

    • wgrad

    • dgrad

  • quant_format (str) –

    specifies the FP8 format to use:

    • FP8E5M2

    • FP8E4M3
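Fake quantization itself can be sketched in plain Python as quantize-then-dequantize. Assumptions, none of them taken from Transformer Engine: the per-tensor scale is fp8_max / amax of the current tensor; rounding is round-to-nearest-even with saturation at the format's largest finite value (448 for E4M3, 57344 for E5M2); NaN and inf inputs are not handled. The function names are hypothetical, and this illustrates the technique rather than TE's kernels.

```python
import math

# Format parameters: (largest finite value, mantissa bits, min normal exponent).
FP8_FORMATS = {
    "FP8E4M3": (448.0, 3, -6),
    "FP8E5M2": (57344.0, 2, -14),
}


def round_to_fp8(x, quant_format):
    """Round x to the nearest value of the FP8 format (ties to even, saturating)."""
    fp8_max, mant_bits, min_exp = FP8_FORMATS[quant_format]
    if x == 0.0:
        return 0.0
    mag = min(abs(x), fp8_max)         # saturate instead of overflowing
    _, e = math.frexp(mag)             # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)          # subnormals share min_exp's spacing
    step = 2.0 ** (exp - mant_bits)    # gap between neighbouring FP8 values
    q = round(mag / step) * step       # round() is round-half-to-even
    return math.copysign(min(q, fp8_max), x)


def fake_quantize(values, quant_format="FP8E4M3"):
    """Quantize with a per-tensor scale, then dequantize; results stay Python floats."""
    fp8_max = FP8_FORMATS[quant_format][0]
    amax = max(abs(v) for v in values)
    scale = fp8_max / amax if amax > 0 else 1.0  # per-tensor, from the current tensor
    return [round_to_fp8(v * scale, quant_format) / scale for v in values]


print(fake_quantize([1.0, 0.3, -2.0], "FP8E4M3"))
```

Here 0.3 comes back as 64/224 (about 0.2857): the value keeps its high-precision type but carries the rounding error an FP8 GEMM input would have, which is what Fig 1 depicts for the activation tensor.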

Example

example_fake_quant_fp8:
    enabled: True
    layers:
        layer_types: [transformer_layer.layernorm_mlp.fc1]
    transformer_engine:
        FakeQuant:
            enabled: True
            quant_format: FP8E5M2
            gemms_struct:
                - gemm: fprop
                  tensors: [activation, weight]
                - gemm: dgrad
                  tensors: [gradient]