unpack_fp4

nvmath.linalg.advanced.helpers.matmul.unpack_fp4(
fp4_tensor: torch.Tensor,
axis: Literal[-1, -2],
) -> torch.Tensor

This function is experimental and potentially subject to future changes.

Unpack an N-D torch tensor with dtype torch.float4_e2m1fn_x2 to float32.

Each byte stores two FP4 values, so the output tensor is twice the size of the input along axis.
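The two-values-per-byte packing can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: decode_e2m1, unpack_byte, and unpack_row are hypothetical helpers, and the nibble order (first value in the low four bits) is an assumption made for illustration. The E2M1 value table itself follows from the format's 1 sign bit, 2 exponent bits, and 1 mantissa bit.

```python
# Magnitudes of the eight non-negative E2M1 codes
# (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code into a Python float."""
    magnitude = E2M1_MAGNITUDES[nibble & 0b0111]
    # Bit 3 is the sign bit.
    return -magnitude if nibble & 0b1000 else magnitude


def unpack_byte(byte: int) -> tuple[float, float]:
    """Split one packed byte into its two FP4 values.

    ASSUMPTION: the first value sits in the low nibble and the
    second in the high nibble; verify against the library before
    relying on this order.
    """
    return decode_e2m1(byte & 0x0F), decode_e2m1(byte >> 4)


def unpack_row(packed: bytes) -> list[float]:
    """Unpack a 1-D run of packed bytes; the result has 2 * len(packed) values."""
    values: list[float] = []
    for byte in packed:
        values.extend(unpack_byte(byte))
    return values
```

Under these assumptions, a run of Q packed bytes always yields 2*Q floats, which is where the doubled dimension comes from.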

  • axis=-1: The last dimension is the packed axis. Input shape (..., Q) with row-major layout produces output (..., 2*Q).

  • axis=-2: The second-to-last dimension is the packed axis. Input shape (..., P, Q) with column-major layout produces output (..., 2*P, Q).
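The shape rule for the two axis values can be sketched as a small helper. unpacked_shape is a hypothetical function written for illustration and is not part of nvmath; it only reproduces the shape arithmetic described above and does not touch tensor data or memory layout.

```python
def unpacked_shape(shape: tuple[int, ...], axis: int) -> tuple[int, ...]:
    """Return the output shape implied by unpacking along `axis`.

    Hypothetical helper: mirrors the documented rule that the
    packed axis doubles in size and all other axes are unchanged.
    """
    if axis not in (-1, -2):
        raise ValueError("axis must be -1 or -2")
    if len(shape) < -axis:
        raise ValueError("tensor has too few dimensions for this axis")
    out = list(shape)
    out[axis] *= 2  # two FP4 values per packed element
    return tuple(out)
```

For example, a packed shape of (4, 8) gives (4, 16) with axis=-1, and a packed shape of (3, 4, 8) gives (3, 8, 8) with axis=-2.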

Parameters:
  • fp4_tensor – FP4 tensor with dtype torch.float4_e2m1fn_x2.

  • axis – The axis along which the tensor was packed. Must be -1 (last dimension) or -2 (second-to-last dimension).

Returns:

A torch tensor with dtype float32 and the unpacked shape, on the same device as the input.

See also

quantize_to_fp4(): quantize and pack float32 values to FP4.