nemo_automodel.components.moe.mxfp8


torchao MXFP8 grouped-GEMM plumbing for the experts="torch_mm_mxfp8" path.

torchao exposes a drop-in differentiable replacement for torch._grouped_mm that dynamically quantizes both operands to MXFP8 (e4m3 data + e8m0 block scales, block_size=32). It mirrors torch._grouped_mm’s contract exactly: 2D activations (M*num_groups, K), 3D [E, K, N] stacked expert weights, int32 offs group boundaries — so no transpose is needed for Automodel’s gate_and_up_projs ([E, dim, up]) or down_projs ([E, inter, dim]).
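The grouped-GEMM contract described above can be illustrated with a dependency-free reference. This is only a sketch of the semantics, assuming `offs` holds the cumulative end row of each expert's group; `grouped_mm_reference` and `matmul` are hypothetical helper names, not part of the module, and plain nested lists stand in for tensors.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    inner, cols = len(b), len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def grouped_mm_reference(A, B, offs):
    """Reference semantics for a grouped GEMM.

    A    : 2D, (total_rows, K), rows already sorted by expert.
    B    : 3D, [E][K][N], one weight matrix per expert.
    offs : cumulative end row of each expert's group (assumption).
    """
    out = []
    start = 0
    for expert, end in enumerate(offs):
        rows = A[start:end]
        if rows:  # an expert may receive no tokens
            out.extend(matmul(rows, B[expert]))
        start = end
    return out


A = [[1, 0], [0, 1], [1, 1]]          # 3 token rows, K = 2
B = [
    [[1, 2], [3, 4]],                 # expert 0, [K, N]
    [[10, 20], [30, 40]],             # expert 1, [K, N]
]
offs = [2, 3]                         # rows 0-1 -> expert 0, row 2 -> expert 1
result = grouped_mm_reference(A, B, offs)
```

Because `B[expert]` is already `[K, N]`, each group multiplies directly against its expert's weight, which is why no transpose is needed for weights stored as `[E, K, N]`.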

torchao is unpinned and (when present) comes from the base image rather than the uv lock, so the API generation is resolved defensively at runtime across known versions and normalized to a uniform mxfp8_grouped_mm(A, B, offs) callable. If torchao is missing entirely, the runtime gate falls back to torch._grouped_mm.
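One way to resolve an API defensively across versions is to probe a list of candidate import locations and take the first that exists. The sketch below shows only that lookup step; the `hypothetical_quant_pkg` paths are invented placeholders (not real torchao module names), and a stdlib function stands in for the callable that would be found.

```python
import importlib


def resolve_first_available(candidates):
    """Return the first importable callable among (module_path, attr_name) pairs.

    Returns None when no candidate can be imported, so the caller can fall
    back to a default implementation.
    """
    for module_path, attr_name in candidates:
        try:
            module = importlib.import_module(module_path)
        except ImportError:
            continue  # this API generation is not installed
        fn = getattr(module, attr_name, None)
        if callable(fn):
            return fn
    return None


# Hypothetical candidates, newest first. The first two do not exist; the
# last is a stdlib stand-in so the example resolves to something real.
candidates = [
    ("hypothetical_quant_pkg.new_api", "grouped_mm"),
    ("hypothetical_quant_pkg.old_api", "scaled_grouped_mm"),
    ("operator", "matmul"),
]
resolved = resolve_first_available(candidates)
missing = resolve_first_available(candidates[:2])
```

A real resolver would additionally wrap whichever callable it finds so that every generation is exposed through the same `(A, B, offs)` signature.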

Public entry: select_grouped_mm returns the grouped-GEMM callable the expert forward should use (MXFP8 with the contiguous-operand relayout when requested and available, else plain torch._grouped_mm).

Module Contents

Functions

| Name | Description |
| --- | --- |
| _default_grouped_mm | Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off. |
| _mxfp8_grouped_mm_or_none | Return the MXFP8 grouped-GEMM callable iff it is usable on this device. |
| _mxfp8_weight_relayout | Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous. |
| _resolve_mxfp8_grouped_mm | Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations. |
| select_grouped_mm | Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs. |

Data

_MXFP8_ACTIVE_ANNOUNCED

_MXFP8_FALLBACK_WARNED

_MXFP8_GROUPED_MM

_MXFP8_RESOLVED

API

nemo_automodel.components.moe.mxfp8._default_grouped_mm(
A,
B,
offs
)

Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off.

nemo_automodel.components.moe.mxfp8._mxfp8_grouped_mm_or_none()

Return the MXFP8 grouped-GEMM callable iff it is usable on this device.

Requires CUDA with compute capability >= 10 (GB200/sm_100+) AND a successful torchao import. Otherwise returns None (callers fall back to torch._grouped_mm). Emits a one-time warning when MXFP8 was requested but is unavailable.
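The gate and its one-time warning can be sketched as follows. This is a simplified illustration, not the module's code: it returns a bool rather than a callable, the name `mxfp8_usable` and the warning text are hypothetical, and the capability tuple is passed in explicitly (in practice it would typically come from `torch.cuda.get_device_capability()`).

```python
import warnings

# Module-level latch so the fallback warning is emitted at most once.
_fallback_warned = False


def mxfp8_usable(capability, torchao_importable):
    """Return True only for compute capability >= 10 with torchao importable.

    capability: (major, minor) tuple, or None when no CUDA device is present.
    """
    global _fallback_warned
    major = capability[0] if capability is not None else 0
    usable = major >= 10 and torchao_importable
    if not usable and not _fallback_warned:
        warnings.warn(
            "MXFP8 grouped GEMM requested but unavailable; "
            "falling back to torch._grouped_mm"
        )
        _fallback_warned = True
    return usable


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    on_sm90 = mxfp8_usable((9, 0), True)       # capability too low
    no_torchao = mxfp8_usable((10, 0), False)  # import failed
    on_sm100 = mxfp8_usable((10, 0), True)     # both conditions met
warning_count = len(caught)
```

The latch keeps repeated calls (for example, one per expert layer) from flooding the log with the same message.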

nemo_automodel.components.moe.mxfp8._mxfp8_weight_relayout(
B
)

Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous.

torchao’s MXFP8 quantizer calls to_mx(B.transpose(-2,-1)) and strictly asserts the input is contiguous, so the weight must be stored as [E,N,K]-contiguous (viewed as [E,K,N]) — also the column-major B_t layout torchao’s grouped GEMM wants.
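The layout requirement is easiest to see through stride arithmetic. The sketch below models shapes and strides with plain tuples (all helper names are hypothetical) and checks that a tensor stored `[E,N,K]`-contiguous and viewed as `[E,K,N]` has a contiguous `(-2,-1)` transpose, whereas a plain `[E,K,N]`-contiguous tensor does not. In PyTorch, one common way to obtain such a layout is `B.transpose(-2, -1).contiguous().transpose(-2, -1)`; whether the function does exactly this is an assumption.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) element strides for a shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def transpose_last_two(shape, strides):
    """Swap the last two dims of a (shape, strides) pair without moving data."""
    new_shape = shape[:-2] + (shape[-1], shape[-2])
    new_strides = strides[:-2] + (strides[-1], strides[-2])
    return new_shape, new_strides


def is_contiguous(shape, strides):
    """Simplified check: strides match the row-major strides for the shape."""
    return strides == contiguous_strides(shape)


E, K, N = 4, 8, 16

# Plain [E, K, N]-contiguous weight: its transpose is NOT contiguous.
plain_strides = contiguous_strides((E, K, N))
t_shape, t_strides = transpose_last_two((E, K, N), plain_strides)
plain_transpose_ok = is_contiguous(t_shape, t_strides)

# Relaid-out weight: stored [E, N, K]-contiguous, then viewed as [E, K, N].
stored_strides = contiguous_strides((E, N, K))
relaid_shape, relaid_strides = transpose_last_two((E, N, K), stored_strides)
rt_shape, rt_strides = transpose_last_two(relaid_shape, relaid_strides)
relaid_transpose_ok = is_contiguous(rt_shape, rt_strides)
```

The relaid-out view keeps the logical shape `[E, K, N]`, so callers index it the same way; only the memory order changes.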

nemo_automodel.components.moe.mxfp8._resolve_mxfp8_grouped_mm()

Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.

Returns a callable mxfp8_grouped_mm(A, B, offs) mirroring torch._grouped_mm, or None if no supported torchao API is importable. The result is cached.

nemo_automodel.components.moe.mxfp8.select_grouped_mm(
use_mxfp8
)

Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs.

When use_mxfp8 and the torchao MXFP8 kernel is usable on this device, returns a wrapper that makes both operands contiguous in the layout torchao requires (A contiguous; B relaid out so its transpose is contiguous — see _mxfp8_weight_relayout) and routes through it. Otherwise returns the plain torch._grouped_mm fallback, leaving the bf16 path byte-identical. Shared by the no-bias helper and the inline bias paths so dispatch + relayout are defined once.
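The dispatch described above can be sketched with the dependencies injected as arguments. This is an illustrative shape only: `select_grouped_mm_sketch` and the fake callables are hypothetical, and the real function takes just `use_mxfp8` and looks up the kernel, fallback, and relayout itself.

```python
def select_grouped_mm_sketch(use_mxfp8, mxfp8_fn, fallback_fn, relayout):
    """Pick the grouped-GEMM callable once, outside the hot loop."""
    if not use_mxfp8 or mxfp8_fn is None:
        # Fallback is returned unwrapped, so the default path is untouched.
        return fallback_fn

    def grouped_mm(A, B, offs):
        # The real wrapper also makes A contiguous; omitted here because the
        # stand-in operands are plain Python objects.
        return mxfp8_fn(A, relayout(B), offs)

    return grouped_mm


calls = []


def fake_fallback(A, B, offs):
    calls.append("fallback")
    return "bf16"


def fake_mxfp8(A, B, offs):
    calls.append(("mxfp8", B))
    return "mxfp8"


def fake_relayout(B):
    return ("relaid", B)


off = select_grouped_mm_sketch(False, fake_mxfp8, fake_fallback, fake_relayout)
unavailable = select_grouped_mm_sketch(True, None, fake_fallback, fake_relayout)
on = select_grouped_mm_sketch(True, fake_mxfp8, fake_fallback, fake_relayout)
on_result = on("A", "B", [1])
```

Returning the fallback object itself, rather than a wrapper around it, is what keeps the non-MXFP8 path identical to calling the fallback directly.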

nemo_automodel.components.moe.mxfp8._MXFP8_ACTIVE_ANNOUNCED = False
nemo_automodel.components.moe.mxfp8._MXFP8_FALLBACK_WARNED = False
nemo_automodel.components.moe.mxfp8._MXFP8_GROUPED_MM = None
nemo_automodel.components.moe.mxfp8._MXFP8_RESOLVED = False