core.fp8_utils
Utility functions related to FP8 that are used throughout Megatron Core.
Module Contents
Functions
| Function | Description |
|---|---|
| is_float8tensor | Check if a tensor is a Transformer Engine Float8Tensor. |
| is_mxfp8tensor | Check if a tensor is a Transformer Engine MXFP8Tensor. |
| dequantize_fp8_tensor | Dequantize an fp8 tensor to a higher-precision tensor. |
| _resolve_callable_from_python_import_path | Resolve a Python import path like ‘pkg.mod.func’ to a callable. |
| get_fp8_align_size | Get the alignment size required for fp8 GEMM. |
| is_column_parallel_linear | Returns whether the given module is a ColumnParallelLinear layer. |
| is_row_parallel_linear | Returns whether the given module is a RowParallelLinear layer. |
| modify_underlying_storage | Replace the underlying raw data of a tensor with new data. |
| quantize_param_shard | Cast shards of fp32 main params to fp8 model params. |
| correct_amax_history_if_needed | Correct the amax history of fp8 tensors when necessary (i.e., in TE 1.x). |
| post_all_gather_processing | Post-processing after all-gather for weights in the distributed optimizer. |
| is_first_last_bf16_layer | Check if the layer is in bf16. |
Data
API
- core.fp8_utils.HAVE_TE
False
- core.fp8_utils.HAVE_TE_FP8_TENSOR_CLASS
False
- core.fp8_utils.is_float8tensor(tensor: torch.Tensor) -> bool
Check if a tensor is a Transformer Engine Float8Tensor.
Note that in TE 2.x, the design of the fp8 tensor class changed in order to support more recipes: Float8Tensor is now used only for delayed scaling and current scaling, while MXFP8 and blockwise scaling each have their own fp8 tensor class. All of these classes inherit from QuantizedTensor. Therefore, FP8_TENSOR_CLASS is Float8Tensor for TE 1.x and QuantizedTensor for TE 2.x, which means that under TE 2.x this check also returns True for MXFP8 and blockwise-scaled tensors.
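The version-dependent check can be sketched as follows. The three classes are hypothetical stand-ins that only mimic the TE 2.x hierarchy described above (nothing is imported from Transformer Engine), and `fp8_tensor_class` and `is_float8tensor_sketch` are illustrative names rather than Megatron APIs. The TE major version is passed in explicitly instead of being detected from an import.

```python
# Hypothetical stand-ins for the Transformer Engine tensor classes; they only
# mimic the TE 2.x inheritance described above and are not imported from TE.
class QuantizedTensor:
    pass


class Float8Tensor(QuantizedTensor):
    """Used for delayed scaling and current scaling."""


class MXFP8Tensor(QuantizedTensor):
    """Used for the MXFP8 recipe (has its own class in TE 2.x)."""


def fp8_tensor_class(te_major_version: int) -> type:
    """Pick the class that FP8_TENSOR_CLASS would refer to for a TE major version."""
    return Float8Tensor if te_major_version < 2 else QuantizedTensor


def is_float8tensor_sketch(obj: object, te_major_version: int) -> bool:
    """Return True if obj is an instance of the version-appropriate fp8 class."""
    return isinstance(obj, fp8_tensor_class(te_major_version))


# Under TE 2.x, an MXFP8 tensor also passes, because it is a QuantizedTensor.
print(is_float8tensor_sketch(MXFP8Tensor(), te_major_version=2))  # True
print(is_float8tensor_sketch(Float8Tensor(), te_major_version=1))  # True
print(is_float8tensor_sketch(object(), te_major_version=2))  # False
```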
- core.fp8_utils.is_mxfp8tensor(tensor: torch.Tensor) -> bool
Check if a tensor is a Transformer Engine MXFP8Tensor.
- core.fp8_utils.dequantize_fp8_tensor(fp8_tensor: torch.Tensor) -> torch.Tensor
Dequantize an fp8 tensor to a higher-precision tensor.
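As an illustration of what dequantization involves, the sketch below decodes FP8 E4M3 bytes (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and the all-ones exponent-and-mantissa pattern reserved for NaN) and multiplies each value by a per-tensor inverse scale. It assumes per-tensor scaling only; MXFP8 and blockwise recipes use per-block scales. The names `decode_e4m3` and `dequantize_sketch` are hypothetical and are not Transformer Engine or Megatron functions.

```python
import math
from typing import List


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # E4M3 has no infinities; this bit pattern is NaN
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent of 1 - bias = -6
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_sketch(raw: bytes, scale_inv: float) -> List[float]:
    """Decode each byte, then undo the per-tensor scale applied at quantization time."""
    return [decode_e4m3(b) * scale_inv for b in raw]


print(dequantize_sketch(bytes([0x38, 0xC0, 0x7E]), scale_inv=0.5))  # [0.5, -1.0, 224.0]
```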
- core.fp8_utils._resolve_callable_from_python_import_path(dotted_path: str)
Resolve a Python import path like ‘pkg.mod.func’ to a callable.
Raises a ValueError with a clear message on failure.
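A minimal sketch of this kind of resolution using the standard library's `importlib`. `resolve_callable` is an illustrative name, and the error messages are assumptions rather than the messages Megatron emits. The sketch treats everything before the last dot as a module path, so it does not handle nested attributes such as ‘pkg.mod.Class.method’.

```python
import importlib
from typing import Callable


def resolve_callable(dotted_path: str) -> Callable:
    """Resolve 'pkg.mod.func' to a callable, raising ValueError on any failure."""
    module_path, sep, attr_name = dotted_path.rpartition(".")
    if not sep or not module_path or not attr_name:
        raise ValueError(f"Expected a dotted path like 'pkg.mod.func', got {dotted_path!r}")
    try:
        module = importlib.import_module(module_path)
    except ImportError as exc:
        raise ValueError(f"Could not import module {module_path!r}: {exc}") from exc
    try:
        obj = getattr(module, attr_name)
    except AttributeError as exc:
        raise ValueError(f"Module {module_path!r} has no attribute {attr_name!r}") from exc
    if not callable(obj):
        raise ValueError(f"{dotted_path!r} resolved to a non-callable {type(obj).__name__}")
    return obj


sqrt = resolve_callable("math.sqrt")
print(sqrt(16.0))  # 4.0
```

Converting every failure mode (bad format, missing module, missing attribute, non-callable) into a ValueError lets a caller handle a misconfigured path with a single except clause.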
- core.fp8_utils._get_custom_recipe(quantizer_factory_python_path: str)
- core.fp8_utils.get_fp8_align_size(fp8_recipe: megatron.core.enums.Fp8Recipe) -> int
Get the alignment size required for fp8 GEMM.
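Alignment requirements like this are usually met by rounding a dimension up to the next multiple of the alignment size. The sketch below shows that arithmetic. The enum is a hypothetical stand-in for Fp8Recipe (not its real member list), and the alignment values (32 for MXFP8, 16 otherwise) are assumptions for illustration rather than values confirmed by this page.

```python
from enum import Enum


class Fp8RecipeSketch(Enum):
    """Hypothetical stand-in for megatron.core.enums.Fp8Recipe."""

    delayed = "delayed"
    tensorwise = "tensorwise"
    mxfp8 = "mxfp8"


def fp8_align_size_sketch(recipe: Fp8RecipeSketch) -> int:
    # Assumed alignment values for illustration: 32 for MXFP8, whose scales
    # cover blocks of 32 elements, and 16 for the other recipes.
    return 32 if recipe is Fp8RecipeSketch.mxfp8 else 16


def pad_to_multiple(n: int, align: int) -> int:
    """Round n up to the nearest multiple of align."""
    return ((n + align - 1) // align) * align


dim = 100
print(pad_to_multiple(dim, fp8_align_size_sketch(Fp8RecipeSketch.tensorwise)))  # 112
print(pad_to_multiple(dim, fp8_align_size_sketch(Fp8RecipeSketch.mxfp8)))  # 128
```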
- core.fp8_utils.is_column_parallel_linear(module)
Returns whether the given module is a ColumnParallelLinear layer.
- core.fp8_utils.is_row_parallel_linear(module)
Returns whether the given module is a RowParallelLinear layer.
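One way such predicates are useful is to decide which weight dimension is sharded across tensor-parallel ranks. In the sketch below, the two classes are hypothetical stand-ins (the real predicates receive Megatron modules), the predicates are plain isinstance checks against those stand-ins, and `weight_partition_dim` is an illustrative helper. Weights follow the torch.nn.Linear [out_features, in_features] layout, so column parallelism splits dimension 0 and row parallelism splits dimension 1.

```python
from typing import Optional


class ColumnParallelLinear:
    """Hypothetical stand-in: weight shape [out_features // tp_size, in_features]."""


class RowParallelLinear:
    """Hypothetical stand-in: weight shape [out_features, in_features // tp_size]."""


def is_column_parallel_linear(module) -> bool:
    return isinstance(module, ColumnParallelLinear)


def is_row_parallel_linear(module) -> bool:
    return isinstance(module, RowParallelLinear)


def weight_partition_dim(module) -> Optional[int]:
    """Return the weight dimension split across tensor-parallel ranks, or None."""
    if is_column_parallel_linear(module):
        return 0  # output features are split across ranks
    if is_row_parallel_linear(module):
        return 1  # input features are split across ranks
    return None


print(weight_partition_dim(ColumnParallelLinear()))  # 0
print(weight_partition_dim(RowParallelLinear()))  # 1
```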
- core.fp8_utils.modify_underlying_storage(tensor: torch.Tensor, new_raw_data: torch.Tensor)
Replace the underlying raw data of a tensor with new data.
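The idea of replacing only the raw payload can be sketched with a hypothetical wrapper class. The class, its `_data` and `scale_inv` attributes, and the same-size check are assumptions for illustration. The point shown is that the wrapper object and its metadata survive, so any other reference to it observes the new bytes.

```python
class RawDataWrapper:
    """Hypothetical stand-in for a quantized tensor: raw uint8 payload plus a scale."""

    def __init__(self, raw: bytearray, scale_inv: float):
        self._data = raw
        self.scale_inv = scale_inv


def modify_underlying_storage_sketch(tensor: RawDataWrapper, new_raw_data: bytearray) -> None:
    """Swap the raw payload in place, keeping the wrapper object and its metadata."""
    if len(new_raw_data) != len(tensor._data):
        raise ValueError("new_raw_data must have the same number of bytes as the old data")
    tensor._data = new_raw_data


param = RawDataWrapper(bytearray(b"\x38\x40"), scale_inv=0.5)
alias = param  # e.g. a reference held elsewhere, such as by an optimizer
modify_underlying_storage_sketch(param, bytearray(b"\x00\x7e"))
print(list(alias._data), alias.scale_inv)  # [0, 126] 0.5
```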
- core.fp8_utils.quantize_param_shard(model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params=None)
Cast shards of fp32 main params to fp8 model params.
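A simplified sketch of quantizing one rank's shard of fp32 main params into the matching slice of a flattened model param, assuming per-tensor current scaling. The function name, the float list standing in for fp8 storage, and the scale-and-clamp arithmetic are assumptions: real fp8 casting also rounds to the E4M3 grid, and the global amax would come from a max all-reduce over data_parallel_group, which is only simulated here by taking the max over all shards in one process.

```python
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def quantize_param_shard_sketch(
    model_param: List[float],
    main_param_shard: List[float],
    start_offset: int,
    global_amax: float,
) -> float:
    """Scale and clamp a shard into model_param[start_offset:], returning the scale.

    global_amax is the max |x| over the whole parameter (in a real run, gathered
    from every data-parallel rank); rounding to the fp8 grid is omitted.
    """
    scale = FP8_E4M3_MAX / global_amax if global_amax > 0 else 1.0
    for i, x in enumerate(main_param_shard):
        model_param[start_offset + i] = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x * scale))
    return scale


# A 4-element parameter split across two data-parallel ranks (simulated sequentially).
model_param = [0.0] * 4
shards = {0: [1.0, -2.0], 2: [0.5, 4.0]}  # start offset -> fp32 shard
amax = max(abs(x) for shard in shards.values() for x in shard)  # stands in for the all-reduce
for offset, shard in shards.items():
    quantize_param_shard_sketch(model_param, shard, offset, amax)
print(model_param)  # [112.0, -224.0, 56.0, 448.0]
```

Using the global amax on every rank keeps a single consistent scale for the whole parameter even though each rank only touches its own slice.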
- core.fp8_utils.correct_amax_history_if_needed(model: List[torch.nn.Module])
Correct the amax history of fp8 tensors when necessary (i.e., in TE 1.x).
- core.fp8_utils.post_all_gather_processing(model_params)
Post-processing after all-gather for weights in the distributed optimizer:
  - tensorwise: may need to create a transposed view to match the backend GEMM.
  - blockwise: create column-wise storage.
- core.fp8_utils.is_first_last_bf16_layer(config: megatron.core.transformer.transformer_config.TransformerConfig, layer_no: int)
Check if the layer is in bf16, i.e., whether layer_no is one of the first or last layers that are kept in bf16.
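Based on the function name, a layer-selection rule of this kind can be sketched as below. The dataclass is a hypothetical subset of TransformerConfig; its field names, their defaults, and the 0-based layer_no are assumptions for illustration, not confirmed by this page.

```python
from dataclasses import dataclass


@dataclass
class BF16LayerConfigSketch:
    """Hypothetical subset of TransformerConfig; field names are assumptions."""

    num_layers: int
    first_last_layers_bf16: bool = False
    num_layers_at_start_in_bf16: int = 1
    num_layers_at_end_in_bf16: int = 1


def is_first_last_bf16_layer_sketch(config: BF16LayerConfigSketch, layer_no: int) -> bool:
    """Return True if the 0-based layer_no is in the leading or trailing bf16 window."""
    if not config.first_last_layers_bf16:
        return False
    in_start = layer_no < config.num_layers_at_start_in_bf16
    in_end = layer_no >= config.num_layers - config.num_layers_at_end_in_bf16
    return in_start or in_end


cfg = BF16LayerConfigSketch(num_layers=8, first_last_layers_bf16=True)
print([n for n in range(cfg.num_layers) if is_first_last_bf16_layer_sketch(cfg, n)])  # [0, 7]
```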