bridge.models.conversion.quantization_utils
Module Contents
Functions
is_fp8_tensor | Return whether tensor uses one of PyTorch’s FP8 dtypes.
dequantize_fp8_blockwise | Dequantize FP8 weights with one scale per 2D block.
maybe_dequantize_fp8_blockwise | Dequantize FP8 block-scaled weights, falling back to a plain cast.
maybe_dequantize_fp8 | Dequantize FP8 weights with a scalar or broadcastable scale tensor.
dequantize_mxfp4 | Dequantize GPT-OSS MXFP4 block/scales tensors.
dequantize_int4 | Dequantize Kimi INT4 packed weights to bfloat16.
quantize_to_int4 | Quantize bfloat16/float16 weights to Kimi INT4 packed format.
Data
API
- bridge.models.conversion.quantization_utils.FP8_BLOCK_SIZE
128
- bridge.models.conversion.quantization_utils.FP8_DTYPES
()
- bridge.models.conversion.quantization_utils.is_fp8_tensor(tensor: torch.Tensor) -> bool
Return whether tensor uses one of PyTorch’s FP8 dtypes.
- bridge.models.conversion.quantization_utils.dequantize_fp8_blockwise(
    weight: torch.Tensor,
    scale_inv: torch.Tensor,
    *,
    block_size: int = FP8_BLOCK_SIZE,
    dtype: torch.dtype = torch.bfloat16,
  )
Dequantize FP8 weights with one scale per 2D block.
DeepSeek-V3 and MiniMax-M2 store linear weights as FP8 tensors with a separate
*_scale_inv tensor. Each scale applies to one 128x128 weight block by default.
- bridge.models.conversion.quantization_utils.maybe_dequantize_fp8_blockwise(
    weight: torch.Tensor,
    scale_inv: torch.Tensor | None = None,
    *,
    block_size: int = FP8_BLOCK_SIZE,
    dtype: torch.dtype = torch.bfloat16,
  )
Dequantize FP8 block-scaled weights, falling back to a plain cast.
- bridge.models.conversion.quantization_utils.maybe_dequantize_fp8(
    weight: torch.Tensor,
    scale_inv: torch.Tensor | None = None,
    *,
    dtype: torch.dtype = torch.bfloat16,
  )
Dequantize FP8 weights with a scalar or broadcastable scale tensor.
- bridge.models.conversion.quantization_utils.dequantize_mxfp4(
    blocks: torch.Tensor,
    scales: torch.Tensor,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 32768 * 1024,
  )
Dequantize GPT-OSS MXFP4 block/scales tensors.
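For readers unfamiliar with MXFP4, the sketch below shows one way a block/scales pair can be decoded. It is an illustration under stated assumptions rather than this module's code: it assumes each `uint8` in `blocks` holds two 4-bit E2M1 values with the low nibble first, that `scales` holds one power-of-two exponent per group with a bias of 127, and that `scales` has the shape of `blocks` without its last dimension. The name `mxfp4_dequant_sketch` is hypothetical, and the sketch ignores the chunking that `rows_per_chunk` suggests.

```python
import torch

# The 16 values representable by a 4-bit E2M1 float (sign, 2 exponent bits,
# 1 mantissa bit), indexed by the raw nibble.
FP4_VALUES = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def mxfp4_dequant_sketch(
    blocks: torch.Tensor,
    scales: torch.Tensor,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Hypothetical sketch: decode packed FP4 nibbles and apply group exponents."""
    lut = torch.tensor(FP4_VALUES, dtype=torch.float32)
    lo = lut[(blocks & 0x0F).long()]  # first value of each byte (assumed order)
    hi = lut[(blocks >> 4).long()]    # second value of each byte
    # Interleave so each byte contributes [lo, hi] to the last dimension.
    vals = torch.stack((lo, hi), dim=-1).flatten(-2)
    # Assumed bias of 127 on the shared exponent.
    exp = scales.to(torch.int32) - 127
    return torch.ldexp(vals, exp.unsqueeze(-1)).to(dtype)


blocks = torch.tensor([[0x21, 0xF7]], dtype=torch.uint8)  # one group, two bytes
scales = torch.tensor([128], dtype=torch.uint8)           # exponent +1, so x2
decoded = mxfp4_dequant_sketch(blocks, scales)
```

Here byte `0x21` decodes to `0.5` and `1.0`, and byte `0xF7` decodes to `6.0` and `-6.0`, each then doubled by the shared exponent.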
- bridge.models.conversion.quantization_utils.dequantize_int4(
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_shape: torch.Tensor,
    group_size: int = 32,
    device: str | torch.device | None = None,
  )
Dequantize Kimi INT4 packed weights to bfloat16.
The checkpoint stores eight offset-binary INT4 values in each int32 slot and carries per-group scales beside the packed tensor.
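The packing described above can be undone with shifts and masks. The following sketch shows the idea under two assumptions that this page does not confirm: the first value of each group of eight sits in the lowest four bits of the `int32`, and offset-binary means the stored nibble equals the signed value plus 8. The name `int4_unpack_sketch` is hypothetical, and the sketch leaves out the `weight_shape` and `device` handling of the real function.

```python
import torch


def int4_unpack_sketch(
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int = 32,
) -> torch.Tensor:
    """Hypothetical sketch: unpack 8 nibbles per int32 and apply group scales."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    # Shift each 4-bit field down to the low bits; the mask also discards
    # the sign bits that an arithmetic right shift drags in.
    nibbles = (weight_packed.unsqueeze(-1) >> shifts) & 0xF
    # Offset-binary: stored 0..15 maps to signed -8..7 (assumed offset of 8).
    ints = nibbles.flatten(-2) - 8
    # One scale per group_size consecutive values along the last dimension.
    scale = weight_scale.to(torch.float32).repeat_interleave(group_size, dim=-1)
    return (ints.to(torch.float32) * scale).to(torch.bfloat16)


# 0x76543210 holds the nibbles 0, 1, ..., 7 from low bits to high bits.
packed = torch.tensor([[0x76543210]], dtype=torch.int32)
scale = torch.tensor([[0.5]])
unpacked = int4_unpack_sketch(packed, scale, group_size=8)
```

With these inputs the stored nibbles 0 through 7 become the signed values -8 through -1, which the scale of 0.5 turns into -4.0 through -0.5.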
- bridge.models.conversion.quantization_utils.quantize_to_int4(
    weight: torch.Tensor,
    group_size: int = 32,
    scale_dtype: torch.dtype = torch.bfloat16,
  )
Quantize bfloat16/float16 weights to Kimi INT4 packed format.
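The forward direction can be sketched as symmetric per-group quantization followed by nibble packing. This is one plausible scheme, not necessarily the one this function uses: the scale rule (group maximum divided by 7), the offset of 8, and the low-nibble-first order are all assumptions, and `int4_pack_sketch` is a hypothetical name. The sketch also assumes the column count is a multiple of both `group_size` and 8.

```python
import torch


def int4_pack_sketch(
    weight: torch.Tensor,
    group_size: int = 32,
    scale_dtype: torch.dtype = torch.bfloat16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: quantize per group, then pack 8 nibbles per int32."""
    rows, cols = weight.shape
    groups = weight.to(torch.float32).reshape(rows, cols // group_size, group_size)
    # Assumed symmetric scale: map the largest magnitude in a group to 7.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    scale = scale.to(scale_dtype)  # quantize with the scale that will be stored
    q = torch.clamp(torch.round(groups / scale.to(torch.float32)), -8, 7)
    # Offset-binary: shift signed -8..7 into unsigned 0..15.
    stored = (q.to(torch.int64) + 8).reshape(rows, cols // 8, 8)
    shifts = torch.arange(0, 32, 4, dtype=torch.int64)
    packed = (stored << shifts).sum(dim=-1)
    # Fold values with bit 31 set into the signed int32 range before casting.
    packed = torch.where(packed >= 2**31, packed - 2**32, packed)
    return packed.to(torch.int32), scale.squeeze(-1)


w = torch.tensor([[-7.0, -3.0, 0.0, 3.0, 7.0, 1.0, -1.0, 5.0]])
packed, scales = int4_pack_sketch(w, group_size=8)
```

The packing is done in `int64` and folded back explicitly so that a top nibble of 8 or more does not rely on overflow behavior when the result is cast to `int32`.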