nemo_automodel.components.moe.mxfp8
torchao MXFP8 grouped-GEMM plumbing for the experts="torch_mm_mxfp8" path.
torchao exposes a drop-in differentiable replacement for torch._grouped_mm that
dynamically quantizes both operands to MXFP8 (e4m3 data + e8m0 block scales,
block_size=32). It mirrors torch._grouped_mm’s contract exactly: 2D activations
(M*num_groups, K), 3D [E, K, N] stacked expert weights, int32 offs group
boundaries — so no transpose is needed for Automodel’s gate_and_up_projs
([E, dim, up]) or down_projs ([E, inter, dim]).
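The following loop-based reference is a plain-Python sketch (no torch dependency) of the 2D x 3D grouped-GEMM contract described above. The rows of A are split into groups by offs, and group g is multiplied by B[g]. The sketch assumes, as is conventional for torch._grouped_mm, that offs holds one cumulative end offset per expert (a running sum of per-expert token counts). It also requires offs[-1] == len(A). The helper names reference_grouped_mm and offsets_from_counts are hypothetical and are not part of this module.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def offsets_from_counts(tokens_per_expert: Sequence[int]) -> List[int]:
    """Running sum of per-expert token counts -> cumulative group end offsets."""
    offs: List[int] = []
    total = 0
    for count in tokens_per_expert:
        total += count
        offs.append(total)
    return offs


def reference_grouped_mm(A: Matrix, B: Sequence[Matrix], offs: Sequence[int]) -> Matrix:
    """Hypothetical loop-based reference for the 2D x 3D grouped-GEMM contract.

    A:    (total_rows, K), every group's rows stacked back to back.
    B:    [E, K, N], one (K, N) weight per expert.
    offs: E cumulative end offsets; expert g owns rows offs[g-1]:offs[g].
    """
    if len(offs) != len(B):
        raise ValueError("need exactly one offset per expert weight")
    if list(offs) != sorted(offs) or (offs and offs[-1] != len(A)):
        raise ValueError("offs must be non-decreasing and end at len(A)")
    out: Matrix = []
    start = 0
    for g, end in enumerate(offs):
        W = B[g]  # (K, N) weight for expert g, indexed W[k][n]
        n_cols = len(W[0])
        for row in A[start:end]:
            # (1, K) @ (K, N): one output row per input row in this group.
            out.append([sum(row[k] * W[k][n] for k in range(len(W))) for n in range(n_cols)])
        start = end  # empty groups (end == start) simply contribute no rows
    return out
```

B is indexed directly as B[g][k][n], which matches the stacked [E, K, N] layout. This is why weights stored as [E, dim, up] or [E, inter, dim] need no transpose under this contract.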
torchao is unpinned and (when present) comes from the base image rather than the uv
lock, so the API generation is resolved defensively at runtime across known versions
and normalized to a uniform mxfp8_grouped_mm(A, B, offs) callable. If torchao is
missing entirely, the runtime gate falls back to torch._grouped_mm.
Public entry: select_grouped_mm returns the grouped-GEMM callable that the expert
forward should use (MXFP8 with the contiguous-operand relayout when requested and
available, otherwise plain torch._grouped_mm).
Module Contents
Functions
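| Function | Description |
|---|---|
| _resolve_mxfp8_grouped_mm | Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations. |
| _mxfp8_grouped_mm_or_none | Return the MXFP8 grouped-GEMM callable iff it is usable on this device. |
| _default_grouped_mm | Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off. |
| _mxfp8_weight_relayout | Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous. |
| select_grouped_mm | Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs. |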
Data
API
- nemo_automodel.components.moe.mxfp8._MXFP8_GROUPED_MM = None
- nemo_automodel.components.moe.mxfp8._MXFP8_RESOLVED = False
- nemo_automodel.components.moe.mxfp8._MXFP8_FALLBACK_WARNED = False
- nemo_automodel.components.moe.mxfp8._MXFP8_ACTIVE_ANNOUNCED = False
- nemo_automodel.components.moe.mxfp8._resolve_mxfp8_grouped_mm()
Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.
Returns a callable mxfp8_grouped_mm(A, B, offs) mirroring torch._grouped_mm, or None if no supported torchao API is importable. The result is cached.
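The resolve-once behavior, together with the None / False initial values of the _MXFP8_GROUPED_MM and _MXFP8_RESOLVED globals, suggests a pattern like the following sketch. It probes a list of candidate (module, attribute) locations newest-first with importlib. It wraps the first hit in an adapter that normalizes it to the (A, B, offs) signature. It then caches the outcome, including a None result, so that a missing torchao is probed only once. The candidate list, the adapters, and the function names are hypothetical. The real torchao import paths are not reproduced here, and the test uses stdlib modules as stand-ins.

```python
import importlib
from typing import Callable, Optional, Sequence, Tuple

# (module path, attribute name, adapter that normalizes fn to fn(A, B, offs))
Candidate = Tuple[str, str, Callable[[Callable], Callable]]

# Hypothetical analogues of _MXFP8_GROUPED_MM / _MXFP8_RESOLVED.
_GROUPED_MM: Optional[Callable] = None
_RESOLVED = False


def resolve_first_available(candidates: Sequence[Candidate]) -> Optional[Callable]:
    """Return the first importable candidate, adapted to (A, B, offs), else None."""
    for module_name, attr, adapt in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # this API generation is not installed
        fn = getattr(module, attr, None)
        if callable(fn):
            return adapt(fn)
    return None


def resolve_cached(candidates: Sequence[Candidate]) -> Optional[Callable]:
    """Resolve once per process; later calls return the cached result (even None)."""
    global _GROUPED_MM, _RESOLVED
    if not _RESOLVED:
        _GROUPED_MM = resolve_first_available(candidates)
        _RESOLVED = True
    return _GROUPED_MM
```

The sketch catches only ImportError. A more defensive resolver might also tolerate other errors raised while importing a partially broken install.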
- nemo_automodel.components.moe.mxfp8._mxfp8_grouped_mm_or_none()
Return the MXFP8 grouped-GEMM callable iff it is usable on this device.
Requires CUDA with compute capability >= 10 (GB200/sm_100+) AND a successful torchao import. Otherwise returns
None (callers fall back to torch._grouped_mm). Emits a one-time warning when MXFP8 was requested but is unavailable.
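The gate can be sketched as pure logic. The caller passes the device's (major, minor) compute capability, or None when CUDA is absent, together with the resolved kernel. In PyTorch, torch.cuda.get_device_capability() returns such a tuple. A module-level flag, analogous to _MXFP8_FALLBACK_WARNED, limits the fallback warning to one emission per process. The function name, the warning text, and the RuntimeWarning category are assumptions.

```python
import warnings
from typing import Callable, Optional, Tuple

_FALLBACK_WARNED = False  # hypothetical analogue of _MXFP8_FALLBACK_WARNED


def mxfp8_kernel_or_none(
    capability: Optional[Tuple[int, int]],
    kernel: Optional[Callable],
) -> Optional[Callable]:
    """Return kernel only on sm_100+ (major >= 10) with a resolved kernel, else None."""
    global _FALLBACK_WARNED
    if capability is not None and capability[0] >= 10 and kernel is not None:
        return kernel
    if not _FALLBACK_WARNED:
        # Warn once; every later fallback stays silent.
        warnings.warn(
            "MXFP8 grouped GEMM requested but unavailable; using torch._grouped_mm",
            RuntimeWarning,
            stacklevel=2,
        )
        _FALLBACK_WARNED = True
    return None
```

Under this gate, a Hopper device reporting (9, 0) falls back even when torchao imports cleanly. This matches the sm_100+ requirement above.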
- nemo_automodel.components.moe.mxfp8._default_grouped_mm(A, B, offs)
Fallback grouped GEMM (plain torch._grouped_mm) used when MXFP8 is off.
- nemo_automodel.components.moe.mxfp8._mxfp8_weight_relayout(B)
Lay the [E,K,N] expert weight out so its (-2,-1) transpose is contiguous.
torchao's MXFP8 quantizer calls to_mx(B.transpose(-2,-1)) and strictly asserts that the input is contiguous. The weight must therefore be stored [E,N,K]-contiguous (viewed as [E,K,N]), which is also the column-major B_t layout that torchao's grouped GEMM expects.
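Stride arithmetic makes the layout requirement easier to see. The stdlib-only sketch below computes row-major strides to show two things. First, a plainly contiguous [E, K, N] tensor has a non-contiguous (-2,-1) transpose. Second, an [E, K, N] view over [E, N, K]-contiguous storage transposes to a contiguous tensor. In PyTorch, that view is commonly produced with B.transpose(-2, -1).contiguous().transpose(-2, -1). Whether _mxfp8_weight_relayout uses exactly this idiom is an assumption, and the helper names below are hypothetical. The sketch also ignores PyTorch's rule that strides of size-1 dimensions do not affect contiguity.

```python
from typing import Sequence, Tuple

Dims = Tuple[int, ...]


def contiguous_strides(shape: Sequence[int]) -> Dims:
    """Row-major (C-contiguous) element strides for a shape."""
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def swap_last_two(dims: Sequence[int]) -> Dims:
    """Permute the last two entries, as transpose(-2, -1) does to shape and strides."""
    return (*dims[:-2], dims[-1], dims[-2])


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Simplified check: strides equal the row-major strides for this shape."""
    return tuple(strides) == contiguous_strides(shape)


def relaid_out_strides(shape: Sequence[int]) -> Dims:
    """Strides of an [E, K, N] view over [E, N, K]-contiguous storage."""
    return swap_last_two(contiguous_strides(swap_last_two(shape)))
```

For E=4, K=8, N=16, the relaid-out view has strides (128, 1, 8). Its transpose has shape (4, 16, 8) and strides (128, 8, 1), which are exactly that shape's contiguous strides. The plain layout's transpose instead has strides (128, 1, 16).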
- nemo_automodel.components.moe.mxfp8.select_grouped_mm(use_mxfp8)
Return the grouped-GEMM callable grouped_mm(A, B, offs) for the expert GEMMs.
When use_mxfp8 is true and the torchao MXFP8 kernel is usable on this device, returns a wrapper that makes both operands contiguous in the layout torchao requires and routes the call through the MXFP8 kernel. A is made contiguous, and B is relaid out so that its transpose is contiguous (see _mxfp8_weight_relayout). Otherwise, returns the plain torch._grouped_mm fallback, leaving the bf16 path byte-identical. The no-bias helper and the inline bias paths share this function, so dispatch and relayout are defined in one place.
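The dispatch can be sketched with its dependencies injected as arguments, so it runs without torch. The MXFP8 probe is called only when use_mxfp8 is true, which means the fallback warning fires only when MXFP8 was actually requested. When MXFP8 is off, the function returns the default callable itself rather than a wrapper. This is one way to keep the bf16 path byte-identical. All names are hypothetical. With real tensors, make_contiguous would be something like lambda a: a.contiguous(), and relayout_weight would play the role of _mxfp8_weight_relayout.

```python
from typing import Any, Callable, Optional

GroupedMM = Callable[[Any, Any, Any], Any]


def select_grouped_mm_sketch(
    use_mxfp8: bool,
    probe_mxfp8: Callable[[], Optional[GroupedMM]],
    default_mm: GroupedMM,
    make_contiguous: Callable[[Any], Any],
    relayout_weight: Callable[[Any], Any],
) -> GroupedMM:
    """Pick the expert grouped-GEMM callable (hypothetical dispatch sketch)."""
    kernel = probe_mxfp8() if use_mxfp8 else None  # probe (and maybe warn) only on request
    if kernel is None:
        return default_mm  # the same object, so the bf16 path is untouched

    def grouped_mm(A: Any, B: Any, offs: Any) -> Any:
        # Fix both operand layouts right before the MXFP8 kernel sees them.
        return kernel(make_contiguous(A), relayout_weight(B), offs)

    return grouped_mm
```

Callers then invoke the result uniformly as grouped_mm(A, B, offs), whichever path was selected.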