nemo_automodel.components.models.common.bidirectional#

Bidirectional model state dict adapter utilities.

This module provides the BiencoderStateDictAdapter for converting between biencoder and HuggingFace state dict formats.

Module Contents#

Classes#

BiencoderStateDictAdapter

Adapter for converting BiencoderModel state dict to/from single-encoder HF format.

Data#

API#

class nemo_automodel.components.models.common.bidirectional.BiencoderStateDictAdapter#

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

Adapter for converting BiencoderModel state dict to/from single-encoder HF format.

Extracts only the query encoder (lm_q) on save, renaming the lm_q. prefix to model. On load, fans model. keys back out to both the lm_q. and lm_p. prefixes. PEFT-prefixed keys (base_model.model.) are handled transparently.
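The mapping described above can be sketched as plain dict transforms. This is an illustrative reimplementation under the stated prefix conventions, not the library source, and it omits the PEFT-prefix handling for brevity:

```python
def biencoder_to_hf(state_dict: dict) -> dict:
    # Keep only query-encoder (lm_q.) keys; rename the prefix to model.
    return {
        "model." + key[len("lm_q."):]: value
        for key, value in state_dict.items()
        if key.startswith("lm_q.")
    }


def hf_to_biencoder(hf_state_dict: dict) -> dict:
    # Fan each model. key back out to both lm_q. and lm_p. prefixes,
    # so query and passage encoders start from the same weights.
    out = {}
    for key, value in hf_state_dict.items():
        if key.startswith("model."):
            suffix = key[len("model."):]
            out["lm_q." + suffix] = value
            out["lm_p." + suffix] = value
    return out
```

Note that the round trip is lossy by design: passage-encoder (lm_p.) weights are dropped on save and re-seeded from the query encoder on load.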

Initialization

_PEFT_PREFIX#

'base_model.model.'

static _swap_key(
key: str,
src: str,
dst: str,
peft_prefix: str,
) → Optional[str]#

Return key with src prefix replaced by dst, handling an optional PEFT wrapper.

Returns None when key doesn’t match src (bare or PEFT-wrapped).
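A hypothetical sketch of the behavior documented above (not the library source): the src prefix is replaced by dst, matching the key either bare or wrapped in the PEFT prefix, and None signals a non-matching key:

```python
from typing import Optional


def swap_key(key: str, src: str, dst: str, peft_prefix: str = "base_model.model.") -> Optional[str]:
    # Bare match: replace the leading src prefix with dst.
    if key.startswith(src):
        return dst + key[len(src):]
    # PEFT-wrapped match: keep the wrapper prefix, swap the inner prefix.
    wrapped = peft_prefix + src
    if key.startswith(wrapped):
        return peft_prefix + dst + key[len(wrapped):]
    # Neither form matched.
    return None
```

For example, `swap_key("lm_q.layer.weight", "lm_q.", "model.")` yields `"model.layer.weight"`, while a key starting with `lm_p.` yields None.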

to_hf(
state_dict: dict[str, Any],
**kwargs,
) → dict[str, Any]#

Convert biencoder state dict to HF format (lm_q -> model).

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) → dict[str, Any]#

Convert HF state dict to biencoder format (model -> lm_q + lm_p).

convert_single_tensor_to_hf(
fqn: str,
tensor: Any,
**kwargs,
) → list[tuple[str, Any]]#

Convert a single tensor from biencoder to HF format. Skips non-lm_q tensors.

nemo_automodel.components.models.common.bidirectional.__all__#

['BiencoderStateDictAdapter']