core.inference.quantization.utils#

Module Contents#

Functions#

_verify_te_to_flashinfer_mxfp8_conversion

quantize_model_to_mxfp8

Converts a Transformer Engine (TE) MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.

mm_mxfp8

Computes a matmul in MXFP8 using FlashInfer.

API#

core.inference.quantization.utils._verify_te_to_flashinfer_mxfp8_conversion(
te_dequantized,
fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
) → None#
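No docstring is attached to `_verify_te_to_flashinfer_mxfp8_conversion`, but its signature (a dequantized TE reference vs. a FlashInfer `MXFP8Tensor`) suggests a round-trip check: dequantize the converted tensor and confirm it stays close to the reference. A minimal sketch of that idea, with illustrative names and a toy block size rather than the real MXFP8 layout:

```python
# Hypothetical sketch of a conversion check: dequantize block-scaled codes
# and compare against a reference within a relative tolerance. The helper
# names, block size, and tolerance are assumptions, not the real API.
def dequantize(codes, scales, block=4):
    # Each contiguous block of `block` codes shares one scale factor.
    return [c * scales[i // block] for i, c in enumerate(codes)]

def verify_close(reference, codes, scales, tol=0.1):
    # Relative comparison, guarding against division by tiny references.
    deq = dequantize(codes, scales)
    return all(
        abs(r - d) <= tol * max(abs(r), 1e-6)
        for r, d in zip(reference, deq)
    )
```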
core.inference.quantization.utils.quantize_model_to_mxfp8(model: torch.nn.Module) → None#

Converts a Transformer Engine (TE) MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.
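The "recursively translating each layer's weights" step can be pictured as a depth-first walk that replaces each weight in place. A toy sketch with a stand-in module tree (the real function walks a `torch.nn.Module`; the class and helper here are illustrative only):

```python
# Toy sketch of recursive per-layer weight conversion. `Module` and
# `quantize_tree` are illustrative stand-ins for torch.nn.Module traversal,
# not the actual Megatron implementation.
class Module:
    def __init__(self, weight=None, children=()):
        self.weight = weight
        self.children = list(children)

def quantize_tree(module, convert):
    # Depth-first walk: convert this layer's weight, then recurse into
    # submodules, mutating the tree in place (mirroring the None return).
    if module.weight is not None:
        module.weight = convert(module.weight)
    for child in module.children:
        quantize_tree(child, convert)
```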

core.inference.quantization.utils.mm_mxfp8(
x: torch.Tensor,
weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
out: torch.Tensor = None,
)#

Computes a matmul in MXFP8 using FlashInfer.

Quantizes the bf16 input activation tensor; the weight must be pre-quantized.
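The core MXFP8 idea is microscaling: each block of values (32 elements in the MX specification) shares one power-of-two scale, and the values themselves are stored as narrow FP8 codes. A self-contained sketch of that scheme, using small integers in [-127, 127] as stand-ins for FP8 codes (the block layout and helper names are illustrative, not the FlashInfer kernel):

```python
# Illustrative MX-style block quantization: one shared power-of-two scale
# per block, with integer codes standing in for FP8 values. This sketches
# the concept only; the real path uses FlashInfer's fused MXFP8 matmul.
import math

def quantize_block(vals):
    # Pick the smallest power-of-two scale that fits the block's max
    # magnitude into the code range [-127, 127].
    amax = max(abs(v) for v in vals)
    if amax == 0.0:
        return [0] * len(vals), 1.0
    scale = 2.0 ** math.ceil(math.log2(amax / 127.0))
    codes = [round(v / scale) for v in vals]
    return codes, scale

def dot_quantized(x, w_codes, w_scale):
    # Dot product of a raw activation row with a quantized weight row:
    # dequantize each weight code with the shared block scale.
    return sum(xi * (ci * w_scale) for xi, ci in zip(x, w_codes))
```

In the real API the activation is quantized on entry and the weight arrives as a pre-quantized `MXFP8Tensor`, so only the codes and scales participate in the matmul.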