# nemo_automodel.components.moe.mxfp8

torchao MXFP8 grouped-GEMM plumbing for the `experts="torch_mm_mxfp8"` path.

torchao exposes a drop-in differentiable replacement for `torch._grouped_mm` that
dynamically quantizes both operands to MXFP8 (e4m3 data with e8m0 block scales,
block\_size=32). It mirrors the `torch._grouped_mm` contract exactly: 2D activations
of shape (M\*num\_groups, K), 3D stacked expert weights of shape \[E, K, N], and
int32 `offs` group boundaries. As a result, no transpose is needed for Automodel's
gate\_and\_up\_projs (\[E, dim, up]) or down\_projs (\[E, inter, dim]).
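To make the "e8m0 block scales, block\_size=32" wording concrete, the following is a minimal pure-Python sketch of how one power-of-two scale per 32-element block can be chosen. It assumes the floor rule from the OCP Microscaling specification (shared exponent = floor(log2(amax)) minus the e4m3 maximum exponent of 8); torchao may use a different rounding mode, and the names `block_scale_exponents`, `BLOCK_SIZE`, and `E4M3_EMAX` are illustrative rather than part of this module.

```python
import math

BLOCK_SIZE = 32  # elements that share one scale
E4M3_EMAX = 8    # exponent of the largest e4m3 normal, 448 = 1.75 * 2**8


def block_scale_exponents(row):
    """Return one power-of-two scale exponent per 32-element block of `row`."""
    if len(row) % BLOCK_SIZE != 0:
        raise ValueError("row length must be a multiple of the block size")
    exponents = []
    for start in range(0, len(row), BLOCK_SIZE):
        amax = max(abs(v) for v in row[start:start + BLOCK_SIZE])
        if amax == 0.0:
            exponents.append(0)  # placeholder for an all-zero block (assumption)
            continue
        # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1
        _, e = math.frexp(amax)
        exponents.append((e - 1) - E4M3_EMAX)
    return exponents


# Two blocks with very different magnitudes get independent scales.
row = [1.0] * 32 + [300.0] * 32
exps = block_scale_exponents(row)
scaled = [v / 2.0 ** exps[i // BLOCK_SIZE] for i, v in enumerate(row)]
print(exps)  # one exponent per block
```

Each scaled value is what would then be cast to e4m3. The exponent itself is what an e8m0 scale stores, which is why a scale carries no mantissa bits.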

torchao is unpinned and (when present) comes from the base image rather than the uv
lock, so the API generation is resolved defensively at runtime across known versions
and normalized to a uniform `mxfp8_grouped_mm(A, B, offs)` callable. If torchao is
missing entirely, the runtime gate falls back to `torch._grouped_mm`.

Public entry: `select_grouped_mm` returns the grouped-GEMM callable the expert
forward should use (mxfp8 with the contiguous-operand relayout when requested and
available, else plain `torch._grouped_mm`).

## Module Contents

### Functions

| Name                                                                                          | Description                                                                        |
| --------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
| [`_default_grouped_mm`](#nemo_automodel-components-moe-mxfp8-_default_grouped_mm)             | Fallback grouped GEMM (plain `torch._grouped_mm`) used when MXFP8 is off.          |
| [`_mxfp8_grouped_mm_or_none`](#nemo_automodel-components-moe-mxfp8-_mxfp8_grouped_mm_or_none) | Return the MXFP8 grouped-GEMM callable iff it is usable on this device.            |
| [`_mxfp8_weight_relayout`](#nemo_automodel-components-moe-mxfp8-_mxfp8_weight_relayout)       | Lay the \[E,K,N] expert weight out so its (-2,-1) transpose is contiguous.         |
| [`_resolve_mxfp8_grouped_mm`](#nemo_automodel-components-moe-mxfp8-_resolve_mxfp8_grouped_mm) | Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations. |
| [`select_grouped_mm`](#nemo_automodel-components-moe-mxfp8-select_grouped_mm)                 | Return the grouped-GEMM callable `grouped_mm(A, B, offs)` for the expert GEMMs.    |

### Data

[`_MXFP8_ACTIVE_ANNOUNCED`](#nemo_automodel-components-moe-mxfp8-_MXFP8_ACTIVE_ANNOUNCED)

[`_MXFP8_FALLBACK_WARNED`](#nemo_automodel-components-moe-mxfp8-_MXFP8_FALLBACK_WARNED)

[`_MXFP8_GROUPED_MM`](#nemo_automodel-components-moe-mxfp8-_MXFP8_GROUPED_MM)

[`_MXFP8_RESOLVED`](#nemo_automodel-components-moe-mxfp8-_MXFP8_RESOLVED)

### API

```python
nemo_automodel.components.moe.mxfp8._default_grouped_mm(
    A,
    B,
    offs
)
```

Fallback grouped GEMM (plain `torch._grouped_mm`) used when MXFP8 is off.
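The `(A, B, offs)` contract can be illustrated with a pure-Python reference that uses nested lists instead of tensors. This is a sketch of the semantics only, under the assumption that `offs` holds the cumulative end row of each group; `grouped_mm_reference` is a hypothetical name, not a function in this module.

```python
def grouped_mm_reference(A, B, offs):
    """Multiply each row group of A by that group's expert matrix.

    A: list of M_total rows, each of length K.
    B: list of E matrices, each K x N.
    offs: cumulative end row of every group (assumed convention), length E.
    """
    out = []
    start = 0
    for expert, end in enumerate(offs):
        W = B[expert]
        n_cols = len(W[0])
        for row in A[start:end]:
            out.append([
                sum(row[k] * W[k][n] for k in range(len(row)))
                for n in range(n_cols)
            ])
        start = end
    return out


A = [[1, 0], [0, 1], [1, 1]]    # 3 token rows, K = 2
B = [
    [[1, 2], [3, 4]],           # expert 0, K x N = 2 x 2
    [[10, 20], [30, 40]],       # expert 1
]
offs = [2, 3]                   # rows 0-1 go to expert 0, row 2 to expert 1
result = grouped_mm_reference(A, B, offs)
print(result)
```

The output has one row per input row, in the same order, so the caller never needs to transpose or reorder the stacked weights.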

```python
nemo_automodel.components.moe.mxfp8._mxfp8_grouped_mm_or_none()
```

Return the MXFP8 grouped-GEMM callable iff it is usable on this device.

Requires CUDA with compute capability >= 10 (GB200/sm\_100+) AND a successful
torchao import. Otherwise returns `None` (callers fall back to
`torch._grouped_mm`). Emits a one-time warning when MXFP8 was requested but is
unavailable.

```python
nemo_automodel.components.moe.mxfp8._mxfp8_weight_relayout(
    B
)
```

Lay the \[E,K,N] expert weight out so its (-2,-1) transpose is contiguous.

torchao's MXFP8 quantizer calls `to_mx(B.transpose(-2,-1))` and strictly asserts
that the input is contiguous, so the weight must be stored as \[E,N,K]-contiguous
(viewed as \[E,K,N]). This is also the column-major B\_t layout that torchao's
grouped GEMM expects.
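Why the relayout satisfies the contiguity assertion can be checked with plain stride arithmetic. The sketch below models strided views with tuples rather than tensors; the helper names are hypothetical, and `is_contiguous` is simplified (it ignores the size-1 dimension special cases that PyTorch handles). With real tensors, the usual idiom for this layout would be `B.transpose(-2, -1).contiguous().transpose(-2, -1)`, though that is an assumption about how the module does it.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a freshly allocated tensor."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose_last_two(shape, strides):
    """Swap the last two dims of a strided view without moving any data."""
    return (shape[:-2] + (shape[-1], shape[-2]),
            strides[:-2] + (strides[-1], strides[-2]))


def is_contiguous(shape, strides):
    return strides == contiguous_strides(shape)


E, K, N = 4, 8, 16

# Plain [E, K, N] storage: its (-2, -1) transpose is a non-contiguous view.
plain = ((E, K, N), contiguous_strides((E, K, N)))
plain_t = transpose_last_two(*plain)

# Relaid out: store [E, N, K] contiguously, then view it as [E, K, N].
stored = ((E, N, K), contiguous_strides((E, N, K)))
relaid = transpose_last_two(*stored)      # logical shape [E, K, N]
relaid_t = transpose_last_two(*relaid)    # what the quantizer would see

print(is_contiguous(*plain_t), is_contiguous(*relaid_t))
```

The relaid-out weight keeps the logical \[E,K,N] shape that callers expect, while its transpose lands back on the contiguous \[E,N,K] storage.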

```python
nemo_automodel.components.moe.mxfp8._resolve_mxfp8_grouped_mm()
```

Resolve a torchao MXFP8 grouped-GEMM callable, normalizing across API generations.

Returns a callable `mxfp8_grouped_mm(A, B, offs)` mirroring `torch._grouped_mm`,
or `None` if no supported torchao API is importable. The result is cached.

```python
nemo_automodel.components.moe.mxfp8.select_grouped_mm(
    use_mxfp8
)
```

Return the grouped-GEMM callable `grouped_mm(A, B, offs)` for the expert GEMMs.

When `use_mxfp8` is set and the torchao MXFP8 kernel is usable on this device,
returns a wrapper that first makes both operands contiguous in the layout torchao
requires (A contiguous; B relaid out so its transpose is contiguous, see
`_mxfp8_weight_relayout`) and then calls the MXFP8 kernel. Otherwise returns the
plain `torch._grouped_mm` fallback, leaving the bf16 path byte-identical. Shared by
the no-bias helper and the inline bias paths so dispatch and relayout are defined once.
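The dispatch pattern can be sketched with injected callables, which keeps the example runnable without torch. Everything here is a stand-in: `select_grouped_mm_sketch` takes the kernel, fallback, and relayout helpers as arguments, whereas the real `select_grouped_mm` takes only `use_mxfp8` and looks the rest up internally.

```python
def select_grouped_mm_sketch(use_mxfp8, mxfp8_kernel, default_mm,
                             relayout_a, relayout_b):
    """Pick the grouped-GEMM callable once; callers just call it with (A, B, offs)."""
    if not (use_mxfp8 and mxfp8_kernel is not None):
        return default_mm  # fallback returned untouched, with no wrapper

    def grouped_mm(A, B, offs):
        # Relayout both operands, then hand off to the MXFP8 kernel.
        return mxfp8_kernel(relayout_a(A), relayout_b(B), offs)

    return grouped_mm


# Stand-ins that only record what they were given.
def fake_default(A, B, offs):
    return ("bf16", A, B, offs)


def fake_mxfp8(A, B, offs):
    return ("mxfp8", A, B, offs)


def mark_a(A):
    return f"contig({A})"


def mark_b(B):
    return f"relaid({B})"


fallback = select_grouped_mm_sketch(True, None, fake_default, mark_a, mark_b)
fast = select_grouped_mm_sketch(True, fake_mxfp8, fake_default, mark_a, mark_b)
print(fallback("A", "B", [2]))
print(fast("A", "B", [2]))
```

Returning the fallback itself, rather than a wrapper around it, is what keeps the non-MXFP8 path identical to calling the default grouped GEMM directly.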

```python
nemo_automodel.components.moe.mxfp8._MXFP8_ACTIVE_ANNOUNCED = False
```

```python
nemo_automodel.components.moe.mxfp8._MXFP8_FALLBACK_WARNED = False
```

```python
nemo_automodel.components.moe.mxfp8._MXFP8_GROUPED_MM = None
```

```python
nemo_automodel.components.moe.mxfp8._MXFP8_RESOLVED = False
```