unpack_fp4
nvmath.linalg.advanced.helpers.matmul.unpack_fp4(fp4_tensor: torch.Tensor, axis: Literal[-1, -2])
This function is experimental and potentially subject to future changes.
Unpack an N-D torch tensor with dtype torch.float4_e2m1fn_x2 to float32.
Since each byte stores two FP4 values, the size of the output tensor along axis is double that of the input.
- axis=-1: The last dimension is the packed axis. Input shape (..., Q) with row-major layout produces output (..., 2*Q).
- axis=-2: The second-to-last dimension is the packed axis. Input shape (..., P, Q) with column-major layout produces output (..., 2*P, Q).
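The shape rules above can be illustrated with a small pure-Python sketch. It is not the nvmath implementation and does not use torch. The helper names (decode_e2m1, unpack_last_axis, unpack_second_to_last_axis) are hypothetical, and the nibble order (low four bits hold the first of the two values) is an assumption that should be checked against the library before relying on it. The E2M1 value table (sign bit, two exponent bits, one mantissa bit) is the standard one.

```python
from typing import List

# E2M1 magnitudes indexed by the low three bits (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code; bit 3 is the sign bit."""
    magnitude = E2M1_MAGNITUDES[nibble & 0b0111]
    return -magnitude if nibble & 0b1000 else magnitude


def unpack_last_axis(packed: List[List[int]]) -> List[List[float]]:
    """axis=-1 analogue: P rows of Q bytes become P rows of 2*Q floats.

    ASSUMPTION: the low nibble of each byte is the first element.
    """
    out = []
    for row in packed:
        values = []
        for byte in row:
            values.append(decode_e2m1(byte & 0x0F))
            values.append(decode_e2m1(byte >> 4))
        out.append(values)
    return out


def unpack_second_to_last_axis(packed: List[List[int]]) -> List[List[float]]:
    """axis=-2 analogue: P rows of Q bytes become 2*P rows of Q floats.

    ASSUMPTION: the low nibble of each byte lands in the earlier output row.
    """
    out = []
    for row in packed:
        out.append([decode_e2m1(byte & 0x0F) for byte in row])
        out.append([decode_e2m1(byte >> 4) for byte in row])
    return out


packed = [[0x21, 0xF4]]  # shape (1, 2)
wide = unpack_last_axis(packed)            # shape (1, 4)
tall = unpack_second_to_last_axis(packed)  # shape (2, 2)
print(wide)
print(tall)
```

In both cases the number of decoded values is twice the number of input bytes; only the axis that absorbs the doubling differs.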
- Parameters:
fp4_tensor – FP4 tensor with dtype torch.float4_e2m1fn_x2.
axis – The axis along which the tensor was packed. Must be -1 (last dimension) or -2 (second-to-last dimension).
- Returns:
A float32 torch tensor with the unpacked shape, on the same device as the input.
See also
quantize_to_fp4(): quantize and pack float32 values to FP4.