quantize_to_fp4#

nvmath.linalg.advanced.helpers.matmul.quantize_to_fp4(x: torch.Tensor, axis: Literal[-1, -2])
This function is experimental and potentially subject to future changes.
Quantize a torch tensor to torch.float4_e2m1fn_x2 dtype.

The function supports 1D, 2D, and higher-dimensional input torch tensors with dtype float32. It quantizes each float32 value to the nearest representable FP4 value and packs two 4-bit codes per byte, halving the packed dimension. The packing direction is controlled by axis:

- axis=-1: Packs consecutive elements along the last dimension. Input shape (..., Q) produces output (..., Q//2) with row-major layout (last stride = 1).
- axis=-2: Packs consecutive elements along the second-to-last dimension. Input shape (..., P, Q) produces output (..., P//2, Q) with column-major layout (second-to-last stride = 1).
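The quantize-and-pack step can be illustrated with a minimal pure-Python sketch for the axis=-1 case. This is not the library's implementation: the names (`E2M1_VALUES`, `quantize_scalar`, `pack_fp4_last_axis`) are illustrative, ties are resolved arbitrarily rather than by any particular rounding rule, and low-nibble-first packing order is an assumption.

```python
# FP4 E2M1 representable magnitudes (3-bit code -> value); bit 3 is the sign.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_scalar(v: float) -> int:
    """Return the 4-bit E2M1 code nearest to v (sign bit | magnitude code)."""
    sign = 8 if v < 0 else 0
    mag = abs(v)
    # Pick the magnitude code whose value is closest to |v| (ties: lowest code).
    code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
    return sign | code

def pack_fp4_last_axis(row: list[float]) -> bytes:
    """Pack consecutive elements along the last axis: two 4-bit codes per byte.

    Assumes even length and low-nibble-first order (an assumption; the
    library's actual nibble order may differ).
    """
    assert len(row) % 2 == 0, "packed dimension must have even size"
    out = bytearray()
    for lo, hi in zip(row[0::2], row[1::2]):
        out.append(quantize_scalar(lo) | (quantize_scalar(hi) << 4))
    return bytes(out)

packed = pack_fp4_last_axis([1.0, -2.0, 0.6, 5.0])
print(len(packed))  # shape (4,) -> (2,): packing halves the last dimension
```

Note how 0.6 snaps to the nearest representable value 0.5, and each output byte carries two input elements.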
- Parameters:
x – Torch tensor with dtype float32 (1D, 2D, or higher-dimensional).
axis – The axis along which to pack. Must be -1 (last dimension) or -2 (second-to-last dimension).
- Returns:
Torch tensor with dtype torch.float4_e2m1fn_x2 on the same device as the input.
Important
The packed dimension must have even size.
Note
This helper quantizes a single tensor and is suitable for understanding how packing for torch.float4_e2m1fn_x2 works in practice, or for experimenting with FP4 GEMMs outside of typical deep-learning workflows. It is not fully optimized for performance but should be adequate for most common use cases. For production whole-model quantization, consider tools such as torchao or bitsandbytes.

See also
unpack_fp4()— decode packed FP4 values back to float32.
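To see what decoding involves conceptually, here is a matching pure-Python sketch of the reverse direction for the axis=-1 case. This mirrors the idea behind unpack_fp4() but is not the library's code; the low-nibble-first order and the helper names are assumptions carried over from the packing sketch above.

```python
# FP4 E2M1 representable magnitudes (3-bit code -> value); bit 3 is the sign.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_code(code: int) -> float:
    """Decode one 4-bit E2M1 code back to a float."""
    value = E2M1_VALUES[code & 0x7]
    return -value if code & 0x8 else value

def unpack_last_axis(packed: bytes) -> list[float]:
    """Invert low-nibble-first packing along the last axis (assumed order)."""
    out = []
    for b in packed:
        out.append(decode_code(b & 0xF))  # low nibble: even index
        out.append(decode_code(b >> 4))   # high nibble: odd index
    return out

# 0xC2 packs codes 2 (= 1.0, low nibble) and 0xC (= -2.0, high nibble).
print(unpack_last_axis(bytes([0xC2])))  # [1.0, -2.0]
```

A round trip through quantize and unpack is lossy in general: every value snaps to the nearest of the sixteen representable FP4 values, so only inputs already on that grid are recovered exactly.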