nemo_automodel.components.moe.mxfp8#

torchao MXFP8 grouped-GEMM plumbing for the experts="torch_mm_mxfp8" path.

torchao exposes a drop-in differentiable replacement for torch._grouped_mm that dynamically quantizes both operands to MXFP8 (e4m3 data + e8m0 block scales, block_size=32). It mirrors torch._grouped_mm’s contract exactly: 2D activations of shape (M*num_groups, K), 3D stacked expert weights of shape [E, K, N], and int32 offs group boundaries. As a result, no transpose is needed for Automodel’s gate_and_up_projs ([E, dim, up]) or down_projs ([E, inter, dim]).
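The calling contract described above can be illustrated with a small reference implementation. This is only a sketch of the semantics on nested Python lists, not the torch or torchao kernel. It assumes that each entry of offs is the exclusive end row of one expert's group (that is, cumulative token counts); the names grouped_mm_reference and matmul are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (rows x K) @ (K x N) product on nested lists."""
    k, n = len(b), len(b[0])
    return [[sum(row[i] * b[i][j] for i in range(k)) for j in range(n)] for row in a]


def grouped_mm_reference(A: Matrix, B: List[Matrix], offs: List[int]) -> Matrix:
    """Multiply each contiguous row group of A by its own expert matrix.

    Assumption: offs[g] is the exclusive end row of group g, so group g
    covers rows offs[g-1]:offs[g] of A (with an implicit leading 0).
    """
    out: Matrix = []
    start = 0
    for expert, end in enumerate(offs):
        # Rows start:end belong to this expert; an empty group adds nothing.
        out.extend(matmul(A[start:end], B[expert]))
        start = end
    return out


# Three tokens (K=2) routed to two experts (N=1): one token, then two tokens.
A = [[1, 2], [3, 4], [5, 6]]
B = [[[1], [1]], [[2], [0]]]
result = grouped_mm_reference(A, B, [1, 3])
```

Because the weights are already stacked as [E, K, N], each expert slice B[expert] can be used directly as the right-hand operand, which is why no transpose is needed.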

torchao is unpinned and (when present) comes from the base image rather than the uv lock, so the API generation is resolved defensively at runtime across known versions and normalized to a uniform mxfp8_grouped_mm(A, B, offs) callable. If torchao is missing entirely, the runtime gate falls back to torch._grouped_mm.

Public entry point: select_grouped_mm returns the grouped-GEMM callable the expert forward should use (mxfp8 with the contiguous-operand relayout when requested and available, else plain torch._grouped_mm).

Module Contents#

Functions#

_resolve_mxfp8_grouped_mm

Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.

_mxfp8_grouped_mm_or_none

Return the MXFP8 grouped-GEMM callable iff it is usable on this device.

_default_grouped_mm

Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off.

_mxfp8_weight_relayout

Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous.

select_grouped_mm

Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs.

Data#

API#

nemo_automodel.components.moe.mxfp8._MXFP8_GROUPED_MM#

None

nemo_automodel.components.moe.mxfp8._MXFP8_RESOLVED#

False

nemo_automodel.components.moe.mxfp8._MXFP8_FALLBACK_WARNED#

False

nemo_automodel.components.moe.mxfp8._MXFP8_ACTIVE_ANNOUNCED#

False

nemo_automodel.components.moe.mxfp8._resolve_mxfp8_grouped_mm()#

Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.

Returns a callable mxfp8_grouped_mm(A, B, offs) mirroring torch._grouped_mm, or None if no supported torchao API is importable. The result is cached.

nemo_automodel.components.moe.mxfp8._mxfp8_grouped_mm_or_none()#

Return the MXFP8 grouped-GEMM callable iff it is usable on this device.

Requires CUDA with compute capability >= 10 (GB200/sm_100+) AND a successful torchao import. Otherwise returns None (callers fall back to torch._grouped_mm). Emits a one-time warning when MXFP8 was requested but is unavailable.
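The gate described above might look roughly like the sketch below. It is written without a torch dependency so the logic can be read in isolation: the caller passes the device capability as a (major, minor) tuple, or None when CUDA is unavailable. With PyTorch, that tuple would typically come from torch.cuda.get_device_capability() after checking torch.cuda.is_available(). The function name, the module-level flag, and the warning text are hypothetical.

```python
import warnings
from typing import Any, Callable, Optional, Tuple

_fallback_warned = False  # one-time warning latch


def mxfp8_or_none(
    capability: Optional[Tuple[int, int]],
    resolved: Optional[Callable[..., Any]],
) -> Optional[Callable[..., Any]]:
    """Return `resolved` only if the device and the import both qualify."""
    global _fallback_warned
    device_ok = capability is not None and capability[0] >= 10
    if device_ok and resolved is not None:
        return resolved
    if not _fallback_warned:
        # Warn once, then stay quiet on later calls.
        warnings.warn("MXFP8 grouped GEMM unavailable; falling back to torch._grouped_mm")
        _fallback_warned = True
    return None
```

Latching the warning in a module-level flag keeps the log readable when the gate is evaluated once per expert layer or per forward pass.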

nemo_automodel.components.moe.mxfp8._default_grouped_mm(A, B, offs)#

Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off.

nemo_automodel.components.moe.mxfp8._mxfp8_weight_relayout(B)#

Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous.

torchao’s MXFP8 quantizer calls to_mx(B.transpose(-2,-1)) and strictly asserts that its input is contiguous, so the weight must be stored as [E,N,K]-contiguous and viewed as [E,K,N]. This is also the column-major B_t layout that torchao’s grouped GEMM expects.
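The layout requirement can be checked with stride arithmetic alone. The sketch below computes element strides for an [E, K, N] view backed by [E, N, K]-contiguous storage and confirms that its (-2, -1) transpose is contiguous, whereas the transpose of a plain contiguous [E, K, N] tensor is not. In PyTorch the same layout is commonly produced with B.transpose(-2, -1).contiguous().transpose(-2, -1); whether the module does exactly that is an assumption. All helper names here are hypothetical.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides, in elements, for a densely packed tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def relayout_strides(E: int, K: int, N: int) -> Tuple[int, int, int]:
    """Strides of an [E, K, N] view whose storage is [E, N, K]-contiguous."""
    s_e, s_n, s_k = contiguous_strides((E, N, K))
    return (s_e, s_k, s_n)  # swap the last two strides to view as [E, K, N]


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if `strides` match the dense row-major strides for `shape`."""
    return tuple(strides) == contiguous_strides(shape)


E, K, N = 2, 4, 3
view = relayout_strides(E, K, N)             # strides of the [E, K, N] view
transposed = (view[0], view[2], view[1])     # strides after transpose(-2, -1)
plain = contiguous_strides((E, K, N))        # ordinary contiguous weight
plain_t = (plain[0], plain[2], plain[1])     # its transpose
```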

nemo_automodel.components.moe.mxfp8.select_grouped_mm(use_mxfp8)#

Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs.

When use_mxfp8 is set and the torchao MXFP8 kernel is usable on this device, returns a wrapper that makes both operands contiguous in the layout torchao requires (A contiguous; B relaid out so its transpose is contiguous, see _mxfp8_weight_relayout) and routes the call through the MXFP8 kernel. Otherwise returns the plain torch._grouped_mm fallback, leaving the bf16 path byte-identical. The function is shared by the no-bias helper and the inline bias paths so that dispatch and relayout are defined once.
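The dispatch pattern can be sketched as follows. To keep the example independent of torch and torchao, the MXFP8 kernel, the fallback, and the two layout helpers are passed in as arguments; the real select_grouped_mm takes only use_mxfp8, so this signature and the name select_grouped_mm_sketch are hypothetical.

```python
from typing import Any, Callable, Optional

GroupedMM = Callable[[Any, Any, Any], Any]


def select_grouped_mm_sketch(
    use_mxfp8: bool,
    mxfp8_fn: Optional[GroupedMM],
    default_fn: GroupedMM,
    make_contiguous: Callable[[Any], Any],
    relayout_weight: Callable[[Any], Any],
) -> GroupedMM:
    """Pick the grouped-GEMM callable once, so callers never branch again."""
    if not use_mxfp8 or mxfp8_fn is None:
        # The fallback is returned unwrapped, so the default path is unchanged.
        return default_fn

    def wrapped(A: Any, B: Any, offs: Any) -> Any:
        # Fix up operand layouts, then route through the MXFP8 kernel.
        return mxfp8_fn(make_contiguous(A), relayout_weight(B), offs)

    return wrapped


# String stand-ins make the data flow visible without any tensors.
default = lambda A, B, offs: ("default", A, B, offs)
fast = lambda A, B, offs: ("mxfp8", A, B, offs)
contig = lambda x: x + "_c"
relay = lambda x: x + "_r"
grouped_mm = select_grouped_mm_sketch(True, fast, default, contig, relay)
```

Returning the fallback itself, rather than a wrapper around it, is one way to guarantee that enabling the feature flag cannot alter the non-MXFP8 path.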