nemo_automodel.components.models.step3p5.state_dict_adapter


State dict adapter for Step3p5 model.

Step3p5 uses grouped MoELinear weights with shape [n_exp, out, in], different from the standard per-expert format. This adapter handles conversion between:

HF Format (Step3p5):

model.layers.{L}.moe.gate_proj.weight   # [n_exp, inter, dim]
model.layers.{L}.moe.up_proj.weight     # [n_exp, inter, dim]
model.layers.{L}.moe.down_proj.weight   # [n_exp, dim, inter]
model.layers.{L}.moe.gate.weight        # [n_exp, dim] (router)
model.layers.{L}.moe.router_bias        # [n_exp] (post-sigmoid router correction bias, optional)
model.layers.{L}.share_expert.*.weight  # Shared expert

Native Format (Automodel):

model.layers.{L}.moe.experts.gate_and_up_projs     # [n_exp, dim, 2*inter]
model.layers.{L}.moe.experts.down_projs            # [n_exp, inter, dim]
model.layers.{L}.moe.gate.weight                   # [n_exp, dim]
model.layers.{L}.moe.gate.e_score_correction_bias  # [n_exp]
model.layers.{L}.share_expert.*.weight

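Note: Router gate weights and shared expert weights pass through with the same key names. Only the expert MLP weights (gate_proj, up_proj, down_proj) need a tensor transformation. The optional router correction bias is listed under a different key in each format (moe.router_bias in HF, moe.gate.e_score_correction_bias in native).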
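The key layout above can be sketched as a small classifier that decides what happens to each HF key. This is a minimal illustration, not the adapter's actual code: the function name `classify_hf_key`, the action labels, and the assumption that `router_bias` maps to `gate.e_score_correction_bias` are all read off the two listings and may not match the real implementation.

```python
import re

# Hypothetical patterns derived from the key listings in this page.
_EXPERT_RE = re.compile(
    r"^(model\.layers\.\d+\.moe)\.(gate_proj|up_proj|down_proj)\.weight$"
)
_BIAS_RE = re.compile(r"^(model\.layers\.\d+\.moe)\.router_bias$")


def classify_hf_key(key: str) -> tuple[str, str]:
    """Return (action, native_key) for one HF key (illustrative only)."""
    m = _EXPERT_RE.match(key)
    if m:
        prefix, proj = m.groups()
        if proj == "down_proj":
            # [n_exp, dim, inter] -> [n_exp, inter, dim]
            return "transpose", f"{prefix}.experts.down_projs"
        # gate_proj and up_proj are combined into one native tensor
        return "merge_gate_up", f"{prefix}.experts.gate_and_up_projs"
    m = _BIAS_RE.match(key)
    if m:
        # Assumed rename, inferred from the matching [n_exp] shapes
        return "rename", f"{m.group(1)}.gate.e_score_correction_bias"
    # Router gate weights, shared expert weights, and everything else
    return "passthrough", key
```

For instance, `classify_hf_key("model.layers.0.moe.gate.weight")` falls through to the pass-through branch because the router key does not match any of the expert projection names.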

Module Contents

Classes

Name | Description
--- | ---
Step3p5StateDictAdapter | Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.

Functions

Name | Description
--- | ---
_create_dtensor_from_local_or_reference | Create a DTensor from a local tensor.
_swap_shard_placements_1_2 | Swap Shard dim 1 and dim 2 in DTensor placements after a transpose(1, 2).

Data

logger

API

class nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter(
config: typing.Any,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32
)

Bases: StateDictAdapter

Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.

Step3p5 HF uses grouped MoELinear with shape [n_experts, out_features, in_features]:

model.layers.{L}.moe.gate_proj.weight  # [n_exp, inter, dim]
model.layers.{L}.moe.up_proj.weight    # [n_exp, inter, dim]
model.layers.{L}.moe.down_proj.weight  # [n_exp, dim, inter]

_hf_prefix
str

Prefix for HuggingFace format keys.

nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter._convert_native_to_hf(
fqn: str,
tensor: torch.Tensor
) -> list[tuple[str, torch.Tensor]] | None

Convert native format expert tensors to HF Step3p5 format.

Native: gate_and_up_projs [n_exp, dim, 2*inter] -> HF: gate_proj, up_proj [n_exp, inter, dim]
Native: down_projs [n_exp, inter, dim] -> HF: down_proj [n_exp, dim, inter]

Preserves DTensor structure when input is a DTensor.
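The shape changes above can be sketched without PyTorch by using nested lists as stand-ins for tensors. This is an illustration under stated assumptions: the helper names are hypothetical, and it assumes gate occupies the first half of the last dimension of gate_and_up_projs and up the second half (the real layout could instead be interleaved).

```python
def transpose_1_2(x):
    """Swap dims 1 and 2 of a nested-list tensor shaped [n_exp, a, b]."""
    return [[list(col) for col in zip(*mat)] for mat in x]


def split_gate_up(gate_and_up):
    """[n_exp, dim, 2*inter] -> (gate_proj, up_proj), each [n_exp, inter, dim].

    Assumes a concatenated layout: gate in the first half of the last dim.
    """
    inter = len(gate_and_up[0][0]) // 2
    gate = [[row[:inter] for row in mat] for mat in gate_and_up]
    up = [[row[inter:] for row in mat] for mat in gate_and_up]
    return transpose_1_2(gate), transpose_1_2(up)


def shape(x):
    """Shape of a rectangular 3-D nested list."""
    return (len(x), len(x[0]), len(x[0][0]))


n_exp, dim, inter = 2, 3, 4
# native[e][d][c] encodes its own index so element moves are easy to follow
native_gate_up = [
    [[float(e * 100 + d * 10 + c) for c in range(2 * inter)] for d in range(dim)]
    for e in range(n_exp)
]
native_down = [[[0.0] * dim for _ in range(inter)] for _ in range(n_exp)]

gate_proj, up_proj = split_gate_up(native_gate_up)  # each [n_exp, inter, dim]
down_proj = transpose_1_2(native_down)              # [n_exp, dim, inter]
```

With real tensors the same steps would typically be a split along the last dimension followed by `transpose(1, 2)`; one native tensor therefore yields two HF entries, which is why the method returns a list of tuples.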

nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
**kwargs
) -> list[tuple[str, typing.Any]]

Convert a single tensor from native format to HuggingFace format.

Parameters:

fqn
str

Fully qualified name of the tensor in native format

tensor
Any

The tensor to convert

**kwargs
Defaults to {}

Additional arguments for conversion

Returns: list[tuple[str, Any]]

List of (fqn, tensor) tuples in HuggingFace format

nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs
) -> dict[str, typing.Any]

Convert HF checkpoint to native format.

Handles Step3p5’s grouped MoELinear format:

  • [n_exp, inter, dim] gate_proj/up_proj -> [n_exp, dim, 2*inter] gate_and_up_projs
  • [n_exp, dim, inter] down_proj -> [n_exp, inter, dim] down_projs
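
The two conversions listed above can be sketched on nested lists standing in for tensors. This is a minimal illustration, not the adapter's code: the helper names are hypothetical, and the gate-then-up concatenation order along the last dimension is an assumption.

```python
def transpose_last_two(x):
    """Swap dims 1 and 2 of a nested-list tensor shaped [n_exp, a, b]."""
    return [[list(col) for col in zip(*mat)] for mat in x]


def merge_gate_up(gate_proj, up_proj):
    """Two [n_exp, inter, dim] tensors -> one [n_exp, dim, 2*inter] tensor.

    Assumes gate comes first and up second along the last dim.
    """
    gate_t = transpose_last_two(gate_proj)  # [n_exp, dim, inter]
    up_t = transpose_last_two(up_proj)      # [n_exp, dim, inter]
    return [
        [g_row + u_row for g_row, u_row in zip(g_mat, u_mat)]
        for g_mat, u_mat in zip(gate_t, up_t)
    ]


n_exp, inter, dim = 2, 2, 3
# Values encode their own index; up is offset by 1000 to tell it apart
hf_gate = [
    [[float(e * 100 + i * 10 + d) for d in range(dim)] for i in range(inter)]
    for e in range(n_exp)
]
hf_up = [[[v + 1000.0 for v in row] for row in mat] for mat in hf_gate]
hf_down = [[[0.0] * inter for _ in range(dim)] for _ in range(n_exp)]

gate_and_up_projs = merge_gate_up(hf_gate, hf_up)  # [n_exp, dim, 2*inter]
down_projs = transpose_last_two(hf_down)           # [n_exp, inter, dim]
```

In this layout, `gate_and_up_projs[e][d][i]` holds `hf_gate[e][i][d]` and `gate_and_up_projs[e][d][inter + i]` holds `hf_up[e][i][d]`.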

nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
**kwargs
) -> dict[str, typing.Any]

Convert from native model state dict to HuggingFace format.

nemo_automodel.components.models.step3p5.state_dict_adapter._create_dtensor_from_local_or_reference(
local_tensor: torch.Tensor,
reference_dtensor: typing.Optional[torch.Tensor],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
rank: typing.Optional[int] = None,
placements_override: typing.Optional[tuple] = None
) -> torch.Tensor

Create a DTensor from a local tensor.

Prefers using reference_dtensor’s mesh/placements if available (for preserving DTensor structure from DCP-loaded tensors). Falls back to creating a new DTensor using device_mesh if reference is not a DTensor.

Parameters:

local_tensor
torch.Tensor

Local portion of the tensor after transformation

reference_dtensor
Optional[torch.Tensor]

Optional DTensor to copy mesh/placements from

device_mesh
Optional[DeviceMesh]
Defaults to None

Device mesh for EP (used if reference is not DTensor)

rank
Optional[int]
Defaults to None

Current rank for device placement

placements_override
Optional[tuple]
Defaults to None

If provided, use these placements instead of the reference DTensor’s placements. Useful after transposing the local tensor, where shard dimensions need to be swapped.

Returns: torch.Tensor

DTensor if mesh is available, otherwise local_tensor
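
The precedence described above (reference DTensor first, then device_mesh, then the plain local tensor) can be sketched as a small decision function. Everything here is a hypothetical stand-in: `FakeDTensor` only mimics the two attributes the description mentions, and `choose_wrapping` returns a description of the choice rather than building a real DTensor. The placements used on the device_mesh fallback path are not stated in this page, so the sketch leaves them as None.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeDTensor:
    """Hypothetical stand-in for a DTensor: a mesh and its placements."""
    device_mesh: str
    placements: tuple


def choose_wrapping(
    reference: Any,
    device_mesh: Optional[str] = None,
    placements_override: Optional[tuple] = None,
) -> tuple:
    """Return (source, mesh, placements) for wrapping a local tensor."""
    if isinstance(reference, FakeDTensor):
        # Prefer the reference's mesh; the override replaces its placements
        placements = (
            placements_override
            if placements_override is not None
            else reference.placements
        )
        return ("reference", reference.device_mesh, placements)
    if device_mesh is not None:
        # Fallback path; placements here are not documented, so left as None
        return ("device_mesh", device_mesh, None)
    # No mesh available: the local tensor is returned unchanged
    return ("local", None, None)


ref = FakeDTensor(device_mesh="ep_mesh", placements=("Shard(1)",))
plan_ref = choose_wrapping(ref)
plan_override = choose_wrapping(ref, placements_override=("Shard(2)",))
plan_mesh = choose_wrapping(None, device_mesh="ep_mesh")
plan_local = choose_wrapping(None)
```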

nemo_automodel.components.models.step3p5.state_dict_adapter._swap_shard_placements_1_2(
placements: tuple
) -> tuple

Swap Shard dim 1 and dim 2 in DTensor placements after a transpose(1, 2).

When we transpose a 3-D tensor’s dims 1 and 2, any Shard placement on those dims must be swapped so that DTensor.from_local infers the correct global shape. Without this, the shard multiplier is applied to the wrong axis.
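
A minimal sketch of the swap and of why it matters, using small stand-in classes instead of PyTorch's placement types. The names `Shard`, `Replicate`, `swap_shard_dims_1_2`, and `inferred_global_shape` below are illustrative definitions local to this example, and the shape inference assumes even sharding (each shard dimension is multiplied by the mesh size).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Stand-in for a Shard placement; only the sharded dim is modeled."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Stand-in for a Replicate placement."""


def swap_shard_dims_1_2(placements: tuple) -> tuple:
    """Map Shard(1) <-> Shard(2); leave every other placement untouched."""
    swap = {1: 2, 2: 1}
    return tuple(
        Shard(swap.get(p.dim, p.dim)) if isinstance(p, Shard) else p
        for p in placements
    )


def inferred_global_shape(local_shape, placements, mesh_sizes):
    """Global shape implied by a local shape, assuming even sharding."""
    shape = list(local_shape)
    for p, size in zip(placements, mesh_sizes):
        if isinstance(p, Shard):
            shape[p.dim] *= size
    return tuple(shape)


# Local shard [8, 4, 3], sharded on dim 1 over 2 ranks: global is [8, 8, 3].
before = (Shard(1),)
# After transpose(1, 2) the local shard is [8, 3, 4].
stale = inferred_global_shape((8, 3, 4), before, (2,))
fixed = inferred_global_shape((8, 3, 4), swap_shard_dims_1_2(before), (2,))
```

With the stale placement the multiplier lands on the wrong axis and the inferred global shape is (8, 6, 4); with the swapped placement it is (8, 3, 8), the transpose of the original global shape.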

nemo_automodel.components.models.step3p5.state_dict_adapter.logger = logging.getLogger(__name__)