nemo_automodel.components.models.step3p5.state_dict_adapter#
State dict adapter for Step3p5 model.
Step3p5 uses grouped MoELinear weights with shape [n_exp, out, in], different from the standard per-expert format. This adapter handles conversion between:
HF Format (Step3p5):

    model.layers.{L}.moe.gate_proj.weight    # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight      # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight    # [n_exp, dim, inter]
    model.layers.{L}.moe.gate.weight         # [n_exp, dim] (router)
    model.layers.{L}.moe.router_bias         # [n_exp] (router bias, optional)
    model.layers.{L}.share_expert.*.weight   # Shared expert

Native Format (Automodel):

    model.layers.{L}.moe.experts.gate_and_up_projs  # [n_exp, dim, 2*inter]
    model.layers.{L}.moe.experts.down_projs         # [n_exp, inter, dim]
    model.layers.{L}.moe.gate.weight                # [n_exp, dim]
    model.layers.{L}.moe.gate.bias                  # [n_exp]
    model.layers.{L}.share_expert.*.weight
Note: Router gate weights and shared expert weights pass through with the same key names. Only the expert MLP weights (gate_proj, up_proj, down_proj) need transformation.
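The expert-weight conversion itself is a transpose plus a concatenation. Below is a minimal sketch of the HF-to-native direction, using only the key names and shapes listed above; the standalone helper function is hypothetical, and the assumption that gate occupies the first half of the concatenated dimension follows the `gate_and_up_projs` naming.

```python
import torch

def hf_experts_to_native(gate_proj, up_proj, down_proj):
    """Sketch of the HF -> native expert-weight layout change.

    gate_proj, up_proj: [n_exp, inter, dim]; down_proj: [n_exp, dim, inter].
    """
    # [n_exp, inter, dim] -> [n_exp, dim, inter] for both projections,
    # then concatenate along the last dim to get [n_exp, dim, 2*inter].
    # Assumption: gate occupies the first half of the concatenated dim.
    gate_and_up_projs = torch.cat(
        [gate_proj.transpose(1, 2), up_proj.transpose(1, 2)], dim=-1
    )
    # [n_exp, dim, inter] -> [n_exp, inter, dim]
    down_projs = down_proj.transpose(1, 2)
    return gate_and_up_projs, down_projs

# Shape check with toy sizes (n_exp=4, inter=8, dim=16).
g, u, d = torch.randn(4, 8, 16), torch.randn(4, 8, 16), torch.randn(4, 16, 8)
gu, dp = hf_experts_to_native(g, u, d)
assert gu.shape == (4, 16, 16) and dp.shape == (4, 8, 16)
```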
Module Contents#
Classes#
Step3p5StateDictAdapter: Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.
Functions#
_create_dtensor_from_local_or_reference: Create a DTensor from a local tensor.
Data#

logger
API#
- nemo_automodel.components.models.step3p5.state_dict_adapter.logger#
'getLogger(...)'
- nemo_automodel.components.models.step3p5.state_dict_adapter._create_dtensor_from_local_or_reference(
      local_tensor: torch.Tensor,
      reference_dtensor: Optional[torch.Tensor],
      device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
      rank: Optional[int] = None,
  )
Create a DTensor from a local tensor.
Prefers using reference_dtensor’s mesh/placements if available (for preserving DTensor structure from DCP-loaded tensors). Falls back to creating a new DTensor using device_mesh if reference is not a DTensor.
- Parameters:
local_tensor – Local portion of the tensor after transformation
reference_dtensor – Optional DTensor to copy mesh/placements from
device_mesh – Device mesh for expert parallelism (used if the reference is not a DTensor)
rank – Current rank for device placement
- Returns:
DTensor if mesh is available, otherwise local_tensor
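A minimal sketch of the preference-then-fallback logic described above, built on the public `DTensor.from_local` API (exposed under `torch.distributed.tensor` in recent PyTorch). The `Shard(0)` placement, sharding the expert dimension across the EP mesh, is an assumption for illustration, and the `rank` handling is omitted; this is not the library's exact implementation.

```python
from typing import Optional

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

def create_dtensor_sketch(
    local_tensor: torch.Tensor,
    reference_dtensor: Optional[torch.Tensor] = None,
    device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    # Prefer the reference's mesh/placements so the DTensor structure
    # of DCP-loaded tensors survives the transformation.
    if isinstance(reference_dtensor, DTensor):
        return DTensor.from_local(
            local_tensor,
            device_mesh=reference_dtensor.device_mesh,
            placements=reference_dtensor.placements,
        )
    # Fallback: build a new DTensor on the EP mesh. Sharding dim 0
    # (the expert dimension) is an assumption for this sketch.
    if device_mesh is not None:
        return DTensor.from_local(local_tensor, device_mesh, [Shard(0)])
    # No mesh available: return the plain local tensor.
    return local_tensor
```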
- class nemo_automodel.components.models.step3p5.state_dict_adapter.Step3p5StateDictAdapter(
      config: Any,
      moe_config: nemo_automodel.components.moe.config.MoEConfig,
      backend: nemo_automodel.components.models.common.BackendConfig,
      dtype: torch.dtype = torch.float32,
  )
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.
Step3p5 HF uses grouped MoELinear with shape [n_experts, out_features, in_features]:

    model.layers.{L}.moe.gate_proj.weight  # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight    # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight  # [n_exp, dim, inter]

The Automodel native format uses:

    model.layers.{L}.moe.experts.gate_and_up_projs  # [n_exp, dim, 2*inter]
    model.layers.{L}.moe.experts.down_projs         # [n_exp, inter, dim]
Initialization
- property _hf_prefix: str#
Prefix for HuggingFace format keys.
- to_hf(
      state_dict: dict[str, Any],
      exclude_key_regex: Optional[str] = None,
      quantization: bool = False,
      **kwargs,
  )
Convert from native model state dict to HuggingFace format.
- from_hf(
      hf_state_dict: dict[str, Any],
      device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
      **kwargs,
  )
Convert HF checkpoint to native format.
Handles Step3p5’s grouped MoELinear format:
[n_exp, inter, dim] gate_proj/up_proj -> [n_exp, dim, 2*inter] gate_and_up_projs
[n_exp, dim, inter] down_proj -> [n_exp, inter, dim] down_projs
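Hypothetical usage, with constructor arguments taken from the class signature above; `cfg`, `moe_cfg`, `backend`, `hf_state_dict`, `ep_mesh`, and `model` are placeholders for objects built elsewhere, and loading the result directly via `load_state_dict` is an assumption about the returned dict.

```python
import torch

# Placeholder objects: cfg, moe_cfg, backend, hf_state_dict, ep_mesh, model.
adapter = Step3p5StateDictAdapter(
    config=cfg, moe_config=moe_cfg, backend=backend, dtype=torch.bfloat16
)
# Regroup the HF per-projection expert weights into the native layout.
native_sd = adapter.from_hf(hf_state_dict, device_mesh=ep_mesh)
model.load_state_dict(native_sd)
```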
- convert_single_tensor_to_hf(
      fqn: str,
      tensor: Any,
      **kwargs,
  )
Convert a single tensor from native format to HuggingFace format.
- Parameters:
fqn – Fully qualified name of the tensor in native format
tensor – The tensor to convert
**kwargs – Additional arguments for conversion
- Returns:
List of (fqn, tensor) tuples in HuggingFace format
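A hedged illustration of the expected return shape: one grouped native tensor expands into a list of per-projection HF entries. The output FQNs and shapes are inferred from the key listings above; `adapter` and `gate_and_up` are placeholders.

```python
# Placeholder objects: adapter (a Step3p5StateDictAdapter), gate_and_up
# (a native [n_exp, dim, 2*inter] tensor).
pairs = adapter.convert_single_tensor_to_hf(
    "model.layers.0.moe.experts.gate_and_up_projs", gate_and_up
)
for hf_fqn, hf_tensor in pairs:
    print(hf_fqn, tuple(hf_tensor.shape))
# Expected (assumption): two entries,
#   model.layers.0.moe.gate_proj.weight  (n_exp, inter, dim)
#   model.layers.0.moe.up_proj.weight    (n_exp, inter, dim)
```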
- _convert_native_to_hf(
      fqn: str,
      tensor: torch.Tensor,
  )
Convert native format expert tensors to HF Step3p5 format.
Native: gate_and_up_projs [n_exp, dim, 2*inter] -> HF: gate_proj, up_proj [n_exp, inter, dim]

Native: down_projs [n_exp, inter, dim] -> HF: down_proj [n_exp, dim, inter]

Preserves the DTensor structure when the input is a DTensor.
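A minimal sketch of that split, mirroring the HF-to-native example earlier in this page; the standalone function is hypothetical, and it again assumes gate occupies the first half of the concatenated dimension.

```python
import torch

def native_experts_to_hf(gate_and_up_projs, down_projs):
    """Sketch of the native -> HF expert-weight layout change."""
    inter = gate_and_up_projs.shape[-1] // 2
    # Split [n_exp, dim, 2*inter] into two [n_exp, dim, inter] halves,
    # then transpose each back to [n_exp, inter, dim].
    # Assumption: gate occupies the first half of the concatenated dim.
    gate_proj = gate_and_up_projs[..., :inter].transpose(1, 2).contiguous()
    up_proj = gate_and_up_projs[..., inter:].transpose(1, 2).contiguous()
    # [n_exp, inter, dim] -> [n_exp, dim, inter]
    down_proj = down_projs.transpose(1, 2).contiguous()
    return gate_proj, up_proj, down_proj
```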