core.distributed.fsdp.src.megatron_fsdp.mixed_precision#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `is_te_min_version` | Check if a minimum version of `transformer-engine` is installed. |
| `is_float8tensor` | Check if a tensor is an FP8 tensor. |
| `is_blockwise_float8tensor` | Check if a tensor is a Blockwise FP8 tensor. |
| `fp8_need_transpose_data` | Check if an FP8 tensor needs transpose data. |
| `fp8_need_transpose_data_for_meta_device_init` | Check if an FP8 tensor needs transpose data, for the meta-device init scenario. |
| `fp8_discard_transpose_cache` | Discard the transpose cache of an FP8 tensor. |
| `fp8_create_transpose_cache` | Create the transpose cache for each of the given FP8 tensors. |
| `fp8_set_raw_data` | Set the raw data of a Transformer Engine Float8Tensor. |
| `fp8_get_raw_data` | Get the underlying raw storage of an FP8 tensor. |
| `fp8_dequantize` | Dequantize an FP8 tensor to a higher precision. |
| `fp8_quantize` | Quantize sharded parameters to FP8. |
Data#
API#
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.logger#
'getLogger(…)'
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_te_min_version(vers, check_equality=True)#
Check if a minimum version of `transformer-engine` is installed.
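A minimal usage sketch: gating a Transformer Engine specific code path behind the version check. The import path and the `"1.10"` threshold are illustrative assumptions, not values taken from this module.

```python
# Sketch only: the import path and version threshold are assumptions.
from megatron_fsdp.mixed_precision import is_te_min_version

if is_te_min_version("1.10"):
    # transformer-engine >= 1.10 is available; take the newer FP8 path.
    use_new_fp8_path = True
else:
    # Older transformer-engine install: fall back to the legacy behavior.
    use_new_fp8_path = False
```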
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is an FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_blockwise_float8tensor(tensor: torch.Tensor) → bool#
Check if a tensor is a Blockwise FP8 tensor.
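A short sketch of how the two predicates can be combined to dispatch on a tensor's representation; the import path is an assumption.

```python
# Sketch: branch on the tensor representation before touching FP8-only state.
import torch

from megatron_fsdp.mixed_precision import (  # import path is an assumption
    is_blockwise_float8tensor,
    is_float8tensor,
)

def describe(tensor: torch.Tensor) -> str:
    # Test the blockwise predicate first so blockwise tensors get the more
    # specific label (assuming they may also satisfy the generic check).
    if is_blockwise_float8tensor(tensor):
        return "blockwise-fp8"
    if is_float8tensor(tensor):
        return "fp8"
    return "high-precision"

# A plain torch tensor is not FP8, so both checks return False.
assert describe(torch.zeros(4)) == "high-precision"
```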
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data(tensor: torch.Tensor) → bool#
Check if an FP8 tensor needs transpose data.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data_for_meta_device_init(
- module: transformer_engine.pytorch.module.base.TransformerEngineBaseModule,
) → bool#
Check if an FP8 tensor needs transpose data, for the meta-device init scenario.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_discard_transpose_cache(tensor: torch.Tensor) → None#
Discard the transpose cache of an FP8 tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_create_transpose_cache(tensors: List[torch.Tensor]) → None#
Create the transpose cache for each of the given FP8 tensors.
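One plausible pattern these two helpers support, shown as a hedged sketch: `fp8_params` is an assumed list of Transformer Engine Float8Tensor parameters, and the import path may differ in your installation.

```python
# Sketch: free transpose caches ahead of a memory-heavy region, then
# rebuild them in one batched call.
from megatron_fsdp.mixed_precision import (
    fp8_create_transpose_cache,
    fp8_discard_transpose_cache,
    fp8_need_transpose_data,
)

for p in fp8_params:
    fp8_discard_transpose_cache(p)

# ... memory-intensive work (e.g. gathering full parameters) ...

# Recreate caches only where transpose data is actually required.
fp8_create_transpose_cache([p for p in fp8_params if fp8_need_transpose_data(p)])
```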
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_create_transpose_cache_fallback(
- tensors: List[torch.Tensor],
)#
Fallback path for `fp8_create_transpose_cache`.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_set_raw_data(
- tensor: torch.Tensor,
- data: torch.Tensor,
- set_transpose: bool = False,
)#
Set the raw data of a Transformer Engine Float8Tensor.
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_get_raw_data(
- tensor: torch.Tensor,
- get_transpose: bool = False,
)#
Get the underlying raw storage of an FP8 tensor.
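A hedged sketch of a get/set roundtrip. `fp8_param` and `flat_buffer` (assumed to share the raw data's dtype) are illustrative assumptions, as is the import path.

```python
# Sketch: back an FP8 parameter's payload with a slice of a flat buffer,
# e.g. so FSDP can all-gather one contiguous region.
from megatron_fsdp.mixed_precision import fp8_get_raw_data, fp8_set_raw_data

raw = fp8_get_raw_data(fp8_param)                  # current raw storage
view = flat_buffer[: raw.numel()].view(raw.shape)  # buffer-backed view
view.copy_(raw)                                    # preserve current values
fp8_set_raw_data(fp8_param, view)                  # re-point the storage
```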
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_dequantize(tensor: torch.Tensor) → torch.Tensor#
Dequantize an FP8 tensor to a higher precision.
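A minimal sketch combining this with `is_float8tensor`; the import path is an assumption.

```python
# Sketch: promote parameters to high precision, e.g. before writing a
# full-precision checkpoint. Non-FP8 tensors pass through unchanged.
import torch

from megatron_fsdp.mixed_precision import fp8_dequantize, is_float8tensor

def to_high_precision(t: torch.Tensor) -> torch.Tensor:
    return fp8_dequantize(t) if is_float8tensor(t) else t
```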
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_quantize(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
Quantize sharded parameters to FP8.
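A hedged call sketch. Every variable below is an illustrative assumption, and the reading of the list arguments as parallel (entry i of `main_params` is quantized into entry i of `model_params` starting at `start_offsets[i]`, with `fsdp_shard_model_params[i]` holding (data, optional transpose) shard views) is inferred from the signature, not confirmed by the source.

```python
# Sketch: write updated high-precision master shards back into FP8 model
# parameters after an optimizer step. All variables are assumptions.
import torch.distributed as dist

from megatron_fsdp.mixed_precision import fp8_quantize

fp8_quantize(
    model_params=fp8_model_params,         # FP8 model parameters
    main_params=fp32_master_shards,        # high-precision shards to quantize
    start_offsets=shard_start_offsets,     # per-parameter shard offsets
    data_parallel_group=dist.group.WORLD,  # group spanning the DP domain
    fsdp_shard_model_params=shard_views,   # assumed (data, transpose) views
)
```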
- core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_quantize_fallback(
- model_params: List[torch.Tensor],
- main_params: List[torch.Tensor],
- start_offsets: List[int],
- data_parallel_group: torch.distributed.ProcessGroup,
- fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
)#
Fallback path for `fp8_quantize`.