nemo_automodel.components.models.step3p5.state_dict_adapter#

State dict adapter for Step3p5 model.

Step3p5 uses grouped MoELinear weights with shape [n_exp, out, in], which differs from the standard per-expert format. This adapter handles conversion between the following formats:

HF Format (Step3p5):

    model.layers.{L}.moe.gate_proj.weight   # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight     # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight   # [n_exp, dim, inter]
    model.layers.{L}.moe.gate.weight        # [n_exp, dim] (router)
    model.layers.{L}.moe.router_bias        # [n_exp] (router bias, optional)
    model.layers.{L}.share_expert.*.weight  # Shared expert

Native Format (Automodel):

    model.layers.{L}.moe.experts.gate_and_up_projs  # [n_exp, dim, 2*inter]
    model.layers.{L}.moe.experts.down_projs         # [n_exp, inter, dim]
    model.layers.{L}.moe.gate.weight                # [n_exp, dim]
    model.layers.{L}.moe.gate.bias                  # [n_exp]
    model.layers.{L}.share_expert.*.weight          # Shared expert

Note: Router gate weights and shared expert weights pass through with the same key names. Only the expert MLP weights (gate_proj, up_proj, down_proj) need transformation.
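For illustration, a hypothetical key mapping for layer 0 (shapes follow the listings above; the router_bias -> gate.bias rename is inferred from the two format listings, not stated explicitly in the source):

    # Illustrative mapping only; the router_bias -> gate.bias rename is an
    # inference from the format listings above.
    hf_to_native = {
        "model.layers.0.moe.gate_proj.weight": "model.layers.0.moe.experts.gate_and_up_projs",  # fused
        "model.layers.0.moe.up_proj.weight": "model.layers.0.moe.experts.gate_and_up_projs",    # fused
        "model.layers.0.moe.down_proj.weight": "model.layers.0.moe.experts.down_projs",
        "model.layers.0.moe.gate.weight": "model.layers.0.moe.gate.weight",  # pass-through
        "model.layers.0.moe.router_bias": "model.layers.0.moe.gate.bias",
    }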

Module Contents#

Classes#

Step3p5StateDictAdapter

Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.

Functions#

_create_dtensor_from_local_or_reference

Create a DTensor from a local tensor.

Data#

API#

nemo_automodel.components.models.step3p5.state_dict_adapter.logger#

‘getLogger(…)’

nemo_automodel.components.models.step3p5.state_dict_adapter._create_dtensor_from_local_or_reference(
local_tensor: torch.Tensor,
reference_dtensor: Optional[torch.Tensor],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
rank: Optional[int] = None,
) → torch.Tensor#

Create a DTensor from a local tensor.

Prefers the mesh/placements of reference_dtensor when available, preserving the DTensor structure of DCP-loaded tensors. Falls back to creating a new DTensor on device_mesh when the reference is not a DTensor.

Parameters:
  • local_tensor – Local portion of the tensor after transformation

  • reference_dtensor – Optional DTensor to copy mesh/placements from

  • device_mesh – Device mesh for EP (used if reference is not DTensor)

  • rank – Current rank for device placement

Returns:

DTensor if mesh is available, otherwise local_tensor
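A minimal sketch of this fallback logic, assuming torch's public torch.distributed.tensor API (DTensor.from_local, Shard) and a Shard(0) expert-parallel placement; this is not the module's actual implementation:

    import torch
    from torch.distributed.tensor import DTensor, Shard

    def create_dtensor_sketch(local_tensor, reference_dtensor=None, device_mesh=None):
        if isinstance(reference_dtensor, DTensor):
            # Preserve mesh/placements from the DCP-loaded reference tensor.
            return DTensor.from_local(
                local_tensor, reference_dtensor.device_mesh, reference_dtensor.placements
            )
        if device_mesh is not None:
            # Assumed EP layout: shard the expert dimension (dim 0) across the mesh.
            return DTensor.from_local(local_tensor, device_mesh, [Shard(0)])
        return local_tensor  # No mesh available: return the plain local tensor.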

class nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter(
config: Any,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32,
)#

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.

Step3p5 HF uses grouped MoELinear with shape [n_experts, out_features, in_features]:

    model.layers.{L}.moe.gate_proj.weight  # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight    # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight  # [n_exp, dim, inter]

Automodel native format uses:

    model.layers.{L}.moe.experts.gate_and_up_projs  # [n_exp, dim, 2*inter]
    model.layers.{L}.moe.experts.down_projs         # [n_exp, inter, dim]

Initialization

property _hf_prefix: str#

Prefix for HuggingFace format keys.

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
quantization: bool = False,
**kwargs,
) → dict[str, Any]#

Convert a native model state dict to HuggingFace format.

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) → dict[str, Any]#

Convert an HF checkpoint to native format.

Handles Step3p5’s grouped MoELinear format:

  • [n_exp, inter, dim] gate_proj/up_proj -> [n_exp, dim, 2*inter] gate_and_up_projs

  • [n_exp, dim, inter] down_proj -> [n_exp, inter, dim] down_projs
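As pure tensor math, this conversion amounts to a transpose plus concatenation. A minimal sketch with toy shapes (the gate-before-up concatenation order is an assumption, not confirmed by the source):

    import torch

    n_exp, dim, inter = 4, 8, 16
    gate_proj = torch.randn(n_exp, inter, dim)  # HF layout
    up_proj = torch.randn(n_exp, inter, dim)
    down_proj = torch.randn(n_exp, dim, inter)

    # Transpose each projection to [n_exp, dim, inter], then fuse along the last dim.
    gate_and_up_projs = torch.cat(
        [gate_proj.transpose(1, 2), up_proj.transpose(1, 2)], dim=-1
    )  # [n_exp, dim, 2*inter]
    down_projs = down_proj.transpose(1, 2)  # [n_exp, inter, dim]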

convert_single_tensor_to_hf(
fqn: str,
tensor: Any,
**kwargs,
) → list[tuple[str, Any]]#

Convert a single tensor from native format to HuggingFace format.

Parameters:
  • fqn – Fully qualified name of the tensor in native format

  • tensor – The tensor to convert

  • **kwargs – Additional arguments for conversion

Returns:

List of (fqn, tensor) tuples in HuggingFace format
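For instance, a single fused native tensor plausibly expands to two HF entries (hypothetical call and result; key names and shapes follow the module docstring, and `adapter` is an already-constructed Step3p5StateDictAdapter):

    import torch

    # Hypothetical: one fused native tensor expands to two HF entries.
    pairs = adapter.convert_single_tensor_to_hf(
        "model.layers.0.moe.experts.gate_and_up_projs",
        torch.randn(4, 8, 32),  # [n_exp, dim, 2*inter]
    )
    # pairs == [("model.layers.0.moe.gate_proj.weight", <[n_exp, inter, dim]>),
    #           ("model.layers.0.moe.up_proj.weight",   <[n_exp, inter, dim]>)]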

_convert_native_to_hf(
fqn: str,
tensor: torch.Tensor,
) → list[tuple[str, torch.Tensor]] | None#

Convert native format expert tensors to HF Step3p5 format.

Native: gate_and_up_projs [n_exp, dim, 2*inter] -> HF: gate_proj, up_proj [n_exp, inter, dim]
Native: down_projs [n_exp, inter, dim] -> HF: down_proj [n_exp, dim, inter]

Preserves DTensor structure when input is a DTensor.
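Sketched as the inverse of the from_hf math above (again assuming gate occupies the first half of the fused last dimension):

    import torch

    n_exp, dim, inter = 4, 8, 16
    gate_and_up_projs = torch.randn(n_exp, dim, 2 * inter)
    down_projs = torch.randn(n_exp, inter, dim)

    gate_half, up_half = gate_and_up_projs.chunk(2, dim=-1)  # each [n_exp, dim, inter]
    gate_proj = gate_half.transpose(1, 2).contiguous()       # [n_exp, inter, dim]
    up_proj = up_half.transpose(1, 2).contiguous()           # [n_exp, inter, dim]
    down_proj = down_projs.transpose(1, 2).contiguous()      # [n_exp, dim, inter]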