core.fp8_utils#

Utility functions related to FP8 that are used throughout Megatron Core.

Module Contents#

Functions#

  • is_float8tensor: Check if a tensor is a Transformer Engine Float8Tensor.

  • is_mxfp8tensor: Check if a tensor is a Transformer Engine MXFP8Tensor.

  • dequantize_fp8_tensor: Dequantize an fp8 tensor to a higher-precision tensor.

  • _resolve_callable_from_python_import_path: Resolve a Python import path like ‘pkg.mod.func’ to a callable.

  • _get_custom_recipe

  • get_fp8_align_size: Get the alignment size required for fp8 GEMM.

  • is_column_parallel_linear: Returns whether the given module is a ColumnParallelLinear layer.

  • is_row_parallel_linear: Returns whether the given module is a RowParallelLinear layer.

  • modify_underlying_storage: Replace the underlying raw data of a tensor with new data.

  • quantize_param_shard: Cast sharded fp32 main params to fp8 model params.

  • correct_amax_history_if_needed: Correct the amax history of fp8 tensors when necessary (i.e., in TE1.x).

  • post_all_gather_processing: Post-processing after all-gather for weights in the distributed optimizer.

  • is_first_last_bf16_layer: Check whether the layer is one of the first/last layers kept in bf16.

Data#

API#

core.fp8_utils.HAVE_TE#

False

core.fp8_utils.HAVE_TE_FP8_TENSOR_CLASS#

False
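These flags record at import time whether Transformer Engine and its fp8 tensor class are available. A minimal sketch of the usual guarded-import pattern, assuming TE may be absent; the exact import path for the tensor class is an assumption and differs between TE1.x and TE2.x:

```python
# Sketch of the usual availability-flag pattern; not the module's exact code.
try:
    import transformer_engine  # noqa: F401

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    # Assumed import path: TE2.x groups all fp8 tensor classes under
    # QuantizedTensor, while TE1.x only provides Float8Tensor.
    from transformer_engine.pytorch.tensor import QuantizedTensor as FP8_TENSOR_CLASS

    HAVE_TE_FP8_TENSOR_CLASS = True
except ImportError:
    FP8_TENSOR_CLASS = None
    HAVE_TE_FP8_TENSOR_CLASS = False
```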

core.fp8_utils.is_float8tensor(tensor: torch.Tensor) bool#

Check if a tensor is a Transformer Engine Float8Tensor.

Note that the design of the fp8 tensor class changed in TE2.x in order to support more recipes: Float8Tensor is now used only for current scaling and delayed scaling, while mxfp8 and blockwise scaling have their own fp8 tensor classes, all of which inherit from QuantizedTensor. So, for TE1.x, FP8_TENSOR_CLASS is Float8Tensor, and for TE2.x, FP8_TENSOR_CLASS is QuantizedTensor.
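A brief usage sketch; the `torch.nn.Linear` below is only a stand-in, since in practice the check is applied to parameters of Transformer Engine layers whose weights are quantized:

```python
import torch

from megatron.core.fp8_utils import is_float8tensor

# Count how many parameters of a module are stored as TE fp8 tensors.
module = torch.nn.Linear(16, 16)  # stand-in for a TE fp8 module
num_fp8 = sum(1 for p in module.parameters() if is_float8tensor(p))
print(f"{num_fp8} fp8 parameters")  # 0 for a plain torch.nn.Linear
```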

core.fp8_utils.is_mxfp8tensor(tensor: torch.Tensor) bool#

Check if a tensor is a Transformer Engine MXFP8Tensor.

core.fp8_utils.dequantize_fp8_tensor(fp8_tensor: torch.Tensor) torch.Tensor#

Dequantize an fp8 tensor to a higher-precision tensor.
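An illustrative helper built on top of this function; the wrapper name is hypothetical:

```python
from megatron.core.fp8_utils import dequantize_fp8_tensor, is_float8tensor


def to_high_precision(tensor):
    # Materialize a higher-precision copy of an fp8 tensor, e.g. for
    # debugging or comparing weights against a reference checkpoint.
    if is_float8tensor(tensor):
        return dequantize_fp8_tensor(tensor)
    return tensor
```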

core.fp8_utils._resolve_callable_from_python_import_path(dotted_path: str)#

Resolve a Python import path like ‘pkg.mod.func’ to a callable.

Raises ValueError with a clear message on failure.
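A minimal sketch of the standard pattern such a helper typically follows (not necessarily the exact implementation):

```python
import importlib


def resolve_callable(dotted_path: str):
    # Split 'pkg.mod.func' into module path and attribute name, import the
    # module, and fetch the attribute; fail with a clear ValueError.
    try:
        module_path, attr_name = dotted_path.rsplit(".", 1)
        fn = getattr(importlib.import_module(module_path), attr_name)
    except (ValueError, ImportError, AttributeError) as e:
        raise ValueError(f"Cannot resolve '{dotted_path}' to a callable: {e}") from e
    if not callable(fn):
        raise ValueError(f"'{dotted_path}' does not point to a callable")
    return fn
```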

core.fp8_utils._get_custom_recipe(
quantizer_factory_python_path: str,
) Union[megatron.core.enums.Fp8Recipe, megatron.core.enums.Fp4Recipe]#

core.fp8_utils.get_fp8_align_size(fp8_recipe: megatron.core.enums.Fp8Recipe) int#

Get the alignment size required for fp8 GEMM.
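A brief usage sketch: rounding a GEMM dimension up to the required alignment. The helper name is illustrative:

```python
from megatron.core.fp8_utils import get_fp8_align_size


def pad_to_fp8_alignment(dim: int, fp8_recipe) -> int:
    # Round `dim` up to the nearest multiple of the recipe's alignment size,
    # so fp8 GEMMs see correctly aligned shapes. `fp8_recipe` is a value of
    # megatron.core.enums.Fp8Recipe.
    align = get_fp8_align_size(fp8_recipe)
    return ((dim + align - 1) // align) * align
```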

core.fp8_utils.is_column_parallel_linear(module)#

Returns whether the given module is a ColumnParallelLinear layer.

core.fp8_utils.is_row_parallel_linear(module)#

Returns whether the given module is a RowParallelLinear layer.
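An illustrative use of the two checks together; the helper name is hypothetical:

```python
from megatron.core.fp8_utils import (
    is_column_parallel_linear,
    is_row_parallel_linear,
)


def collect_parallel_linears(model):
    # Gather the tensor-parallel linear layers of a model, e.g. to apply
    # fp8-specific handling only to GEMM-bearing modules.
    return [
        m
        for m in model.modules()
        if is_column_parallel_linear(m) or is_row_parallel_linear(m)
    ]
```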

core.fp8_utils.modify_underlying_storage(
tensor: torch.Tensor,
new_raw_data: torch.Tensor,
)#

Replace the underlying raw data of a tensor with new data.

core.fp8_utils.quantize_param_shard(
model_params,
main_params,
start_offsets,
data_parallel_group,
fsdp_shard_model_params=None,
)#

Cast sharded fp32 main params to fp8 model params.
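A hedged sketch of where this fits in the distributed-optimizer flow; the wrapper and variable names are illustrative, and the interpretation of the arguments follows only from their names:

```python
from megatron.core.fp8_utils import quantize_param_shard


def cast_main_shards_to_fp8(fp8_model_params, fp32_main_shards, start_offsets, dp_group):
    # After the optimizer step updates the fp32 main-param shards, write them
    # back into the corresponding fp8 model params. `start_offsets` is assumed
    # to give each shard's starting element within its model param, and the
    # data-parallel group is presumably used to keep quantization metadata
    # consistent across ranks.
    quantize_param_shard(
        model_params=fp8_model_params,
        main_params=fp32_main_shards,
        start_offsets=start_offsets,
        data_parallel_group=dp_group,
    )
```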

core.fp8_utils.correct_amax_history_if_needed(model: List[torch.nn.Module])#

Correct the amax history of fp8 tensors when necessary (i.e., in TE1.x).

core.fp8_utils.post_all_gather_processing(model_params)#

Post-processing after all-gather for weights in the distributed optimizer.

  • tensorwise: may need to create a transposed view to match the backend GEMM.

  • blockwise: create column-wise storage.

core.fp8_utils.is_first_last_bf16_layer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
layer_no: int,
)#

Check whether the layer at layer_no is one of the first/last layers that are kept in bf16 (i.e., excluded from fp8).
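An illustrative per-layer decision built on this check; the helper name is hypothetical, and treating `config.fp8` as the TransformerConfig field that enables fp8 is an assumption:

```python
from megatron.core.fp8_utils import is_first_last_bf16_layer


def layer_uses_fp8(config, layer_no: int) -> bool:
    # A layer participates in fp8 compute only if fp8 is enabled in the config
    # and the layer is not one of the first/last layers kept in bf16.
    return config.fp8 is not None and not is_first_last_bf16_layer(config, layer_no)
```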