Source code for nemo_retriever.retriever_graph_utils

# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Helpers for :class:`~nemo_retriever.retriever.Retriever` graph execution."""

from __future__ import annotations

import numbers
from typing import Any

import pandas as pd

# Keyword names reserved for graph coordination; ``filter_retrieval_kwargs``
# removes them so they are not forwarded to ``VDB.retrieval``.
_RESERVED_RETRIEVE_KWARGS = frozenset({"query_texts", "refine_factor"})


def filter_retrieval_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *kwargs* without the graph-coordination keys.

    Keys listed in ``_RESERVED_RETRIEVE_KWARGS`` are reserved for graph
    coordination and are not forwarded to ``VDB.retrieval``; every other
    entry is kept as-is. The input mapping is not modified.
    """
    forwarded: dict[str, Any] = {}
    for name, value in kwargs.items():
        if name in _RESERVED_RETRIEVE_KWARGS:
            continue
        forwarded[name] = value
    return forwarded
def hits_lists_to_rerank_dataframe(
    query_texts: list[str],
    hits_per_query: list[list[dict[str, Any]]],
) -> pd.DataFrame:
    """Flatten per-query hit lists into one row per (query, hit).

    Args:
        query_texts: Query strings, aligned by position with *hits_per_query*.
        hits_per_query: For each query, the list of hit dictionaries.

    Returns:
        A DataFrame with columns ``query``, ``text`` and ``_hit``. ``text`` is
        the hit's ``"text"`` value as a string (a missing or ``None`` text
        becomes ``""``); ``_hit`` is a shallow copy of the hit so it can be
        rebuilt after reranking. The three columns are present even when
        there are no hits at all.
    """
    rows: list[dict[str, Any]] = []
    for query, hits in zip(query_texts, hits_per_query):
        for hit in hits:
            text = hit.get("text")
            rows.append(
                {
                    "query": query,
                    # Avoid turning a null text into the literal string "None".
                    "text": "" if text is None else str(text),
                    "_hit": dict(hit),
                }
            )
    # Pin the schema so an empty result still exposes the expected columns.
    return pd.DataFrame(rows, columns=["query", "text", "_hit"])
def rerank_long_dataframe_to_hits(
    df: pd.DataFrame,
    *,
    query_texts: list[str],
    top_k: int,
    score_column: str = "rerank_score",
) -> list[list[dict[str, Any]]]:
    """Group long rerank output by query and keep the best hits per query.

    Args:
        df: Long-format rerank output with a ``query`` column, a ``_hit``
            column holding the hit payload, and *score_column*.
        query_texts: Queries whose hits are wanted; the result follows this
            order. Rows whose query is not listed are ignored.
        top_k: Number of rows kept per query (passed through ``int``).
        score_column: Name of the column holding the rerank score.

    Returns:
        One list per entry of *query_texts*, highest score first. Each hit is
        a copy of the row's ``_hit`` payload; when the score is a real number
        (Python or NumPy scalar) it is stored as a ``float`` under
        ``"_rerank_score"``. Queries with no rows yield an empty list.

    Raises:
        ValueError: If *df* is non-empty and lacks *score_column*.
    """
    if df.empty:
        return [[] for _ in query_texts]
    if score_column not in df.columns:
        raise ValueError(f"Rerank output missing score column {score_column!r}; columns={list(df.columns)}")

    limit = int(top_k)
    # Per-query score ordering (global sort in NemotronRerank is disabled for
    # multi-query batches). A stable sort returns a new frame, so the caller's
    # DataFrame is left untouched without an explicit copy.
    ranked = df.sort_values(score_column, ascending=False, kind="stable")
    # Split once instead of re-scanning the whole frame for every query.
    by_query = {key: group for key, group in ranked.groupby("query", sort=False)}

    out: list[list[dict[str, Any]]] = []
    for query in query_texts:
        group = by_query.get(query)
        picked: list[dict[str, Any]] = []
        if group is not None:
            best = group.head(limit)
            for payload, score in zip(best["_hit"], best[score_column]):
                hit = dict(payload)
                # numbers.Real also matches NumPy scalars (np.float32, np.int64, ...).
                if isinstance(score, numbers.Real):
                    hit["_rerank_score"] = float(score)
                picked.append(hit)
        out.append(picked)
    return out