nemo_automodel.components.models.biencoder.state_dict_adapter
Module Contents
Classes
BiencoderStateDictAdapter | Adapter for converting a BiencoderModel state dict to single-encoder format.
API
- class nemo_automodel.components.models.biencoder.state_dict_adapter.BiencoderStateDictAdapter
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter
Adapter for converting a BiencoderModel state dict to single-encoder format.
This adapter extracts only the query encoder (lm_q) state dict and replaces the “lm_q.” prefix with “model.”. The result is compatible with the standard HuggingFace model format. Keys belonging to any other part of the biencoder are not included.
Initialization
Initialize the adapter.
- to_hf(state_dict: dict[str, Any], **kwargs)
Convert a biencoder state dict to HuggingFace format.
Keeps only the lm_q keys and replaces their “lm_q.” prefix with “model.”.
- Parameters:
state_dict – The biencoder model state dict
- Returns:
The converted HuggingFace format state dict, containing only the query encoder's weights
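The filtering and renaming that to_hf describes can be sketched with plain dictionaries. This is a minimal illustration, not the library's implementation. The helper name biencoder_to_hf is hypothetical, and strings stand in for tensors. The "lm_p." key is a made-up example of a non-query key. The sketch also assumes the prefix is rewritten only at the start of each key.

```python
from typing import Any

Q_PREFIX = "lm_q."    # query-encoder prefix in the biencoder state dict
HF_PREFIX = "model."  # prefix used by the single-encoder HuggingFace layout


def biencoder_to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical sketch of to_hf: keep lm_q.* keys and rename them to model.*."""
    hf_state_dict: dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(Q_PREFIX):
            # Swap only the leading prefix; the rest of the key is unchanged.
            hf_state_dict[HF_PREFIX + key[len(Q_PREFIX):]] = value
        # Every other key is left out of the result.
    return hf_state_dict


# Strings stand in for tensors; "lm_p." is a hypothetical non-query key.
biencoder_sd = {
    "lm_q.embed_tokens.weight": "q_emb",
    "lm_q.layers.0.mlp.weight": "q_mlp",
    "lm_p.embed_tokens.weight": "p_emb",
}
print(biencoder_to_hf(biencoder_sd))
# {'model.embed_tokens.weight': 'q_emb', 'model.layers.0.mlp.weight': 'q_mlp'}
```

Checking the prefix with startswith rather than a global string replace keeps a key that merely contains "lm_q." somewhere in the middle from being renamed.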
- from_hf(hf_state_dict: dict[str, Any], device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, **kwargs)
Convert a HuggingFace state dict to biencoder format.
Replaces the “model.” prefix with “lm_q.” so the weights can be loaded into the biencoder's query encoder.
- Parameters:
hf_state_dict – The HuggingFace format state dict
device_mesh – Optional device mesh (not used in this adapter)
- Returns:
The converted biencoder format state dict
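The reverse mapping can be sketched the same way. This is again a hypothetical helper (hf_to_biencoder), not the library's code. It accepts a device_mesh argument only to mirror the documented signature and ignores it. How the real adapter treats keys that lack the “model.” prefix is not documented here. This sketch simply passes them through unchanged, and that behavior is an assumption.

```python
from typing import Any, Optional

Q_PREFIX = "lm_q."
HF_PREFIX = "model."


def hf_to_biencoder(
    hf_state_dict: dict[str, Any], device_mesh: Optional[Any] = None
) -> dict[str, Any]:
    """Hypothetical sketch of from_hf: rename model.* keys to lm_q.*."""
    del device_mesh  # accepted but unused, mirroring the documented signature
    biencoder_sd: dict[str, Any] = {}
    for key, value in hf_state_dict.items():
        if key.startswith(HF_PREFIX):
            biencoder_sd[Q_PREFIX + key[len(HF_PREFIX):]] = value
        else:
            # Assumption: keys without the "model." prefix pass through unchanged.
            biencoder_sd[key] = value
    return biencoder_sd


hf_sd = {"model.embed_tokens.weight": "w0", "model.norm.weight": "w1"}
print(hf_to_biencoder(hf_sd))
# {'lm_q.embed_tokens.weight': 'w0', 'lm_q.norm.weight': 'w1'}
```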
- convert_single_tensor_to_hf(fqn: str, tensor: Any, **kwargs)
Convert a single tensor from biencoder format to HuggingFace format.
- Parameters:
fqn – Fully qualified name of the tensor in biencoder format
tensor – The tensor to convert
**kwargs – Additional arguments (unused)
- Returns:
A list of (fqn, tensor) tuples in HuggingFace format, or an empty list if the tensor is not part of lm_q.
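The per-tensor contract can be illustrated with a small standalone function. The name single_tensor_to_hf is hypothetical, and the code is a sketch of the documented behavior rather than the adapter's source. A tensor under lm_q yields one renamed entry, and anything else yields an empty list.

```python
from typing import Any

Q_PREFIX = "lm_q."
HF_PREFIX = "model."


def single_tensor_to_hf(fqn: str, tensor: Any, **kwargs: Any) -> list[tuple[str, Any]]:
    """Hypothetical sketch: map one biencoder tensor name to zero or one HF entries."""
    # Extra keyword arguments are accepted and ignored.
    if not fqn.startswith(Q_PREFIX):
        return []  # not part of the query encoder, so nothing is emitted
    return [(HF_PREFIX + fqn[len(Q_PREFIX):], tensor)]


# A caller streaming tensors one at a time can extend its output uniformly.
converted: list[tuple[str, Any]] = []
for name, t in [("lm_q.layers.0.weight", "a"), ("lm_p.layers.0.weight", "b")]:
    converted.extend(single_tensor_to_hf(name, t))
print(converted)
# [('model.layers.0.weight', 'a')]
```

Returning a list lets the caller handle "drop this tensor" (empty list) and "rename this tensor" (one entry) with the same extend call, as the loop above does.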