Calls to Nvidia-DL-Framework-Inspect
Let’s look more closely at how Nvidia-DL-Framework-Inspect and Transformer Engine work together. Transformer Engine layers contain hook calls inside each of their GEMMs. Users can define their own feature classes or use the feature classes provided with TE. The config.yaml file describes which hooks are used for which layers. Nvidia-DL-Framework-Inspect combines three things (TE training, feature classes and config.yaml) and takes care of inserting the hooks in the correct places. This process is illustrated in the image below.
Fig 1: Example of Nvidia-DL-Framework-Inspect affecting a training script with one Linear layer. For tensors mentioned in config.yaml, the behavior of the modify_tensor_enabled() and modify_tensor() calls is replaced with the definitions from the feature class. Other calls return default values; in effect, they do nothing.
This page lists all calls from Transformer Engine to Nvidia-DL-Framework-Inspect for each GEMM. There are two categories of API calls, each used for a different purpose:
- GEMM calls: invoked during every GEMM; used to process or quantize tensors and to collect information about them.
- Routing calls: invoked at the beginning of every forward pass; they indicate whether a feature is going to use modify_tensor(), etc.
If all routing calls for a layer return False, the layer runs in its optimized version with Transformer Engine fusions. If any routing call returns True, the layer runs without fusions, because some tensors cannot be accessed when fusions happen. Importantly, if no feature is used for a layer, it should run as fast as it would without debug_api being initialized.
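The fused/unfused decision described above can be sketched in plain Python. The helpers `is_enabled` and `use_fused_path` are hypothetical and are not part of Transformer Engine or Nvidia-DL-Framework-Inspect; the sketch only illustrates the rule, not how TE implements it. It accepts both return forms of a routing call: a bare bool (deprecated) and an `(enabled, next_enabled_iter)` tuple.

```python
from typing import Iterable, Optional, Tuple, Union

# A routing call returns either a bare bool (deprecated) or (enabled, next_enabled_iter).
RoutingResult = Union[bool, Tuple[bool, Optional[int]]]


def is_enabled(result: RoutingResult) -> bool:
    """Extract the enabled flag from either return form of a routing call."""
    if isinstance(result, tuple):
        return bool(result[0])
    return bool(result)


def use_fused_path(results: Iterable[RoutingResult]) -> bool:
    """Keep the fused, optimized layer only if every routing call returns False."""
    return not any(is_enabled(r) for r in results)


# Hypothetical routing results for one layer at one iteration.
print(use_fused_path([(False, None), (False, 100), False]))  # True: fused path
print(use_fused_path([(False, None), (True, 10)]))  # False: unfused path
```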
- transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor(config: Dict, layer_name: str, gemm: str, tensor_name: str, tensor: torch.Tensor, default_quantizer: transformer_engine.pytorch.tensor.Quantizer, iteration: int, out: torch.Tensor | transformer_engine.pytorch.tensor.QuantizedTensor) → torch.Tensor | transformer_engine.pytorch.tensor.QuantizedTensor | None
Allows tensor modification. For example, the FakeQuant feature uses it to emulate casting to FP8. It is invoked at most once for each tensor within a given GEMM operation.
This call is invoked if modify_tensor_enabled returns True and the feature is enabled for the given tensor_name and gemm. In that case, it is called instead of the default quantization.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
tensor (torch.Tensor) – tensor in high precision,
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
default_quantizer (Quantizer) – quantizer used to cast the tensor to lower precision if modify_tensor is not invoked. For example, the per-tensor scaling feature uses it to obtain the FP8 dtype of the tensor. If the recipe indicates that the tensor is not cast (for example, when running without FP8 autocast), then default_quantizer=None,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
out (Union[torch.Tensor, QuantizedTensor]) – output tensor, used in the weight caching mechanism.
- Returns:
Union[torch.Tensor, transformer_engine.pytorch.tensor.QuantizedTensor, None] – can be a torch.Tensor or one of Transformer Engine’s QuantizedTensor types. Both tensors returned for a GEMM must have the same type: if both are Float8Tensor, the GEMM runs in FP8; if both are torch.Tensor, the GEMM runs in high precision. Keep this in mind especially when modify_tensor() processes only one of the GEMM’s tensors. For example, FakeQuant disables the FP8 GEMM so that the second tensor is also kept in high precision. If the tensor is not an input to any GEMM (namely output, wgrad and dgrad), the return type should match the input type.
Should return None if out is not None.
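The "same type for both GEMM operands" rule can be illustrated with a small, self-contained sketch. `HighPrecision` and `Quantized` are hypothetical stand-ins for torch.Tensor and a QuantizedTensor such as Float8Tensor, and `gemm_precision` is a hypothetical helper. Raising TypeError on a mismatch is a choice of this sketch to make the rule explicit; it is not a description of what Transformer Engine does with mismatched operands.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class HighPrecision:
    """Hypothetical stand-in for a plain torch.Tensor."""
    data: List[float]


@dataclass
class Quantized:
    """Hypothetical stand-in for a QuantizedTensor such as Float8Tensor."""
    data: List[float]


def gemm_precision(a: object, b: object) -> str:
    """Return the precision a GEMM would run in for operands a and b."""
    if isinstance(a, Quantized) and isinstance(b, Quantized):
        return "fp8"
    if isinstance(a, HighPrecision) and isinstance(b, HighPrecision):
        return "high_precision"
    # Mixed operand types violate the rule; this sketch rejects them explicitly.
    raise TypeError(
        f"GEMM operands must have the same type, got {type(a).__name__} and {type(b).__name__}"
    )


# A FakeQuant-style modify_tensor returns a high-precision activation...
activation = HighPrecision([0.5, -1.25])
# ...so the weight must also stay in high precision instead of being quantized.
weight = HighPrecision([2.0, 0.75])
print(gemm_precision(activation, weight))  # high_precision
```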
- transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled(config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int) → bool | Tuple[bool, int | None]
Determines whether modify_tensor will be run for a given GEMM and tensor name. It has higher priority than fp8_gemm_enabled: if modify_tensor_enabled returns True or (True, next_enabled_iter), then modify_tensor is invoked for the respective tensor unconditionally.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration at which the feature will be enabled. It can return (bool, None) if the feature will not be enabled in any later iteration for that layer, gemm and tensor. Returning the next enabled iteration can reduce CPU overhead, especially when modify_tensor runs only at long intervals. Returning only a bool is deprecated.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
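Below is a minimal sketch of a modify_tensor_enabled implementation that returns the tuple form with a next-iteration hint. The class is hypothetical: it only mirrors the documented parameter list and does not subclass TEDefaultFeatures or register with Nvidia-DL-Framework-Inspect. The config keys `freq`, `start_step` and `end_step` are assumptions made for this example, not documented options.

```python
from typing import Any, Dict, Optional, Tuple


class EveryNthIterationFeature:
    """Hypothetical feature enabling modify_tensor every `freq` iterations."""

    def modify_tensor_enabled(
        self,
        config: Dict[str, Any],
        layer_name: str,
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> Tuple[bool, Optional[int]]:
        # Hypothetical config keys; end_step=None means "no end".
        freq = config.get("freq", 1)
        start = config.get("start_step", 0)
        end = config.get("end_step")
        if end is not None and iteration > end:
            return False, None  # will never be enabled again
        if iteration < start:
            return False, start  # first enabled iteration is start_step
        offset = iteration - start
        enabled = offset % freq == 0
        next_iter: Optional[int] = start + (offset // freq + 1) * freq
        if end is not None and next_iter > end:
            next_iter = None  # no enabled iteration remains after this one
        return enabled, next_iter


feature = EveryNthIterationFeature()
print(feature.modify_tensor_enabled({"freq": 10}, "fc1", "fprop", "activation", 5))  # (False, 10)
```

Returning `(False, 10)` at iteration 5 lets the caller skip querying the feature until iteration 10, which is the CPU saving the paragraph above refers to.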
- transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled(config: Dict, layer_name: str, gemm: str, iteration: int) → bool | Tuple[bool, int | None]
If the tensor is not processed by modify_tensor and an FP8 recipe is enabled, the value returned by fp8_gemm_enabled decides whether the tensor is cast to FP8. If the tensor is processed by modify_tensor, or FP8 autocast is not enabled, the result of this call is ignored.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled. It can return (bool, None) if the feature will never be enabled for that layer and gemm. Returning the next enabled iteration can help optimize CPU usage.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
gemm (str) – one of [fprop, dgrad, wgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
Union[bool, Tuple[bool, Optional[int]]] - default is (True, None)
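The priority between modify_tensor_enabled, FP8 autocast and fp8_gemm_enabled can be summarized as a small decision function. `tensor_path` is a hypothetical helper written only to restate the rules above; it is not part of either library.

```python
def tensor_path(modify_enabled: bool, fp8_autocast: bool, fp8_gemm_enabled: bool) -> str:
    """Return how a GEMM input tensor is handled under the documented priority rules."""
    # modify_tensor_enabled has the highest priority: when it is enabled,
    # modify_tensor replaces the default quantization.
    if modify_enabled:
        return "modify_tensor"
    # Otherwise, under FP8 autocast, fp8_gemm_enabled decides whether to cast.
    if fp8_autocast and fp8_gemm_enabled:
        return "fp8_cast"
    # Without FP8 autocast, or when fp8_gemm_enabled is False, no cast happens.
    return "high_precision"


for args in [(True, True, False), (False, True, True), (False, True, False), (False, False, True)]:
    print(args, tensor_path(*args))
```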
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor(config: Dict, layer_name: str, tensor_name: str, tensor: torch.Tensor, rowwise_quantized_tensor: torch.Tensor | None, columnwise_quantized_tensor: torch.Tensor | None, quantizer: transformer_engine.pytorch.tensor.Quantizer | None, iteration: int, tp_group: torch.distributed.ProcessGroup) → None
Invoked if inspect_tensor_enabled returns True. It can be used to obtain information about the high-precision tensor. For example, the LogTensorStats feature uses it.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
tensor (torch.Tensor) – tensor in high precision,
rowwise_quantized_tensor (Optional[torch.Tensor]) – rowwise quantized tensor,
columnwise_quantized_tensor (Optional[torch.Tensor]) – columnwise quantized tensor,
quantizer (Optional[Quantizer]) – quantizer,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
tp_group (torch.distributed.ProcessGroup) – process group of the tensor-parallel group, used to reduce weight statistics. This is not the reduction group from debug_api.
- Return type:
Should return nothing.
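As an illustration, here is a sketch of an inspect_tensor-style hook that records simple statistics and returns nothing. `StatsCollector` is hypothetical: tensors are plain lists of floats instead of torch.Tensor, the `stats` config key is an assumption, and `tp_group` is ignored, whereas a real feature collecting weight statistics would reduce them across that group. This is not how LogTensorStats is implemented.

```python
import math
from typing import Any, Dict, List, Optional, Tuple


class StatsCollector:
    """Hypothetical feature mimicking the inspect_tensor parameter list."""

    def __init__(self) -> None:
        self.logged: Dict[Tuple[str, str, int], Dict[str, float]] = {}

    def inspect_tensor(
        self,
        config: Dict[str, Any],
        layer_name: str,
        tensor_name: str,
        tensor: List[float],
        rowwise_quantized_tensor: Optional[List[float]],
        columnwise_quantized_tensor: Optional[List[float]],
        quantizer: Optional[Any],
        iteration: int,
        tp_group: Optional[Any],
    ) -> None:
        n = len(tensor)
        mean = sum(tensor) / n
        all_stats = {
            "min": min(tensor),
            "max": max(tensor),
            "mean": mean,
            "std": math.sqrt(sum((v - mean) ** 2 for v in tensor) / n),  # population std
        }
        wanted = config.get("stats", list(all_stats))
        # tp_group is ignored here; weight statistics would need reducing across it.
        self.logged[(layer_name, tensor_name, iteration)] = {k: all_stats[k] for k in wanted}


collector = StatsCollector()
values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
collector.inspect_tensor({"stats": ["mean", "std"]}, "fc1", "activation", values, None, None, None, 0, None)
print(collector.logged[("fc1", "activation", 0)])  # {'mean': 5.0, 'std': 2.0}
```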
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize(config: Dict, layer_name: str, tensor_name: str, tensor: torch.Tensor, iteration: int, tp_group: torch.distributed.ProcessGroup, rowwise: bool) → None
This call is deprecated; use inspect_tensor instead.
Similar to inspect_tensor, but run after the FP8 cast or modify_tensor, whichever of them is applied. If neither the FP8 cast nor modify_tensor is invoked, inspect_tensor_postquantize is not invoked either. The LogFp8Stats feature uses this call to collect FP8 statistics after quantization.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
tensor (torch.Tensor) – tensor in FP8, or the tensor processed by the modify_tensor call,
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
tp_group (torch.distributed.ProcessGroup) – process group of the tensor-parallel group, used to reduce weight statistics. This is not the reduction group from debug_api.
- Return type:
Should return nothing.
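Because inspect_tensor receives both the high-precision tensor and the quantized tensors, a post-quantization statistic can be computed there instead of in inspect_tensor_postquantize. The sketch below computes one such statistic, the fraction of nonzero high-precision values that became zero after quantization. The function `underflow_fraction` is hypothetical, plain lists stand in for tensors, and the quantized values are assumed to have already been converted back to floats (dequantized); this is an illustrative definition, not necessarily the one LogFp8Stats uses.

```python
from typing import Sequence


def underflow_fraction(high_precision: Sequence[float], dequantized: Sequence[float]) -> float:
    """Fraction of nonzero high-precision values that are zero after quantization."""
    if len(high_precision) != len(dequantized):
        raise ValueError("tensors must have the same number of elements")
    # Only elements that were nonzero before quantization can underflow.
    nonzero = [(h, q) for h, q in zip(high_precision, dequantized) if h != 0.0]
    if not nonzero:
        return 0.0
    return sum(1 for _, q in nonzero if q == 0.0) / len(nonzero)


# The tiny value 1e-9 is flushed to zero by the (simulated) quantization.
print(underflow_fraction([0.0, 1e-9, 0.5, 2.0], [0.0, 0.0, 0.5, 2.0]))  # 0.3333333333333333
```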
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled(config: Dict, layer_name: str, tensor_name: str, iteration: int) → bool | Tuple[bool, int | None]
A routing call, run at the initialization of the layer. Determines whether inspect_tensor will be invoked for a given layer and tensor.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. It can return (bool, None) if the feature will never be enabled for that layer and tensor. Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large. Returning only a bool is deprecated.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad].
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
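To see why returning the next enabled iteration saves CPU time, consider a caller that re-queries a routing call only at the iteration it named. `CachedRoutingCall` is a hypothetical wrapper written for illustration; it does not describe how Transformer Engine or Nvidia-DL-Framework-Inspect actually cache routing results. A deprecated bare-bool result carries no hint, so the sketch assumes the call must be queried again on the next iteration.

```python
from typing import Callable, Optional, Tuple, Union

RoutingResult = Union[bool, Tuple[bool, Optional[int]]]


class CachedRoutingCall:
    """Hypothetical caller-side wrapper that skips routing calls between hints."""

    def __init__(self, routing_fn: Callable[[int], RoutingResult]) -> None:
        self.routing_fn = routing_fn
        self.next_query: Optional[int] = 0  # query on the first iteration seen
        self.calls = 0  # number of times routing_fn was actually invoked

    def enabled(self, iteration: int) -> bool:
        # None means "never enabled again"; earlier iterations are known to be disabled.
        if self.next_query is None or iteration < self.next_query:
            return False
        self.calls += 1
        result = self.routing_fn(iteration)
        if isinstance(result, tuple):
            is_on, self.next_query = result
        else:
            # Deprecated bare bool: no hint, so ask again on the next iteration.
            is_on, self.next_query = result, iteration + 1
        return bool(is_on)


# Hypothetical routing call enabled every 10 iterations.
cache = CachedRoutingCall(lambda it: (it % 10 == 0, (it // 10 + 1) * 10))
print([it for it in range(25) if cache.enabled(it)], cache.calls)  # [0, 10, 20] 3
```

Only three routing calls are made over 25 iterations, whereas a bare-bool implementation would be queried on every iteration.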
- transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled(config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int) → bool | Tuple[bool, int | None]
This call is deprecated; use inspect_tensor and inspect_tensor_enabled instead.
A routing call, run at the initialization of the layer. Determines whether inspect_tensor_postquantize will be invoked for a given GEMM and tensor.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name. Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor_postquantize is large. Returning only a bool is deprecated.
- Parameters:
config (Dict) – dictionary containing information from config.yaml corresponding to the feature, tensor_name and gemm.
layer_name (str)
gemm (str) – one of [fprop, dgrad, wgrad],
tensor_name (str) – one of [activation, weight, gradient, output, wgrad, dgrad],
iteration (int) – iteration number - equal to the number of times debug_api.step() was called.
- Return type:
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)