nemo_automodel.components.models.gpt_oss.state_dict_adapter#

Module Contents#

Classes#

GPTOSSStateDictAdapter

Data#

FP4_VALUES

API#

nemo_automodel.components.models.gpt_oss.state_dict_adapter.FP4_VALUES#

None

class nemo_automodel.components.models.gpt_oss.state_dict_adapter.GPTOSSStateDictAdapter(
config: transformers.GptOssConfig,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.moe.utils.BackendConfig,
dtype: torch.dtype = torch.bfloat16,
)#

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter
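
A minimal instantiation sketch is shown below. The `MoEConfig` and `BackendConfig` constructors are assumed to be default-constructible here purely for illustration; their actual required fields are defined in `nemo_automodel.components.moe`.

```python
# Illustrative wiring sketch; MoEConfig/BackendConfig arguments are assumptions.
import torch
from transformers import GptOssConfig

from nemo_automodel.components.models.gpt_oss.state_dict_adapter import GPTOSSStateDictAdapter
from nemo_automodel.components.moe.layers import MoEConfig
from nemo_automodel.components.moe.utils import BackendConfig

config = GptOssConfig()      # HF model hyperparameters
moe_config = MoEConfig()     # assumed default-constructible for this sketch
backend = BackendConfig()    # assumed default-constructible for this sketch

adapter = GPTOSSStateDictAdapter(config, moe_config, backend, dtype=torch.bfloat16)
```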

_apply_key_mapping(
state_dict: dict[str, Any],
mapping: dict[str, str],
) → dict[str, Any]#
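
The helper's body is not reproduced here; the sketch below only illustrates the general regex-based key-renaming pattern such a method typically implements. The mapping entries are hypothetical, not the adapter's real mapping table.

```python
import re
from typing import Any

def apply_key_mapping(state_dict: dict[str, Any], mapping: dict[str, str]) -> dict[str, Any]:
    """Illustrative stand-in: rename checkpoint keys via regex substitution."""
    out: dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in mapping.items():
            new_key = re.sub(pattern, replacement, new_key)
        out[new_key] = value
    return out

# Hypothetical HF-to-internal mapping entries, for illustration only:
mapping = {r"^model\.layers\.": "layers.", r"\.self_attn\.": ".attention."}
```
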
_add_quantization_block_scale_tensors(
state_dict: dict[str, Any],
) → dict[str, Any]#
_dequantize_block_scale_tensors(
state_dict: dict[str, Any],
) → dict[str, Any]#
_convert_moe_packed_tensors(
blocks,
scales,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) → torch.Tensor#

Convert the mxfp4 weights to bfloat16.

Source: https://github.com/huggingface/transformers/blob/869735d37d0f929311ac6611728c482a4414ba8c/src/transformers/integrations/mxfp4.py#L77
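
The referenced Transformers helper unpacks two 4-bit codes per byte, looks them up in the FP4 value table, and applies per-block power-of-two scales. Below is a simplified sketch of that pattern, not the chunked, memory-aware implementation linked above. The lookup table is the standard E2M1 (mxfp4) value set and is assumed to match this module's `FP4_VALUES`; the expected shapes are `blocks: (..., num_blocks, bytes_per_block)` (uint8) and `scales: (..., num_blocks)` (uint8, exponent bias 127).

```python
import torch

# E2M1 (mxfp4) representable values; assumed to mirror FP4_VALUES in this module.
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]

def dequant_mxfp4(blocks: torch.Tensor, scales: torch.Tensor,
                  dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Simplified sketch: each uint8 in `blocks` packs two 4-bit codes;
    `scales` holds per-block uint8 exponents stored with a bias of 127."""
    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
    lo = blocks & 0x0F                               # low nibble -> even positions
    hi = blocks >> 4                                 # high nibble -> odd positions
    vals = torch.stack([lut[lo.long()], lut[hi.long()]], dim=-1).flatten(-2)
    exp = scales.to(torch.int32) - 127               # remove exponent bias
    return torch.ldexp(vals, exp.unsqueeze(-1)).to(dtype)  # vals * 2**exp per block
```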

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
quantization: bool = False,
**kwargs,
) → dict[str, Any]#

Convert from native model state dict to HuggingFace format.
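
Typical usage is to run the adapter (constructed above) over a native state dict before saving a HuggingFace-compatible checkpoint. The `model` variable and the exclusion regex below are illustrative.

```python
# Illustrative usage; `model` is a native GPT-OSS module (construction not shown).
native_sd = model.state_dict()

hf_sd = adapter.to_hf(
    native_sd,
    exclude_key_regex=r".*rotary_emb.*",  # hypothetical pattern for keys to drop
    quantization=False,                   # assumption: True also emits block/scale tensors
)
```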

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) → dict[str, Any]#

Convert HF checkpoint to native format.

  • Apply key mappings from HF to internal format

  • Add quantization block and scale tensors
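
In the other direction, an HF-format state dict can be remapped into the native format with the same adapter. The loading path below is only one way to obtain an HF state dict; the checkpoint id and `device_mesh=None` are illustrative.

```python
from transformers import AutoModelForCausalLM

# Illustrative: load an HF-format gpt-oss checkpoint to obtain its state dict.
hf_model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
hf_sd = hf_model.state_dict()

# Remap keys to the internal naming scheme; device_mesh is optional.
native_sd = adapter.from_hf(hf_sd, device_mesh=None)
```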