core.distributed.fsdp.src.megatron_fsdp.mixed_precision#

Module Contents#

Functions#

is_te_min_version

Check if a minimum version of transformer-engine is installed.

is_float8tensor

Check if a tensor is an FP8 tensor.

is_blockwise_float8tensor

Check if a tensor is a blockwise FP8 tensor.

fp8_need_transpose_data

Check if an FP8 tensor needs transpose data.

fp8_need_transpose_data_for_meta_device_init

Check if an FP8 tensor needs transpose data in the meta-device initialization scenario.

fp8_discard_transpose_cache

Discard the transpose cache of an FP8 tensor.

fp8_create_transpose_cache

Create the transpose caches for a list of FP8 tensors.

_fp8_create_transpose_cache_fallback

Fallback implementation of fp8_create_transpose_cache.

fp8_set_raw_data

Set the raw data of a Transformer Engine Float8Tensor.

fp8_get_raw_data

Get the underlying raw storage of an FP8 tensor.

fp8_dequantize

Dequantize an FP8 tensor to higher precision.

fp8_quantize

Quantize sharded parameters to FP8.

_fp8_quantize_fallback

Fallback implementation of fp8_quantize.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.logger#

'getLogger(…)'

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_te_min_version(vers, check_equality=True)#

Check if a minimum version of transformer-engine is installed.
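
A minimal usage sketch: gate an FP8-specific code path on the installed Transformer Engine version. The top-level `megatron_fsdp.mixed_precision` import path is an assumption inferred from this page's module name.

```python
# Sketch only: the import path is assumed from the module name above.
from megatron_fsdp.mixed_precision import is_te_min_version

if is_te_min_version("1.0"):
    # Transformer Engine >= 1.0 is installed; the FP8 helpers are usable.
    print("TE >= 1.0 detected")
else:
    # Older Transformer Engine: fall back to a non-FP8 path.
    print("TE < 1.0; skipping FP8 paths")
```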

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_float8tensor(tensor: torch.Tensor) → bool#

Check if a tensor is an FP8 tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.is_blockwise_float8tensor(tensor: torch.Tensor) → bool#

Check if a tensor is a blockwise FP8 tensor.
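
A quick sketch of both predicates (assumed import path). A plain `torch.Tensor` is not a Transformer Engine FP8 tensor, so both checks return False here:

```python
import torch

# Assumed import path, derived from this page's module name.
from megatron_fsdp.mixed_precision import (
    is_blockwise_float8tensor,
    is_float8tensor,
)

t = torch.randn(4, 4, dtype=torch.bfloat16)
print(is_float8tensor(t))            # False: plain torch.Tensor, not FP8
print(is_blockwise_float8tensor(t))  # False: not blockwise FP8 either
```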

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data(tensor: torch.Tensor) → bool#

Check if an FP8 tensor needs transpose data.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_need_transpose_data_for_meta_device_init(
module: transformer_engine.pytorch.module.base.TransformerEngineBaseModule,
) → bool#

Check if an FP8 tensor needs transpose data in the meta-device initialization scenario.
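
A hedged sketch of how the two checks might be combined: inspect materialized parameters directly, but fall back to the module-level check while parameters still live on the meta device (where no FP8 payload exists yet). The helper `decide_transpose` is hypothetical.

```python
import torch

# Assumed import path; `decide_transpose` is a hypothetical helper.
from megatron_fsdp.mixed_precision import (
    fp8_need_transpose_data,
    fp8_need_transpose_data_for_meta_device_init,
)

def decide_transpose(module: torch.nn.Module) -> bool:
    """Return True if any parameter of `module` needs transpose data."""
    params = list(module.parameters())
    if params and not params[0].is_meta:
        # Materialized parameters: inspect the tensors directly.
        return any(fp8_need_transpose_data(p) for p in params)
    # Meta-device init: the tensors carry no FP8 state yet, so defer to
    # the module-level check (which expects a Transformer Engine module).
    return fp8_need_transpose_data_for_meta_device_init(module)
```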

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_discard_transpose_cache(tensor: torch.Tensor) → None#

Discard the transpose cache of an FP8 tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_create_transpose_cache(tensors: List[torch.Tensor]) → None#

Create the transpose caches for a list of FP8 tensors.
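
A sketch of the cache lifecycle as it might look around a sharding step: drop stale transpose caches first, then rebuild them for all FP8 parameters in one batched call. The orchestration is illustrative; only the three imported helpers come from this module.

```python
import torch

# Assumed import path; the surrounding orchestration is illustrative only.
from megatron_fsdp.mixed_precision import (
    fp8_create_transpose_cache,
    fp8_discard_transpose_cache,
    is_float8tensor,
)

def refresh_transpose_caches(params: list[torch.Tensor]) -> None:
    fp8_params = [p for p in params if is_float8tensor(p)]
    # Drop stale caches (e.g. after the raw data buffer was rewritten)...
    for p in fp8_params:
        fp8_discard_transpose_cache(p)
    # ...then rebuild them for the whole list in one batched call.
    fp8_create_transpose_cache(fp8_params)
```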

core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_create_transpose_cache_fallback(
tensors: List[torch.Tensor],
) → None#

Fallback implementation of fp8_create_transpose_cache.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_set_raw_data(
tensor: torch.Tensor,
data: torch.Tensor,
set_transpose: bool = False,
) → None#

Set the raw data of a Transformer Engine Float8Tensor.

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_get_raw_data(
tensor: torch.Tensor,
get_transpose: bool = False,
) → torch.Tensor#

Get the underlying raw storage of an FP8 tensor.
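
A sketch of a raw-storage round trip, e.g. to repoint an FP8 parameter at a slice of a flat, shared data buffer the way an FSDP implementation might. Only the two `fp8_*` calls come from this module; the buffer layout and the `repoint_fp8_param` helper are hypothetical.

```python
import torch

# Assumed import path; `flat_buffer` layout and this helper are hypothetical.
from megatron_fsdp.mixed_precision import fp8_get_raw_data, fp8_set_raw_data

def repoint_fp8_param(param: torch.Tensor, flat_buffer: torch.Tensor,
                      start: int) -> None:
    # Read the current raw payload to learn its shape and size...
    raw = fp8_get_raw_data(param)
    n = raw.numel()
    # ...copy it into the shared flat buffer...
    flat_buffer[start:start + n].copy_(raw.reshape(-1))
    # ...and point the parameter at the buffer slice as its new raw storage.
    fp8_set_raw_data(param, flat_buffer[start:start + n].view_as(raw))
```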

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_dequantize(tensor: torch.Tensor) → torch.Tensor#

Dequantize an FP8 tensor to higher precision.
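
A small sketch for debugging: dequantize before computing statistics so the numbers reflect decoded values rather than raw FP8 bytes (assumed import path; `param_stats` is a hypothetical helper):

```python
import torch

# Assumed import path; `param_stats` is a hypothetical debugging helper.
from megatron_fsdp.mixed_precision import fp8_dequantize, is_float8tensor

def param_stats(p: torch.Tensor) -> tuple[float, float]:
    # Dequantize FP8 parameters so min/max are taken over decoded values.
    hp = fp8_dequantize(p) if is_float8tensor(p) else p
    return hp.min().item(), hp.max().item()
```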

core.distributed.fsdp.src.megatron_fsdp.mixed_precision.fp8_quantize(
model_params: List[torch.Tensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) → None#

Quantize sharded parameters to FP8.
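
The arguments are parallel per-parameter lists: the FP8 model tensor to update, the higher-precision main (optimizer) shard to read, what is presumably each shard's start offset within its full parameter, plus the owning data-parallel group and the sharded model-parameter pairs from the type hint. A schematic call under those assumptions; assembling real sharded state is Megatron-FSDP's job, so this shows shape only:

```python
# Schematic only: every variable below is a placeholder that Megatron-FSDP
# would assemble from its sharded parameter and optimizer state.
from megatron_fsdp.mixed_precision import fp8_quantize

fp8_quantize(
    model_params=model_params,             # FP8 model tensors to update
    main_params=main_params,               # high-precision main shards
    start_offsets=start_offsets,           # per-shard offsets (see above)
    data_parallel_group=dp_group,          # owning DP process group
    fsdp_shard_model_params=shard_params,  # (Tensor, Optional[Tensor]) pairs
)
```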

core.distributed.fsdp.src.megatron_fsdp.mixed_precision._fp8_quantize_fallback(
model_params: List[torch.Tensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) → None#

Fallback implementation of fp8_quantize.