bridge.models.kimi_vl.utils#
Module Contents#
Functions#
- dequantize_int4: Dequantize INT4 packed weights to bfloat16.
- quantize_to_int4: Quantize bfloat16/float16 weights to INT4 packed format.
API#
- bridge.models.kimi_vl.utils.dequantize_int4(
- weight_packed: torch.Tensor,
- weight_scale: torch.Tensor,
- weight_shape: torch.Tensor,
- group_size: int = 32,
- device: str | torch.device | None = None,
- )
Dequantize INT4 packed weights to bfloat16.
Extracts local tensors from DTensors before unpacking (bitwise ops donβt work on DTensor). Both weight_packed and weight_scale should have matching sharding so .to_local() gives corresponding slices automatically.
- Parameters:
weight_packed – INT4 packed weights [out_features, in_features // 8], may be a DTensor
weight_scale – Per-group scales [out_features, num_groups], should be a DTensor with the same sharding
weight_shape – Original shape [2]; stores the global dimensions
group_size – Elements per scale group (default 32)
device – Target device for computation
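A minimal sketch of the unpacking scheme these parameters describe, assuming 8 signed 4-bit values per int32 with the lowest nibble first and two's-complement sign extension. The function name `dequantize_int4_sketch` is hypothetical, and the DTensor `.to_local()` handling mentioned above is omitted for brevity:

```python
import torch

def dequantize_int4_sketch(weight_packed, weight_scale, weight_shape, group_size=32):
    # Hypothetical re-implementation (not the library's code): each int32
    # is assumed to hold 8 signed 4-bit values, lowest nibble first.
    out_features, in_features = int(weight_shape[0]), int(weight_shape[1])
    shifts = torch.arange(8, dtype=torch.int32) * 4
    # Extract the 8 nibbles of every int32.
    nibbles = (weight_packed.unsqueeze(-1) >> shifts) & 0xF
    # Sign-extend: nibbles >= 8 encode negative values in two's complement.
    vals = torch.where(nibbles >= 8, nibbles - 16, nibbles).float()
    vals = vals.reshape(out_features, in_features // group_size, group_size)
    # Apply the per-group scale and restore the original layout.
    out = vals * weight_scale.float().unsqueeze(-1)
    return out.reshape(out_features, in_features).to(torch.bfloat16)

# Round trip on one group of 8 known INT4 values with scale 0.5.
vals = torch.tensor([-8, -4, -1, 0, 1, 3, 5, 7], dtype=torch.int32)
packed = ((vals & 0xF) << (torch.arange(8, dtype=torch.int32) * 4)).sum(dtype=torch.int32).reshape(1, 1)
scale = torch.tensor([[0.5]], dtype=torch.float16)
shape = torch.tensor([1, 8], dtype=torch.int64)
restored = dequantize_int4_sketch(packed, scale, shape, group_size=8)
```

Every restored value here is an integer multiple of the scale (e.g. -8 * 0.5 = -4.0), so the bfloat16 result is exact for this example.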
- bridge.models.kimi_vl.utils.quantize_to_int4(
- weight: torch.Tensor,
- group_size: int = 32,
- )
Quantize bfloat16/float16 weights to INT4 packed format.
- Returns:
weight_packed – INT4 values packed into int32 (8 values per int32)
weight_scale – Per-group scale factors (float16)
weight_shape – Original tensor shape (int64)
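A minimal sketch of a quantizer producing this packed format, assuming symmetric per-group quantization (the largest magnitude in each group maps to ±7) and the same nibble layout as above. The function name `quantize_to_int4_sketch` and the symmetric-scale choice are assumptions, not the library's implementation:

```python
import torch

def quantize_to_int4_sketch(weight: torch.Tensor, group_size: int = 32):
    # Hypothetical re-implementation (not the library's code): symmetric
    # per-group INT4 quantization, 8 signed 4-bit values per int32.
    out_features, in_features = weight.shape
    w = weight.float().reshape(out_features, in_features // group_size, group_size)
    # Symmetric scale: the largest magnitude in each group maps to +/-7.
    scale = (w.abs().amax(dim=-1) / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(w / scale.unsqueeze(-1)), -8, 7)
    # Pack in int64 to sidestep signed overflow, then fold into int32 range.
    nibbles = (q.to(torch.int64) & 0xF).reshape(out_features, in_features // 8, 8)
    shifts = torch.arange(8, dtype=torch.int64) * 4
    packed64 = (nibbles << shifts).sum(dim=-1)
    weight_packed = torch.where(packed64 >= 2**31, packed64 - 2**32, packed64).to(torch.int32)
    weight_scale = scale.to(torch.float16)
    weight_shape = torch.tensor([out_features, in_features], dtype=torch.int64)
    return weight_packed, weight_scale, weight_shape

# Example: one group of 8 values chosen so quantization is exact (scale 0.5).
w = torch.tensor([[-3.5, -2.0, -0.5, 0.0, 0.5, 1.5, 2.5, 3.5]])
wp, ws, wshape = quantize_to_int4_sketch(w, group_size=8)
```

Packing via a shifted sum works because the eight 4-bit fields occupy disjoint bit positions, so the sum never carries between nibbles.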