nemo_automodel.components.moe.mxfp8
torchao MXFP8 grouped-GEMM plumbing for the experts="torch_mm_mxfp8" path.
torchao exposes a drop-in differentiable replacement for torch._grouped_mm that
dynamically quantizes both operands to MXFP8 (e4m3 data + e8m0 block scales,
block_size=32). It mirrors torch._grouped_mm’s contract exactly: 2D activations
(M*num_groups, K), 3D [E, K, N] stacked expert weights, int32 offs group
boundaries — so no transpose is needed for Automodel’s gate_and_up_projs
([E, dim, up]) or down_projs ([E, inter, dim]).
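To make the shape contract above concrete, here is a minimal loop-based reference in plain PyTorch. It is a sketch, not the torchao or torch._grouped_mm kernel: the name grouped_mm_reference is hypothetical, and it assumes offs holds the cumulative end row of each expert's token group.

```python
import torch


def grouped_mm_reference(A, B, offs):
    """Hypothetical reference: rows A[start:end] of group e are multiplied by B[e].

    A:    [total_tokens, K]  2D activations, tokens sorted by expert
    B:    [E, K, N]          stacked expert weights
    offs: [E] int32          assumed to be cumulative end rows, one per expert
    """
    out = A.new_zeros(A.shape[0], B.shape[-1])
    start = 0
    for e, end in enumerate(offs.tolist()):
        # An empty group (start == end) is a no-op.
        out[start:end] = A[start:end] @ B[e]
        start = end
    return out


E, K, N = 3, 4, 5
tokens_per_expert = [2, 0, 3]  # expert 1 receives no tokens
A = torch.randn(sum(tokens_per_expert), K)
B = torch.randn(E, K, N)
offs = torch.cumsum(torch.tensor(tokens_per_expert), dim=0).to(torch.int32)
out = grouped_mm_reference(A, B, offs)
```

Because B is already stored as [E, K, N], each group is a plain A_e @ B[e] with no transpose, which is the point made above about gate_and_up_projs and down_projs.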
torchao is unpinned and (when present) comes from the base image rather than the uv
lock, so the API generation is resolved defensively at runtime across known versions
and normalized to a uniform mxfp8_grouped_mm(A, B, offs) callable. If torchao is
missing entirely, the runtime gate falls back to torch._grouped_mm.
Public entry: select_grouped_mm returns the grouped-GEMM callable the expert
forward should use (MXFP8 with the contiguous-operand relayout when requested and
available, else plain torch._grouped_mm).
Module Contents
Functions
Data
API
Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off.
Return the MXFP8 grouped-GEMM callable iff it is usable on this device.
Requires CUDA with compute capability >= 10 (GB200/sm_100+) AND a successful
torchao import. Otherwise returns None (callers fall back to
torch._grouped_mm). Emits a one-time warning when MXFP8 was requested but is
unavailable.
Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous.
torchao’s MXFP8 quantizer calls to_mx(B.transpose(-2,-1)) and strictly asserts
the input is contiguous, so the weight must be stored as [E,N,K]-contiguous (viewed
as [E,K,N]) — also the column-major B_t layout torchao’s grouped GEMM wants.
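One way to obtain that layout in plain PyTorch is to materialize the transposed weight contiguously and then view it back. The sketch below is an illustration of the idea under that assumption; relayout_for_contiguous_transpose is a hypothetical name, not the module's helper.

```python
import torch


def relayout_for_contiguous_transpose(B):
    """Return a tensor equal to B ([E, K, N]) whose (-2, -1) transpose is contiguous."""
    # transpose -> copy into [E, N, K] contiguous storage -> view back as [E, K, N]
    return B.transpose(-2, -1).contiguous().transpose(-2, -1)


B = torch.randn(2, 4, 6)  # [E, K, N], ordinary row-major storage
B_r = relayout_for_contiguous_transpose(B)

# Same values and shape, different memory layout.
same_values = torch.equal(B_r, B)
transpose_is_contiguous = B_r.transpose(-2, -1).is_contiguous()
```

The values and logical shape are unchanged; only the strides differ, so the result can be passed anywhere the original [E, K, N] weight was expected.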
Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.
Returns a callable mxfp8_grouped_mm(A, B, offs) mirroring torch._grouped_mm,
or None if no supported torchao API is importable. The result is cached.
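The resolve-then-cache pattern described here can be sketched with the standard library alone. The module paths and attribute names in _CANDIDATES are placeholders, not real torchao locations, and resolve_first and resolve_cached are hypothetical names; a real resolver would also wrap each hit in an adapter so every generation exposes the same (A, B, offs) signature.

```python
import importlib
from functools import lru_cache

# Placeholder (module, attribute) pairs, newest API first.
# These are NOT real torchao paths; substitute the ones you support.
_CANDIDATES = (
    ("hypothetical_pkg.new_api", "grouped_mm"),
    ("hypothetical_pkg.old_api", "grouped_mm_legacy"),
)


def resolve_first(candidates):
    """Return the first importable callable among candidates, else None."""
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # this API generation is not installed; try the next
        fn = getattr(module, attr, None)
        if callable(fn):
            return fn
    return None


@lru_cache(maxsize=1)
def resolve_cached():
    # The import probing runs once; later calls reuse the stored result.
    return resolve_first(_CANDIDATES)


kernel = resolve_cached()  # None here, since the placeholders do not exist
```

Caching the None result matters as much as caching a hit: it keeps a missing dependency from triggering repeated import attempts on every forward pass.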
Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs.
When use_mxfp8 and the torchao MXFP8 kernel is usable on this device, returns a
wrapper that makes both operands contiguous in the layout torchao requires (A
contiguous; B relaid out so its transpose is contiguous — see _mxfp8_weight_relayout)
and routes through it. Otherwise returns the plain torch._grouped_mm fallback,
leaving the bf16 path byte-identical. Shared by the no-bias helper and the inline
bias paths so dispatch + relayout are defined once.
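The dispatch logic can be sketched with the two kernels injected as arguments, which avoids depending on torchao or on torch._grouped_mm being available. Everything here is hypothetical: select_grouped_mm_sketch, its parameters, and the stand-in kernels only illustrate the control flow described above.

```python
import torch


def select_grouped_mm_sketch(use_mxfp8, mxfp8_kernel, fallback):
    """Hypothetical dispatcher: return the callable the expert GEMMs should use."""
    if not use_mxfp8 or mxfp8_kernel is None:
        return fallback  # returned unchanged, so the bf16 path is untouched

    def wrapped(A, B, offs):
        A = A.contiguous()
        # Relay B out so that B.transpose(-2, -1) is contiguous.
        B = B.transpose(-2, -1).contiguous().transpose(-2, -1)
        return mxfp8_kernel(A, B, offs)

    return wrapped


seen = {}


def fake_mxfp8_kernel(A, B, offs):
    # Stand-in kernel: records the layouts it receives, computes one group.
    seen["A_contiguous"] = A.is_contiguous()
    seen["Bt_contiguous"] = B.transpose(-2, -1).is_contiguous()
    return A @ B[0]


def fake_fallback(A, B, offs):
    return A @ B[0]


A = torch.randn(6, 3).t()  # [3, 6], deliberately non-contiguous
B = torch.randn(1, 6, 4)   # [E, K, N]
offs = torch.tensor([3], dtype=torch.int32)

grouped_mm = select_grouped_mm_sketch(True, fake_mxfp8_kernel, fake_fallback)
out = grouped_mm(A, B, offs)
```

Returning the fallback object itself, rather than a wrapper around it, is what keeps the non-MXFP8 path identical to calling the plain kernel directly.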