# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Record adapters for the graph-pipeline VDB upload/retrieval path."""
from __future__ import annotations
import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypedDict
[docs]
class RetrievalHit(TypedDict, total=False):
    """Shape of a single hit returned by ``Retriever.query`` / ``Retriever.queries``.

    ``metadata`` is a native ``dict`` at this boundary — never a JSON string. The
    LanceDB storage layer JSON-encodes it on write and decodes it on read; a
    re-encoded string must not leak back out here. ``_normalize_hit`` is the
    point where this contract is enforced.

    ``total=False`` because the optional fields (``stored_image_uri``,
    ``content_type``, ``bbox_xyxy_norm``, ``_distance``, ``_score``) are only set
    when the raw hit carries them. The remaining keys are always populated by
    ``_normalize_hit``.
    """

    # Hit text with surrounding whitespace stripped ("" when none was found).
    text: str
    # Decoded content metadata (e.g. ``page_number``, ``type``); always a dict.
    metadata: dict[str, Any]
    # ``source``, ``source_id`` and ``path`` all carry the same resolved source identifier.
    source: str
    source_id: str
    path: str
    # Page number coerced to ``int``; ``None`` when absent or not convertible.
    page_number: int | None
    # File stem of ``source_id`` (no directory, final suffix removed); "" when unknown.
    pdf_basename: str
    # ``"<pdf_basename>_<page_number>"`` when both are known, otherwise "".
    pdf_page: str
    # Optional fields, copied verbatim from the raw hit (or its ``entity``) when present.
    stored_image_uri: str
    content_type: str
    bbox_xyxy_norm: list[float]
    # Backend scoring fields, passed through untouched. NOTE(review): whether a lower
    # ``_distance`` / higher ``_score`` means "better" depends on the backend — confirm.
    _distance: float
    _score: float
def _embedding_from_graph_row(row: dict[str, Any], metadata: dict[str, Any]) -> Any:
if metadata.get("embedding") is not None:
return metadata["embedding"]
payload = row.get("text_embeddings_1b_v2")
return payload.get("embedding") if isinstance(payload, dict) else None
def _first_str(*values: Any) -> str:
for value in values:
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _optional_int(value: Any) -> int | None:
if value is not None:
try:
return int(value)
except (TypeError, ValueError):
pass
return None
def _dict_or_empty(value: Any) -> dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
def _client_record_from_graph_row(row: dict[str, Any]) -> dict[str, Any] | None:
    """Build one client-VDB record from a graph-pipeline row, or ``None`` if it is not uploadable.

    A row is skipped when it has no embedding (see ``_embedding_from_graph_row``) or no
    truthy text (``row["text"]``, then ``row["content"]``, then ``metadata["content"]``).

    The result is ``{"document_type": ..., "metadata": ...}``. Its ``metadata`` is a
    shallow copy of ``row["metadata"]`` with ``embedding``, ``content``,
    ``content_metadata`` and ``source_metadata`` filled in. Row-level hints are copied
    into ``content_metadata`` only when that key is not already present there. The
    hints are the page number, content type, stored image URI, normalized bbox, and
    segment/frame timing keys.
    """
    metadata = _dict_or_empty(row.get("metadata"))
    embedding = _embedding_from_graph_row(row, metadata)
    text = row.get("text") or row.get("content") or metadata.get("content")
    if embedding is None or not text:
        return None

    content_metadata = _dict_or_empty(metadata.get("content_metadata"))

    # Prefer a page number already in content_metadata; fall back to the row's own field.
    page_number = _optional_int(content_metadata.get("page_number"))
    if page_number is None:
        page_number = _optional_int(row.get("page_number"))
    if page_number is not None:
        content_metadata.setdefault("page_number", page_number)

    # (content_metadata key, preferred row key, fallback row key): the underscore-prefixed
    # row key wins, and the hint is only copied when it is truthy.
    for target_key, preferred_key, fallback_key in (
        ("type", "_content_type", "content_type"),
        ("stored_image_uri", "_stored_image_uri", "stored_image_uri"),
        ("bbox_xyxy_norm", "_bbox_xyxy_norm", "bbox_xyxy_norm"),
    ):
        hint = row.get(preferred_key) or row.get(fallback_key)
        if hint:
            content_metadata.setdefault(target_key, hint)

    # Timing keys sit at the top level of metadata; mirror them into content_metadata.
    for timing_key in ("segment_start_seconds", "segment_end_seconds", "frame_timestamp_seconds"):
        if timing_key in metadata:
            content_metadata.setdefault(timing_key, metadata[timing_key])

    source_path = _first_str(
        metadata.get("source_path"),
        row.get("path"),
        row.get("source_id"),
        row.get("source"),
        metadata.get("source_id"),
    )
    if source_path:
        source_name = Path(source_path).name
    else:
        source_name = str(row.get("filename") or row.get("source_id") or "")

    source_metadata = _dict_or_empty(metadata.get("source_metadata"))
    if source_path:
        source_metadata.setdefault("source_id", source_path)
    if source_name:
        source_metadata.setdefault("source_name", source_name)

    # Existing metadata keys keep their position; the four fields below override/append.
    record_metadata = {
        **metadata,
        "embedding": embedding,
        "content": str(text),
        "content_metadata": content_metadata,
        "source_metadata": source_metadata,
    }
    return {"document_type": str(row.get("document_type") or "text"), "metadata": record_metadata}
