nemo_automodel.components.models.step3p5.state_dict_adapter
State dict adapter for Step3p5 model.
Step3p5 uses grouped MoELinear weights with shape [n_exp, out, in], different from the standard per-expert format. This adapter handles conversion between:
HF Format (Step3p5):
    model.layers.{L}.moe.gate_proj.weight    # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight      # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight    # [n_exp, dim, inter]
    model.layers.{L}.moe.gate.weight         # [n_exp, dim] (router)
    model.layers.{L}.moe.router_bias         # [n_exp] (post-sigmoid router correction bias, optional)
    model.layers.{L}.share_expert.*.weight   # Shared expert
Native Format (Automodel):
    model.layers.{L}.moe.experts.gate_and_up_projs       # [n_exp, dim, 2*inter]
    model.layers.{L}.moe.experts.down_projs              # [n_exp, inter, dim]
    model.layers.{L}.moe.gate.weight                     # [n_exp, dim]
    model.layers.{L}.moe.gate.e_score_correction_bias    # [n_exp]
    model.layers.{L}.share_expert.*.weight
Note: Router gate weights and shared expert weights pass through with the same key names. Only the expert MLP weights (gate_proj, up_proj, down_proj) need transformation.
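The key correspondence listed above can be sketched as a small lookup. This is an illustrative sketch only: `native_key_for` and the two regular expressions are hypothetical names, not part of the adapter, and the `router_bias` to `gate.e_score_correction_bias` rename is inferred from the two key listings rather than from the adapter's code.

```python
import re

# Hypothetical patterns for illustration; the real adapter may match keys differently.
_EXPERT_KEY = re.compile(r"^(model\.layers\.\d+\.moe)\.(gate_proj|up_proj|down_proj)\.weight$")
_ROUTER_BIAS = re.compile(r"^(model\.layers\.\d+\.moe)\.router_bias$")


def native_key_for(hf_key: str) -> str:
    """Return the native key that holds the data of one HF Step3p5 key."""
    m = _EXPERT_KEY.match(hf_key)
    if m:
        prefix, proj = m.groups()
        # gate_proj and up_proj are fused into a single native tensor.
        name = "down_projs" if proj == "down_proj" else "gate_and_up_projs"
        return f"{prefix}.experts.{name}"
    m = _ROUTER_BIAS.match(hf_key)
    if m:
        # Assumption: the optional router bias maps to e_score_correction_bias.
        return f"{m.group(1)}.gate.e_score_correction_bias"
    # Router gate weights and shared expert weights keep their names.
    return hf_key
```

Note that two HF keys (`gate_proj` and `up_proj`) resolve to the same native key, which is why the reverse conversion has to emit more than one tensor per native entry.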
API
Bases: StateDictAdapter
Converts between HF Step3p5 checkpoints and Automodel grouped-experts native format.
Step3p5 HF uses grouped MoELinear with shape [n_experts, out_features, in_features]:
    model.layers.{L}.moe.gate_proj.weight    # [n_exp, inter, dim]
    model.layers.{L}.moe.up_proj.weight      # [n_exp, inter, dim]
    model.layers.{L}.moe.down_proj.weight    # [n_exp, dim, inter]
Prefix for HuggingFace format keys.
Convert native format expert tensors to HF Step3p5 format.
Native: gate_and_up_projs [n_exp, dim, 2*inter] -> HF: gate_proj, up_proj [n_exp, inter, dim]
Native: down_projs [n_exp, inter, dim] -> HF: down_proj [n_exp, dim, inter]
Preserves DTensor structure when input is a DTensor.
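A minimal sketch of the native-to-HF shape transformation on plain (non-distributed) tensors follows. It assumes the fused tensor stores gate in the first `inter` columns of the last dimension and up in the remaining columns; that ordering is an assumption, and `native_experts_to_hf` is a hypothetical name rather than the adapter's method. DTensor handling is left out.

```python
import torch


def native_experts_to_hf(gate_and_up_projs: torch.Tensor, down_projs: torch.Tensor):
    """Split and transpose native expert tensors into HF-shaped tensors."""
    inter = gate_and_up_projs.shape[-1] // 2
    # Assumption: gate occupies the first half of the last dim, up the second half.
    gate_t, up_t = gate_and_up_projs.split(inter, dim=-1)  # each [n_exp, dim, inter]
    gate_proj = gate_t.transpose(1, 2).contiguous()        # [n_exp, inter, dim]
    up_proj = up_t.transpose(1, 2).contiguous()            # [n_exp, inter, dim]
    down_proj = down_projs.transpose(1, 2).contiguous()    # [n_exp, dim, inter]
    return gate_proj, up_proj, down_proj


n_exp, inter, dim = 2, 3, 4
fused = torch.arange(n_exp * dim * 2 * inter, dtype=torch.float32).reshape(n_exp, dim, 2 * inter)
down = torch.arange(n_exp * inter * dim, dtype=torch.float32).reshape(n_exp, inter, dim)
gate_proj, up_proj, down_proj = native_experts_to_hf(fused, down)
```

The `.contiguous()` calls are a design choice for the sketch: `transpose` returns a view, and a checkpoint writer typically wants densely laid out tensors.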
Convert a single tensor from native format to HuggingFace format.
Parameters:
Fully qualified name of the tensor in native format
The tensor to convert
Additional arguments for conversion
Returns: list[tuple[str, Any]]
List of (fqn, tensor) tuples in HuggingFace format
Convert HF checkpoint to native format.
Handles Step3p5’s grouped MoELinear format:
- [n_exp, inter, dim] gate_proj/up_proj -> [n_exp, dim, 2*inter] gate_and_up_projs
- [n_exp, dim, inter] down_proj -> [n_exp, inter, dim] down_projs
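The two shape rules above can be sketched on plain tensors as a transpose followed by a concatenation. This is a sketch under assumptions, not the adapter's code: `hf_experts_to_native` is a hypothetical name, and placing gate before up along the last dimension is an assumed layout.

```python
import torch


def hf_experts_to_native(gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor):
    """Fuse HF grouped expert weights into native-shaped tensors."""
    gate_t = gate_proj.transpose(1, 2)  # [n_exp, inter, dim] -> [n_exp, dim, inter]
    up_t = up_proj.transpose(1, 2)      # [n_exp, inter, dim] -> [n_exp, dim, inter]
    # Assumption: gate fills the first inter columns, up the last inter columns.
    gate_and_up_projs = torch.cat([gate_t, up_t], dim=-1).contiguous()  # [n_exp, dim, 2*inter]
    down_projs = down_proj.transpose(1, 2).contiguous()                 # [n_exp, inter, dim]
    return gate_and_up_projs, down_projs


n_exp, inter, dim = 4, 6, 8
hf_gate = torch.randn(n_exp, inter, dim)
hf_up = torch.randn(n_exp, inter, dim)
hf_down = torch.randn(n_exp, dim, inter)
native_gate_up, native_down = hf_experts_to_native(hf_gate, hf_up, hf_down)
```

With this layout, `x @ native_gate_up[e]` for an input `x` of shape `[dim]` yields the gate and up activations of expert `e` side by side, which is what makes the fused format convenient for grouped matmuls.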
Convert from native model state dict to HuggingFace format.
Create a DTensor from a local tensor.
Prefers using reference_dtensor’s mesh/placements if available (for preserving DTensor structure from DCP-loaded tensors). Falls back to creating a new DTensor using device_mesh if reference is not a DTensor.
Parameters:
Local portion of the tensor after transformation
Optional DTensor to copy mesh/placements from
Device mesh for EP (used if reference is not DTensor)
Current rank for device placement
If provided, use these placements instead of the reference DTensor’s placements. Useful after transposing the local tensor, where shard dimensions need to be swapped.
Returns: torch.Tensor
DTensor if mesh is available, otherwise local_tensor
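The fallback order described above could look roughly like the following. This is a hedged sketch, not the adapter's implementation: `to_dtensor` is a hypothetical name, the rank-based device placement is omitted, and sharding dim 0 (the expert dimension) when only a device mesh is given is an assumption about how expert parallelism lays out the weights.

```python
import torch

try:
    from torch.distributed.tensor import DTensor, Shard
except ImportError:  # older PyTorch releases expose these under _tensor
    from torch.distributed._tensor import DTensor, Shard


def to_dtensor(local_tensor, reference_dtensor=None, device_mesh=None, placements=None):
    """Wrap local_tensor as a DTensor when a mesh can be found, else return it unchanged."""
    if isinstance(reference_dtensor, DTensor):
        # Prefer the reference's mesh; explicit placements override the reference's.
        if placements is None:
            placements = reference_dtensor.placements
        return DTensor.from_local(
            local_tensor, reference_dtensor.device_mesh, placements, run_check=False
        )
    if device_mesh is not None:
        if placements is None:
            placements = [Shard(0)]  # assumption: experts are sharded along dim 0
        return DTensor.from_local(local_tensor, device_mesh, placements, run_check=False)
    # No mesh information at all: stay with the plain local tensor.
    return local_tensor
```

Without an initialized process group only the last branch can be exercised, for example `to_dtensor(torch.zeros(2, 3))` simply returns its input.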
Swap Shard dim 1 and dim 2 in DTensor placements after a transpose(1, 2).
When we transpose a 3-D tensor’s dims 1 and 2, any Shard placement on those
dims must be swapped so that DTensor.from_local infers the correct global
shape. Without this, the shard multiplier is applied to the wrong axis.
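A small worked case shows why the swap matters. Take a global tensor of shape [8, 16, 32] sharded on dim 1 across 4 ranks, so each local shard is [8, 4, 32]. After transpose(1, 2) the local shard is [8, 32, 4]. Keeping Shard(1) would make DTensor.from_local infer a global shape of [8, 128, 4], while Shard(2) gives the intended [8, 32, 16]. The sketch below assumes nothing beyond PyTorch's `Shard` and `Replicate` placement types; `swap_shard_dims_1_2` is a hypothetical name.

```python
try:
    from torch.distributed.tensor import Replicate, Shard
except ImportError:  # older PyTorch releases expose these under _tensor
    from torch.distributed._tensor import Replicate, Shard


def swap_shard_dims_1_2(placements):
    """Return placements with Shard(1) and Shard(2) exchanged; others are kept."""
    swapped = []
    for p in placements:
        if isinstance(p, Shard) and p.dim in (1, 2):
            swapped.append(Shard(3 - p.dim))  # 1 -> 2 and 2 -> 1
        else:
            swapped.append(p)  # Shard(0), Replicate, etc. are unaffected by the transpose
    return tuple(swapped)


new_placements = swap_shard_dims_1_2((Shard(0), Shard(1), Replicate(), Shard(2)))
```

Placements on dim 0 (the expert dimension) are left alone because the transpose does not move that axis.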