bridge.models.conversion.quantization_utils

Module Contents

Functions

is_fp8_tensor

Return whether tensor uses one of PyTorch’s FP8 dtypes.

dequantize_fp8_blockwise

Dequantize FP8 weights with one scale per 2D block.

maybe_dequantize_fp8_blockwise

Dequantize FP8 block-scaled weights, falling back to a plain cast.

maybe_dequantize_fp8

Dequantize FP8 weights with a scalar or broadcastable scale tensor.

dequantize_mxfp4

Dequantize GPT-OSS MXFP4 block/scales tensors.

dequantize_int4

Dequantize Kimi INT4 packed weights to bfloat16.

quantize_to_int4

Quantize bfloat16/float16 weights to Kimi INT4 packed format.

Data

API

bridge.models.conversion.quantization_utils.FP8_BLOCK_SIZE

128

bridge.models.conversion.quantization_utils.FP8_DTYPES

()

bridge.models.conversion.quantization_utils.is_fp8_tensor(tensor: torch.Tensor) → bool

Return whether tensor uses one of PyTorch’s FP8 dtypes.

bridge.models.conversion.quantization_utils.dequantize_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor,
*,
block_size: int = FP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16,
) → torch.Tensor

Dequantize FP8 weights with one scale per 2D block.

DeepSeek-V3 and MiniMax-M2 store linear weights as FP8 tensors with a separate *_scale_inv tensor. Each scale applies to one block_size × block_size weight block (128×128 by default).
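The block-to-scale mapping can be illustrated with a minimal pure-Python sketch. It uses nested lists of already-decoded floats instead of torch tensors, and the name `dequantize_blockwise` is hypothetical. It assumes the DeepSeek-V3 convention that the stored *_scale_inv value is multiplied into every element of its block, and that partial edge blocks get their own scale (hence `ceil` for the grid size); neither detail is confirmed by this page.

```python
import math


def dequantize_blockwise(weight, scale_inv, block_size=128):
    """Multiply each element by the scale of the block that contains it.

    weight: 2D list of decoded FP8 values (floats).
    scale_inv: 2D list of shape ceil(rows / block_size) x ceil(cols / block_size).
    Assumption: the scale is multiplied in (DeepSeek-V3 style *_scale_inv).
    """
    rows, cols = len(weight), len(weight[0])
    expected = (math.ceil(rows / block_size), math.ceil(cols / block_size))
    actual = (len(scale_inv), len(scale_inv[0]))
    if actual != expected:
        raise ValueError(f"scale grid {actual} does not match {expected}")
    # Element (i, j) belongs to block (i // block_size, j // block_size).
    return [
        [weight[i][j] * scale_inv[i // block_size][j // block_size] for j in range(cols)]
        for i in range(rows)
    ]


# A 3x3 weight with block_size=2 needs a 2x2 scale grid; the last block
# row and column are partial.
w = [[1.0, 1.0, 1.0],
     [1.0, 1.0, 1.0],
     [1.0, 1.0, 1.0]]
s = [[2.0, 3.0],
     [4.0, 5.0]]
print(dequantize_blockwise(w, s, block_size=2))
# [[2.0, 2.0, 3.0], [2.0, 2.0, 3.0], [4.0, 4.0, 5.0]]
```

With the real default of 128, a 7168×2048 weight would need a 56×16 scale grid under the same indexing rule.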

bridge.models.conversion.quantization_utils.maybe_dequantize_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor | None = None,
*,
block_size: int = FP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16,
) → torch.Tensor

Dequantize FP8 block-scaled weights, falling back to a plain cast.

bridge.models.conversion.quantization_utils.maybe_dequantize_fp8(
weight: torch.Tensor,
scale_inv: torch.Tensor | None = None,
*,
dtype: torch.dtype = torch.bfloat16,
) → torch.Tensor

Dequantize FP8 weights with a scalar or broadcastable scale tensor.
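The difference between a scalar scale and a broadcastable one can be sketched in plain Python. This is a hypothetical helper (`dequantize_with_scale`) operating on nested lists, not the library function. It assumes that `scale_inv=None` means "no scaling" (standing in for a plain dtype cast) and that a per-row list behaves like a torch tensor of shape (rows, 1); both are assumptions for illustration.

```python
def dequantize_with_scale(weight, scale_inv=None):
    """Apply a scalar or per-row scale to a 2D list of decoded FP8 values.

    Assumptions (illustrative only):
    - scale_inv is None: values are returned unchanged (stands in for a cast).
    - scale_inv is a number: one scale for the whole tensor.
    - scale_inv is a list with one entry per row: broadcast along each row,
      like a torch tensor of shape (rows, 1).
    """
    if scale_inv is None:
        return [list(row) for row in weight]
    if isinstance(scale_inv, (int, float)):
        return [[x * scale_inv for x in row] for row in weight]
    if len(scale_inv) != len(weight):
        raise ValueError("per-row scale must have one entry per row")
    return [[x * s for x in row] for row, s in zip(weight, scale_inv)]


print(dequantize_with_scale([[2.0, 4.0]], 0.5))               # scalar
print(dequantize_with_scale([[1.0, 1.0], [1.0, 1.0]], [2.0, 3.0]))  # per row
```

Compared with the block-wise variant, a broadcastable scale needs no block index arithmetic: the scale's shape alone decides which elements share it.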

bridge.models.conversion.quantization_utils.dequantize_mxfp4(
blocks: torch.Tensor,
scales: torch.Tensor,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) → torch.Tensor

Dequantize GPT-OSS MXFP4 block/scales tensors.
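A minimal pure-Python sketch of decoding MXFP4 bytes, assuming the layout used by the public GPT-OSS checkpoints: each byte packs two E2M1 (FP4) codes with the low nibble as the earlier element, and each block of 32 values (16 bytes) shares one E8M0 scale byte interpreted as a power-of-two exponent with bias 127. The helper name `dequantize_mxfp4_block` is hypothetical; the library works on whole torch tensors and, judging by the `rows_per_chunk` parameter, processes rows in chunks to bound memory.

```python
import math

# E2M1 values indexed by the 4-bit code; bit 3 is the sign.
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def dequantize_mxfp4_block(block_bytes, scale_byte):
    """Decode packed FP4 bytes that share one E8M0 scale.

    Assumptions: low nibble first, scale is a biased exponent (bias 127).
    A full MXFP4 block is 16 bytes -> 32 values; any length works here.
    """
    exponent = scale_byte - 127
    out = []
    for b in block_bytes:
        out.append(math.ldexp(FP4_VALUES[b & 0x0F], exponent))  # low nibble
        out.append(math.ldexp(FP4_VALUES[b >> 4], exponent))    # high nibble
    return out


# 0x21 -> codes 1, 2 (0.5, 1.0); 0xF7 -> codes 7, 15 (6.0, -6.0); scale 2**1.
print(dequantize_mxfp4_block(bytes([0x21, 0xF7]), 128))
# [1.0, 2.0, 12.0, -12.0]
```

Using `ldexp` instead of multiplying by `2 ** exponent` keeps the scaling exact and mirrors how an E8M0 scale is a pure exponent.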

bridge.models.conversion.quantization_utils.dequantize_int4(
weight_packed: torch.Tensor,
weight_scale: torch.Tensor,
weight_shape: torch.Tensor,
group_size: int = 32,
device: str | torch.device | None = None,
) → torch.Tensor

Dequantize Kimi INT4 packed weights to bfloat16.

The checkpoint stores eight offset-binary INT4 values in each int32 slot and carries per-group scales beside the packed tensor.
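The packing described above can be illustrated with a small pure-Python unpacker. The name `unpack_int4` is hypothetical. Beyond what this page states (eight offset-binary values per int32, per-group scales), it assumes that nibble i of word k holds element 8k + i (least-significant nibble first), that the offset is 8 so decoded values lie in [-8, 7], and that element n uses `scales[n // group_size]`.

```python
def unpack_int4(packed_words, scales, num_values, group_size=32):
    """Unpack offset-binary INT4 values from int32 words and apply group scales.

    Assumptions (not confirmed by the docs): least-significant nibble first,
    stored nibble = q + 8, element n uses scales[n // group_size].
    """
    values = []
    for word in packed_words:
        word &= 0xFFFFFFFF  # view a possibly negative int32 as unsigned bits
        for i in range(8):
            nibble = (word >> (4 * i)) & 0xF
            values.append(nibble - 8)  # undo the offset: [0, 15] -> [-8, 7]
    values = values[:num_values]  # drop padding in the last word
    return [q * scales[n // group_size] for n, q in enumerate(values)]


# Nibbles (low to high): 8, 9, 7, 0, 15, 8, 8, 8 -> q = 0, 1, -1, -8, 7, 0, 0, 0
print(unpack_int4([0x888F0798], [0.5], 8))
# [0.0, 0.5, -0.5, -4.0, 3.5, 0.0, 0.0, 0.0]
```

The mask matters because torch stores the packed tensor as signed int32, so a word whose top nibble is 8 or more reads back as a negative number.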

bridge.models.conversion.quantization_utils.quantize_to_int4(
weight: torch.Tensor,
group_size: int = 32,
scale_dtype: torch.dtype = torch.bfloat16,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Quantize bfloat16/float16 weights to Kimi INT4 packed format.
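The reverse direction can be sketched as symmetric per-group quantization followed by packing. Everything here is an illustrative assumption rather than the library's documented behavior: the hypothetical `quantize_group_int4` uses scale = max|w| / 7 per group, rounds and clamps to [-8, 7], stores q + 8, and packs eight nibbles per 32-bit word, least-significant first. It returns two lists (packed words, scales) and omits the shape tensor that the real function's three-element return type suggests.

```python
def quantize_group_int4(weights, group_size=32):
    """Quantize a flat list of floats to packed offset-binary INT4.

    Assumptions (hypothetical): symmetric scale = max|w| / 7 per group,
    q = round(w / scale) clamped to [-8, 7], stored as q + 8, eight nibbles
    per word, least-significant nibble first. Words are unsigned ints.
    """
    scales, codes = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        amax = max(abs(w) for w in group)
        scale = amax / 7 if amax > 0 else 1.0  # avoid dividing by zero
        scales.append(scale)
        for w in group:
            q = max(-8, min(7, round(w / scale)))
            codes.append(q + 8)  # offset binary: [-8, 7] -> [0, 15]
    codes += [8] * (-len(codes) % 8)  # pad the last word with encoded zeros
    packed = []
    for k in range(0, len(codes), 8):
        word = 0
        for i, code in enumerate(codes[k:k + 8]):
            word |= code << (4 * i)
        packed.append(word)
    return packed, scales


packed, scales = quantize_group_int4([0.0, 7.0, -7.0, 3.0], group_size=4)
print(hex(packed[0]), scales)
# 0x8888b1f8 [1.0]
```

Dividing by 7 rather than 8 keeps positive and negative extremes symmetric at the cost of never producing -8 except through clamping, a common trade-off for signed 4-bit schemes.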