bridge.models.conversion.quantization_utils

Module Contents

Functions

is_fp8_tensor

Return whether tensor uses one of PyTorch’s FP8 dtypes.

is_float8_e8m0_dtype

Return whether dtype is PyTorch’s E8M0 scale dtype.

scale_from_amax

Build positive quantization scales in the same scale family as scale_dtype.

dequantize_fp8_blockwise

Dequantize FP8 weights with one scale per 2D block.

maybe_dequantize_fp8_blockwise

Dequantize FP8 block-scaled weights, falling back to a plain cast.

maybe_dequantize_fp8

Dequantize FP8 weights with a scalar or broadcastable scale tensor.

dequantize_fp8_e4m3fn_with_scale

Dequantize FP8 E4M3 weights with a companion scale tensor.

_quantize_fp8_2d_blocks

_quantize_fp8_per_row_tiles

_quantize_fp8_1d_scale

quantize_fp8_e4m3fn_like_scale

Quantize a 2-D weight to FP8 E4M3 using source_scale geometry and dtype.

dequantize_mxfp4

Dequantize GPT-OSS MXFP4 block/scales tensors.

dequantize_mxfp4_e2m1_packed

Dequantize MXFP4 E2M1 weights packed two values per byte.

is_mxfp4_e2m1_scale_geometry

Return whether source_scale describes packed MXFP4 E2M1 K tiles.

quantize_mxfp4_e2m1_like_scale

Quantize a 2-D weight to packed MXFP4 E2M1 using source scale geometry.

maybe_dequantize_hf_quantized_weight

Load and dequantize HF *.weight tensors that carry sibling *.scale tensors.

requantize_hf_weight_scale_pairs

Recreate quantized HF *.weight/*.scale pairs using source scale layout.

dequantize_int4

Dequantize Kimi INT4 packed weights to bfloat16.

quantize_to_int4

Quantize bfloat16/float16 weights to Kimi INT4 packed format.

Data

FP8_BLOCK_SIZE

FP8_DTYPES

FP8_E4M3_MAX

FP4_E2M1_MAX

MXFP4_BLOCK_SIZE

_FP4_E2M1_TABLE_VALUES

API

bridge.models.conversion.quantization_utils.FP8_BLOCK_SIZE

128

bridge.models.conversion.quantization_utils.FP8_DTYPES

()

bridge.models.conversion.quantization_utils.FP8_E4M3_MAX

448.0

bridge.models.conversion.quantization_utils.FP4_E2M1_MAX

6.0

bridge.models.conversion.quantization_utils.MXFP4_BLOCK_SIZE

32

bridge.models.conversion.quantization_utils._FP4_E2M1_TABLE_VALUES

None

bridge.models.conversion.quantization_utils.is_fp8_tensor(tensor: torch.Tensor) -> bool

Return whether tensor uses one of PyTorch’s FP8 dtypes.

bridge.models.conversion.quantization_utils.is_float8_e8m0_dtype(dtype: torch.dtype) -> bool

Return whether dtype is PyTorch’s E8M0 scale dtype.

bridge.models.conversion.quantization_utils.scale_from_amax(
amax: torch.Tensor,
max_quantized_value: float,
scale_dtype: torch.dtype,
) -> torch.Tensor

Build positive quantization scales in the same scale family as scale_dtype.
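Here is a minimal sketch of what "the same scale family" can mean. It assumes that float scale dtypes keep the exact ratio amax / max_quantized_value, while a power-of-two family such as E8M0 rounds the exponent up so that amax / scale never exceeds the quantized range. The function name, the `power_of_two` flag, and the 1e-12 floor are illustrative assumptions, not the module's implementation.

```python
import math


def scale_from_amax_sketch(amax: float, max_quantized_value: float, power_of_two: bool) -> float:
    """Hypothetical helper: return a positive scale with amax / scale <= max_quantized_value."""
    # Assumption: all-zero blocks are floored so the scale stays strictly positive.
    amax = max(abs(amax), 1e-12)
    scale = amax / max_quantized_value
    if power_of_two:
        # Assumption: E8M0-style scales hold only powers of two, so the exponent is
        # rounded up; a larger scale keeps amax / scale inside the quantized range.
        scale = 2.0 ** math.ceil(math.log2(scale))
    return scale
```

For example, an E4M3 block whose amax is 448 gets scale 1.0. An MXFP4 block whose amax is 5 gets the power-of-two scale 1.0 rather than 5/6.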

bridge.models.conversion.quantization_utils.dequantize_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor,
*,
block_size: int = FP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize FP8 weights with one scale per 2D block.

DeepSeek-V3 and MiniMax-M2 store linear weights as FP8 tensors with a separate *_scale_inv tensor. Each scale applies to one 128x128 weight block by default.
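The block indexing can be sketched in plain Python. This sketch assumes two things. First, as in DeepSeek-V3's reference dequantization, the stored *_scale_inv values are multiplied into the decoded FP8 values despite the name. Second, edge blocks smaller than block_size still receive one scale entry each, which is why the expected shape uses ceil division. Nested lists stand in for tensors, and `dequantize_blockwise_sketch` is a hypothetical name.

```python
import math


def dequantize_blockwise_sketch(weight, scale_inv, block_size=128):
    """Hypothetical helper. weight: M x K nested lists of decoded FP8 values.

    scale_inv: ceil(M / block_size) x ceil(K / block_size) nested lists.
    """
    rows = len(weight)
    cols = len(weight[0]) if rows else 0
    expected = (math.ceil(rows / block_size), math.ceil(cols / block_size))
    actual = (len(scale_inv), len(scale_inv[0]) if scale_inv else 0)
    if actual != expected:
        raise ValueError(f"scale_inv shape {actual} does not match expected {expected}")
    # Assumption: each element is multiplied by the scale of the block it falls in.
    # Partial blocks at the bottom/right edges have their own (ceil-divided) entry.
    return [
        [weight[i][j] * scale_inv[i // block_size][j // block_size] for j in range(cols)]
        for i in range(rows)
    ]
```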

bridge.models.conversion.quantization_utils.maybe_dequantize_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor | None = None,
*,
block_size: int = FP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize FP8 block-scaled weights, falling back to a plain cast.

bridge.models.conversion.quantization_utils.maybe_dequantize_fp8(
weight: torch.Tensor,
scale_inv: torch.Tensor | None = None,
*,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize FP8 weights with a scalar or broadcastable scale tensor.

bridge.models.conversion.quantization_utils.dequantize_fp8_e4m3fn_with_scale(
weight: torch.Tensor,
scale: torch.Tensor,
*,
name: str = '',
block_size: int = FP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize FP8 E4M3 weights with a companion scale tensor.

Supports three common HF checkpoint scale layouts: 1-D per-row or row-block scales, 2-D per-row K tiles, and 2-D block scales.
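One way to tell these layouts apart is to compare the scale shape with the weight shape. The helper below is hypothetical. Its matching rules and their order are assumptions inferred from the layout names above, not the module's actual dispatch logic.

```python
import math


def classify_scale_layout(weight_shape, scale_shape, block_size=128):
    """Hypothetical helper: name the scale layout for a 2-D (rows, cols) weight."""
    rows, cols = weight_shape
    row_blocks = math.ceil(rows / block_size)
    col_blocks = math.ceil(cols / block_size)
    if len(scale_shape) == 1:
        (n,) = scale_shape
        if n == rows:
            return "per_row"          # assumption: one scale per output row
        if n == row_blocks:
            return "row_block"        # assumption: one scale per block_size rows
    elif len(scale_shape) == 2:
        s_rows, s_cols = scale_shape
        if s_rows == rows:
            return "per_row_k_tiles"  # assumption: tile width is ceil(cols / s_cols)
        if (s_rows, s_cols) == (row_blocks, col_blocks):
            return "block_2d"         # assumption: one scale per block_size x block_size block
    raise ValueError(f"unrecognized scale shape {scale_shape} for weight {weight_shape}")
```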

bridge.models.conversion.quantization_utils._quantize_fp8_2d_blocks(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
name: str = '',
block_size: int = FP8_BLOCK_SIZE,
) -> tuple[torch.Tensor, torch.Tensor]

bridge.models.conversion.quantization_utils._quantize_fp8_per_row_tiles(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
name: str = '',
) -> tuple[torch.Tensor, torch.Tensor]

bridge.models.conversion.quantization_utils._quantize_fp8_1d_scale(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
name: str = '',
block_size: int = FP8_BLOCK_SIZE,
) -> tuple[torch.Tensor, torch.Tensor]

bridge.models.conversion.quantization_utils.quantize_fp8_e4m3fn_like_scale(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
name: str = '',
block_size: int = FP8_BLOCK_SIZE,
) -> tuple[torch.Tensor, torch.Tensor]

Quantize a 2-D weight to FP8 E4M3 using source_scale geometry and dtype.
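For the 2-D block geometry, re-quantization maps each block's largest magnitude onto FP8_E4M3_MAX and rounds the scaled values onto the E4M3 grid. This sketch makes two assumptions. Scales are stored as float dequantization multipliers (q * scale approximates weight). The E4M3FN cast is emulated in pure Python instead of calling `.to(torch.float8_e4m3fn)`. Both function names are hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest FP8 E4M3FN value (ties to even), saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    mag = min(abs(x), FP8_E4M3_MAX)
    _, exp = math.frexp(mag)       # mag = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)           # binade exponent; below 2**-6 the grid is subnormal
    step = 2.0 ** (e - 3)          # 3 mantissa bits give 8 grid points per binade
    q = round(mag / step) * step   # Python's round() breaks ties to even
    return math.copysign(min(q, FP8_E4M3_MAX), x)


def quantize_fp8_2d_blocks_sketch(weight, block_size=128):
    """Hypothetical helper: return (q, scale) with q * scale approximating weight.

    Assumption: one float scale per block_size x block_size block, stored as the
    multiplier used for dequantization.
    """
    rows, cols = len(weight), len(weight[0])
    n_rb, n_cb = math.ceil(rows / block_size), math.ceil(cols / block_size)
    q = [[0.0] * cols for _ in range(rows)]
    scale = [[0.0] * n_cb for _ in range(n_rb)]
    for bi in range(n_rb):
        for bj in range(n_cb):
            rs = range(bi * block_size, min((bi + 1) * block_size, rows))
            cs = range(bj * block_size, min((bj + 1) * block_size, cols))
            amax = max(abs(weight[i][j]) for i in rs for j in cs)
            # Map the block's largest magnitude onto the top of the E4M3 range.
            s = max(amax, 1e-12) / FP8_E4M3_MAX
            scale[bi][bj] = s
            for i in rs:
                for j in cs:
                    q[i][j] = round_to_e4m3(weight[i][j] / s)
    return q, scale
```

With 3 mantissa bits, every normal value lands within 1/16 relative error of the original. That bound is a convenient tolerance for checking a round trip.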

bridge.models.conversion.quantization_utils.dequantize_mxfp4(
blocks: torch.Tensor,
scales: torch.Tensor,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) -> torch.Tensor

Dequantize GPT-OSS MXFP4 block/scales tensors.
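Each MXFP4 block packs 32 E2M1 values into 16 bytes and shares one exponent byte. The sketch below decodes a single block. It assumes the layout used by the GPT-OSS conversion in Hugging Face Transformers:

- the low nibble holds the even element and the high nibble the odd element;
- the scale byte is an exponent with bias 127.

`dequantize_mxfp4_block_sketch` is hypothetical. The real function works on whole tensors and presumably processes rows in chunks of rows_per_chunk.

```python
import math

# E2M1 code -> value: bit 3 is the sign, bits 0-2 index the magnitude.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequantize_mxfp4_block_sketch(block: bytes, scale_byte: int) -> list:
    """Hypothetical helper: decode one 32-element MXFP4 block with its exponent byte."""
    if len(block) != 16:
        raise ValueError("an MXFP4 block packs 32 values into 16 bytes")
    if scale_byte == 0xFF:                 # E8M0 reserves 0xFF for NaN
        return [math.nan] * 32
    factor = 2.0 ** (scale_byte - 127)     # E8M0: exponent bias 127, no mantissa
    out = []
    for byte in block:
        # Assumption: low nibble -> even element, high nibble -> odd element.
        out.append(FP4_VALUES[byte & 0x0F] * factor)
        out.append(FP4_VALUES[byte >> 4] * factor)
    return out
```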

bridge.models.conversion.quantization_utils.dequantize_mxfp4_e2m1_packed(
weight_packed: torch.Tensor,
scale: torch.Tensor,
*,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize MXFP4 E2M1 weights packed two values per byte.

scale is expected to hold one value per row and per K tile. E8M0 scale tensors can be passed directly; .to(torch.float32) materializes their power-of-two values.

bridge.models.conversion.quantization_utils.is_mxfp4_e2m1_scale_geometry(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
block_size: int = MXFP4_BLOCK_SIZE,
) -> bool

Return whether source_scale describes packed MXFP4 E2M1 K tiles.

bridge.models.conversion.quantization_utils.quantize_mxfp4_e2m1_like_scale(
weight: torch.Tensor,
source_scale: torch.Tensor,
*,
name: str = '',
block_size: int = MXFP4_BLOCK_SIZE,
) -> tuple[torch.Tensor, torch.Tensor]

Quantize a 2-D weight to packed MXFP4 E2M1 using source scale geometry.
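Quantizing in this direction does three things per 32-value block: it picks a power-of-two scale, rounds each scaled value to the nearest E2M1 code, and packs two codes per byte. This sketch assumes the scale is the smallest power of two that keeps amax / scale within FP4_E2M1_MAX. The OCP MX specification instead derives the exponent from floor(log2(amax)) and clips. The sketch also assumes the even element goes in the low nibble. All names are hypothetical.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # index = 3-bit magnitude code
FP4_E2M1_MAX = 6.0


def encode_e2m1(x: float) -> int:
    """Return the 4-bit E2M1 code nearest to x (ties go to the smaller magnitude)."""
    mag_code = min(range(8), key=lambda i: abs(abs(x) - E2M1_MAGNITUDES[i]))
    return (0x8 if x < 0 else 0) | mag_code


def quantize_mxfp4_block_sketch(values):
    """Hypothetical helper: quantize 32 floats to (16 packed bytes, E8M0 scale byte)."""
    if len(values) != 32:
        raise ValueError("MXFP4 blocks hold 32 values")
    amax = max(abs(v) for v in values)
    # Assumption: smallest power-of-two scale with amax / scale <= 6,
    # clamped to the E8M0 exponent range [-127, 127].
    exponent = math.ceil(math.log2(amax / FP4_E2M1_MAX)) if amax > 0 else -127
    exponent = max(-127, min(127, exponent))
    scale = 2.0 ** exponent
    codes = [encode_e2m1(v / scale) for v in values]
    # Assumption: even element in the low nibble, odd element in the high nibble.
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, 32, 2))
    return packed, exponent + 127
```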

bridge.models.conversion.quantization_utils.maybe_dequantize_hf_quantized_weight(
hf_param: str | dict[str, str],
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
*,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor | dict[str, torch.Tensor]

Load and dequantize HF *.weight tensors that carry sibling *.scale tensors.
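The lookup can be sketched as follows, with plain floats in place of tensors. Two parts are assumptions:

- the key derivation (prefix.weight to prefix.scale), read off the *.weight / *.scale wording above;
- the `dequantize` callback, which stands in for whichever FP8 or MXFP4 routine matches the scale layout.

The function name is hypothetical.

```python
def load_maybe_dequantized_sketch(hf_param, hf_state_dict, dequantize):
    """Hypothetical helper: resolve a key (or a dict of keys) and dequantize scaled weights."""
    if isinstance(hf_param, dict):
        # Dict-valued params map logical names to HF keys; resolve each entry.
        return {
            name: load_maybe_dequantized_sketch(key, hf_state_dict, dequantize)
            for name, key in hf_param.items()
        }
    tensor = hf_state_dict[hf_param]
    if hf_param.endswith(".weight"):
        # Assumption: the sibling scale shares the prefix and ends in ".scale".
        scale_key = hf_param[: -len(".weight")] + ".scale"
        if scale_key in hf_state_dict:
            return dequantize(tensor, hf_state_dict[scale_key])
    return tensor  # unscaled entries pass through unchanged in this sketch
```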

bridge.models.conversion.quantization_utils.requantize_hf_weight_scale_pairs(
converted_weights_dict: collections.abc.Mapping[str, torch.Tensor],
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
*,
use_mxfp4: collections.abc.Callable[[str, torch.Tensor, torch.Tensor], bool] | None = None,
) -> dict[str, torch.Tensor]

Recreate quantized HF *.weight/*.scale pairs using source scale layout.

use_mxfp4 lets model bridges opt specific parameters into packed MXFP4 output. Other scaled weights are emitted as FP8 E4M3.
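The callback contract can be illustrated with a hypothetical routing helper and policy. In `experts_only`, MoE expert weights opt into MXFP4 by name. Every other scaled weight falls back to FP8 E4M3, as described above. The name-matching rule and both function names are assumptions for illustration; tensors are replaced by placeholders.

```python
from typing import Callable, Dict, Optional, Tuple

# Same shape as the documented callback: (name, weight, source_scale) -> bool.
UseMxfp4 = Callable[[str, object, object], bool]


def choose_output_formats(
    pairs: Dict[str, Tuple[object, object]],
    use_mxfp4: Optional[UseMxfp4] = None,
) -> Dict[str, str]:
    """Hypothetical helper: map each weight key to the format it would be re-quantized to."""
    formats = {}
    for name, (weight, source_scale) in pairs.items():
        opted_in = use_mxfp4 is not None and use_mxfp4(name, weight, source_scale)
        formats[name] = "mxfp4_e2m1" if opted_in else "fp8_e4m3"
    return formats


def experts_only(name: str, weight: object, source_scale: object) -> bool:
    # Hypothetical policy: keep MoE expert projections in MXFP4.
    return ".experts." in name
```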

bridge.models.conversion.quantization_utils.dequantize_int4(
weight_packed: torch.Tensor,
weight_scale: torch.Tensor,
weight_shape: torch.Tensor,
group_size: int = 32,
device: str | torch.device | None = None,
) -> torch.Tensor

Dequantize Kimi INT4 packed weights to bfloat16.

The checkpoint stores eight offset-binary INT4 values in each int32 slot and carries per-group scales beside the packed tensor.
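Unpacking follows from the description. Each 4-bit field stores value + 8, so subtracting 8 recovers the signed range [-8, 7]. The sketch makes two assumptions: the first value sits in the lowest nibble, and scales apply to consecutive runs of group_size values along a row. Both helper names are hypothetical. Here num_cols plays the role of weight_shape in trimming padding.

```python
def unpack_int4_word(word: int) -> list:
    """Hypothetical helper: split one 32-bit word into eight signed INT4 values."""
    word &= 0xFFFFFFFF  # view a negative int32 as its unsigned bit pattern
    # Assumption: low nibble first; offset-binary means stored = value + 8.
    return [((word >> (4 * i)) & 0xF) - 8 for i in range(8)]


def dequantize_int4_row_sketch(packed_row, scales, group_size=32, num_cols=None):
    """Hypothetical helper: unpack a row of int32 words and apply per-group scales."""
    values = [v for word in packed_row for v in unpack_int4_word(word)]
    if num_cols is not None:
        values = values[:num_cols]  # drop padding in the last word
    # Assumption: scale g covers values [g * group_size, (g + 1) * group_size).
    return [v * scales[j // group_size] for j, v in enumerate(values)]
```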

bridge.models.conversion.quantization_utils.quantize_to_int4(
weight: torch.Tensor,
group_size: int = 32,
scale_dtype: torch.dtype = torch.bfloat16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Quantize bfloat16/float16 weights to Kimi INT4 packed format.
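The inverse direction can be sketched the same way. The sketch assumes symmetric per-group scales of amax / 7, the same offset-binary encoding, and low-nibble-first packing. These choices are assumptions, not taken from the Kimi checkpoint format. The real function also returns a third tensor, presumably the weight_shape consumed by dequantize_int4. The name is hypothetical.

```python
def quantize_int4_row_sketch(row, group_size=32):
    """Hypothetical helper: return (packed int32 words, per-group scales) for one row."""
    scales, codes = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(abs(v) for v in group)
        # Assumption: symmetric scale mapping amax to 7; all-zero groups get scale 1.0.
        scale = amax / 7 if amax > 0 else 1.0
        scales.append(scale)
        # Clamp to the signed INT4 range, then add the offset-binary bias of 8.
        codes.extend(max(-8, min(7, round(v / scale))) + 8 for v in group)
    codes += [8] * (-len(codes) % 8)  # pad the last word with encoded zeros
    packed = []
    for start in range(0, len(codes), 8):
        word = 0
        for i, code in enumerate(codes[start:start + 8]):
            word |= code << (4 * i)  # assumption: low nibble holds the first value
        if word >= 1 << 31:
            word -= 1 << 32  # reinterpret as a signed int32
        packed.append(word)
    return packed, scales
```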