nemo_export.model_adapters.reranker.reranker_adapter#

Module Contents#

Classes#

SequenceClassificationModelAdapterWithoutTypeIds

Adapter for sequence classification models that don’t use token type IDs.

SequenceClassificationModelAdapterWithTypeIds

Adapter for sequence classification models that use token type IDs.

Functions#

get_llama_reranker_hf_model

Load and adapt a HuggingFace reranker model for export.

API#

class nemo_export.model_adapters.reranker.reranker_adapter.SequenceClassificationModelAdapterWithoutTypeIds(
model: transformers.AutoModelForSequenceClassification,
)#

Bases: torch.nn.Module

Adapter for sequence classification models that don’t use token type IDs.

This adapter wraps a HuggingFace AutoModelForSequenceClassification model and provides a simplified forward method that only takes input_ids and attention_mask as inputs, excluding token_type_ids.

Parameters:

model – A HuggingFace AutoModelForSequenceClassification model to wrap.

.. attribute:: config

The configuration object from the wrapped model.

Initialization

forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) torch.Tensor#

Forward pass through the sequence classification model.

Parameters:
  • input_ids – Token IDs for the input sequences. Shape: (batch_size, sequence_length)

  • attention_mask – Attention mask indicating which tokens should be attended to. Shape: (batch_size, sequence_length)

Returns:

Logits from the classification head. Shape: (batch_size, num_labels)

Return type:

torch.Tensor
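The adapter's role can be sketched as a thin torch.nn.Module wrapper that calls the underlying model with keyword arguments and unpacks the HuggingFace-style output object into a plain logits tensor, which keeps the exported graph's output a simple Tensor. The StubClassifier below is a hypothetical stand-in for a real AutoModelForSequenceClassification; names and dimensions are illustrative, not from the actual implementation.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StubClassifier(nn.Module):
    """Hypothetical stand-in for a HuggingFace AutoModelForSequenceClassification."""

    def __init__(self, vocab_size=100, hidden=8, num_labels=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, num_labels)
        self.config = SimpleNamespace(num_labels=num_labels)

    def forward(self, input_ids, attention_mask, **kwargs):
        # Mean-pool embeddings over non-masked positions, then classify.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (self.embed(input_ids) * mask).sum(dim=1) / mask.sum(dim=1)
        return SimpleNamespace(logits=self.head(pooled))


class AdapterWithoutTypeIds(nn.Module):
    """Sketch of the adapter: forward takes only input_ids and attention_mask
    and returns the raw logits tensor rather than a model-output object."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config  # expose the wrapped model's config

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits


adapter = AdapterWithoutTypeIds(StubClassifier())
input_ids = torch.randint(0, 100, (2, 5))
attention_mask = torch.ones(2, 5, dtype=torch.long)
print(adapter(input_ids, attention_mask).shape)  # torch.Size([2, 1])
```

Returning `.logits` directly (instead of the full output object) is what makes the wrapper convenient for tracing and export, since exporters expect plain tensors as graph outputs.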

class nemo_export.model_adapters.reranker.reranker_adapter.SequenceClassificationModelAdapterWithTypeIds(
model: transformers.AutoModelForSequenceClassification,
)#

Bases: torch.nn.Module

Adapter for sequence classification models that use token type IDs.

This adapter wraps a HuggingFace AutoModelForSequenceClassification model and provides a forward method that includes token_type_ids for models that require this input (e.g., BERT-based models).

Parameters:

model – A HuggingFace AutoModelForSequenceClassification model to wrap.

.. attribute:: config

The configuration object from the wrapped model.

Initialization

forward(
input_ids: torch.Tensor,
token_type_ids: torch.Tensor,
attention_mask: torch.Tensor,
) torch.Tensor#

Forward pass through the sequence classification model.

Parameters:
  • input_ids – Token IDs for the input sequences. Shape: (batch_size, sequence_length)

  • token_type_ids – Token type IDs to distinguish between different parts of the input. Shape: (batch_size, sequence_length)

  • attention_mask – Attention mask indicating which tokens should be attended to. Shape: (batch_size, sequence_length)

Returns:

Logits from the classification head. Shape: (batch_size, num_labels)

Return type:

torch.Tensor
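For a reranker, token_type_ids typically mark which segment each token belongs to when a query and a candidate passage are packed into one sequence, BERT-style. The sketch below shows how such inputs might be assembled; the token IDs are illustrative placeholders, not drawn from a real vocabulary.

```python
# Hypothetical query/passage pair packed BERT-style: [CLS] query [SEP] passage [SEP].
# Segment 0 covers the query half, segment 1 the passage half.
query_ids = [101, 7592, 102]   # illustrative IDs only
passage_ids = [2023, 2003, 102]

input_ids = query_ids + passage_ids
token_type_ids = [0] * len(query_ids) + [1] * len(passage_ids)
attention_mask = [1] * len(input_ids)  # no padding in this toy example

print(token_type_ids)  # [0, 0, 0, 1, 1, 1]
```

Models whose tokenizers do not emit this segment signal (e.g. LLaMA-style tokenizers) use the WithoutTypeIds adapter instead.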

nemo_export.model_adapters.reranker.reranker_adapter.get_llama_reranker_hf_model(
model_name_or_path: str | os.PathLike[str],
trust_remote_code: bool = False,
attn_implementation: str | None = None,
)#

Load and adapt a HuggingFace reranker model for export.

This function loads a sequence classification model from HuggingFace and wraps it with an appropriate adapter based on whether the model uses token_type_ids. It also handles specific configuration adjustments for certain model types.

Parameters:
  • model_name_or_path – Path to the model directory or HuggingFace model identifier.

  • trust_remote_code – Whether to trust and execute remote code from the model repository. Defaults to False.

  • attn_implementation – Specific attention implementation to use. If provided, the model’s attention implementation will be set to this value. Defaults to None.

Returns:

A tuple containing:
  • model – The wrapped sequence classification model (either with or without token type IDs).

  • tokenizer – The corresponding tokenizer for the model.

Return type:

tuple

.. note::

The function automatically determines whether to use the adapter with or without token_type_ids based on the tokenizer’s model_input_names attribute.

For models with attn_implementation specified, the config is reset after initialization to handle cases where the config is mutated during initialization.
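The adapter selection described in the note can be sketched as a membership check on the tokenizer's model_input_names list. The helper name pick_adapter is hypothetical; the real function wraps the loaded model directly rather than returning a label.

```python
def pick_adapter(model_input_names):
    """Choose the adapter variant from the tokenizer's declared inputs (sketch)."""
    if "token_type_ids" in model_input_names:
        return "SequenceClassificationModelAdapterWithTypeIds"
    return "SequenceClassificationModelAdapterWithoutTypeIds"


# BERT-style tokenizers declare token_type_ids; LLaMA-style tokenizers do not.
print(pick_adapter(["input_ids", "token_type_ids", "attention_mask"]))
# SequenceClassificationModelAdapterWithTypeIds
print(pick_adapter(["input_ids", "attention_mask"]))
# SequenceClassificationModelAdapterWithoutTypeIds
```

Keying the decision off `model_input_names` (rather than the model architecture) means any tokenizer that advertises token_type_ids automatically gets the matching forward signature.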