apply_mxfp8_scale

nvmath.linalg.advanced.helpers.matmul.apply_mxfp8_scale(
x: torch.Tensor,
scales_1d: torch.Tensor,
output_dtype: Literal['smallest'] | torch.dtype = 'smallest',
) → torch.Tensor

This function is experimental and potentially subject to future changes.

Apply MXFP8 block scale factors to a tensor x.

Parameters:
  • x – The tensor to which the scaling should be applied. Currently it must be a torch.Tensor.

  • scales_1d – The block scale factors (stored in cuBLAS-compatible interleaved layout) to apply. Its shape must be compatible with x, and currently it must also be a torch.Tensor.

  • output_dtype – Output dtype. Either 'smallest' (default), in which case the smallest dtype that can represent the result is chosen automatically, or a floating-point torch.dtype (float16, float32, or float64). An explicit dtype must be at least as wide as the smallest dtype that can represent the result; otherwise ValueError is raised.

Returns:

A tensor containing the values of x with the scales applied, in the chosen or provided dtype.
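
As a rough illustration of what "scales applied" means, the sketch below multiplies each block of values by a power-of-two factor. It is not the library's implementation. It assumes the MXFP8 conventions of the OCP Microscaling specification (32 elements per block, one 8-bit exponent-only scale with bias 127) and a plain, non-interleaved scale layout, which differs from the cuBLAS-compatible layout this function expects. The names apply_block_scales, BLOCK_SIZE, and E8M0_BIAS are hypothetical.

```python
BLOCK_SIZE = 32   # assumed: elements sharing one scale factor
E8M0_BIAS = 127   # assumed: exponent bias of the 8-bit scale encoding


def apply_block_scales(row, scale_bytes):
    """Multiply each 32-element block of `row` by 2**(scale_byte - 127).

    `row` is a flat list of floats and `scale_bytes` holds one integer
    scale per block, in plain order (NOT the cuBLAS interleaved layout).
    The NaN encoding of the scale (255) is not handled in this sketch.
    """
    if len(row) != BLOCK_SIZE * len(scale_bytes):
        raise ValueError("need exactly one scale per 32-element block")
    out = []
    for block_index, scale_byte in enumerate(scale_bytes):
        factor = 2.0 ** (scale_byte - E8M0_BIAS)  # Python floats are float64
        start = block_index * BLOCK_SIZE
        out.extend(value * factor for value in row[start:start + BLOCK_SIZE])
    return out


# Two blocks: the first is scaled by 2**0, the second by 2**1.
scaled = apply_block_scales([1.0] * 32 + [2.0] * 32, [127, 128])
```

Here the first 32 outputs stay at 1.0 and the last 32 become 4.0.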

Raises:

ValueError – If the result would overflow or underflow the requested output_dtype.

Behavior:

The operation is computed in float64. The function then determines the smallest dtype (float16, float32, or float64) that can represent the result without overflow or underflow. If output_dtype='smallest', that minimum is used. If an explicit output_dtype was passed, it must be at least as wide as that minimum; otherwise, ValueError is raised. Finally, the result is cast to the chosen dtype and returned.
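
The selection rule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: overflow is taken to mean a magnitude above the largest finite value of a format, underflow a nonzero magnitude below its smallest positive subnormal, and dtypes are represented by name strings rather than torch.dtype objects. The helper names smallest_float_dtype and choose_output_dtype are hypothetical.

```python
import sys

# (name, largest finite value, smallest positive subnormal), narrowest first.
_FLOAT_FORMATS = [
    ("float16", 65504.0, 2.0 ** -24),
    ("float32", (2.0 - 2.0 ** -23) * 2.0 ** 127, 2.0 ** -149),
    ("float64", sys.float_info.max, 5e-324),
]


def smallest_float_dtype(values):
    """Return the narrowest format whose range covers every nonzero value."""
    magnitudes = [abs(v) for v in values if v != 0.0]
    for name, largest, tiniest in _FLOAT_FORMATS:
        if all(tiniest <= m <= largest for m in magnitudes):
            return name
    # Reached only for inf or NaN inputs, which no range check accepts.
    raise ValueError("values are not representable as finite floats")


def choose_output_dtype(values, output_dtype="smallest"):
    """Apply the 'smallest' default or validate an explicit request."""
    minimum = smallest_float_dtype(values)
    if output_dtype == "smallest":
        return minimum
    names = [name for name, _, _ in _FLOAT_FORMATS]
    if output_dtype not in names:
        raise ValueError(f"unsupported output dtype: {output_dtype!r}")
    if names.index(output_dtype) < names.index(minimum):
        raise ValueError(
            f"result would over/underflow {output_dtype}; need at least {minimum}"
        )
    return output_dtype
```

Under these assumptions, choose_output_dtype([1e5]) selects "float32" because 1e5 exceeds the float16 maximum of 65504, while choose_output_dtype([1e5], "float16") raises ValueError.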

Note

This function is not intended for production usage due to its relatively low performance and high memory consumption. Prefer result_type to request non-FP8 output.