Calls to Nvidia-DL-Framework-Inspect
Let’s look deeper into how Nvidia-DL-Framework-Inspect and Transformer Engine work together. Transformer Engine layers contain hook calls inside each of their GEMMs. Users can define their own feature classes or use the feature classes provided with TE. The config.yaml file describes which hooks are used for which layers. Nvidia-DL-Framework-Inspect combines three things (TE training, feature classes, and config.yaml) and takes care of inserting hooks in the correct places. This process is illustrated in the image below.
Fig 1: Example of Nvidia-DL-Framework-Inspect affecting a training script with one Linear layer. For tensors mentioned in config.yaml, the behavior of the modify_tensor_enabled() and modify_tensor() calls is replaced by the definitions from the feature class. Other calls return default values; in effect, they do nothing.
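The substitution described in Fig 1 can be sketched in plain Python. This is a hypothetical illustration, not the Nvidia-DL-Framework-Inspect implementation: the classes DefaultFeatures and FakeQuantLike, the helper modify_tensor_enabled_for, and the `configured` mapping (a simplified stand-in for what config.yaml describes) are all assumptions made for this example.

```python
# Hypothetical sketch of the Fig 1 behavior: configured tensors use the feature's
# definition, every other tensor falls back to the default (a no-op).
# None of these names come from Nvidia-DL-Framework-Inspect.


class DefaultFeatures:
    """Stand-in for the defaults: modify_tensor_enabled() returns False."""

    def modify_tensor_enabled(self, config, layer_name, gemm, tensor_name, iteration):
        return False


class FakeQuantLike(DefaultFeatures):
    """Hypothetical user feature that overrides the default."""

    def modify_tensor_enabled(self, config, layer_name, gemm, tensor_name, iteration):
        return True


def modify_tensor_enabled_for(configured, layer_name, gemm, tensor_name, iteration):
    """Dispatch to the configured feature, or to the defaults if none is configured.

    `configured` maps (layer_name, tensor_name) -> (feature, config_dict).
    """
    feature, config = configured.get(
        (layer_name, tensor_name), (DefaultFeatures(), {})
    )
    return feature.modify_tensor_enabled(config, layer_name, gemm, tensor_name, iteration)


configured = {("fc1", "activation"): (FakeQuantLike(), {})}
print(modify_tensor_enabled_for(configured, "fc1", "fprop", "activation", 0))  # True
print(modify_tensor_enabled_for(configured, "fc1", "fprop", "weight", 0))      # False
```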
This page lists all calls from Transformer Engine to Nvidia-DL-Framework-Inspect for each GEMM. The order of these calls is illustrated in the image below.
Fig 2: The calls to Nvidia-DL-Framework-Inspect made by Transformer Engine. There are two types of calls: GEMM calls and routing calls.
There are two categories of API calls, each used for a different purpose:
GEMM calls - invoked during every GEMM, used to process or quantize tensors and to collect information about them,
Routing calls - invoked at the beginning of every forward pass, they indicate whether a feature is going to use modify_tensor(), etc.
If all routing calls for a layer return False, the layer runs in its optimized version with Transformer Engine fusions. If any routing call returns True, the layer runs without fusions. This is necessary because some tensors cannot be accessed once fusions are applied. Importantly, if no feature is used for a layer, it should run as fast as it would without initializing debug_api.
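The fusion rule above amounts to an any() over a layer's routing calls. The sketch below assumes a feature object exposing the three routing calls whose documented default is False (modify_tensor_enabled, inspect_tensor_enabled, inspect_tensor_postquantize_enabled). NoFeature, InspectActivations and needs_unfused_path are hypothetical names, and checking every tensor against every GEMM is a simplification.

```python
# Hypothetical sketch: a layer can keep its fusions only if every routing call
# returns False. Class and function names are illustrative, not TE APIs.

GEMMS = ("fprop", "dgrad", "wgrad")


class NoFeature:
    """Routing calls with their documented default of False."""

    def modify_tensor_enabled(self, config, layer_name, gemm, tensor_name, iteration):
        return False

    def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
        return False

    def inspect_tensor_postquantize_enabled(self, config, layer_name, gemm, tensor_name, iteration):
        return False


class InspectActivations(NoFeature):
    """Hypothetical feature that only wants to inspect activations."""

    def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
        return tensor_name == "activation"


def needs_unfused_path(feature, config, layer_name, tensor_names, iteration):
    """Return True if any routing call asks for access to a tensor."""
    for tensor_name in tensor_names:
        if feature.inspect_tensor_enabled(config, layer_name, tensor_name, iteration):
            return True
        # Simplification: every tensor is checked against every GEMM.
        for gemm in GEMMS:
            if feature.modify_tensor_enabled(config, layer_name, gemm, tensor_name, iteration):
                return True
            if feature.inspect_tensor_postquantize_enabled(
                config, layer_name, gemm, tensor_name, iteration
            ):
                return True
    return False


tensors = ["activation", "weight", "gradient"]
print(needs_unfused_path(NoFeature(), {}, "fc1", tensors, 0))           # False -> fused
print(needs_unfused_path(InspectActivations(), {}, "fc1", tensors, 0))  # True -> unfused
```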
- transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor(config: Dict, layer_name: str, gemm: str, tensor_name: str, tensor: torch.Tensor, default_quantizer: transformer_engine.pytorch.tensor.Quantizer, iteration: int, out: torch.Tensor | transformer_engine.pytorch.tensor.QuantizedTensor) → torch.Tensor | transformer_engine.pytorch.tensor.QuantizedTensor | None
It allows tensor modification. For example, the FakeQuant feature uses it to emulate casting to FP8. It can be invoked at most once for each tensor within a given GEMM operation.
This call is invoked if modify_tensor_enabled returns True and the feature is enabled for the given tensor_name and gemm. In that case, it is called instead of the default quantization.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
tensor (torch.Tensor) – tensor in high precision,
default_quantizer (Quantizer) – quantizer used to cast the tensor to lower precision if modify_tensor is not invoked. For example, the per-tensor scaling feature uses it to obtain the FP8 dtype of the tensor. If the recipe indicates that the tensor is not cast (for example, when running without FP8 autocast), then default_quantizer=None,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
out (Union[torch.Tensor, QuantizedTensor]) – output tensor, used in the weight caching mechanism.
- Returns:
Union[torch.Tensor, transformer_engine.pytorch.tensor.QuantizedTensor, None] – can be a torch.Tensor or one of Transformer Engine’s QuantizedTensor types. The rule is that both tensors returned for a GEMM must have the same type: if both are Float8Tensor, the GEMM runs in FP8; if both are torch.Tensor, the GEMM runs in high precision. Take this into account especially if modify_tensor() processes only one of the GEMM's tensors. For example, FakeQuant disables the FP8 GEMM to ensure that the second tensor is also in high precision. If the tensor is not an input to any GEMM (namely output, wgrad, and dgrad), the return type should match the input type.
Should return None if out is not None.
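A minimal sketch of a modify_tensor() implementation that rounds values to a uniform grid, loosely in the spirit of fake quantization. Assumptions: the class RoundToGrid and the config key "step" are hypothetical, a Python list of floats stands in for torch.Tensor so the sketch runs without PyTorch, default_quantizer is ignored, and filling out in place when it is given is an assumed reading of the weight-caching rule. This is not how FakeQuant actually emulates FP8.

```python
# Hypothetical feature sketch. A list of floats stands in for torch.Tensor;
# with real tensors one would use tensor ops (e.g. out.copy_(result)).


class RoundToGrid:
    def modify_tensor(self, config, layer_name, gemm, tensor_name, tensor,
                      default_quantizer, iteration, out=None):
        # "step" is an assumed config key, not one defined by Transformer Engine.
        step = config.get("step", 0.25)
        # default_quantizer is unused: this sketch replaces quantization entirely.
        result = [round(x / step) * step for x in tensor]
        if out is not None:
            # Assumption: the caller-provided buffer receives the result ...
            out[:] = result
            # ... and, as documented, None is returned when out is given.
            return None
        # Return a new "tensor" of the same stand-in type as the input.
        return result


feature = RoundToGrid()
print(feature.modify_tensor({"step": 0.5}, "fc1", "fprop", "activation",
                            [0.3, 1.2, -0.7], None, 0))  # [0.5, 1.0, -0.5]
```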
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor(config: Dict, layer_name: str, tensor_name: str, tensor: torch.Tensor, iteration: int, tp_group: torch.distributed.ProcessGroup) → None
This call is invoked if inspect_tensor_enabled returns True. It can be used to obtain information about the high-precision tensor; for example, the LogTensorStats feature uses it.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
tensor (torch.Tensor) – tensor in high precision,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
tp_group (torch.distributed.ProcessGroup) – process group of the tensor-parallel group, used for weight statistics reduction. This is not the reduction group from debug_api.
- Return type:
Should return nothing.
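A hypothetical inspect_tensor() in the spirit of LogTensorStats, which records simple statistics and returns nothing. MinMaxMeanLogger and its records dictionary are assumptions; a list of floats stands in for the high-precision torch.Tensor, and tp_group is accepted but not used (per the description above, a real feature would use it to reduce weight statistics across the tensor-parallel group).

```python
class MinMaxMeanLogger:
    """Hypothetical stats feature; not the LogTensorStats implementation."""

    def __init__(self):
        # (layer_name, tensor_name, iteration) -> statistics dict
        self.records = {}

    def inspect_tensor(self, config, layer_name, tensor_name, tensor, iteration, tp_group=None):
        # tp_group is ignored in this single-process sketch.
        values = list(tensor)
        self.records[(layer_name, tensor_name, iteration)] = {
            "min": min(values),
            "max": max(values),
            "mean": sum(values) / len(values),
        }
        # Nothing is returned, matching the documented return type.


logger = MinMaxMeanLogger()
logger.inspect_tensor({}, "fc1", "activation", [1.0, -2.0, 4.0], iteration=3)
print(logger.records[("fc1", "activation", 3)])
# {'min': -2.0, 'max': 4.0, 'mean': 1.0}
```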
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize(config: Dict, layer_name: str, tensor_name: str, gemm: str, tensor: torch.Tensor, iteration: int, tp_group: torch.distributed.ProcessGroup) → None
Similar to inspect_tensor, but run after the FP8 cast or after modify_tensor, whichever of them is run. If neither the FP8 cast nor modify_tensor is invoked, inspect_tensor_postquantize is not invoked either. The LogFp8Stats feature uses this call to collect FP8 statistics after quantization.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
gemm (str) – one of [fprop, dgrad, wgrad],
tensor (torch.Tensor) – tensor in FP8, or the tensor returned by the modify_tensor call,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
tp_group (torch.distributed.ProcessGroup) – process group of the tensor-parallel group, used for weight statistics reduction. This is not the reduction group from debug_api.
- Return type:
Should return nothing.
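Because inspect_tensor_postquantize() sees the tensor after quantization, it can measure effects such as values flushed to zero. The sketch below is hypothetical (ZeroFractionLogger is not a TE feature and does not reproduce LogFp8Stats); a list of floats stands in for the quantized tensor and tp_group is unused. The argument order follows the signature above: tensor_name, gemm, then tensor.

```python
class ZeroFractionLogger:
    """Hypothetical feature: fraction of exact zeros after quantization."""

    def __init__(self):
        # (layer_name, tensor_name, gemm, iteration) -> fraction of zeros
        self.records = {}

    def inspect_tensor_postquantize(self, config, layer_name, tensor_name, gemm,
                                    tensor, iteration, tp_group=None):
        # tp_group is ignored in this single-process sketch.
        values = list(tensor)
        zeros = sum(1 for v in values if v == 0)
        self.records[(layer_name, tensor_name, gemm, iteration)] = zeros / len(values)


logger = ZeroFractionLogger()
logger.inspect_tensor_postquantize({}, "fc1", "activation", "fprop",
                                   [0.0, 0.5, 0.0, 1.0], iteration=7)
print(logger.records[("fc1", "activation", "fprop", 7)])  # 0.5
```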
- transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled(config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int) → bool
It determines whether modify_tensor will be run for a given GEMM and tensor name. It has higher priority than fp8_gemm_enabled: if modify_tensor_enabled returns True, modify_tensor is invoked for the respective tensor regardless of fp8_gemm_enabled.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
bool - default is False
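One way a feature might use the iteration argument of modify_tensor_enabled() is to modify tensors only within a window of training steps; when the call returns False, the tensor takes the default path. The class StepWindowFeature and the config keys start_step and end_step are hypothetical.

```python
class StepWindowFeature:
    """Hypothetical feature: modify tensors only for start_step <= iteration < end_step."""

    def modify_tensor_enabled(self, config, layer_name, gemm, tensor_name, iteration):
        # Assumed config keys; not defined by Transformer Engine.
        start = config.get("start_step", 0)
        end = config.get("end_step", float("inf"))
        return start <= iteration < end


feature = StepWindowFeature()
cfg = {"start_step": 100, "end_step": 200}
print(feature.modify_tensor_enabled(cfg, "fc1", "fprop", "activation", 150))  # True
print(feature.modify_tensor_enabled(cfg, "fc1", "fprop", "activation", 200))  # False
```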
- transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled(config: Dict, layer_name: str, gemm: str, iteration: int) → bool
If the tensor is not processed with modify_tensor and the FP8 recipe is enabled, the decision whether to cast it to FP8 is based on the value returned by fp8_gemm_enabled. If the tensor is processed with modify_tensor, or FP8 autocast is not enabled, the result of this call does not matter.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
gemm (str) – one of [fprop, dgrad, wgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
bool - default is True
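The precedence described for modify_tensor_enabled() and fp8_gemm_enabled() can be summarized as a small decision function. tensor_path and its string labels are hypothetical; the three booleans stand for the result of modify_tensor_enabled(), whether FP8 autocast is active, and the result of fp8_gemm_enabled().

```python
def tensor_path(modify_enabled, fp8_autocast, fp8_gemm_enabled):
    """Return how a GEMM input tensor is processed (hypothetical helper).

    modify_enabled   -- result of modify_tensor_enabled()
    fp8_autocast     -- whether the FP8 recipe / autocast is active
    fp8_gemm_enabled -- result of fp8_gemm_enabled() (documented default: True)
    """
    if modify_enabled:
        # Highest priority: fp8_gemm_enabled() does not matter here.
        return "modify_tensor"
    if fp8_autocast and fp8_gemm_enabled:
        # Default quantization with default_quantizer.
        return "fp8_cast"
    # Without FP8 autocast, or with fp8_gemm_enabled() == False, no cast happens.
    return "high_precision"


print(tensor_path(False, True, True))   # fp8_cast
print(tensor_path(False, True, False))  # high_precision
```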
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled(config: Dict, layer_name: str, tensor_name: str, iteration: int) → bool
It is a routing call, run at the initialization of the layer. If it returns True, inspect_tensor will be invoked for the given tensor.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
bool - default is False
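Since inspect_tensor_enabled() receives the iteration number, a feature can, for example, restrict inspection to every N-th step. EveryNSteps and the config key freq are hypothetical.

```python
class EveryNSteps:
    """Hypothetical feature: inspect tensors only on every freq-th iteration."""

    def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
        freq = config.get("freq", 1)  # assumed config key
        return iteration % freq == 0


feature = EveryNSteps()
print(feature.inspect_tensor_enabled({"freq": 10}, "fc1", "activation", 20))  # True
print(feature.inspect_tensor_enabled({"freq": 10}, "fc1", "activation", 21))  # False
```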
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled(config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int) → bool
It is a routing call, run at the initialization of the layer. If it returns True, inspect_tensor_postquantize will be invoked for the given GEMM and tensor.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str) – name of the layer,
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
bool - default is False
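Combining this routing call with the rule stated for inspect_tensor_postquantize (it runs only after an FP8 cast or a modify_tensor call), the condition for invoking it can be sketched as follows. postquantize_called and the path labels are hypothetical.

```python
def postquantize_called(postquantize_enabled, path):
    """Return True if inspect_tensor_postquantize() would run (hypothetical helper).

    path -- how the tensor was processed: "modify_tensor", "fp8_cast" or
            "high_precision" (illustrative labels).
    """
    # Even when the routing call returns True, nothing is inspected if the
    # tensor was neither cast to FP8 nor processed by modify_tensor().
    return postquantize_enabled and path in ("modify_tensor", "fp8_cast")


print(postquantize_called(True, "fp8_cast"))        # True
print(postquantize_called(True, "high_precision"))  # False
```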