[docs]
def to_client_vdb_records(rows: list[dict[str, Any]]) -> list[list[dict[str, Any]]]:
    """Convert graph-pipeline rows into the nested record shape expected by client VDBs.

    Accepts a list of row dicts, or any object with ``to_dict("records")`` (e.g. a
    DataFrame), which is flattened first. Non-dict rows are skipped.

    When no row survives conversion (empty input, or every row lacks text/embedding),
    returns ``[]``. This value is falsy, so ``if not records`` skips
    :meth:`~nemo_retriever.vdb.adt_vdb.VDB.run`. When at least one row converts,
    returns ``[batch]`` with a single non-empty inner list. It never returns ``[[]]``,
    which would be truthy and could trip backends on an empty insert.
    """
    if hasattr(rows, "to_dict"):
        rows = rows.to_dict("records")

    batch: list[dict[str, Any]] = []
    for row in rows or []:
        # Only plain dicts are convertible; None / Series / other objects are skipped.
        if not isinstance(row, dict):
            continue
        # Convert each row exactly once.
        record = _client_record_from_graph_row(row)
        if record is not None:
            batch.append(record)

    # Legacy contract: nothing uploadable -> [] (falsy), never [[]].
    return [batch] if batch else []
def _mapping(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if not isinstance(value, str) or not value.strip():
return {}
try:
parsed = json.loads(value)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
def _normalize_hit(hit: dict[str, Any]) -> RetrievalHit:
    """Adapt one LanceDB client hit into the ``RetrievalHit`` shape returned by ``Retriever``.

    A hit may be flat, or may wrap its fields in a nested ``"entity"`` dict. Fields are
    read from the entity first, with the outer hit as a fallback. ``source`` and
    ``content_metadata`` may arrive as dicts or JSON strings; ``_mapping`` decodes them.

    Args:
        hit: A single raw hit dict.

    Returns:
        A ``RetrievalHit`` whose ``metadata`` is a native dict (never a JSON string).
    """
    entity = hit.get("entity") if isinstance(hit.get("entity"), dict) else hit

    source = _mapping(entity.get("source") or hit.get("source") or entity.get("source_metadata"))
    if not source and isinstance(entity.get("source"), str):
        # A ``source`` string that is not a JSON object is taken as the source id itself.
        source = {"source_id": entity["source"]}

    content_metadata = _mapping(entity.get("content_metadata") or hit.get("content_metadata") or entity.get("metadata"))

    source_id = _first_str(
        source.get("source_id"),
        source.get("source_name"),
        entity.get("source_id"),
        hit.get("source_id"),
        hit.get("path"),
    )

    # Resolve the page number the same way ``_client_record_from_graph_row`` does: use
    # the first candidate that actually converts to int. An unusable value in
    # content_metadata (e.g. "" or "n/a") no longer hides a valid top-level page number.
    # A None-valued entity key no longer blocks the outer hit's value either.
    page_number = _optional_int(content_metadata.get("page_number"))
    if page_number is None:
        page_number = _optional_int(entity.get("page_number"))
    if page_number is None:
        page_number = _optional_int(hit.get("page_number"))

    pdf_basename = Path(source_id).stem if source_id else ""

    normalized: RetrievalHit = {
        "text": _first_str(entity.get("text"), entity.get("content"), hit.get("text")),
        # Keep `metadata` as a native dict on the API boundary. The LanceDB
        # storage layer JSON-encodes it on write (see `_json_str` in
        # `vdb/lancedb.py`); we already parse it back on read in
        # `LanceDB.retrieval`. Re-encoding it here forced every downstream
        # consumer (`Retriever.query()` callers, the CLI, the SKILL.md jq
        # recipe) to do its own `fromjson`/`json.loads` — and most didn't,
        # producing silent `metadata.type == "?"` lookups.
        "metadata": content_metadata,
        "source": source_id,
        "source_id": source_id,
        "path": source_id,
        "page_number": page_number,
        "pdf_basename": pdf_basename,
        "pdf_page": f"{pdf_basename}_{page_number}" if pdf_basename and page_number is not None else "",
    }
    # Optional passthrough fields: the outer hit takes precedence over the entity.
    for key in ("stored_image_uri", "content_type", "bbox_xyxy_norm", "_distance", "_score"):
        if key in hit:
            normalized[key] = hit[key]
        elif key in entity:
            normalized[key] = entity[key]
    return normalized
def _hit_to_dict(hit: Any) -> dict[str, Any] | None:
if isinstance(hit, dict):
return hit
if isinstance(hit, Mapping):
return dict(hit)
if hasattr(hit, "to_dict"):
try:
converted = hit.to_dict()
except Exception:
return None
return converted if isinstance(converted, dict) else None
return None
[docs]
def normalize_retrieval_results(results: Any) -> list[list[RetrievalHit]]:
    """Normalize raw VDB retrieval output into one list of ``RetrievalHit`` per query.

    Accepted shapes:

    * ``None``: returns ``[]``.
    * a single hit mapping: treated as one query with one hit, ``[[hit]]``.
    * an iterable with one entry per query. Each entry is an iterable of hits, a single
      hit mapping, or ``None`` (treated as "no hits" for that query).

    Hits that ``_hit_to_dict`` cannot convert are dropped. The per-query list is still
    emitted, so the output stays aligned with the input queries.
    """
    if results is None:
        return []
    # Any Mapping counts as a bare hit, matching what ``_hit_to_dict`` accepts.
    if isinstance(results, Mapping):
        results = [[results]]
    normalized: list[list[RetrievalHit]] = []
    for hits in results:
        if hits is None:
            # Presumably "no hits" for this query. Iterating None raises TypeError,
            # so emit an empty list instead to keep per-query alignment.
            normalized.append([])
            continue
        if isinstance(hits, Mapping):
            hits = [hits]
        normalized.append(
            [_normalize_hit(hit_dict) for hit in hits if (hit_dict := _hit_to_dict(hit)) is not None]
        )
    return normalized