core.distributed.fsdp.src.megatron_fsdp.mixed_precision#

Module Contents#

Classes#

MixedPrecisionPolicy

Megatron-FSDP Mixed Precision Dataclass

Functions#

is_te_min_version

Check if a minimum version of transformer-engine is installed.

is_float8tensor

Check if a tensor is an FP8 tensor.

is_blockwise_float8tensor

Check if a tensor is a Blockwise FP8 tensor.

fp8_need_transpose_data

Check if an FP8 tensor needs transpose data.

fp8_need_transpose_data_for_meta_device_init

Check if an FP8 tensor needs transpose data in the meta-device initialization scenario.

fp8_discard_transpose_cache

Discard the transpose cache of an FP8 tensor.

fp8_create_transpose_cache

Create the transpose cache for each FP8 tensor in a list.

_fp8_create_transpose_cache_fallback

fp8_set_raw_data

Set the raw data of a Transformer Engine Float8Tensor.

fp8_get_raw_data

Get the underlying raw storage of an FP8 tensor.

fp8_dequantize

Dequantize an FP8 tensor to higher precision.

fp8_quantize

Quantize sharded parameters to FP8.

_fp8_quantize_fallback

get_quantized_model_init_context_cls

Get the TransformerEngine model parameter quantization context manager.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.logger#

'getLogger(…)'

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_te_min_version(vers, check_equality=True)#

Check if a minimum version of transformer-engine is installed.
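
As an illustration of what such a version check involves, here is a dependency-free sketch (not the actual implementation, which inspects the installed transformer-engine package; the function name and parsing strategy are illustrative):

```python
def is_min_version(installed: str, required: str, check_equality: bool = True) -> bool:
    """Return True if `installed` meets the `required` minimum version."""
    def parse(version: str):
        # Keep only the leading numeric release segments, e.g. "1.7.0.dev0" -> (1, 7, 0).
        parts = []
        for piece in version.split("."):
            if piece.isdigit():
                parts.append(int(piece))
            else:
                break
        return tuple(parts)

    a, b = parse(installed), parse(required)
    # check_equality=True means "at least this version"; False means "strictly newer".
    return a >= b if check_equality else a > b

print(is_min_version("1.7.0", "1.6"))  # True
```

Tuple comparison handles versions of different lengths naturally, e.g. `(1, 7, 0) >= (1, 6)`.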

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_float8tensor(tensor: torch.Tensor) -> bool#

Check if a tensor is an FP8 tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_blockwise_float8tensor(tensor: torch.Tensor) -> bool#

Check if a tensor is a Blockwise FP8 tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data(tensor: torch.Tensor) -> bool#

Check if an FP8 tensor needs transpose data.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data_for_meta_device_init(
module: transformer_engine.pytorch.module.base.TransformerEngineBaseModule,
) -> bool#

Check if an FP8 tensor needs transpose data in the meta-device initialization scenario.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_discard_transpose_cache(tensor: torch.Tensor) -> None#

Discard the transpose cache of an FP8 tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_create_transpose_cache(tensors: List[torch.Tensor]) -> None#

Create the transpose cache for each FP8 tensor in a list.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_create_transpose_cache_fallback(
tensors: List[torch.Tensor],
) -> None#
core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_set_raw_data(
tensor: torch.Tensor,
data: torch.Tensor,
set_transpose: bool = False,
) -> None#

Set the raw data of a Transformer Engine Float8Tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_get_raw_data(
tensor: torch.Tensor,
get_transpose: bool = False,
) -> torch.Tensor#

Get the underlying raw storage of an FP8 tensor.
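
A toy sketch of the set/get raw-data pattern (illustrative only; the real functions manipulate Transformer Engine `Float8Tensor` internals, i.e. its raw low-precision storage and optional transpose cache, and the class and helper names here are hypothetical):

```python
class ToyFp8Tensor:
    """Stand-in for a quantized tensor: raw storage plus an optional transpose cache."""
    def __init__(self, data, transpose=None):
        self._data = data            # raw low-precision storage
        self._transpose = transpose  # cached transpose, or None

def set_raw_data(tensor, data, set_transpose=False):
    # Swap in new raw storage (or a new transpose cache) without re-quantizing.
    if set_transpose:
        tensor._transpose = data
    else:
        tensor._data = data

def get_raw_data(tensor, get_transpose=False):
    # Return the underlying storage; optionally the cached transpose.
    return tensor._transpose if get_transpose else tensor._data

t = ToyFp8Tensor(data=[7, 8, 9])
set_raw_data(t, [1, 2, 3])
print(get_raw_data(t))  # [1, 2, 3]
```

Swapping the raw storage in place is what lets FSDP point a quantized parameter at a freshly all-gathered buffer without a redundant copy.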

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_dequantize(tensor: torch.Tensor) -> torch.Tensor#

Dequantize an FP8 tensor to higher precision.
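
Conceptually, scaled FP8 storage keeps a low-precision code plus an inverse scale, and dequantization multiplies them back. A pure-Python sketch of that arithmetic (448.0 is the largest finite E4M3 value; the actual cast is performed by Transformer Engine kernels, and these function names are illustrative):

```python
def quantize(values, amax, fp8_max=448.0):
    """Scale values so amax maps to fp8_max, then round (stand-in for the FP8 cast)."""
    scale = fp8_max / amax
    codes = [round(v * scale) for v in values]  # raw low-precision data
    return codes, 1.0 / scale                   # (raw data, scale_inv)

def dequantize(codes, scale_inv):
    """Recover approximate high-precision values from codes and the inverse scale."""
    return [c * scale_inv for c in codes]

codes, scale_inv = quantize([0.5, -1.0, 2.0], amax=2.0)
print(dequantize(codes, scale_inv))  # [0.5, -1.0, 2.0]
```

Values at or near `amax` use the full FP8 range; the inverse scale is all that is needed to dequantize later.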

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_quantize(
model_params: List[torch.Tensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> None#

Quantize sharded parameters to FP8.
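
The sharded-quantization signature above can be pictured as each rank quantizing its contiguous shard of the high-precision main parameters and writing it into the flat model buffer at that shard's start offset. A hypothetical, dependency-free sketch (the function name and `quant` callable are illustrative, not the Megatron-FSDP implementation):

```python
def quantize_shard_into(model_flat, main_shard, start_offset, quant):
    """Quantize one rank's contiguous shard and write it at start_offset."""
    for i, value in enumerate(main_shard):
        model_flat[start_offset + i] = quant(value)

# Rank owning offsets [4, 6) of an 8-element flat parameter buffer,
# with a toy quantizer that rounds to quarter steps:
buf = [0.0] * 8
quantize_shard_into(buf, [1.25, 2.75], start_offset=4, quant=lambda v: round(v * 4) / 4)
print(buf)  # [0.0, 0.0, 0.0, 0.0, 1.25, 2.75, 0.0, 0.0]
```

In the real function the per-shard offsets come from `start_offsets`, and ranks coordinate over `data_parallel_group`.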

core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_quantize_fallback(
model_params: List[torch.Tensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> None#
core.distributed.fsdp.src.megatron_fsdp.mixed_precision.get_quantized_model_init_context_cls()#

Get the TransformerEngine model parameter quantization context manager.

class core.distributed.fsdp.src.megatron_fsdp.mixed_precision.MixedPrecisionPolicy#

Megatron-FSDP Mixed Precision Dataclass

main_params_dtype: Optional[torch.dtype]#

None

Data type for the main weight buffer used for distributed optimization and quantization with Megatron-FSDP. If set to None, the model-compute weight buffer takes the role of the main weights; when no sharding is applied, the native model weights become the main weights. Defaults to torch.float32.

main_grads_dtype: Optional[torch.dtype]#

None

Data type for the main gradient buffer used for distributed optimization with Megatron-FSDP. If set to None, main gradients match the dtype of the model-compute parameters specified by the user model. Defaults to torch.float32.

grad_comm_dtype: Optional[torch.dtype]#

None

Data type for gradient gather / scatter communications. It can be used to reduce communication latency, but adds type-casting and copy overhead. With NCCL UBR v2.27+, gradient reduction may be performed in high precision depending on the network domain (NVLink or IB), enabling mixed-precision communication and accumulation: for example, setting grad_comm_dtype to BF16 can support FP32 reduction even though the input and output communication buffers are BF16. If set to None, main_grads_dtype is used. Defaults to torch.float32. Note that when using no_shard, optim, or a FixedPoolAllocator (fsdp_double_buffer), allocating dtype-custom gradient communication buffers (per FSDP group) adds memory overhead.
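
To illustrate how the three fields interact, here is a dependency-free stand-in (the real MixedPrecisionPolicy holds torch.dtype values; strings are used here so the sketch needs no torch, and the helper function is illustrative, not part of the API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MixedPrecisionPolicySketch:
    """Stand-in mirroring the documented MixedPrecisionPolicy fields."""
    main_params_dtype: Optional[str] = None  # e.g. "float32"
    main_grads_dtype: Optional[str] = None   # e.g. "float32"
    grad_comm_dtype: Optional[str] = None    # e.g. "bfloat16"

def effective_grad_comm_dtype(policy: MixedPrecisionPolicySketch) -> Optional[str]:
    # Per the docs above: when grad_comm_dtype is None, main_grads_dtype is used.
    if policy.grad_comm_dtype is not None:
        return policy.grad_comm_dtype
    return policy.main_grads_dtype

policy = MixedPrecisionPolicySketch(main_grads_dtype="float32", grad_comm_dtype="bfloat16")
print(effective_grad_comm_dtype(policy))  # bfloat16
```

With the real class one would pass torch dtypes instead, e.g. `grad_comm_dtype=torch.bfloat16` to communicate gradients in BF16 while accumulating in FP32.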