nemo_automodel._transformers.biencoder#

Generic Biencoder Model for embedding and retrieval tasks.

Module Contents#

Classes#

BiencoderModel

Biencoder model that encodes queries and passages separately using bidirectional backbones.

Functions#

pool

Pool hidden states using the specified pooling method.

Data#

API#

nemo_automodel._transformers.biencoder.logger#

‘get_logger(…)’

nemo_automodel._transformers.biencoder.pool(
last_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
pool_type: str,
) → torch.Tensor#

Pool hidden states using the specified pooling method.

Parameters:
  • last_hidden_states – Hidden states from the model [batch_size, seq_len, hidden_size]

  • attention_mask – Attention mask [batch_size, seq_len]

  • pool_type – Type of pooling to apply

Returns:

Pooled embeddings [batch_size, hidden_size]
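
A minimal usage sketch, assuming pool_type='avg' selects mask-aware mean pooling (it is the default pooling mode of BiencoderModel below); the tensor shapes follow the parameter docs above:

import torch

from nemo_automodel._transformers.biencoder import pool

batch_size, seq_len, hidden_size = 2, 8, 16
last_hidden_states = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, 5:] = 0  # second sequence is padded after position 4

embeddings = pool(last_hidden_states, attention_mask, pool_type="avg")
print(embeddings.shape)  # torch.Size([2, 16])

# For reference, mask-aware mean pooling amounts to the following
# (a sketch, not necessarily the library's exact code):
mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
manual_avg = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)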

nemo_automodel._transformers.biencoder.SUPPORTED_BACKBONES#

None

class nemo_automodel._transformers.biencoder.BiencoderModel(
lm_q: transformers.PreTrainedModel,
lm_p: transformers.PreTrainedModel,
pooling: str = 'avg',
l2_normalize: bool = True,
share_encoder: bool = True,
)#

Bases: torch.nn.Module

Biencoder model that encodes queries and passages separately using bidirectional backbones.

Initialization
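
A direct-construction sketch, assuming a bidirectional Hugging Face backbone; the model name is illustrative, and passing the same module for both encoders mirrors the shared-encoder default:

from transformers import AutoModel

from nemo_automodel._transformers.biencoder import BiencoderModel

backbone = AutoModel.from_pretrained("bert-base-uncased")  # illustrative backbone
model = BiencoderModel(
    lm_q=backbone,
    lm_p=backbone,  # same module for queries and passages
    pooling="avg",
    l2_normalize=True,
    share_encoder=True,
)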

forward(
input_dict: dict,
encoder: str = 'query',
) → Optional[torch.Tensor]#

Forward pass — delegates to encode().

Going through forward() ensures FSDP2 unshard hooks fire via __call__.

encode(
input_dict: dict,
encoder: str = 'query',
) → Optional[torch.Tensor]#

Encode inputs using the query or passage encoder.

Parameters:
  • input_dict – Tokenized inputs (input_ids, attention_mask, etc.)

  • encoder – “query” or “passage”

Returns:

Embeddings [batch_size, hidden_dim], or None if input_dict is empty.
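
An end-to-end sketch; the backbone and tokenizer names are illustrative assumptions. Calling the model object routes through forward(), which, per the note above, is the FSDP2-safe path to encode():

import torch
from transformers import AutoTokenizer

from nemo_automodel._transformers.biencoder import BiencoderModel

model = BiencoderModel.build("bert-base-uncased")  # illustrative backbone
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

queries = tokenizer(
    ["what is dense retrieval?"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    q_emb = model(dict(queries), encoder="query")  # forward() delegates to encode()
    p_emb = model(dict(queries), encoder="passage")

# With l2_normalize=True (the default), cosine similarity reduces to a dot product:
scores = q_emb @ p_emb.T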

_encode(
encoder: transformers.PreTrainedModel,
input_dict: dict,
) → Optional[torch.Tensor]#

Encode inputs using the given encoder backbone.

classmethod build(
model_name_or_path: str,
share_encoder: bool = True,
pooling: str = 'avg',
l2_normalize: bool = True,
trust_remote_code: bool = False,
**hf_kwargs,
)#

Build a biencoder model from a pretrained backbone.
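
A build sketch spelling out the documented defaults; the backbone name is illustrative, and the trailing keyword shows the **hf_kwargs passthrough (the exact set of accepted keys is an assumption):

from nemo_automodel._transformers.biencoder import BiencoderModel

model = BiencoderModel.build(
    "bert-base-uncased",  # illustrative; must be a supported backbone
    share_encoder=True,   # one set of weights for queries and passages
    pooling="avg",
    l2_normalize=True,
    trust_remote_code=False,
    torch_dtype="auto",   # forwarded to the HF loader via **hf_kwargs (assumption)
)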

save_pretrained(save_directory: str, **kwargs)#

Save the model to the output directory.

If a checkpointer is passed via kwargs, delegates to Checkpointer.save_model for distributed/FSDP-safe saving; otherwise falls back to the HF-native save.
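
A save sketch for the HF-native fallback path (no checkpointer in kwargs); the backbone name and directory are illustrative:

from nemo_automodel._transformers.biencoder import BiencoderModel

model = BiencoderModel.build("bert-base-uncased")  # illustrative backbone
model.save_pretrained("./biencoder_ckpt")  # no checkpointer kwarg, so the HF-native save path is used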