expand_block_scale

nvmath.linalg.advanced.helpers.matmul.expand_block_scale(
    scales_1d: torch.Tensor,
    operand_or_shape: torch.Tensor | tuple[int, ...],
    block_scaling_format: BlockScalingFormat,
    *,
    axis: Literal[-1, -2] | None = None,
    output_dtype: Literal['smallest'] | torch.dtype = 'smallest',
    device: Literal['cuda', 'cpu'] | None = None,
)
This function is experimental and potentially subject to future changes.
Expand NVFP4/MXFP8 block scales from a 1D cuBLAS-compatible interleaved array to the full operand shape.
Matmul (cuBLAS) expects and returns the block scale factors in a specific interleaved layout. This function takes that 1D interleaved scale array (either provided as input or returned by cuBLASLt for NVFP4/MXFP8 output) and expands it to a full ND tensor with shape operand_or_shape, where each element gets its corresponding scale value. This can be useful, for example, to manually dequantize the result of a matmul by elementwise multiplication of the expanded scales with the result.
Parameters:
scales_1d –
1D tensor of scale values with dtype:
- for NVFP4: torch.float8_e4m3fn, or torch.uint8 (interpreted as torch.float8_e4m3fn)
- for MXFP8: torch.uint8, interpreted as exponent (UE8M0)
The scales are expected to be stored in the cuBLAS-compatible interleaved layout (e.g. as returned by matmul’s d_out_scale). The number of elements in the tensor must equal the number of elements in the operand tensor divided by the number of elements in a block (for NVFP4: 16, for MXFP8: 32).
operand_or_shape – Operand tensor or its logical (non-packed, non-blocked) shape. The scales are expanded to match this shape.
block_scaling_format – The block scaling format of the operand: BlockScalingFormat.NVFP4 or BlockScalingFormat.MXFP8. Internally, it is validated to be consistent with the operand dtype, and a ValueError is raised if it is not.
axis –
The blocked dimension of the operand tensor. For example, for NVFP4/MXFP8 matmul, A is blocked in rows (axis = -1) and B is blocked in columns (axis = -2). Depending on operand_or_shape:
- if a shape is passed to operand_or_shape, then axis is required;
- if an operand is passed to operand_or_shape, then axis can be omitted and the blocked dimension is inferred from the operand’s layout.
output_dtype –
Output dtype. If provided, must be a torch dtype:
- for NVFP4: float8_e4m3fn, float16, float32, or float64
- for MXFP8: uint8 (exponent UE8M0), float16, float32, or float64
It must be wide enough to represent the result, or a ValueError is raised. If 'smallest' (the default), the smallest of the accepted dtypes that can represent the result is chosen automatically (for MXFP8: uint8 interpreted as exponent (UE8M0); for NVFP4: float8_e4m3fn).
device – Device for the output tensor. When None (the default), the device is inferred from scales_1d. When specified, must be "cuda" or "cpu".
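Setting aside the cuBLAS interleaved layout (which expand_block_scale handles internally), the expansion itself can be pictured as repeating each scale over its block along the blocked axis. The NumPy sketch below illustrates this for a row-blocked (axis = -1), NVFP4-style operand with block size 16. The helper expand_scales_naive is purely illustrative and is not part of the nvmath API; unlike the real helper, it assumes the scales are already in plain block order rather than the interleaved layout:

```python
import numpy as np

BLOCK = 16  # NVFP4 block size (MXFP8 would use 32)

def expand_scales_naive(scales_1d, shape, axis=-1, block=BLOCK):
    """Conceptual (non-interleaved) scale expansion: repeat each
    scale over its block of `block` elements along `axis`."""
    n_blocks = shape[axis] // block
    # One scale per block: reshape the flat scale array accordingly ...
    blocked_shape = list(shape)
    blocked_shape[axis] = n_blocks
    scales = np.asarray(scales_1d, dtype=np.float32).reshape(blocked_shape)
    # ... then repeat each scale `block` times along the blocked axis.
    return np.repeat(scales, block, axis=axis)

# A 2x32 operand blocked in rows has 2 * 32 / 16 = 4 scale values.
scales = [1.0, 2.0, 3.0, 4.0]
expanded = expand_scales_naive(scales, (2, 32), axis=-1)
print(expanded.shape)  # (2, 32): one scale value per operand element
```

Each element of the 2x32 output holds the scale of the 16-element block it belongs to (e.g. the first 16 elements of row 0 all hold 1.0, the next 16 hold 2.0).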
Returns:
Tensor with shape operand_or_shape (and dtype as specified by output_dtype) containing the expanded scales, on the target device. Each element contains the scale value that applies to the corresponding position in the FP4/FP8 matrix.
Note
For computing a single scale index rather than expanding all scales, use get_block_scale_offset() instead.
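The dequantization use case mentioned above reduces to an elementwise multiply of the quantized result with the expanded scales. A minimal NumPy sketch of that final step, using plain float arrays as stand-ins for the FP4/FP8 data and for the output of expand_block_scale (block size 16, row-blocked, non-interleaved order assumed for illustration only):

```python
import numpy as np

BLOCK = 16  # NVFP4 block size (MXFP8 would use 32)

# Stand-ins: a quantized 2x32 result and one scale per 16-element row block.
# In real code these would be FP4/FP8 data and scales from nvmath's matmul.
quantized = np.ones((2, 32), dtype=np.float32)
scales_1d = np.array([0.5, 2.0, 4.0, 8.0], dtype=np.float32)

# Conceptual (non-interleaved) expansion of the scales to the operand shape,
# followed by elementwise dequantization.
expanded = np.repeat(scales_1d.reshape(2, 2), BLOCK, axis=-1)
dequantized = quantized * expanded
```

With expand_block_scale, the expansion line would be replaced by a call to the helper on the cuBLAS-interleaved d_out_scale array; the multiplication step is unchanged.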