nemo_automodel.components.models.biencoder.state_dict_adapter

Module Contents

Classes

BiencoderStateDictAdapter

Adapter for converting BiencoderModel state dict to single encoder format.

API

class nemo_automodel.components.models.biencoder.state_dict_adapter.BiencoderStateDictAdapter

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

Adapter for converting BiencoderModel state dict to single encoder format.

This adapter extracts only the query encoder (lm_q) entries from the state dict and rewrites their “lm_q.” prefix to “model.”, so the resulting keys match the standard HuggingFace model format.

Initialization

Initialize the adapter.

to_hf(state_dict: dict[str, Any], **kwargs) → dict[str, Any]

Convert from biencoder state dict to HuggingFace format.

Keeps only the lm_q keys and replaces their “lm_q.” prefix with “model.”.

Parameters:
  • state_dict – The biencoder model state dict

Returns:

The converted HuggingFace-format state dict, containing only the query encoder entries
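
The filtering and renaming described for to_hf can be illustrated with a minimal standalone sketch. This is not the adapter's actual implementation: the helper name to_hf_sketch, the example keys, the "lm_p." prefix used for a non-query entry, and the string placeholders standing in for tensors are all assumptions made for illustration.

```python
from typing import Any


def to_hf_sketch(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical stand-in for the documented to_hf behavior."""
    src, dst = "lm_q.", "model."
    # Keep only keys under the query encoder prefix and swap the prefix.
    return {
        dst + key[len(src):]: value
        for key, value in state_dict.items()
        if key.startswith(src)
    }


# Placeholder values stand in for tensors; "lm_p." is an assumed name
# for an entry that does not belong to the query encoder.
biencoder_sd = {
    "lm_q.embed_tokens.weight": "q_embed",
    "lm_q.layers.0.mlp.weight": "q_mlp",
    "lm_p.embed_tokens.weight": "p_embed",
}
hf_sd = to_hf_sketch(biencoder_sd)
print(sorted(hf_sd))
```

In this sketch the entry outside lm_q is dropped and the two query encoder keys come back under the "model." prefix.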

from_hf(hf_state_dict: dict[str, Any], device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, **kwargs) → dict[str, Any]

Convert HuggingFace state dict to biencoder format.

Converts the “model.” prefix to the “lm_q.” prefix for loading into the biencoder.

Parameters:
  • hf_state_dict – The HuggingFace format state dict

  • device_mesh – Optional device mesh (not used in this adapter)

Returns:

The converted biencoder format state dict
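
The reverse mapping described for from_hf can be sketched in the same way. Again this is only an illustration under stated assumptions: from_hf_sketch is a hypothetical helper, the values are placeholders rather than tensors, and keys without the "model." prefix are simply skipped, which is an assumption because the description above does not say how the adapter treats them.

```python
from typing import Any


def from_hf_sketch(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical stand-in for the documented from_hf behavior."""
    src, dst = "model.", "lm_q."
    converted: dict[str, Any] = {}
    for key, value in hf_state_dict.items():
        # Assumption: keys outside the "model." prefix are ignored.
        if key.startswith(src):
            converted[dst + key[len(src):]] = value
    return converted


hf_sd = {
    "model.embed_tokens.weight": "embed",
    "model.norm.weight": "norm",
}
biencoder_sd = from_hf_sketch(hf_sd)
print(sorted(biencoder_sd))
```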

convert_single_tensor_to_hf(fqn: str, tensor: Any, **kwargs) → list[tuple[str, Any]]

Convert a single tensor from biencoder format to HuggingFace format.

Parameters:
  • fqn – Fully qualified name of the tensor in biencoder format

  • tensor – The tensor to convert

  • **kwargs – Additional arguments (unused)

Returns:

List of (fqn, tensor) tuples in HuggingFace format. Returns an empty list if the tensor is not part of lm_q.
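
The per-tensor contract (one renamed pair for an lm_q tensor, an empty list otherwise) might look like the following sketch. The function name convert_single_tensor_to_hf_sketch, the example names, the "lm_p." prefix, and the float placeholders used instead of tensors are assumptions for illustration, not the adapter's code.

```python
from typing import Any


def convert_single_tensor_to_hf_sketch(fqn: str, tensor: Any) -> list[tuple[str, Any]]:
    """Hypothetical stand-in for the documented per-tensor conversion."""
    prefix = "lm_q."
    if not fqn.startswith(prefix):
        # Not part of the query encoder: nothing to emit.
        return []
    return [("model." + fqn[len(prefix):], tensor)]


# Floats stand in for tensors; "lm_p." is an assumed non-query prefix.
kept = convert_single_tensor_to_hf_sketch("lm_q.layers.0.weight", 1.0)
dropped = convert_single_tensor_to_hf_sketch("lm_p.layers.0.weight", 2.0)
print(kept, dropped)
```

Returning a list rather than a single pair lets a caller handle the "no output" case uniformly, for example by extending a result list with whatever each call returns.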