expand_block_scale

nvmath.linalg.advanced.helpers.matmul.expand_block_scale(
scales_1d: torch.Tensor,
operand_or_shape: torch.Tensor | tuple[int, ...],
block_scaling_format: BlockScalingFormat,
*,
axis: Literal[-1, -2] | None = None,
output_dtype: Literal['smallest'] | torch.dtype = 'smallest',
device: Literal['cuda', 'cpu'] | None = None,
) → torch.Tensor

This function is experimental and potentially subject to future changes.

Expand NVFP4/MXFP8 block scales from a 1D cuBLAS-compatible interleaved array to the full operand shape.

Matmul (cuBLAS) expects and returns the block scale factors in a specific interleaved layout.

This function takes that 1D interleaved scale array (either provided as input or returned by cuBLASLt for NVFP4/MXFP8 output) and expands it to a full ND tensor with shape operand_or_shape, in which each element holds its corresponding scale value. This is useful, for example, for manually dequantizing the result of a matmul by elementwise multiplication of the expanded scales with the result.
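The expansion and dequantization idea can be sketched in plain Python for a simplified row-major, non-interleaved layout (the real helper additionally handles the cuBLAS interleaved layout and dtype conversions). All function and variable names below are illustrative, not part of the nvmath API:

```python
# Simplified reference for what expand_block_scale computes, ignoring the
# cuBLAS interleaved layout. One scale applies to each `block` contiguous
# elements along the last axis (16 for NVFP4, 32 for MXFP8).

def expand_scales_rowwise(scales_1d, shape, block):
    """Expand scales to the full logical shape, blocks along the last axis."""
    rows, cols = shape
    assert cols % block == 0
    assert len(scales_1d) == rows * cols // block
    expanded = []
    for r in range(rows):
        row = []
        for c in range(cols):
            # Row-major: each row holds cols // block consecutive scales.
            row.append(scales_1d[r * (cols // block) + c // block])
        expanded.append(row)
    return expanded

# Dequantize: elementwise multiply quantized values by the expanded scales.
scales = [2.0, 0.5]                  # one scale per 4-element block
quant = [[1, 2, 3, 4, 5, 6, 7, 8]]  # logical shape (1, 8), block size 4
exp = expand_scales_rowwise(scales, (1, 8), block=4)
dequant = [[q * s for q, s in zip(qr, sr)] for qr, sr in zip(quant, exp)]
```

With the real helper, the same elementwise multiplication would be applied between the returned tensor and the FP4/FP8 operand (after converting to a common dtype).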

Parameters:
  • scales_1d

    1D tensor of scale values with dtype:

    • for NVFP4: torch.float8_e4m3fn, or torch.uint8 (interpreted as torch.float8_e4m3fn)

    • for MXFP8: torch.uint8, interpreted as exponent (UE8M0)

    The scales are expected to be stored in the cuBLAS-compatible interleaved layout (e.g. as returned by matmul’s d_out_scale). The number of elements in the tensor must equal the number of elements in the operand tensor divided by the block size (16 for NVFP4, 32 for MXFP8).

  • operand_or_shape – Operand tensor or its logical (non-packed, non-blocked) shape. The scales are expanded to match this shape.

  • block_scaling_format – The block scaling format of the operand: BlockScalingFormat.NVFP4 or BlockScalingFormat.MXFP8. Internally, it is validated to be consistent with the operand dtype, and a ValueError is raised if not.

  • axis

    The blocked dimension of the operand tensor. For example, for NVFP4/MXFP8 matmul, A is blocked in rows (axis = -1), and B is blocked in columns (axis = -2). Depending on operand_or_shape:

    • if a shape is passed to operand_or_shape, then axis is required

    • if an operand is passed to operand_or_shape, then axis can be omitted and the blocked dimension is inferred from the operand’s layout.

  • output_dtype

    Output dtype. If provided, must be a torch dtype:

    • for NVFP4: float8_e4m3fn, float16, float32, or float64

    • for MXFP8: uint8 (exponent UE8M0), float16, float32, or float64

    It must be wide enough to represent the result, or a ValueError is raised. If 'smallest' (the default), the smallest accepted dtype that can represent the result is chosen automatically (for MXFP8: uint8 interpreted as exponent (UE8M0); for NVFP4: float8_e4m3fn).

  • device – Device for the output tensor. When None (default), the device is inferred from scales_1d. When specified, must be "cuda" or "cpu".
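The relationship between the operand shape, the block size, and the required length of scales_1d can be checked with a small helper. The block sizes 16 (NVFP4) and 32 (MXFP8) come from the parameter description above; the helper itself is illustrative, not part of the nvmath API:

```python
from math import prod

# Block sizes per format, as documented for scales_1d.
BLOCK = {"NVFP4": 16, "MXFP8": 32}

def num_scales(shape, fmt):
    """Required length of scales_1d for an operand of the given logical shape."""
    n = prod(shape)
    block = BLOCK[fmt]
    assert n % block == 0, "operand size must be divisible by the block size"
    return n // block

print(num_scales((128, 64), "NVFP4"))  # 128 * 64 / 16 = 512
print(num_scales((128, 64), "MXFP8"))  # 128 * 64 / 32 = 256
```

Note that axis does not change this count; it only determines along which dimension each block of 16 or 32 elements runs.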

Returns:

Tensor with shape operand_or_shape (and the dtype specified by output_dtype) containing the expanded scales, on the target device. Each element contains the scale value that applies to the corresponding position in the FP4/FP8 matrix.

Note

For computing a single scale index rather than expanding all scales, use get_block_scale_offset() instead.
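For intuition, a simplified analogue of that single-index computation, again for a plain row-major, non-interleaved layout (unlike the actual cuBLAS interleaved layout that get_block_scale_offset() accounts for), might look like:

```python
# Illustrative only: index into a row-major, non-interleaved scales_1d for
# element (row, col) of a 2D operand with blocks along the last axis.
def scale_offset_rowwise(row, col, cols, block):
    """Offset of the scale that applies to element (row, col)."""
    return row * (cols // block) + col // block

# Element (2, 37) of a (4, 64) NVFP4 operand (block size 16):
# row 2 starts at scale 2 * 4 = 8; column 37 falls in block 37 // 16 = 2.
print(scale_offset_rowwise(2, 37, cols=64, block=16))  # 10
```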