nemo_automodel._transformers.biencoder#
Generic Biencoder Model for embedding and retrieval tasks.
Module Contents#
Classes#
BiencoderModel: Biencoder model that encodes queries and passages separately using bidirectional backbones.
Functions#
pool: Pool hidden states using the specified pooling method.
Data#
API#
- nemo_automodel._transformers.biencoder.logger#
`get_logger(…)`
- nemo_automodel._transformers.biencoder.pool(
- last_hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- pool_type: str,
- )#
Pool hidden states using the specified pooling method.
- Parameters:
last_hidden_states – Hidden states from the model [batch_size, seq_len, hidden_size]
attention_mask – Attention mask [batch_size, seq_len]
pool_type – Type of pooling to apply
- Returns:
Pooled embeddings [batch_size, hidden_size]
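As an illustration, here is a minimal sketch of what average pooling (`pool_type='avg'`, the default in BiencoderModel) typically computes. This is an approximation for intuition, not the library's exact implementation: padded positions are masked out, then the remaining hidden states are averaged over the sequence dimension.

```python
import torch

def avg_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padded positions before averaging over the sequence dimension.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)  # [batch, seq_len, 1]
    summed = (last_hidden_states * mask).sum(dim=1)                   # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # avoid division by zero
    return summed / counts
```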
- nemo_automodel._transformers.biencoder.SUPPORTED_BACKBONES#
None
- class nemo_automodel._transformers.biencoder.BiencoderModel(
- lm_q: transformers.PreTrainedModel,
- lm_p: transformers.PreTrainedModel,
- pooling: str = 'avg',
- l2_normalize: bool = True,
- share_encoder: bool = True,
- )#
Bases:
torch.nn.Module
Biencoder model that encodes queries and passages separately using bidirectional backbones.
Initialization
- forward(
- input_dict: dict,
- encoder: str = 'query',
- )#
Forward pass; delegates to encode().
Going through forward() ensures FSDP2 unshard hooks fire via `__call__`.
- encode(
- input_dict: dict,
- encoder: str = 'query',
- )#
Encode inputs using the query or passage encoder.
- Parameters:
input_dict – Tokenized inputs (input_ids, attention_mask, etc.)
encoder – “query” or “passage”
- Returns:
Embeddings [batch_size, hidden_dim], or None if input_dict is empty.
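A hedged end-to-end sketch follows; the backbone name, tokenizer, and texts are placeholders (assumptions), not prescribed by the library:

```python
from transformers import AutoTokenizer
from nemo_automodel._transformers.biencoder import BiencoderModel

# Assumption: "intfloat/e5-base-v2" stands in for any supported bidirectional backbone.
name = "intfloat/e5-base-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = BiencoderModel.build(name, share_encoder=True, pooling="avg", l2_normalize=True)

queries = dict(tokenizer(["what is dense retrieval?"], padding=True, return_tensors="pt"))
passages = dict(tokenizer(
    ["Dense retrieval matches queries and passages in embedding space."],
    padding=True, return_tensors="pt",
))

# Call the model object (i.e., forward()) rather than encode() directly,
# so FSDP2 unshard hooks fire via __call__.
q_emb = model(queries, encoder="query")     # [batch_size, hidden_dim]
p_emb = model(passages, encoder="passage")  # [batch_size, hidden_dim]

# With l2_normalize=True, the dot product is cosine similarity.
scores = q_emb @ p_emb.T
```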
- _encode(
- encoder: transformers.PreTrainedModel,
- input_dict: dict,
- )#
Encode inputs using the given encoder backbone.
- classmethod build(
- model_name_or_path: str,
- share_encoder: bool = True,
- pooling: str = 'avg',
- l2_normalize: bool = True,
- trust_remote_code: bool = False,
- **hf_kwargs,
- )#
Build biencoder model from a pretrained backbone.
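For instance (a sketch; the backbone name is a placeholder, and `torch_dtype` is assumed to be forwarded through `**hf_kwargs` to the underlying Hugging Face from_pretrained call):

```python
import torch
from nemo_automodel._transformers.biencoder import BiencoderModel

# Shared weights between query and passage encoders (the default):
shared = BiencoderModel.build("intfloat/e5-base-v2", share_encoder=True)

# Separate query/passage encoders, with extra kwargs forwarded via **hf_kwargs:
split = BiencoderModel.build(
    "intfloat/e5-base-v2",
    share_encoder=False,
    pooling="avg",
    l2_normalize=True,
    torch_dtype=torch.bfloat16,
)
```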
- save_pretrained(save_directory: str, **kwargs)#
Save model to output directory.
If `checkpointer` is in kwargs, delegates to Checkpointer.save_model for distributed/FSDP-safe saving. Otherwise falls back to the HF-native save.
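Usage sketch, assuming `model` from the build() example above; `ckpt` is a hypothetical, already-configured Checkpointer instance:

```python
# HF-native save (no checkpointer passed):
model.save_pretrained("./biencoder-out")

# Distributed/FSDP-safe save, delegating to Checkpointer.save_model
# (assumption: `ckpt` is a configured Checkpointer instance):
model.save_pretrained("./biencoder-out", checkpointer=ckpt)
```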