quantize_to_fp4

nvmath.linalg.advanced.helpers.matmul.quantize_to_fp4(x: torch.Tensor, axis: Literal[-1, -2]) → torch.Tensor

This function is experimental and potentially subject to future changes.

Quantize a torch tensor to torch.float4_e2m1fn_x2 dtype.

The function supports 1D, 2D, and higher-dimensional input torch tensors with dtype float32. It quantizes each float32 value to the nearest representable FP4 value and packs two 4-bit codes per byte, halving the packed dimension. The packing direction is controlled by axis:

  • axis=-1: Packs consecutive elements along the last dimension. Input shape (..., Q) produces output (..., Q//2) with row-major layout (last stride = 1).

  • axis=-2: Packs consecutive elements along the second-to-last dimension. Input shape (..., P, Q) produces output (..., P//2, Q) with column-major layout (second-to-last stride = 1).
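The quantize-and-pack scheme described above can be illustrated with a small pure-Python sketch. This is an independent reimplementation for illustration, not the library's code: the nibble order (first element of each pair in the low nibble) and the tie-breaking rule are assumptions.

```python
# Illustrative sketch of FP4 (E2M1) quantization and packing along the
# last axis (the axis=-1 case). NOT nvmath's implementation; the nibble
# order (first element in the low nibble) is an assumption.

# The 8 non-negative values representable in E2M1; code i maps to
# MAGNITUDES[i], and setting bit 3 (code | 8) negates the value.
MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_code(v: float) -> int:
    """Return the 4-bit E2M1 code nearest to v (ties broken here toward
    the smaller magnitude; the library's rounding may differ)."""
    sign = 8 if v < 0 else 0
    mag = abs(v)
    code = min(range(8), key=lambda i: abs(MAGNITUDES[i] - mag))
    return sign | code

def pack_last_axis(row: list[float]) -> bytes:
    """Pack consecutive pairs of FP4 codes into bytes, halving the
    packed dimension. The packed dimension must have even size."""
    assert len(row) % 2 == 0, "packed dimension must have even size"
    out = bytearray()
    for first, second in zip(row[::2], row[1::2]):
        out.append(fp4_code(first) | (fp4_code(second) << 4))
    return bytes(out)

# A row of 4 floats packs into 2 bytes: the last dimension is halved.
packed = pack_last_axis([1.0, 6.0, -2.0, 0.5])
print([hex(b) for b in packed])  # ['0x72', '0x1c']
```

The axis=-2 case is the same operation applied down columns instead of along rows.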

Parameters:
  • x – Torch tensor with dtype float32 (1D, 2D, or higher-dimensional).

  • axis – The axis along which to pack. Must be -1 (last dimension) or -2 (second-to-last dimension).

Returns:

Torch tensor with dtype torch.float4_e2m1fn_x2 on the same device as the input.

Important

The packed dimension must have even size.

Note

This helper quantizes a single tensor. It is useful for understanding how packing for torch.float4_e2m1fn_x2 works in practice, or for experimenting with FP4 GEMMs outside of typical deep-learning workflows. It is not fully optimized for performance, but should be adequate for most common use cases. For production whole-model quantization, consider tools such as torchao or bitsandbytes.
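The decoding direction can be sketched the same way: a pure-Python analogue of what unpacking to float32 involves. Again, this is an illustration rather than the library's code, and the nibble order (first element in the low nibble) is an assumption.

```python
# Illustrative sketch of decoding packed FP4 (E2M1) bytes back to
# floats along the last axis. NOT nvmath's implementation; the nibble
# order (first element in the low nibble) is an assumption.

# Code i (ignoring the sign bit) maps to MAGNITUDES[i].
MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_code(code: int) -> float:
    """Map a 4-bit E2M1 code to its float value; bit 3 is the sign."""
    value = MAGNITUDES[code & 0b0111]
    return -value if code & 0b1000 else value

def unpack_last_axis(packed: bytes) -> list[float]:
    """Expand each byte into two floats, doubling the packed dimension."""
    out = []
    for byte in packed:
        out.append(decode_code(byte & 0x0F))  # low nibble first
        out.append(decode_code(byte >> 4))    # then high nibble
    return out

# Two packed bytes unpack to four float values.
print(unpack_last_axis(bytes([0x72, 0x1C])))  # [1.0, 6.0, -2.0, 0.5]
```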

See also

unpack_fp4() — decode packed FP4 values back to float32.