nemo_automodel.components.models.qwen3_5_moe.state_dict_adapter#
State-dict adapter for Qwen3.5-MoE.
HF Qwen3.5-MoE stores expert weights as aggregated 3-D tensors:
model.language_model.layers.{L}.mlp.experts.gate_up_proj # [n_experts, 2*moe_inter, hidden]
model.language_model.layers.{L}.mlp.experts.down_proj # [n_experts, hidden, moe_inter]
NeMo uses a different naming convention and transposed layout (x @ weight):
model.language_model.layers.{L}.mlp.experts.gate_and_up_projs # [n_experts, hidden, 2*moe_inter]
model.language_model.layers.{L}.mlp.experts.down_projs # [n_experts, moe_inter, hidden]
Both expert tensors require .transpose(1, 2) when converting between formats.
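The layout relationship can be sketched with toy shapes (the sizes below are illustrative only; real values come from the model config):

```python
import torch

# Illustrative sizes only; real values come from the model config.
n_experts, hidden, moe_inter = 4, 8, 16

# HF layout for the aggregated expert weights.
hf_gate_up = torch.randn(n_experts, 2 * moe_inter, hidden)  # [n_experts, 2*moe_inter, hidden]
hf_down = torch.randn(n_experts, hidden, moe_inter)         # [n_experts, hidden, moe_inter]

# NeMo layout (used as x @ weight): swap the last two dims.
nemo_gate_up = hf_gate_up.transpose(1, 2).contiguous()  # [n_experts, hidden, 2*moe_inter]
nemo_down = hf_down.transpose(1, 2).contiguous()        # [n_experts, moe_inter, hidden]

# The conversion is its own inverse: transposing again restores the HF layout.
assert torch.equal(nemo_gate_up.transpose(1, 2), hf_gate_up)
```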
Additionally, the shared-expert key is singular in HF and plural in NeMo:
HF: .mlp.shared_expert.{gate,up,down}_proj.weight
NeMo: .mlp.shared_experts.{gate,up,down}_proj.weight
All other keys (attention, linear_attn/GatedDeltaNet, norms, embeddings, lm_head, vision encoder) pass through unchanged.
Module Contents#
Classes#
Qwen3_5MoeStateDictAdapter: Converts between HF Qwen3.5-MoE checkpoints and the NeMo native format.
API#
- class nemo_automodel.components.models.qwen3_5_moe.state_dict_adapter.Qwen3_5MoeStateDictAdapter(
- config: Any,
- moe_config: nemo_automodel.components.moe.layers.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- dtype: torch.dtype = torch.float32,
- )
Bases:
nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

Converts between HF Qwen3.5-MoE checkpoints and the NeMo native format.
Handles:
Aggregated expert weight renaming (gate_up_proj ↔ gate_and_up_projs)
Shared expert key mapping (shared_expert ↔ shared_experts)
Expert-parallel sharding when a device mesh is provided
Initialization
- _apply_key_mapping(
- state_dict: dict[str, Any],
- mapping: dict[str, str],
- )
Apply key substring mappings to state dict keys.
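A minimal sketch of what such a substring mapping might do (the helper name and the exact mapping table below are illustrative, not copied from the implementation):

```python
from typing import Any


def apply_key_mapping(state_dict: dict[str, Any], mapping: dict[str, str]) -> dict[str, Any]:
    # Rename keys by substring replacement; values are carried over untouched.
    out: dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            new_key = new_key.replace(old, new)
        out[new_key] = value
    return out


# Hypothetical HF -> NeMo substring mapping, following the renames described above.
hf_to_nemo = {
    ".mlp.experts.gate_up_proj": ".mlp.experts.gate_and_up_projs",
    ".mlp.experts.down_proj": ".mlp.experts.down_projs",
    ".mlp.shared_expert.": ".mlp.shared_experts.",
}
sd = {"model.language_model.layers.0.mlp.shared_expert.gate_proj.weight": None}
renamed = apply_key_mapping(sd, hf_to_nemo)
```

Note the trailing dots in the shared-expert entry: they keep `shared_expert` from matching inside an already-renamed `shared_experts` key.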
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: Optional[str] = None,
- quantization: bool = False,
- **kwargs,
- )
Convert a NeMo-native state dict to the HF format, optionally excluding keys that match exclude_key_regex.
- from_hf(
- hf_state_dict: dict[str, Any],
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- **kwargs,
- )
Convert an HF state dict to the NeMo native format, sharding expert weights when a device mesh is provided.
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: Any,
- **kwargs,
- )
Convert a single native tensor back to HF format.
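A rough standalone sketch of the per-tensor NeMo -> HF direction described in this module (an approximation for illustration, not the adapter's actual implementation):

```python
import torch


def single_tensor_to_hf_sketch(fqn: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
    # Aggregated expert weights: rename and transpose the last two dims.
    if ".mlp.experts.gate_and_up_projs" in fqn:
        return fqn.replace("gate_and_up_projs", "gate_up_proj"), tensor.transpose(1, 2).contiguous()
    if ".mlp.experts.down_projs" in fqn:
        return fqn.replace("down_projs", "down_proj"), tensor.transpose(1, 2).contiguous()
    # Shared expert: plural (NeMo) -> singular (HF), no layout change.
    if ".mlp.shared_experts." in fqn:
        return fqn.replace(".mlp.shared_experts.", ".mlp.shared_expert."), tensor
    # Everything else (attention, norms, embeddings, lm_head, ...) passes through.
    return fqn, tensor


# NeMo down_projs is [n_experts, moe_inter, hidden]; HF expects [n_experts, hidden, moe_inter].
name, w = single_tensor_to_hf_sketch(
    "model.language_model.layers.0.mlp.experts.down_projs",
    torch.zeros(4, 16, 8),
)
```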