Source code for nemo_retriever.model.local.nemotron_rerank_v2

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Local wrapper for nvidia/llama-nemotron-rerank-1b-v2 cross-encoder reranker."""

from __future__ import annotations

from typing import List, Optional

from nemo_retriever.utils.hf_cache import configure_global_hf_cache_base
from nemo_retriever.utils.hf_model_registry import get_hf_revision
from ..model import BaseModel, ModelRunMode


_DEFAULT_MODEL = "nvidia/llama-nemotron-rerank-1b-v2"
_DEFAULT_MAX_LENGTH = 512
_DEFAULT_BATCH_SIZE = 32


def _prompt_template(query: str, passage: str) -> str:
    """Format a (query, passage) pair as the model expects."""
    return f"question:{query} \n \n passage:{passage}"


class NemotronRerankV2(BaseModel):
    """
    Local cross-encoder reranker wrapping nvidia/llama-nemotron-rerank-1b-v2.

    The model scores (query, document) pairs and returns raw logits; higher
    values indicate greater relevance.  It is fine-tuned from
    meta-llama/Llama-3.2-1B with bi-directional attention and supports 26
    languages with sequences up to 8 192 tokens.

    Example::

        reranker = NemotronRerankV2()
        scores = reranker.score("What is ML?", ["Machine learning is…", "Paris is…"])
        # scores -> [20.6, -23.1]  (higher = more relevant)
    """

    def __init__(
        self,
        model_name: str = _DEFAULT_MODEL,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
    ) -> None:
        """
        Load the tokenizer and the sequence-classification model.

        Parameters
        ----------
        model_name:
            Hugging Face model identifier passed to ``from_pretrained``.
        device:
            Torch device string.  When omitted, ``"cuda"`` is used if
            ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        hf_cache_dir:
            Optional ``cache_dir`` forwarded to both ``from_pretrained`` calls.
        """
        super().__init__()

        # Heavy dependencies are imported lazily so that importing this module
        # does not require torch/transformers.
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        configure_global_hf_cache_base()

        self._model_name = model_name
        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): trust_remote_code=True executes Python shipped in the
        # model repository.  Presumably the revision returned below pins that
        # code to a known commit -- confirm against the HF model registry.
        kwargs: dict = {"trust_remote_code": True}
        if hf_cache_dir:
            kwargs["cache_dir"] = hf_cache_dir

        revision = get_hf_revision(model_name, strict=False)
        if revision is not None:
            kwargs["revision"] = revision

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",
            **kwargs,
        )
        # Batched tokenization pads, so a pad token must exist; fall back to EOS.
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._model = (
            AutoModelForSequenceClassification.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                **kwargs,
            )
            .eval()
            .to(self._device)
        )
        # Keep the model config consistent with the tokenizer's padding fallback.
        if self._model.config.pad_token_id is None:
            self._model.config.pad_token_id = self._tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # BaseModel abstract properties
    # ------------------------------------------------------------------

    @property
    def model_name(self) -> str:
        """Model identifier this instance was constructed with."""
        return self._model_name

    @property
    def model_type(self) -> str:
        """Model category; always ``"reranker"``."""
        return "reranker"

    @property
    def model_runmode(self) -> ModelRunMode:
        """Execution mode; always ``"local"``."""
        return "local"

    @property
    def input(self) -> str:
        """Textual description of the input type."""
        return "List[Tuple[str, str]]"

    @property
    def output(self) -> str:
        """Textual description of the output type."""
        return "List[float]"

    @property
    def input_batch_size(self) -> int:
        """Default number of pairs per forward pass."""
        return _DEFAULT_BATCH_SIZE

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _score_texts(
        self,
        texts: List[str],
        *,
        max_length: int,
        batch_size: int,
    ) -> List[float]:
        """Tokenize already-formatted prompts and return one logit per prompt.

        Shared implementation behind :meth:`score` and :meth:`score_pairs`.

        Raises
        ------
        ValueError
            If *batch_size* is smaller than 1.  A zero step makes ``range``
            raise an opaque error, and a negative step yields an empty range,
            which would silently return no scores at all.
        """
        import torch

        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        # Tokenize all texts in a single call to avoid repeated setup overhead.
        full_batch = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=max_length,
        )

        all_scores: List[float] = []
        with torch.inference_mode():
            for start in range(0, len(texts), batch_size):
                # Slice every encoded tensor along dim 0 and move only that
                # slice to the target device.
                batch = {k: v[start : start + batch_size].to(self._device) for k, v in full_batch.items()}
                logits = self._model(**batch).logits
                all_scores.extend(logits.view(-1).cpu().tolist())
        return all_scores

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def score(
        self,
        query: str,
        documents: List[str],
        *,
        max_length: int = _DEFAULT_MAX_LENGTH,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ) -> List[float]:
        """
        Score relevance of *documents* to *query*.

        Parameters
        ----------
        query:
            The search query.
        documents:
            Candidate passages/documents to score.
        max_length:
            Tokenizer truncation length (default 512; max supported 8 192).
        batch_size:
            Number of (query, doc) pairs to process per GPU forward pass.
            Must be at least 1.

        Returns
        -------
        List[float]
            Raw logit scores aligned with *documents* (higher = more relevant).
            An empty list is returned when *documents* is empty.

        Raises
        ------
        ValueError
            If *batch_size* is smaller than 1 and *documents* is non-empty.
        """
        if not documents:
            return []

        texts = [_prompt_template(query, d) for d in documents]
        return self._score_texts(texts, max_length=max_length, batch_size=batch_size)

    def score_pairs(
        self,
        pairs: List[tuple],
        *,
        max_length: int = _DEFAULT_MAX_LENGTH,
        batch_size: int = _DEFAULT_BATCH_SIZE,
    ) -> List[float]:
        """
        Score a list of (query, document) pairs.

        Parameters
        ----------
        pairs:
            Sequence of ``(query, document)`` tuples.
        max_length:
            Tokenizer truncation length.
        batch_size:
            GPU forward-pass batch size.  Must be at least 1.

        Returns
        -------
        List[float]
            Raw logit scores aligned with *pairs* (higher = more relevant).
            An empty list is returned when *pairs* is empty.

        Raises
        ------
        ValueError
            If *batch_size* is smaller than 1 and *pairs* is non-empty.
        """
        if not pairs:
            return []

        texts = [_prompt_template(q, d) for q, d in pairs]
        return self._score_texts(texts, max_length=max_length, batch_size=batch_size)