bridge.models.kimi_vl.utils#

Module Contents#

Functions#

dequantize_int4

Dequantize INT4 packed weights to bfloat16.

quantize_to_int4

Quantize bfloat16/float16 weights to INT4 packed format.

API#

bridge.models.kimi_vl.utils.dequantize_int4(
weight_packed: torch.Tensor,
weight_scale: torch.Tensor,
weight_shape: torch.Tensor,
group_size: int = 32,
device: str | torch.device | None = None,
) → torch.Tensor#

Dequantize INT4 packed weights to bfloat16.

Extracts local tensors from DTensors before unpacking (bitwise ops don’t work on DTensor). Both weight_packed and weight_scale should have matching sharding so .to_local() gives corresponding slices automatically.

Parameters:
  • weight_packed – INT4 packed weights [out_features, in_features // 8], may be DTensor

  • weight_scale – Per-group scales [out_features, num_groups], should be DTensor with same sharding

  • weight_shape – Original shape [2], stores global dimensions

  • group_size – Elements per scale group (default 32)

  • device – Target device for computation
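The unpack-and-rescale step can be sketched as follows. This is a minimal illustration assuming 8 signed 4-bit values per int32, lowest nibble first, with symmetric per-group scales; the actual nibble order and sign convention used by `bridge.models.kimi_vl.utils` may differ, and the DTensor `.to_local()` handling is omitted.

```python
import torch

def dequantize_int4_sketch(weight_packed, weight_scale, weight_shape, group_size=32):
    """Unpack int32-packed INT4 weights and apply per-group scales.

    Hypothetical layout: 8 signed 4-bit values per int32, lowest nibble
    first. The real kernel's packing order may differ.
    """
    out_features, in_features = int(weight_shape[0]), int(weight_shape[1])
    # One shift per nibble position: 0, 4, 8, ..., 28
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    # [out, in//8, 1] >> [8] -> [out, in//8, 8]; mask off everything but the nibble
    nibbles = (weight_packed.unsqueeze(-1) >> shifts) & 0xF
    # Sign-extend from 4 bits: values 8..15 represent -8..-1
    signed = torch.where(nibbles >= 8, nibbles - 16, nibbles)
    values = signed.reshape(out_features, in_features).to(torch.bfloat16)
    # Broadcast per-group scales [out, num_groups] -> [out, in_features]
    scales = weight_scale.repeat_interleave(group_size, dim=1).to(torch.bfloat16)
    return values * scales
```

For example, a packed row of `0x33333333` words (every nibble equal to 3) with a per-group scale of 0.5 dequantizes to a row of 1.5s.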

bridge.models.kimi_vl.utils.quantize_to_int4(
weight: torch.Tensor,
group_size: int = 32,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Quantize bfloat16/float16 weights to INT4 packed format.

Returns:
  • weight_packed – INT4 values packed into int32 (8 values per int32)

  • weight_scale – Per-group scale factors (float16)

  • weight_shape – Original tensor shape (int64)

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
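The quantization direction can be sketched in the same way. This is a hedged illustration, not the module's implementation: it assumes symmetric per-group quantization (scale = max |w| / 7) and the same hypothetical lowest-nibble-first packing as above.

```python
import torch

def quantize_to_int4_sketch(weight, group_size=32):
    """Symmetric per-group INT4 quantization, 8 values packed per int32.

    Assumed scheme: each group of `group_size` elements shares one scale
    chosen so values map into [-7, 7]; nibbles are packed low-to-high.
    """
    out_f, in_f = weight.shape
    w = weight.float().reshape(out_f, in_f // group_size, group_size)
    # Per-group scale so the largest magnitude maps to 7
    scale = (w.abs().amax(dim=-1, keepdim=True) / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int32)
    # Keep only the low 4 bits of each value (two's complement nibble)
    q = q.reshape(out_f, in_f // 8, 8) & 0xF
    # Shift each nibble into its slot; disjoint bits make sum equal to OR
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    weight_packed = (q << shifts).sum(dim=-1).to(torch.int32)
    weight_scale = scale.squeeze(-1).to(torch.float16)
    weight_shape = torch.tensor([out_f, in_f], dtype=torch.int64)
    return weight_packed, weight_scale, weight_shape
```

Under these assumptions a constant tensor of 7.0 yields a scale of 1.0 and packed words of `0x77777777`, and round-tripping through the dequantize sketch recovers the original values up to quantization error.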