bridge.models.conversion.quantization_utils
Module Contents
Functions
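| Function | Description |
|---|---|
| `is_fp8_tensor` | Return whether `tensor` uses one of PyTorch’s FP8 dtypes. |
| `is_float8_e8m0_dtype` | Return whether `dtype` is PyTorch’s E8M0 scale dtype. |
| `scale_from_amax` | Build positive quantization scales in the same scale family as `scale_dtype`. |
| `dequantize_fp8_blockwise` | Dequantize FP8 weights with one scale per 2D block. |
| `maybe_dequantize_fp8_blockwise` | Dequantize FP8 block-scaled weights, falling back to a plain cast. |
| `maybe_dequantize_fp8` | Dequantize FP8 weights with a scalar or broadcastable scale tensor. |
| `dequantize_fp8_e4m3fn_with_scale` | Dequantize FP8 E4M3 weights with a companion scale tensor. |
| `quantize_fp8_e4m3fn_like_scale` | Quantize a 2-D weight to FP8 E4M3 using `source_scale` geometry and dtype. |
| `dequantize_mxfp4` | Dequantize GPT-OSS MXFP4 block/scales tensors. |
| `dequantize_mxfp4_e2m1_packed` | Dequantize MXFP4 E2M1 weights packed two values per byte. |
| `is_mxfp4_e2m1_scale_geometry` | Return whether `source_scale` describes packed MXFP4 E2M1 K tiles. |
| `quantize_mxfp4_e2m1_like_scale` | Quantize a 2-D weight to packed MXFP4 E2M1 using source scale geometry. |
| `maybe_dequantize_hf_quantized_weight` | Load and dequantize HF `*.weight` tensors that carry sibling `*.scale` tensors. |
| `requantize_hf_weight_scale_pairs` | Recreate quantized HF `*.weight`/`*.scale` pairs using source scale layout. |
| `dequantize_int4` | Dequantize Kimi INT4 packed weights to bfloat16. |
| `quantize_to_int4` | Quantize bfloat16/float16 weights to Kimi INT4 packed format. |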
Data
API
- bridge.models.conversion.quantization_utils.FP8_BLOCK_SIZE
128
- bridge.models.conversion.quantization_utils.FP8_DTYPES
()
- bridge.models.conversion.quantization_utils.FP8_E4M3_MAX
448.0
- bridge.models.conversion.quantization_utils.FP4_E2M1_MAX
6.0
- bridge.models.conversion.quantization_utils.MXFP4_BLOCK_SIZE
32
- bridge.models.conversion.quantization_utils._FP4_E2M1_TABLE_VALUES
None
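The `FP4_E2M1_MAX` and `FP8_E4M3_MAX` constants follow directly from the bit layouts of the two formats. The pure-Python sketch below decodes all 16 E2M1 codes (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit) and computes the largest finite E4M3FN value (4 exponent bits with bias 7, 3 mantissa bits, all-ones reserved for NaN). The helper names are hypothetical, and the table is derived from the format rather than read from the module's `_FP4_E2M1_TABLE_VALUES`.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa."""
    sign = -1.0 if code & 0b1000 else 1.0
    exponent = (code >> 1) & 0b11
    mantissa = code & 0b1
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent 1 - bias = 0.
        return sign * mantissa * 0.5
    # Normal: implicit leading 1, exponent bias 1.
    return sign * (1.0 + mantissa * 0.5) * 2.0 ** (exponent - 1)


def e4m3fn_max() -> float:
    """Largest finite E4M3FN value: exponent field 15, mantissa 0b110 (0b111 is NaN)."""
    exponent_bias = 7
    return (1.0 + 6 / 8) * 2.0 ** (15 - exponent_bias)


E2M1_VALUES = [decode_e2m1(code) for code in range(16)]
print(E2M1_VALUES[:8])  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print(e4m3fn_max())     # 448.0
```

The largest E2M1 magnitude, 6.0, and the largest E4M3FN magnitude, 448.0, match `FP4_E2M1_MAX` and `FP8_E4M3_MAX`.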
- bridge.models.conversion.quantization_utils.is_fp8_tensor(tensor: torch.Tensor) -> bool
Return whether `tensor` uses one of PyTorch’s FP8 dtypes.
- bridge.models.conversion.quantization_utils.is_float8_e8m0_dtype(dtype: torch.dtype) -> bool
Return whether `dtype` is PyTorch’s E8M0 scale dtype.
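- bridge.models.conversion.quantization_utils.scale_from_amax(amax: torch.Tensor, max_quantized_value: float, scale_dtype: torch.dtype)
Build positive quantization scales from `amax` in the same scale family as `scale_dtype`. For example, an E8M0 `scale_dtype` stores only an exponent, so scales in that family are powers of two.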
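A minimal sketch of one plausible amax-to-scale rule, on plain Python floats rather than tensors. The function name and the `power_of_two` flag are hypothetical stand-ins for the `scale_dtype` check, and rounding the exponent up for power-of-two families is an assumption, not necessarily what `scale_from_amax` does.

```python
import math


def scale_from_amax_sketch(amax: float, max_quantized_value: float, power_of_two: bool) -> float:
    """Hypothetical scalar version of an amax -> scale rule.

    power_of_two=True mimics an E8M0 scale family, where only 2**k is representable.
    """
    # Guard against all-zero blocks so the scale stays strictly positive.
    amax = max(amax, 1e-12)
    scale = amax / max_quantized_value
    if power_of_two:
        # Round the exponent up so that amax / scale never exceeds max_quantized_value.
        scale = 2.0 ** math.ceil(math.log2(scale))
    return scale


print(scale_from_amax_sketch(896.0, 448.0, power_of_two=False))  # 2.0
print(scale_from_amax_sketch(7.0, 6.0, power_of_two=True))       # 2.0
```

Rounding up trades a little precision for the guarantee that no value in the block overflows the quantized range after division.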
- bridge.models.conversion.quantization_utils.dequantize_fp8_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor, *, block_size: int = FP8_BLOCK_SIZE, dtype: torch.dtype = torch.bfloat16)
Dequantize FP8 weights with one scale per 2D block.
DeepSeek-V3 and MiniMax-M2 store linear weights as FP8 tensors with a separate `*_scale_inv` tensor. Each scale applies to one `block_size` × `block_size` weight block (128 × 128 by default).
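The indexing rule can be illustrated in plain Python: element (i, j) is multiplied by `scale_inv[i // block_size][j // block_size]`. This is a hypothetical sketch on nested lists (the real function operates on tensors and casts to `dtype`), and it assumes the scale grid is rounded up so that partial edge blocks get their own scale.

```python
def dequantize_blockwise_sketch(
    weight: list[list[float]],
    scale_inv: list[list[float]],
    block_size: int = 128,
) -> list[list[float]]:
    """Multiply each element by the scale of the block_size x block_size tile it falls in."""
    rows, cols = len(weight), len(weight[0])
    # The scale grid has ceil(rows / block_size) x ceil(cols / block_size) entries.
    expected = (-(-rows // block_size), -(-cols // block_size))
    if (len(scale_inv), len(scale_inv[0])) != expected:
        raise ValueError(f"scale_inv shape must be {expected}")
    return [
        [weight[i][j] * scale_inv[i // block_size][j // block_size] for j in range(cols)]
        for i in range(rows)
    ]


# A 3x3 weight with 2x2 blocks needs a 2x2 scale grid; the last row/column form partial blocks.
w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
s = [[10.0, 100.0], [0.5, 2.0]]
print(dequantize_blockwise_sketch(w, s, block_size=2))
```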
- bridge.models.conversion.quantization_utils.maybe_dequantize_fp8_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor | None = None, *, block_size: int = FP8_BLOCK_SIZE, dtype: torch.dtype = torch.bfloat16)
Dequantize FP8 block-scaled weights, falling back to a plain cast to `dtype`.
- bridge.models.conversion.quantization_utils.maybe_dequantize_fp8(weight: torch.Tensor, scale_inv: torch.Tensor | None = None, *, dtype: torch.dtype = torch.bfloat16)
Dequantize FP8 weights with a scalar or broadcastable scale tensor.
- bridge.models.conversion.quantization_utils.dequantize_fp8_e4m3fn_with_scale(weight: torch.Tensor, scale: torch.Tensor, *, name: str = '', block_size: int = FP8_BLOCK_SIZE, dtype: torch.dtype = torch.bfloat16)
Dequantize FP8 E4M3 weights with a companion scale tensor.
Supports three common HF checkpoint scale layouts: 1-D per-row or row-block scales, 2-D per-row K tiles, and 2-D block scales.
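One way to tell the three layouts apart is by comparing shapes, assuming a weight of shape (rows, K) and tiles of `block_size` along each tiled dimension. The function below is a hypothetical illustration of that shape matching, not the module's actual dispatch logic; when two layouts produce the same shape (for example, a single-row weight), it simply returns the first match.

```python
def classify_scale_layout(
    weight_shape: tuple[int, int],
    scale_shape: tuple[int, ...],
    block_size: int = 128,
) -> str:
    """Hypothetical mapping from (weight, scale) shapes to the three documented layouts."""
    rows, cols = weight_shape
    row_blocks = -(-rows // block_size)  # ceil(rows / block_size)
    col_tiles = -(-cols // block_size)   # ceil(cols / block_size)
    if len(scale_shape) == 1:
        if scale_shape[0] == rows:
            return "1-D per-row"
        if scale_shape[0] == row_blocks:
            return "1-D row-block"
    elif len(scale_shape) == 2:
        if scale_shape == (rows, col_tiles):
            return "2-D per-row K tiles"
        if scale_shape == (row_blocks, col_tiles):
            return "2-D block"
    raise ValueError(f"unsupported scale shape {scale_shape} for weight {weight_shape}")


for shape in [(256,), (2,), (256, 4), (2, 4)]:
    print(shape, classify_scale_layout((256, 512), shape))
```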
- bridge.models.conversion.quantization_utils._quantize_fp8_2d_blocks(weight: torch.Tensor, source_scale: torch.Tensor, *, name: str = '', block_size: int = FP8_BLOCK_SIZE)
- bridge.models.conversion.quantization_utils._quantize_fp8_per_row_tiles(weight: torch.Tensor, source_scale: torch.Tensor, *, name: str = '')
- bridge.models.conversion.quantization_utils._quantize_fp8_1d_scale(weight: torch.Tensor, source_scale: torch.Tensor, *, name: str = '', block_size: int = FP8_BLOCK_SIZE)
- bridge.models.conversion.quantization_utils.quantize_fp8_e4m3fn_like_scale(weight: torch.Tensor, source_scale: torch.Tensor, *, name: str = '', block_size: int = FP8_BLOCK_SIZE)
Quantize a 2-D weight to FP8 E4M3 using the geometry and dtype of `source_scale`.
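For the 2-D block-scale layout, requantization can be pictured as the inverse of blockwise dequantization: compute each block's absolute maximum, choose `scale_inv = amax / FP8_E4M3_MAX`, then divide and clamp. The sketch below is hypothetical and stays in plain Python, so it skips the final rounding to FP8 codes and ignores the source scale's dtype, both of which a real FP8 quantizer has to handle.

```python
FP8_E4M3_MAX = 448.0  # same value as the module constant


def quantize_blocks_sketch(
    weight: list[list[float]], block_size: int = 128
) -> tuple[list[list[float]], list[list[float]]]:
    """Return (scaled values, per-block scale_inv) for a 2-D block layout.

    Values are only scaled and clamped; a real kernel would also round each
    value to the nearest FP8 E4M3 code.
    """
    rows, cols = len(weight), len(weight[0])
    n_rb, n_cb = -(-rows // block_size), -(-cols // block_size)
    scale_inv = [[0.0] * n_cb for _ in range(n_rb)]
    for bi in range(n_rb):
        for bj in range(n_cb):
            amax = max(
                abs(weight[i][j])
                for i in range(bi * block_size, min((bi + 1) * block_size, rows))
                for j in range(bj * block_size, min((bj + 1) * block_size, cols))
            )
            # Map the block's largest magnitude onto the top of the E4M3 range.
            scale_inv[bi][bj] = max(amax, 1e-12) / FP8_E4M3_MAX
    quantized = [
        [
            max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, weight[i][j] / scale_inv[i // block_size][j // block_size]))
            for j in range(cols)
        ]
        for i in range(rows)
    ]
    return quantized, scale_inv


q, s = quantize_blocks_sketch([[1.0, -2.0], [0.5, 4.0]], block_size=2)
print(s)  # one scale_inv for the single 2x2 block
```

Multiplying `q` by the matching `scale_inv` recovers the original weights, which is the property `dequantize_fp8_blockwise` relies on.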
- bridge.models.conversion.quantization_utils.dequantize_mxfp4(blocks: torch.Tensor, scales: torch.Tensor, *, dtype: torch.dtype = torch.bfloat16, rows_per_chunk: int = 32768 * 1024)
Dequantize GPT-OSS MXFP4 block/scales tensors.
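A single GPT-OSS-style MXFP4 block can be decoded by looking up each 4-bit code in the E2M1 table and applying the block's shared power-of-two scale, stored as a biased E8M0 exponent with value `2 ** (e - 127)`. The sketch below is hypothetical: it works on one block of raw bytes, assumes the low nibble holds the earlier element, and ignores the E8M0 NaN encoding (255).

```python
import math

# E2M1 code -> value lookup; bit 3 of the code is the sign.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequantize_mxfp4_block_sketch(packed: bytes, scale_e8m0: int) -> list[float]:
    """Decode one MXFP4 block: two FP4 codes per byte plus one shared E8M0 exponent."""
    exponent = scale_e8m0 - 127  # E8M0 stores a biased power-of-two exponent
    values = []
    for byte in packed:
        # Assumed nibble order: the low nibble is the earlier element.
        values.append(math.ldexp(FP4_VALUES[byte & 0x0F], exponent))
        values.append(math.ldexp(FP4_VALUES[byte >> 4], exponent))
    return values


# 0x21 -> codes (1, 2) -> (0.5, 1.0); 0xF7 -> codes (7, 15) -> (6.0, -6.0); scale 2**1.
print(dequantize_mxfp4_block_sketch(bytes([0x21, 0xF7]), 128))  # [1.0, 2.0, 12.0, -12.0]
```

A full block of `MXFP4_BLOCK_SIZE` (32) values occupies 16 bytes plus one scale byte.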
- bridge.models.conversion.quantization_utils.dequantize_mxfp4_e2m1_packed(weight_packed: torch.Tensor, scale: torch.Tensor, *, dtype: torch.dtype = torch.bfloat16)
Dequantize MXFP4 E2M1 weights packed two values per byte.
`scale` is expected to hold one scale per row and per K tile. E8M0 scale tensors can be passed directly; `.to(torch.float32)` materializes their power-of-two values.
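With one scale per row and per K tile, logical element (r, k) lives in byte `k // 2` of row r (one of its two nibbles) and uses `scale[r][k // 32]`. The following plain-Python sketch assumes the scales have already been materialized as floats (the step `.to(torch.float32)` performs for E8M0 tensors) and that the low nibble holds the even-indexed element; the function name and the nibble order are assumptions.

```python
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
MXFP4_BLOCK_SIZE = 32


def dequantize_packed_rows_sketch(
    weight_packed: list[bytes],
    scale: list[list[float]],
    block_size: int = MXFP4_BLOCK_SIZE,
) -> list[list[float]]:
    """Dequantize a (rows, K // 2) packed matrix with a (rows, K // block_size) float scale grid."""
    out = []
    for row_bytes, row_scales in zip(weight_packed, scale):
        row = []
        for k_half, byte in enumerate(row_bytes):
            for nibble, code in enumerate((byte & 0x0F, byte >> 4)):
                k = 2 * k_half + nibble  # logical column index
                row.append(FP4_VALUES[code] * row_scales[k // block_size])
        out.append(row)
    return out


# One row with K = 64 (32 bytes): every code decodes to 1.0, and each K tile has its own scale.
result = dequantize_packed_rows_sketch([bytes([0x22] * 32)], [[0.5, 4.0]])
print(result[0][0], result[0][63])  # 0.5 4.0
```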
- bridge.models.conversion.quantization_utils.is_mxfp4_e2m1_scale_geometry(weight: torch.Tensor, source_scale: torch.Tensor, *, block_size: int = MXFP4_BLOCK_SIZE)
Return whether `source_scale` describes packed MXFP4 E2M1 K tiles.
- bridge.models.conversion.quantization_utils.quantize_mxfp4_e2m1_like_scale(weight: torch.Tensor, source_scale: torch.Tensor, *, name: str = '', block_size: int = MXFP4_BLOCK_SIZE)
Quantize a 2-D weight to packed MXFP4 E2M1 using source scale geometry.
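Quantizing one MXFP4 block reverses the decoding: pick a shared power-of-two scale from the block's absolute maximum, round each scaled value to the nearest E2M1 magnitude, and pack two 4-bit codes per byte. In this hypothetical sketch, the exponent rule `ceil(log2(amax / 6))`, the nearest-value rounding, and the low-nibble-first packing are all assumptions; `quantize_mxfp4_e2m1_like_scale` takes its tile geometry from `source_scale` and may round differently.

```python
import math

FP4_POSITIVE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0-7; setting bit 3 negates
FP4_E2M1_MAX = 6.0


def quantize_mxfp4_block_sketch(values: list[float]) -> tuple[bytes, int]:
    """Quantize one even-length block to packed E2M1 codes plus a biased E8M0 exponent."""
    amax = max(max(abs(v) for v in values), 1e-30)
    # Smallest power of two that keeps amax / scale within the E2M1 range (assumed rule).
    exponent = math.ceil(math.log2(amax / FP4_E2M1_MAX))
    scale = 2.0 ** exponent
    codes = []
    for v in values:
        magnitude = abs(v) / scale
        # Nearest representable magnitude; ties resolve to the smaller code.
        code = min(range(8), key=lambda c: abs(FP4_POSITIVE[c] - magnitude))
        codes.append(code | 0x8 if v < 0 else code)
    # Pack pairs with the earlier element in the low nibble (assumed order).
    packed = bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))
    return packed, exponent + 127


print(quantize_mxfp4_block_sketch([1.0, 2.0, 12.0, -12.0]))  # (b'!\xf7', 128)
```

The sketch does not clamp the exponent to the E8M0 range, so extremely small or large blocks would need extra handling.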
- bridge.models.conversion.quantization_utils.maybe_dequantize_hf_quantized_weight(hf_param: str | dict[str, str], hf_state_dict: collections.abc.Mapping[str, torch.Tensor], *, dtype: torch.dtype = torch.bfloat16)
Load and dequantize HF `*.weight` tensors that carry sibling `*.scale` tensors.
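The sibling-key convention can be sketched with plain dictionaries: a key ending in `.weight` is dequantized only if the state dict also holds the matching `.scale` key. This hypothetical helper covers only the `str` form of `hf_param` (not the `dict[str, str]` form) and takes the dequantization routine as a parameter instead of choosing among the FP8 and MXFP4 decoders documented in this module.

```python
from collections.abc import Callable, Mapping


def load_maybe_dequantized_sketch(
    hf_param: str,
    hf_state_dict: Mapping[str, list[float]],
    dequantize: Callable[[list[float], list[float]], list[float]],
) -> list[float]:
    """Return hf_state_dict[hf_param], dequantized when a sibling '*.scale' entry exists."""
    weight = hf_state_dict[hf_param]
    if hf_param.endswith(".weight"):
        scale_key = hf_param[: -len(".weight")] + ".scale"
        if scale_key in hf_state_dict:
            return dequantize(weight, hf_state_dict[scale_key])
    # No sibling scale: this sketch returns the tensor unchanged.
    return weight


def scale_per_tensor(weight: list[float], scale: list[float]) -> list[float]:
    # Stand-in dequantizer: one scalar scale for the whole tensor.
    return [x * scale[0] for x in weight]


state = {"mlp.up.weight": [1.0, 2.0], "mlp.up.scale": [0.5], "norm.weight": [1.0]}
print(load_maybe_dequantized_sketch("mlp.up.weight", state, scale_per_tensor))  # [0.5, 1.0]
print(load_maybe_dequantized_sketch("norm.weight", state, scale_per_tensor))    # [1.0]
```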
- bridge.models.conversion.quantization_utils.requantize_hf_weight_scale_pairs(converted_weights_dict: collections.abc.Mapping[str, torch.Tensor], hf_state_dict: collections.abc.Mapping[str, torch.Tensor], *, use_mxfp4: collections.abc.Callable[[str, torch.Tensor, torch.Tensor], bool] | None = None)
Recreate quantized HF `*.weight`/`*.scale` pairs using the source scale layout. `use_mxfp4` lets model bridges opt specific parameters into packed MXFP4 output; all other scaled weights are emitted as FP8 E4M3.
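The role of `use_mxfp4` can be pictured as a per-parameter predicate inside a dispatch loop. In this hypothetical sketch, tensors are replaced by placeholders and the function only reports which format each parameter would get. The argument order `(name, weight, source_scale)` for the predicate and the pass-through of tensors without a sibling `.scale` are assumptions.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping


def plan_requantization(
    converted_weights: Mapping[str, object],
    hf_state_dict: Mapping[str, object],
    use_mxfp4: Callable[[str, object, object], bool] | None = None,
) -> dict[str, str]:
    """Hypothetical planner: report the output format each converted tensor would get."""
    plan = {}
    for name, weight in converted_weights.items():
        if not name.endswith(".weight"):
            plan[name] = "unquantized"
            continue
        scale_key = name[: -len(".weight")] + ".scale"
        if scale_key not in hf_state_dict:
            # No source scale to mirror: assumed to pass through unquantized.
            plan[name] = "unquantized"
        elif use_mxfp4 is not None and use_mxfp4(name, weight, hf_state_dict[scale_key]):
            plan[name] = "mxfp4"  # the bridge opted this parameter into packed MXFP4
        else:
            plan[name] = "fp8_e4m3"  # default for every other scaled weight
    return plan


def experts_to_mxfp4(name: str, weight: object, source_scale: object) -> bool:
    # Example predicate: send expert weights to MXFP4, everything else to FP8.
    return ".experts." in name


weights = {"layers.0.experts.w1.weight": 1, "layers.0.attn.q.weight": 2, "norm.weight": 3}
hf = {"layers.0.experts.w1.scale": 0, "layers.0.attn.q.scale": 0}
print(plan_requantization(weights, hf, use_mxfp4=experts_to_mxfp4))
```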
- bridge.models.conversion.quantization_utils.dequantize_int4(weight_packed: torch.Tensor, weight_scale: torch.Tensor, weight_shape: torch.Tensor, group_size: int = 32, device: str | torch.device | None = None)
Dequantize Kimi INT4 packed weights to bfloat16.
The checkpoint stores eight offset-binary INT4 values in each int32 slot and carries per-group scales beside the packed tensor.
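The offset-binary packing can be unpacked with shifts and masks: each 4-bit field stores `q + 8` for a signed value `q` in [-8, 7], and the decoded value is multiplied by its group's scale. This plain-Python sketch flattens the weight to one dimension and assumes the first value sits in the lowest nibble; the function names are hypothetical, and the real `dequantize_int4` also takes `weight_shape` and returns bfloat16.

```python
def unpack_int4_offset_binary(packed: list[int], count: int) -> list[int]:
    """Unpack `count` signed INT4 values stored 8 per 32-bit word, offset by +8."""
    values = []
    for word in packed:
        word &= 0xFFFFFFFF  # reinterpret a (possibly negative) int32 as unsigned 32-bit
        for i in range(8):
            # Assumed order: bits 4i..4i+3 hold the i-th value of the word.
            values.append(((word >> (4 * i)) & 0xF) - 8)
    return values[:count]


def dequantize_int4_sketch(
    packed: list[int], scales: list[float], count: int, group_size: int = 32
) -> list[float]:
    """Multiply each unpacked value by the scale of its group of `group_size` elements."""
    q = unpack_int4_offset_binary(packed, count)
    return [v * scales[idx // group_size] for idx, v in enumerate(q)]


# Build one word from stored nibbles (q + 8) and view it as a signed int32.
nibbles = [8, 9, 7, 0, 15, 8, 8, 8]
word = sum(n << (4 * i) for i, n in enumerate(nibbles))
signed_word = word - (1 << 32) if word >= 1 << 31 else word
print(unpack_int4_offset_binary([signed_word], 8))  # [0, 1, -1, -8, 7, 0, 0, 0]
```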
- bridge.models.conversion.quantization_utils.quantize_to_int4(weight: torch.Tensor, group_size: int = 32, scale_dtype: torch.dtype = torch.bfloat16)
Quantize bfloat16/float16 weights to Kimi INT4 packed format.
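A matching packing step can be sketched as symmetric per-group quantization followed by offset-binary packing into 32-bit words. The scale rule (group absolute maximum mapped onto +7), the low-nibble-first order, and the zero padding are assumptions made for illustration; the actual rounding and `scale_dtype` handling in `quantize_to_int4` may differ.

```python
def quantize_to_int4_sketch(
    weight: list[float], group_size: int = 32
) -> tuple[list[int], list[float]]:
    """Symmetric per-group INT4 quantization, packed as 8 offset-binary nibbles per int32."""
    scales, codes = [], []
    for start in range(0, len(weight), group_size):
        group = weight[start : start + group_size]
        # Assumed scale rule: map the group's absolute max onto +7.
        scale = max(max(abs(v) for v in group), 1e-12) / 7.0
        scales.append(scale)
        # Python's round() rounds halves to even; clamp to the signed INT4 range, then offset by 8.
        codes.extend(max(-8, min(7, round(v / scale))) + 8 for v in group)
    codes.extend([8] * (-len(codes) % 8))  # pad with encoded zeros to a multiple of 8
    packed = []
    for start in range(0, len(codes), 8):
        word = sum(code << (4 * i) for i, code in enumerate(codes[start : start + 8]))
        # Reinterpret as a signed 32-bit integer, as an int32 tensor would store it.
        packed.append(word - (1 << 32) if word >= 1 << 31 else word)
    return packed, scales


packed, scales = quantize_to_int4_sketch([0.0, 1.0, -1.0, 7.0, -7.0, 3.5, 2.0, 0.5], group_size=8)
print(scales)  # [1.0]
```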