apply_mxfp8_scale

nvmath.linalg.advanced.helpers.matmul.apply_mxfp8_scale(x: torch.Tensor, scales_1d: torch.Tensor, output_dtype: Literal['smallest'] | torch.dtype = 'smallest')

This function is experimental and potentially subject to future changes.

Apply MXFP8 block scale factors to a tensor x.

- Parameters:
  - x – The tensor to which the scaling should be applied. Currently it must be a torch.Tensor.
  - scales_1d – The block scale factors (stored in a cuBLAS-compatible interleaved layout) to apply. Its shape must be compatible with x, and currently it must also be a torch.Tensor.
  - output_dtype – The output dtype. If provided, it must be a floating-point torch.dtype (float16, float32, or float64) and at least as wide as the smallest dtype that can represent the result; otherwise ValueError is raised. If 'smallest' (the default), the smallest dtype that can represent the result is chosen automatically.
- Returns:
  A tensor with the values of x with scales applied, in the chosen or provided dtype.
- Raises:
  ValueError – When the result would overflow or underflow the requested dtype.
- Behavior:
  The operation is computed in float64. The function then determines the smallest dtype (float16, float32, or float64) that can represent the result without overflow or underflow. If output_dtype was passed, it must be at least as wide as that minimum, otherwise ValueError is raised; if output_dtype='smallest', that minimum is used. The result is finally cast to the chosen dtype and returned.
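The behavior above can be sketched as a small NumPy reference model. This is a hypothetical illustration, not the library's implementation: the function name apply_scales_reference is invented, the scales are assumed to be E8M0 exponents (pure powers of two, 2**(e - 127)) applied per 32-element block as in the OCP MX format, and the cuBLAS interleaved layout is ignored for clarity.

```python
import numpy as np

BLOCK = 32  # assumed MX block size: one scale factor per 32 consecutive elements

def apply_scales_reference(x, scale_exponents, output_dtype="smallest"):
    """Hypothetical reference model of apply_mxfp8_scale's documented behavior."""
    # Compute in float64, mirroring the documented behavior.
    x64 = np.asarray(x, dtype=np.float64).reshape(-1, BLOCK)
    # Assumption: E8M0 scales are powers of two, 2**(e - 127).
    scales = 2.0 ** (np.asarray(scale_exponents, dtype=np.float64) - 127.0)
    result = (x64 * scales[:, None]).ravel()

    # Find the smallest dtype that holds the result without over/underflow.
    for dt in (np.float16, np.float32, np.float64):
        cast = result.astype(dt)
        no_overflow = bool(np.all(np.isfinite(cast[np.isfinite(result)])))
        no_underflow = not np.any((cast == 0) & (result != 0))
        if no_overflow and no_underflow:
            smallest = dt
            break

    if output_dtype == "smallest":
        return result.astype(smallest)
    if np.dtype(output_dtype).itemsize < np.dtype(smallest).itemsize:
        raise ValueError("requested dtype would over/underflow the result")
    return result.astype(output_dtype)
```

For example, a block scale exponent of 227 means a factor of 2**100, which overflows float16 (max ~65504), so the model widens the result to float32; requesting float16 explicitly in that case raises ValueError, matching the documented contract.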
Note
This function is not intended for production use due to its relatively low performance and high memory consumption. Prefer result_type to request non-FP8 output.