Debug features
- class transformer_engine.debug.features.log_tensor_stats.LogTensorStats
This feature handles the logging of basic tensor statistics.
In a distributed setting, the auxiliary stats are computed on each rank and gathered after the debug_api.step() call. Do not forget to invoke debug_api.step() at every step to log stats!
LogTensorStats supports micro-batching. If multiple forward/backward passes are invoked per debug_api.step(), then stats for all tensors except weights will be accumulated.
LogTensorStats can introduce significant overhead, so logging stats with freq > 1 is recommended. In steps where LogTensorStats does not log, the overhead is smaller; moreover, if no other feature is used for the layer, the TE layer runs as fast as it would without debug_api initialized.
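A minimal sketch of how a YAML-configured feature is typically driven from a training script is shown below. It assumes debug_api is provided by the nvdlfw_inspect package used by Transformer Engine's debug tools; the paths, layer name, and shapes are placeholders, and the exact initialization arguments may vary between versions.

import torch
import transformer_engine.pytorch as te
import nvdlfw_inspect.api as debug_api

# Point debug_api at the YAML configuration (placeholder paths).
debug_api.initialize(
    config_file="./example_config.yaml",
    feature_dirs=["/path/to/transformer_engine/debug/features"],
    log_dir="./debug_logs",
)

# The `name` argument should match the layer selection in the config.
layer = te.Linear(1024, 1024, name="fc1")

for _ in range(100):
    inp = torch.randn(32, 1024, device="cuda")
    with te.fp8_autocast(enabled=True):
        out = layer(inp)
    out.sum().backward()
    debug_api.step()  # gather and flush the stats collected in this step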
- Parameters:
stats (List[str]) –
list of statistics to log
min
max
mean
std
l1_norm
l2_norm
cur_amax – maximal absolute value of the tensor
dynamic_range – equal to torch.log2(amax) - torch.log2(amin); see the sketch after the example below
tensors/tensors_struct (List[str]) –
list of tensors to log
activation
gradient
weight
output
wgrad
dgrad
freq (Optional[int], default = 1) – frequency of logging stats; stats are logged every freq steps
start_step (Optional[int], default = None) – start step of logging stats
end_step (Optional[int], default = None) – end step of logging stats
start_end_list (Optional[list([int, int])], default = None) – non-overlapping list of (start, end) pairs in increasing order. If not None, start_step and end_step are ignored
Example
example_tensor_stat_collection:
  enabled: True
  layers:
    layer_name_regex_pattern: .*(fc1|self_attention).*
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [mean]
          freq: 10
          start_step: 5
          end_step: 100
        - tensor: gradient
          stats: [mean, max, min]
          freq: 2
          start_end_list: [[0, 20], [80, 100]]
        - tensor: weight
          stats: [dynamic_range]
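For reference, the sketch below shows how the cur_amax and dynamic_range statistics can be computed for a single tensor. It assumes amin denotes the smallest nonzero absolute value of the tensor; this is an illustration, not the exact TE implementation.

import torch

def cur_amax(x: torch.Tensor) -> float:
    # Largest absolute value of the tensor.
    return x.abs().max().item()

def dynamic_range(x: torch.Tensor) -> float:
    # torch.log2(amax) - torch.log2(amin); amin is taken over nonzero
    # magnitudes here, which is an assumption of this sketch.
    absx = x.abs()
    amax = absx.max()
    amin = absx[absx > 0].min()
    return (torch.log2(amax) - torch.log2(amin)).item()

x = torch.randn(1024)
print(cur_amax(x), dynamic_range(x))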
- class transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
This feature handles logging of FP8 tensor stats.
In a distributed setting, the auxiliary stats are computed on each rank and gathered after the debug_api.step() call. Do not forget to invoke debug_api.step() at every step to log stats!
LogFp8TensorStats supports micro-batching. If multiple forward/backward passes are invoked per debug_api.step(), then stats for all tensors except weights will be accumulated.
LogFp8TensorStats can introduce significant overhead, so logging stats with freq > 1 is recommended. In steps where LogFp8TensorStats does not log, the overhead is smaller; if no other feature is used for the layer, the TE layer runs as fast as it would without debug_api initialized.
- Parameters:
stats (List[str]) –
list of statistics to log
underflows% – percentage of elements of the tensor equal to 0 (see the sketch after the example below)
tensors/tensors_struct (List[str]) –
list of tensors to log
activation
gradient
weight
freq (Optional[int], default = 1) – frequency of logging stats; stats are logged every freq steps
start_step (Optional[int], default = None) – start step of logging stats
end_step (Optional[int], default = None) – end step of logging stats
start_end_list (Optional[list([int, int])], default = None) – non-overlapping list of (start, end) pairs in increasing order. If not None, start_step and end_step are ignored
Example
example_fp8_tensor_stat_collection:
  enabled: True
  layers:
    layer_types: [layernorm_linear]
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [underflows%]
          freq: 1
        - tensor: gradient
          stats: [underflows%]
          freq: 5
          start_step: 0
          end_step: 80
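The sketch below illustrates what the underflows% statistic measures, using a plain PyTorch FP8 round-trip as a stand-in for the quantization TE performs; the underflow_pct helper is hypothetical, not the TE implementation.

import torch

def underflow_pct(x: torch.Tensor, scale: float = 1.0) -> float:
    # Share of elements that become exactly zero after scaling and casting to FP8 (E4M3).
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)
    return (x_fp8.to(torch.float32) == 0).float().mean().item() * 100.0

x = torch.randn(4096) * 1e-4        # small-magnitude values
print(underflow_pct(x))             # most elements flush to zero at scale 1.0
print(underflow_pct(x, scale=1e4))  # a larger scale preserves them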
- class transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
The selected GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
- Parameters:
gemms (List[str]) –
list of GEMMs for which FP8 is disabled
fprop
dgrad
wgrad
Example
example_disable_fp8_gemm:
  enabled: True
  layers:
    layer_types: [fc1]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [dgrad, wgrad]
- class transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
Disables all FP8 GEMMs in the layer.
Example
example_disable_fp8_layer:
  enabled: True
  layers:
    layer_types: [fc1]
  transformer_engine:
    DisableFP8Layer:
      enabled: True
- class transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
Enables per-tensor current scaling for specific tensors.
Can be used only within a DelayedScaling recipe autocast; the sketch after the example below contrasts the two scaling schemes.
- Parameters:
gemms/gemms_struct (List[str]) –
list of gemms to enable per-tensor current scaling for
fprop
dgrad
wgrad
tensors/tensors_struct (List[str]) –
list of tensors to enable per-tensor current scaling for
activation
gradient
weight
Example
example_per_tensor_scaling:
  enabled: True
  layers:
    layer_types: [transformer_layer.self_attn.layernorm_q]
  transformer_engine:
    PerTensorScaling:
      enabled: True
      gemms: [dgrad]
      tensors: [weight, activation]
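The sketch below contrasts the two scaling schemes conceptually. FP8_E4M3_MAX = 448 is the largest E4M3 magnitude; TE's actual recipes additionally apply margins and amax-history logic, so this is not the exact implementation.

import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def delayed_scaling_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # DelayedScaling: the scale is derived from amaxes recorded in previous iterations.
    return FP8_E4M3_MAX / amax_history.max()

def per_tensor_current_scale(x: torch.Tensor) -> torch.Tensor:
    # Per-tensor current scaling: the scale is derived from the tensor being quantized now.
    return FP8_E4M3_MAX / x.abs().max()

x = torch.randn(1024)
history = torch.tensor([0.5, 1.2, 0.9])  # hypothetical amax history from earlier steps
print(delayed_scaling_scale(history), per_tensor_current_scale(x))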
- class transformer_engine.debug.features.fake_quant.FakeQuant
Disables FP8 GEMM. Fake quantizes chosen tensors to FP8 (using a per-tensor current scaling factor, not delayed scaling) and runs high-precision GEMM.
Fig 1: Comparison of an FP8 FPROP GEMM with the same GEMM run in BF16 with fake quantization of the activation tensor. Green tensors hold the same values but have different dtypes.
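The sketch below shows the round-trip FakeQuant performs on a selected tensor before the high-precision GEMM. The fake_quantize helper is hypothetical and uses PyTorch's FP8 dtypes directly instead of TE's kernels; 57344 is the largest E5M2 magnitude, and quant_format selects between E4M3 and E5M2.

import torch

def fake_quantize(x: torch.Tensor, fp8_dtype=torch.float8_e5m2, fp8_max=57344.0) -> torch.Tensor:
    # Quantize with a per-tensor current scale, then dequantize back, so the values
    # lie on the FP8 grid while the tensor keeps its original high-precision dtype.
    scale = fp8_max / x.abs().amax().float().clamp(min=1e-12)
    x_fp8 = (x.float() * scale).to(fp8_dtype)   # quantize
    return (x_fp8.float() / scale).to(x.dtype)  # dequantize

x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(64, 32, dtype=torch.bfloat16)
# The FPROP GEMM runs in BF16 on the fake-quantized activation.
y = torch.nn.functional.linear(fake_quantize(x), w)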
- Parameters:
gemms/gemms_struct (List[str]) –
list of gemms to fake quantize
fprop
dgrad
wgrad
tensors/tensors_struct (List[str]) –
list of tensors to fake quantize
activation
gradient
weight
output
wgrad
dgrad
quant_format (str) –
specifies the FP8 format to use:
FP8E5M2
FP8E4M3
Example
example_fake_quant_fp8:
  enabled: True
  layers:
    layer_types: [transformer_layer.layernorm_mlp.fc1]
  transformer_engine:
    FakeQuant:
      enabled: True
      quant_format: FP8E5M2
      gemms_struct:
        - gemm: fprop
          tensors: [activation, weight]
        - gemm: dgrad
          tensors: [gradient]