core.distributed.fsdp.src.megatron_fsdp.mixed_precision#
Module Contents#
Classes#

|  |  |
| --- | --- |
| MixedPrecisionPolicy | Megatron-FSDP Mixed Precision Dataclass |

Functions#

|  |  |
| --- | --- |
| is_te_min_version | Check if minimum version of `transformer-engine` is installed. |
| is_float8tensor | Check if a tensor is a FP8 tensor. |
| is_blockwise_float8tensor | Check if a tensor is a Blockwise FP8 tensor. |
| fp8_need_transpose_data | Check if a FP8 tensor needs transpose data. |
| fp8_need_transpose_data_for_meta_device_init | Check if a FP8 tensor needs transpose data, for meta device init scenario. |
| fp8_discard_transpose_cache | Discard the transpose cache of a FP8 tensor. |
| fp8_create_transpose_cache | Create the transpose cache of a FP8 tensor. |
| fp8_set_raw_data | Set the raw data of a Transformer Engine Float8Tensor. |
| fp8_get_raw_data | Get the underlying raw storage of a FP8 tensor. |
| fp8_dequantize | Dequantize a FP8 tensor to a higher precision. |
| fp8_quantize | Quantize sharded parameters to FP8. |
| get_quantized_model_init_context_cls | Get the TransformerEngine model parameter quantization context manager. |
Data#
API#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.logger#
‘getLogger(…)’
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_te_min_version(vers, check_equality=True)#
Check if minimum version of `transformer-engine` is installed.
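A minimal usage sketch, assuming the module is importable under the documented path; the version string "1.0" is illustrative:

```python
# Hedged sketch: gate FP8-specific code paths on the installed
# transformer-engine version. The "1.0" threshold is illustrative only.
from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import is_te_min_version

use_fp8 = is_te_min_version("1.0")
if not use_fp8:
    print("transformer-engine version requirement not met; keeping BF16/FP32 params.")
```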
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_blockwise_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is a Blockwise FP8 tensor.
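For illustration, a hedged sketch that uses these two checks to enumerate a module's FP8 parameters; the import path is assumed from this page and `iter_fp8_params` is a hypothetical helper:

```python
# Hedged sketch: enumerate a module's FP8 parameters and classify their
# quantization scheme (per-tensor vs. blockwise).
import torch

from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import (
    is_blockwise_float8tensor,
    is_float8tensor,
)

def iter_fp8_params(module: torch.nn.Module):
    for name, param in module.named_parameters():
        if is_float8tensor(param):
            scheme = "blockwise" if is_blockwise_float8tensor(param) else "per-tensor"
            yield name, scheme, param
```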
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data(tensor: torch.Tensor) → bool#
Check if a FP8 tensor needs transpose data.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data_for_meta_device_init(
- module: transformer_engine.pytorch.module.base.TransformerEngineBaseModule,
)#
Check if a FP8 tensor needs transpose data, for meta device init scenario.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_discard_transpose_cache(tensor: torch.Tensor) → None#
Discard the transpose cache of a FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_create_transpose_cache(tensors: List[torch.Tensor]) → None#
Create the transpose cache of a FP8 tensor.
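A hedged sketch of the transpose-cache lifecycle, pairing discard and create; `refresh_transpose_caches` is a hypothetical helper and the import path is assumed from this page:

```python
# Hedged sketch: discard per-tensor transpose data to free memory, then
# rebuild the caches in one batched call before a kernel that consumes
# transposed FP8 data.
from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import (
    fp8_create_transpose_cache,
    fp8_discard_transpose_cache,
    is_float8tensor,
)

def refresh_transpose_caches(params):
    fp8_params = [p for p in params if is_float8tensor(p)]
    for p in fp8_params:
        fp8_discard_transpose_cache(p)       # drop any stale transpose data
    fp8_create_transpose_cache(fp8_params)   # rebuild caches for all FP8 params
```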
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_create_transpose_cache_fallback(
- tensors: List[torch.Tensor],
)#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_set_raw_data(
- tensor: torch.Tensor,
- data: torch.Tensor,
- set_transpose: bool = False,
)#
Set the raw data of a Transformer Engine Float8Tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_get_raw_data(
- tensor: torch.Tensor,
- get_transpose: bool = False,
)#
Get the underlying raw storage of a FP8 tensor.
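A hedged sketch combining the raw-data getter and setter to rebind an FP8 parameter's storage; it assumes `new_buffer` matches the dtype and shape of the existing raw data, and `rebind_fp8_storage` is a hypothetical helper:

```python
# Hedged sketch: move an FP8 parameter's raw storage into an externally
# allocated buffer (e.g. a slice of a contiguous FSDP data buffer).
from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import (
    fp8_get_raw_data,
    fp8_set_raw_data,
)

def rebind_fp8_storage(fp8_param, new_buffer):
    old = fp8_get_raw_data(fp8_param)         # current raw (quantized) storage
    new_buffer.copy_(old)                     # preserve the quantized bytes
    fp8_set_raw_data(fp8_param, new_buffer)   # repoint the tensor at new_buffer
```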
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_dequantize(tensor: torch.Tensor) → torch.Tensor#
Dequantize a FP8 tensor to a higher precision.
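A minimal hedged sketch: materialize a higher-precision copy of a possibly-FP8 parameter for inspection, leaving the quantized parameter untouched (`param_norm` is a hypothetical helper):

```python
# Hedged sketch: dequantize only when the parameter is actually FP8.
from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import (
    fp8_dequantize,
    is_float8tensor,
)

def param_norm(param):
    hp = fp8_dequantize(param) if is_float8tensor(param) else param
    return hp.float().norm()
```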
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_quantize(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
Quantize sharded parameters to FP8.
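A hedged optimizer-step sketch: after the high-precision main shards are updated, requantize them back into the FP8 model parameters. The per-parameter lists are assumed to be index-aligned, and the helper name and use of the default process group are illustrative:

```python
# Hedged sketch: requantize updated main shards into FP8 model parameters.
import torch.distributed as dist

from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import fp8_quantize

def requantize_after_step(model_params, main_params, start_offsets,
                          fsdp_shard_model_params):
    fp8_quantize(
        model_params=model_params,              # FP8 model parameters
        main_params=main_params,                # updated high-precision shards
        start_offsets=start_offsets,            # shard start offset per parameter
        data_parallel_group=dist.group.WORLD,   # assumed data-parallel group
        fsdp_shard_model_params=fsdp_shard_model_params,
    )
```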
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_quantize_fallback(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.get_quantized_model_init_context_cls()#
Get the TransformerEngine model parameter quantization context manager.
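A hedged sketch of constructing a model under the returned context; whether the context class takes constructor arguments is an assumption, and `build_model()` is a hypothetical stand-in for the user's model constructor:

```python
# Hedged sketch: create parameters directly in quantized form by building
# the model inside the quantized-init context.
from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import (
    get_quantized_model_init_context_cls,
)

init_ctx_cls = get_quantized_model_init_context_cls()
with init_ctx_cls():
    model = build_model()  # hypothetical: any TransformerEngine-based model
```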
- class core.distributed.fsdp.src.megatron_fsdp.mixed_precision.MixedPrecisionPolicy#
Megatron-FSDP Mixed Precision Dataclass
- main_params_dtype: Optional[torch.dtype]#
None
Data type for the main weight buffer utilized for distributed optimization and quantization with Megatron-FSDP. If set to None, the model compute weight buffer will take the role of the main weights, or when no sharding is applied, the native model weights become the main weights. Defaults to torch.float32.
- main_grads_dtype: Optional[torch.dtype]#
None
Data type for the main gradient buffer utilized for distributed optimization with Megatron-FSDP. If set to None, main gradients will match the dtype of the model compute parameters specified by the user model. Defaults to torch.float32.
- grad_comm_dtype: Optional[torch.dtype]#
None
Data type for gradient gather / scatter communications. Can be utilized to reduce communication latency, but adds overhead for type-casting and copy operations. If using NCCL UBR v2.27+, gradient reduction may be performed in high precision depending on the network domain (NVLink or IB), which can enable mixed-precision communication and accumulation, e.g. setting grad_comm_dtype to BF16 can support FP32 reduction even though the input and output communication buffers are BF16. If set to None, the main_grads_dtype is used. Defaults to torch.float32. If using `no_shard`, `optim`, or a `FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient communication buffers (per FSDP group) adds memory overhead.
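A hedged configuration sketch: FP32 main weights and gradients with BF16 gradient communication. Only the fields documented above are set; how the policy object is handed to Megatron-FSDP is not shown here:

```python
# Hedged sketch: a common mixed-precision policy. Other fields of the
# dataclass, if any, are left at their defaults.
import torch

from core.distributed.fsdp.src.megatron_fsdp.mixed_precision import MixedPrecisionPolicy

policy = MixedPrecisionPolicy(
    main_params_dtype=torch.float32,   # high-precision main weight buffer
    main_grads_dtype=torch.float32,    # high-precision gradient accumulation
    grad_comm_dtype=torch.bfloat16,    # cheaper gather/scatter traffic
)
```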