core.inference.quantization.utils#
Module Contents#
Functions#
| quantize_model_to_mxfp8 | Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer's weights. |
| _should_quantize_param | Return True if a parameter should be quantized to FlashInfer MXFP8. |
| _to_bf16 | Convert a parameter value to BF16 for quantization. |
| collect_mxfp8_param_metadata | Record shape/dtype/device for each parameter that will be quantized. |
| quantize_params_to_mxfp8 | Quantize model parameters to FlashInfer MXFP8Tensor format. |
| mm_mxfp8 | Computes a matmul in MXFP8 using FlashInfer. |
API#
- core.inference.quantization.utils._verify_te_to_flashinfer_mxfp8_conversion(
- te_dequantized,
- fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- )#
- core.inference.quantization.utils.quantize_model_to_mxfp8(model: torch.nn.Module) → None#
Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.
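The recursive layer-by-layer translation can be illustrated with a small stdlib-only sketch (names and the dict-based "module tree" are hypothetical; the real function walks `torch.nn.Module` children and rewrites their weights in place):

```python
# Hypothetical sketch: walk a nested "module" tree and translate each leaf
# weight, mirroring the recursive traversal described above. Plain dicts
# stand in for torch.nn.Module so the sketch runs without torch.
def convert_module(module: dict) -> dict:
    out = {}
    for name, child in module.items():
        if isinstance(child, dict):
            # Submodule: recurse into its children.
            out[name] = convert_module(child)
        else:
            # Leaf weight: tag it as "translated" to the target format.
            out[name] = ("mxfp8", child)
    return out

model = {"layer0": {"weight": 1.0}, "head": {"weight": 2.0}}
converted = convert_module(model)
```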
- core.inference.quantization.utils._should_quantize_param(val: torch.Tensor) → bool#
Return True if a parameter should be quantized to FlashInfer MXFP8.
- core.inference.quantization.utils._to_bf16(val: torch.Tensor) → torch.Tensor#
Convert a parameter value to BF16 for quantization.
- core.inference.quantization.utils.collect_mxfp8_param_metadata(
- model: torch.nn.Module,
- )#
Record shape/dtype/device for each parameter that will be quantized.
Called once before the first quantization to record the original parameter metadata (shape, dtype, device) before any format conversion.
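A minimal sketch of the metadata-recording pattern, assuming a mapping of fully-qualified parameter names to parameter objects (the `FakeParam` type and the mapping are stand-ins for `model.named_parameters()`):

```python
# Hypothetical sketch: snapshot (shape, dtype, device) per fully-qualified
# parameter name before quantization overwrites the original format.
from dataclasses import dataclass

@dataclass
class FakeParam:
    shape: tuple
    dtype: str
    device: str

def collect_param_metadata(params: dict) -> dict:
    """Record original shape/dtype/device for each parameter name."""
    return {name: (p.shape, p.dtype, p.device) for name, p in params.items()}

params = {"decoder.layers.0.weight": FakeParam((1024, 1024), "bf16", "cuda:0")}
meta = collect_param_metadata(params)
```

Because the snapshot is taken once, before any conversion, it can later be consulted even after the live parameters have been replaced by quantized tensors.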
- core.inference.quantization.utils.quantize_params_to_mxfp8(
- model: torch.nn.Module,
- persistent_buffers: Optional[Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]] = None,
- _prefix: str = '',
- )#
Quantize model parameters to FlashInfer MXFP8Tensor format.
Handles both TEMXFP8Tensor (fp8_param=True) and BF16/FP16 nn.Parameter inputs. When persistent_buffers is provided, new quantized values are copy_()'d into the existing MXFP8Tensor objects so that CUDA-graph device-pointer captures remain valid.
- Parameters:
  - model – The model whose parameters should be quantized.
  - persistent_buffers – If not None, a dict mapping fully-qualified parameter names to previously-created MXFP8Tensor objects. Updated in-place and returned.
  - _prefix – Internal recursion prefix – callers should not set this.
- Returns:
  The persistent_buffers dict (created on first call if None).
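The persistent-buffer contract above can be sketched with a stdlib-only stand-in (the `Buf` class and `quantize_params` name are illustrative, not the real API): the first call builds the dict, and every later call writes new values into the *same* objects, which is what keeps device pointers captured by a CUDA graph valid.

```python
# Hypothetical sketch of the persistent-buffer pattern: later calls update
# the existing objects in place via copy_()-style writes instead of
# replacing them, so object identity (and thus device pointers) is stable.
class Buf:
    def __init__(self, data):
        self.data = data

    def copy_(self, other):
        # In-place update: the Buf object itself is reused.
        self.data = other.data

def quantize_params(values: dict, persistent_buffers=None):
    if persistent_buffers is None:
        # First call: create the buffer dict.
        persistent_buffers = {name: Buf(v) for name, v in values.items()}
    else:
        # Later calls: write into the existing objects.
        for name, v in values.items():
            persistent_buffers[name].copy_(Buf(v))
    return persistent_buffers

bufs = quantize_params({"w": 1})        # first call: dict created
first_obj = bufs["w"]
same = quantize_params({"w": 2}, bufs)  # second call: updated in place
```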
- core.inference.quantization.utils.mm_mxfp8(
- x: torch.Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- out: torch.Tensor = None,
- )#
Computes a matmul in MXFP8 using FlashInfer.
Quantizes the bf16 input activation tensor. Weight must be pre-quantized.
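For intuition about what "quantizing the input" means here, the following is an illustrative sketch of MX-style block scaling (per the OCP microscaling format: elements share one power-of-two scale per 32-value block), not the FlashInfer kernel itself:

```python
# Illustrative MX-style block quantization: each block of 32 values shares a
# single power-of-two scale (an E8M0-style exponent), and elements are
# stored as value / scale. This is a conceptual sketch, not FlashInfer code.
import math

BLOCK = 32

def quantize_block(vals):
    assert len(vals) == BLOCK
    amax = max(abs(v) for v in vals) or 1.0   # guard the all-zero block
    scale = 2.0 ** math.floor(math.log2(amax))  # shared power-of-two scale
    return scale, [v / scale for v in vals]

scale, q = quantize_block([0.5] * BLOCK)
```

In the real path, the activation `x` goes through a quantization like this on the fly, while `weight` must already be an MXFP8Tensor produced by the quantization utilities above.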