core.distributed.fsdp.src.megatron_fsdp.mixed_precision#
Module Contents#
Classes#
- MixedPrecisionPolicy: Megatron-FSDP Mixed Precision Dataclass
Functions#
- is_te_min_version: Check if minimum version of transformer-engine is installed.
- is_float8tensor: Check if a tensor is a FP8 tensor.
- is_blockwise_float8tensor: Check if a tensor is a Blockwise FP8 tensor.
- fp8_need_transpose_data: Check if a FP8 tensor needs transpose data.
- fp8_need_transpose_data_for_meta_device_init: Check if a FP8 tensor needs transpose data, for the meta-device initialization scenario.
- fp8_discard_transpose_cache: Discard the transpose cache of a FP8 tensor.
- fp8_create_transpose_cache: Create the transpose cache of a FP8 tensor.
- fp8_set_raw_data: Set the raw data of a Transformer Engine Float8Tensor.
- fp8_get_raw_data: Get the underlying raw storage of a FP8 tensor.
- fp8_dequantize: Dequantize a FP8 tensor to a higher precision.
- fp8_quantize: Quantize sharded parameters to FP8.
- get_quantized_model_init_context_cls: Get the TransformerEngine model parameter quantization context manager.
Data#
API#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.logger#
'getLogger(...)'
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_te_min_version(vers, check_equality=True)#
Check if minimum version of transformer-engine is installed.
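A minimal sketch of how such a minimum-version check can be implemented. This is illustrative only: the real is_te_min_version inspects the installed transformer-engine distribution, whereas here the installed version string is passed in so the example is self-contained, and the helper name meets_min_version is hypothetical.

```python
def meets_min_version(installed: str, required: str, check_equality: bool = True) -> bool:
    """Compare dotted version strings numerically, e.g. '1.10.0' >= '1.9'.

    Note: a plain string comparison would get this wrong ('1.10' < '1.9'
    lexicographically), so the components are compared as integer tuples.
    """
    inst = tuple(int(part) for part in installed.split("."))
    req = tuple(int(part) for part in required.split("."))
    # check_equality=True means "at least this version"; False means "strictly newer".
    return inst >= req if check_equality else inst > req

# (1, 10, 0) >= (1, 9) under tuple comparison, so this passes:
assert meets_min_version("1.10.0", "1.9")
```

The check_equality flag mirrors the signature shown above, where the default behavior accepts an exact version match.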
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_blockwise_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is a Blockwise FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data(tensor: torch.Tensor) → bool#
Check if a FP8 tensor needs transpose data.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data_for_meta_device_init(
- module: transformer_engine.pytorch.module.base.TransformerEngineBaseModule,
)#
Check if a FP8 tensor needs transpose data, for the meta-device initialization scenario.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_discard_transpose_cache(tensor: torch.Tensor) → None#
Discard the transpose cache of a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_create_transpose_cache(tensors: List[torch.Tensor]) → None#
Create the transpose cache of a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_create_transpose_cache_fallback(
- tensors: List[torch.Tensor],
)#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_set_raw_data(
- tensor: torch.Tensor,
- data: torch.Tensor,
- set_transpose: bool = False,
)#
Set the raw data of a Transformer Engine Float8Tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_get_raw_data(
- tensor: torch.Tensor,
- get_transpose: bool = False,
)#
Get the underlying raw storage of a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_dequantize(tensor: torch.Tensor) → torch.Tensor#
Dequantize a FP8 tensor to a higher precision.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_quantize(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
Quantize sharded parameters to FP8.
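To make the quantize/dequantize pair above concrete, here is a hedged, illustrative sketch of per-tensor FP8-style scaling. It is NOT the Transformer Engine implementation behind fp8_quantize and fp8_dequantize (which operates on real Float8Tensor storage and rounds to the E4M3 format); it only shows the scaling idea: map the tensor's absolute maximum onto the representable FP8 range, then invert the scale to dequantize.

```python
# E4M3 FP8 has a maximum representable magnitude of 448.
E4M3_MAX = 448.0

def compute_scale(values):
    """Per-tensor scale that maps the absolute maximum of `values` to the FP8 range."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / amax if amax > 0 else 1.0

def quantize(values):
    """Return (scaled_values, scale). Real FP8 would also round each value to E4M3."""
    scale = compute_scale(values)
    return [v * scale for v in values], scale

def dequantize(scaled_values, scale):
    """Invert the scaling to recover higher-precision values."""
    return [v / scale for v in scaled_values]

vals = [0.5, -2.0, 1.25]
quantized, scale = quantize(vals)
assert all(abs(v) <= E4M3_MAX for v in quantized)
# Exact round-trip here only because the sketch skips E4M3 rounding:
assert dequantize(quantized, scale) == vals
```

In the real code path the quantized data lives in the sharded model-parameter buffers while the main_params hold the high-precision master copy, which is why fp8_quantize takes both lists plus the shard offsets.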
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_quantize_fallback(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.get_quantized_model_init_context_cls()#
Get the TransformerEngine model parameter quantization context manager.
- class core.distributed.fsdp.src.megatron_fsdp.mixed_precision.MixedPrecisionPolicy#
Megatron-FSDP Mixed Precision Dataclass
- main_params_dtype: Optional[torch.dtype]#
None
Data type for the main weight buffer utilized for distributed optimization and quantization with Megatron-FSDP. If set to None, the model compute weight buffer will take the role of the main weights, or when no sharding is applied, the native model weights become the main weights. Defaults to torch.float32.
- main_grads_dtype: Optional[torch.dtype]#
None
Data type for the main gradient buffer utilized for distributed optimization with Megatron-FSDP. If set to None, main gradients will match the dtype of the model compute parameters specified by the user model. Defaults to None.
- grad_comm_dtype: Optional[torch.dtype]#
None
Data type for gradient gather / scatter communications. Can be utilized to reduce communication latency, but adds overhead for type-casting and copy operations. If using NCCL UBR v2.27+, gradient reduction may be performed in high precision depending on the network domain (NVLink or IB), and can enable mixed-precision communication and accumulation, e.g. setting grad_comm_dtype to BF16 can support FP32 reduction even though we have BF16 input and output communication buffers. If set to None, the main_grads_dtype is used. If using HSDP (either DP-Replicate or DP-Outer in outer_dp_sharding_strategy), no_shard, optim, or a FixedPoolAllocator (fsdp_double_buffer), allocating dtype-custom gradient communication buffers (per FSDP group) adds memory overhead. Defaults to None. No additional memory is allocated when grad_comm_dtype == main_grads_dtype.
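The three fields above form a fallback chain: main_grads_dtype falls back to the model compute dtype, and grad_comm_dtype falls back to the resolved main_grads_dtype. A minimal pure-Python sketch of that resolution logic, assuming nothing beyond the field descriptions above; the real MixedPrecisionPolicy stores torch.dtype values, but strings stand in for dtypes here so the example runs without torch, and the class name PolicySketch and method resolved_grad_comm_dtype are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PolicySketch:
    """Illustrative stand-in for MixedPrecisionPolicy using dtype-name strings."""
    main_params_dtype: Optional[str] = None  # None -> compute weights act as main weights
    main_grads_dtype: Optional[str] = None   # None -> match the model compute dtype
    grad_comm_dtype: Optional[str] = None    # None -> use main_grads_dtype

    def resolved_grad_comm_dtype(self, model_compute_dtype: str) -> str:
        # main_grads_dtype=None falls back to the model compute dtype...
        grads_dtype = self.main_grads_dtype or model_compute_dtype
        # ...and grad_comm_dtype=None falls back to the resolved gradient dtype.
        return self.grad_comm_dtype or grads_dtype

# FP32 main grads, no explicit comm dtype: gradients are communicated in float32.
policy = PolicySketch(main_grads_dtype="float32")
assert policy.resolved_grad_comm_dtype("bfloat16") == "float32"
```

As noted above, choosing grad_comm_dtype equal to the resolved main_grads_dtype avoids allocating separate dtype-custom communication buffers.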