core.inference.quantization.utils

Module Contents

Functions

_verify_te_to_flashinfer_mxfp8_conversion

quantize_model_to_mxfp8
    Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.

_should_quantize_param
    Return True if a parameter should be quantized to FlashInfer MXFP8.

_to_bf16
    Convert a parameter value to BF16 for quantization.

collect_mxfp8_param_metadata
    Record shape/dtype/device for each parameter that will be quantized.

quantize_params_to_mxfp8
    Quantize model parameters to FlashInfer MXFP8Tensor format.

mm_mxfp8
    Computes a matmul in MXFP8 using FlashInfer.

API

core.inference.quantization.utils._verify_te_to_flashinfer_mxfp8_conversion(
    te_dequantized,
    fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
) → None

core.inference.quantization.utils.quantize_model_to_mxfp8(model: torch.nn.Module) → None

Converts a TE MXFP8 model to a FlashInfer MXFP8 model by recursively translating each layer’s weights.
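
A minimal usage sketch. The model constructor below is hypothetical; any torch.nn.Module whose layers hold TE MXFP8 weights would take its place.

```python
import torch
from megatron.core.inference.quantization.utils import quantize_model_to_mxfp8

# Hypothetical: build_te_mxfp8_model() stands in for whatever produces a
# Transformer Engine MXFP8 model in your pipeline.
model: torch.nn.Module = build_te_mxfp8_model()

# Converts the model in place (returns None); afterwards its layers hold
# FlashInfer MXFP8 weights instead of TE MXFP8 weights.
quantize_model_to_mxfp8(model)
```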

core.inference.quantization.utils._should_quantize_param(val: torch.Tensor) → bool

Return True if a parameter should be quantized to FlashInfer MXFP8.
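
A minimal sketch of what such a predicate might check; the actual criteria live in the Megatron source, and the rule below is an assumption, not the real implementation.

```python
import torch

def should_quantize_param(val: torch.Tensor) -> bool:
    # Assumption: quantize only 2-D floating-point weight matrices, leaving
    # biases, norms, and embeddings in their original precision.
    return val.dim() == 2 and val.dtype in (torch.bfloat16, torch.float16)
```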

core.inference.quantization.utils._to_bf16(val: torch.Tensor) → torch.Tensor

Convert a parameter value to BF16 for quantization.
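
In the simplest case this is a plain dtype cast, as in the sketch below; how the real helper treats TE FP8 parameters (which would first need dequantization) is an assumption here.

```python
import torch

def to_bf16(val: torch.Tensor) -> torch.Tensor:
    # Plain cast covers BF16/FP16 parameters; a TE fp8 param would first be
    # dequantized (assumption about the real helper's behavior).
    return val.to(torch.bfloat16)
```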

core.inference.quantization.utils.collect_mxfp8_param_metadata(
    model: torch.nn.Module,
) → Dict[str, Tuple[torch.Size, torch.dtype, torch.device]]

Record shape/dtype/device for each parameter that will be quantized.

Called once before the first quantization to record the original parameter metadata (shape, dtype, device) before any format conversion.
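
An illustrative call showing the shape of the returned mapping; `model` is the module to be quantized, as in the sketch above.

```python
from megatron.core.inference.quantization.utils import collect_mxfp8_param_metadata

# Call once, before the first quantization pass.
metadata = collect_mxfp8_param_metadata(model)

# Each entry maps a fully-qualified parameter name to its original
# (shape, dtype, device) triple.
for name, (shape, dtype, device) in metadata.items():
    print(f"{name}: {tuple(shape)}, {dtype}, {device}")
```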

core.inference.quantization.utils.quantize_params_to_mxfp8(
    model: torch.nn.Module,
    persistent_buffers: Optional[Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]] = None,
    _prefix: str = '',
) → Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]

Quantize model parameters to FlashInfer MXFP8Tensor format.

Handles both TEMXFP8Tensor (fp8_param=True) and BF16/FP16 nn.Parameter inputs. When persistent_buffers is provided, the new quantized values are copied into the existing MXFP8Tensor objects via copy_(), so that device pointers captured by CUDA graphs remain valid.

Parameters:
  • model – The model whose parameters should be quantized.

  • persistent_buffers – If not None, a dict mapping fully-qualified parameter names to previously-created MXFP8Tensor objects. Updated in-place and returned.

  • _prefix – Internal recursion prefix; callers should not set this.

Returns:

The persistent_buffers dict (created on first call if None).
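
A usage sketch of the persistent-buffer pattern described above, with the CUDA-graph capture itself elided:

```python
from megatron.core.inference.quantization.utils import quantize_params_to_mxfp8

# First call: allocates one MXFP8Tensor per quantized parameter.
buffers = quantize_params_to_mxfp8(model)

# ... capture a CUDA graph that reads from these buffers ...

# Later, after the underlying weights change (e.g. a new checkpoint is
# loaded): re-quantize into the same buffers, so the device pointers baked
# into the captured graph remain valid.
buffers = quantize_params_to_mxfp8(model, persistent_buffers=buffers)
```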

core.inference.quantization.utils.mm_mxfp8(
    x: torch.Tensor,
    weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    out: torch.Tensor = None,
)

Computes a matmul in MXFP8 using FlashInfer.

Quantizes the BF16 input activation tensor on the fly; the weight must already be quantized.
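
Illustrative usage. The activation shape follows the usual torch.matmul convention, `buffers` is the dict returned by quantize_params_to_mxfp8 above, and the assumption that the result is returned when `out` is omitted is inferred from the signature rather than stated in the docs.

```python
import torch
from megatron.core.inference.quantization.utils import mm_mxfp8

weight = next(iter(buffers.values()))  # a pre-quantized MXFP8Tensor
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")  # assumed shape

y = mm_mxfp8(x, weight)          # x is quantized on the fly (assumed return)
out = torch.empty_like(y)
mm_mxfp8(x, weight, out=out)     # write into a preallocated output buffer
```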