core.inference.quantization.utils#
Module Contents#
Functions#
| `quantize_model_to_mxfp8` | Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer's weights. |
| `mm_mxfp8` | Computes a matmul in MXFP8 using FlashInfer. |
API#
- core.inference.quantization.utils._verify_te_to_flashinfer_mxfp8_conversion(
- te_dequantized,
- fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- )
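A helper with this name plausibly checks that the FlashInfer-side quantized weight dequantizes back to (approximately) the TE-dequantized reference. The sketch below illustrates that idea as an elementwise tolerance comparison; `dequantize()`, the `FakeMXFP8` stand-in, and the `atol` value are assumptions for illustration, not Megatron-Core's actual `MXFP8Tensor` API.

```python
# Hypothetical verification sketch: dequantize the FlashInfer-side tensor and
# compare it elementwise against the TE-dequantized reference. The
# `dequantize()` method and tolerance are illustrative assumptions.

def verify_conversion(te_dequantized, fi_quantized, atol=1e-2):
    fi_dequantized = fi_quantized.dequantize()  # assumed method
    assert len(te_dequantized) == len(fi_dequantized), "shape mismatch"
    for a, b in zip(te_dequantized, fi_dequantized):
        assert abs(a - b) <= atol, f"mismatch: {a} vs {b}"

class FakeMXFP8:
    """Toy stand-in for an MXFP8 tensor: a shared scale plus integer payload."""
    def __init__(self, scale, data):
        self.scale, self.data = scale, data

    def dequantize(self):
        return [self.scale * d for d in self.data]

te = [0.5, 1.0, 1.5]
fi = FakeMXFP8(0.5, [1, 2, 3])
verify_conversion(te, fi)  # passes: 0.5 * [1, 2, 3] matches te exactly
```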
- core.inference.quantization.utils.quantize_model_to_mxfp8(model: torch.nn.Module) → None#
Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.
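The "recursively translating each layer's weights" step can be sketched as a plain tree walk that rewrites each weight payload in place. The `Module` class and `convert_weight` tag below are minimal stand-ins for illustration, not `torch.nn.Module` or Megatron-Core's actual TE-to-FlashInfer translation.

```python
# Minimal sketch of recursive in-place weight translation, assuming a
# torch-like module tree. All names here are illustrative stand-ins.

class Module:
    """Toy stand-in for torch.nn.Module with named children and one weight."""
    def __init__(self, **children):
        self.children = children
        self.weight = None

def convert_weight(te_weight):
    # Stand-in for the TE -> FlashInfer MXFP8 translation: tag the payload.
    return ("flashinfer_mxfp8", te_weight)

def quantize_model_to_mxfp8(module):
    """Recursively translate each layer's weight in place."""
    if module.weight is not None:
        module.weight = convert_weight(module.weight)
    for child in module.children.values():
        quantize_model_to_mxfp8(child)

leaf = Module()
leaf.weight = [1.0, 2.0]
root = Module(block=Module(linear=leaf))
quantize_model_to_mxfp8(root)  # leaf.weight is now tagged as converted
```

The real function returns `None` because, as in this sketch, the model is mutated in place rather than copied.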
- core.inference.quantization.utils.mm_mxfp8(
- x: torch.Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- out: torch.Tensor = None,
- )
Computes a matmul in MXFP8 using FlashInfer.
Quantizes the bf16 input activation tensor; the weight must be pre-quantized.
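The on-the-fly activation quantization above relies on MXFP8's block format: each block of 32 values shares one power-of-two scale chosen so the block fits in FP8 E4M3 range. The sketch below shows that shared-scale step in pure Python, following the OCP Microscaling (MX) convention; it is not Megatron-Core's or FlashInfer's actual kernel, and rounding the scaled values to 8-bit E4M3 is omitted for brevity.

```python
# Illustrative MXFP8-style block scaling (assumptions: block size 32, shared
# power-of-two scale per block, FP8 E4M3 max magnitude 448). Rounding to
# actual 8-bit E4M3 values is elided; only the scaling step is shown.
import math

BLOCK = 32
E4M3_MAX = 448.0

def quantize_block(values):
    """Pick a power-of-two scale so the block's max maps into E4M3 range."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 0, [0.0] * len(values)
    # Shared exponent e with amax / 2**e <= E4M3_MAX.
    e = math.ceil(math.log2(amax / E4M3_MAX))
    scale = 2.0 ** e
    return e, [v / scale for v in values]

def dequantize_block(e, quantized):
    scale = 2.0 ** e
    return [v * scale for v in quantized]

vals = [float(i) for i in range(BLOCK)]      # amax = 31.0
e, q = quantize_block(vals)                  # e = -3, so scale = 0.125
round_trip = dequantize_block(e, q)          # exact here: no rounding applied
```

Pre-quantizing the weight once and reusing it across calls is what makes this scheme pay off at inference time: only the activation needs this scaling step per matmul.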