core.inference.quantization.utils#

Module Contents#

Functions#

_verify_te_to_mcore_mxfp8_conversion

Verify that the mcore MXFP8Tensor produced from a TE weight matches the TE dequantized reference.

quantize_model_to_mxfp8

Convert TE MXFP8 weights to mcore MXFP8Tensor format.

_should_quantize_param

Return True if a parameter should be quantized to FlashInfer MXFP8.

_to_bf16

Convert a parameter value to BF16 for quantization.

collect_mxfp8_param_metadata

Record shape/dtype/device for each parameter that will be quantized.

quantize_params_to_mxfp8

Quantize model parameters to MXFP8Tensor format.

_mm_mxfp8_flashinfer

MXFP8 matmul via FlashInfer.

_mm_mxfp8_torch

MXFP8 matmul via torch.nn.functional.scaled_mm.

mm_mxfp8

Compute a matmul in MXFP8.

API#

core.inference.quantization.utils._verify_te_to_mcore_mxfp8_conversion(
    te_dequantized,
    fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
) → None#

core.inference.quantization.utils.quantize_model_to_mxfp8(
    model: torch.nn.Module,
    backend: str = 'flashinfer',
) → None#

Convert TE MXFP8 weights to mcore MXFP8Tensor format.

Recursively walks the model and replaces each TEMXFP8Tensor parameter with an MXFP8Tensor re-quantized via the specified backend.

Parameters:
  • model – The model whose TE MXFP8 parameters should be converted.

  • backend – ‘flashinfer’ or ‘triton’ quantization backend.
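The recursive walk-and-replace described above can be sketched in plain Python. This is a hedged illustration, not the megatron implementation: the `Module` class, `convert_params` helper, and the rounding "quantizer" are stand-ins for `torch.nn.Module`, the real conversion routine, and the backend quantization kernel.

```python
class Module:
    """Minimal stand-in for torch.nn.Module's child/parameter layout."""
    def __init__(self):
        self.params = {}     # name -> parameter value
        self.children = {}   # name -> child Module

def convert_params(module, convert, should_convert):
    """Depth-first walk: replace each eligible parameter in place."""
    for name, val in list(module.params.items()):
        if should_convert(val):
            module.params[name] = convert(val)
    for child in module.children.values():
        convert_params(child, convert, should_convert)

# Toy usage: "quantize" floats by rounding; leave other values untouched.
root = Module()
root.params["weight"] = 0.123
child = Module()
child.params["weight"] = 4.567
child.params["step"] = 7          # int: not a quantizable parameter
root.children["layer0"] = child
convert_params(root, convert=lambda v: round(v, 1),
               should_convert=lambda v: isinstance(v, float))
```

The key property mirrored here is that replacement happens attribute-by-attribute during the walk, so the module tree itself is never rebuilt.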

core.inference.quantization.utils._should_quantize_param(val: torch.Tensor) → bool#

Return True if a parameter should be quantized to FlashInfer MXFP8.

core.inference.quantization.utils._to_bf16(val: torch.Tensor) → torch.Tensor#

Convert a parameter value to BF16 for quantization.

core.inference.quantization.utils.collect_mxfp8_param_metadata(
    model: torch.nn.Module,
) → Dict[str, Tuple[torch.Size, torch.dtype, torch.device]]#

Record shape/dtype/device for each parameter that will be quantized.

Called once before the first quantization to record the original parameter metadata (shape, dtype, device) before any format conversion.
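The metadata pass can be sketched as a single walk over named parameters. This is an illustrative sketch only: plain tuples stand in for `torch.Size`/`torch.dtype`/`torch.device`, and the "only 2-D weights" eligibility rule is a hypothetical example of a `_should_quantize_param`-style predicate, not the actual criterion.

```python
def collect_param_metadata(named_params, should_quantize):
    """Map fully-qualified name -> (shape, dtype, device) for eligible params."""
    meta = {}
    for name, (shape, dtype, device) in named_params.items():
        if should_quantize(shape):
            meta[name] = (shape, dtype, device)
    return meta

named = {
    "decoder.linear_fc1.weight": ((4096, 1024), "bfloat16", "cuda:0"),
    "decoder.layernorm.bias":    ((4096,),      "float32",  "cuda:0"),
}
# Illustrative rule: only 2-D weight matrices get quantized.
meta = collect_param_metadata(named, should_quantize=lambda s: len(s) == 2)
```

Recording this before any conversion matters because quantization changes the parameter's dtype (and possibly shape), so the original metadata would otherwise be lost.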

core.inference.quantization.utils.quantize_params_to_mxfp8(
    model: torch.nn.Module,
    persistent_buffers: Optional[Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]] = None,
    _prefix: str = '',
    backend: str = 'flashinfer',
) → Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]#

Quantize model parameters to MXFP8Tensor format.

Handles both TEMXFP8Tensor (fp8_param=True) and BF16/FP16 nn.Parameter inputs. When persistent_buffers is provided, new quantized values are copy_()’d into the existing MXFP8Tensor objects so that CUDA-graph device-pointer captures remain valid.

Parameters:
  • model – The model whose parameters should be quantized.

  • persistent_buffers – If not None, a dict mapping fully-qualified parameter names to previously-created MXFP8Tensor objects. Updated in-place and returned.

  • _prefix – Internal recursion prefix; callers should not set this.

  • backend – ‘flashinfer’ or ‘triton’ quantization backend.

Returns:

The persistent_buffers dict (created on first call if None).
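The persistent-buffer pattern above can be sketched without torch. In this hedged stand-in, a `bytearray` plays the role of a tensor with a stable storage address, and slice assignment plays the role of `MXFP8Tensor.copy_()`: re-quantization overwrites the contents while the object (and hence any captured device pointer) stays the same.

```python
def quantize_into_buffers(new_values, persistent_buffers=None):
    """First call creates buffers; later calls copy new data into them."""
    if persistent_buffers is None:
        persistent_buffers = {}
    for name, data in new_values.items():
        if name in persistent_buffers:
            # Analogue of copy_(): update contents in place, keeping the
            # buffer object itself (and its address) unchanged.
            persistent_buffers[name][:] = data
        else:
            persistent_buffers[name] = bytearray(data)
    return persistent_buffers

bufs = quantize_into_buffers({"w": b"\x01\x02"})
first = bufs["w"]
bufs = quantize_into_buffers({"w": b"\x09\x08"}, bufs)
# first and bufs["w"] are the same object; only the contents changed.
```

This identity-preserving update is what keeps a previously captured CUDA graph valid after weights are re-quantized.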

core.inference.quantization.utils._mm_mxfp8_flashinfer(
    x_mxfp8: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    out=None,
)#

MXFP8 matmul via FlashInfer.

core.inference.quantization.utils._mm_mxfp8_torch(
    x_mxfp8: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    out=None,
)#

MXFP8 matmul via torch.nn.functional.scaled_mm.

core.inference.quantization.utils.mm_mxfp8(
    x: torch.Tensor,
    weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
    out: torch.Tensor = None,
)#

Compute a matmul in MXFP8.

Quantizes the BF16 input activation tensor on the fly; the weight must already be quantized to MXFP8. Dispatches to the FlashInfer or torch implementation based on weight.backend.
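The quantize-then-dispatch flow can be sketched in pure Python. Everything here is illustrative: `QuantizedWeight`, the two kernel stubs, and rounding as "quantization" are stand-ins for MXFP8Tensor, the FlashInfer/torch kernels, and real MXFP8 quantization; only the dispatch-on-`weight.backend` shape matches the description above.

```python
from dataclasses import dataclass

@dataclass
class QuantizedWeight:
    """Stand-in for a pre-quantized MXFP8 weight carrying its backend tag."""
    data: list
    backend: str = "flashinfer"

def _mm_flashinfer(x, w):   # stand-in for the FlashInfer matmul kernel
    return sum(a * b for a, b in zip(x, w.data))

def _mm_torch(x, w):        # stand-in for the torch matmul kernel
    return sum(a * b for a, b in zip(x, w.data))

_KERNELS = {"flashinfer": _mm_flashinfer, "torch": _mm_torch}

def mm(x, weight):
    """Quantize the activation on the fly, then dispatch on weight.backend."""
    x_q = [round(v) for v in x]          # toy "on-the-fly quantization"
    return _KERNELS[weight.backend](x_q, weight)

w = QuantizedWeight([1, 2, 3], backend="torch")
y = mm([0.9, 2.1, 3.0], w)               # x rounds to [1, 2, 3]
```

Keeping the weight pre-quantized while quantizing activations per call matches the asymmetry in the real API: weights change rarely, activations change every forward pass.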