# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Graph operator for persisting post-embedding row images to storage."""
from __future__ import annotations
import base64
import binascii
import hashlib
import logging
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import fsspec
import pandas as pd
from nemo_retriever.graph.abstract_operator import AbstractOperator
from nemo_retriever.graph.cpu_operator import CPUOperator
logger = logging.getLogger(__name__)

# Alternate spellings accepted for ``image_format`` mapped to their canonical name.
_FORMAT_ALIASES = {
    "jpg": "jpeg",
}
# Canonical format names accepted as the fallback file extension.
_SUPPORTED_FORMATS = {"jpeg", "png"}
def _normalize_image_format(image_format: str) -> str:
    """Return the canonical lowercase name for *image_format*.

    A falsy value is treated as ``"png"``. Surrounding whitespace and letter
    case are ignored, and aliases listed in ``_FORMAT_ALIASES`` are resolved
    before validation.

    Raises:
        ValueError: If the resolved name is not in ``_SUPPORTED_FORMATS``.
    """
    requested = str(image_format or "png").strip().lower()
    canonical = _FORMAT_ALIASES.get(requested, requested)
    if canonical in _SUPPORTED_FORMATS:
        return canonical
    raise ValueError(f"Unsupported image_format: {image_format!r}. Supported formats: png, jpeg")
def _sniff_image_format(raw: bytes) -> str | None:
    """Identify PNG or JPEG data from the leading magic bytes of *raw*.

    Returns ``"png"`` or ``"jpeg"`` when the corresponding signature is found
    at the start of the buffer, otherwise ``None``.
    """
    signatures = (
        (b"\x89PNG\r\n\x1a\n", "png"),
        (b"\xff\xd8\xff", "jpeg"),
    )
    return next((name for magic, name in signatures if raw.startswith(magic)), None)
def _decode_image_b64(value: Any) -> bytes | None:
    """Decode a base64 image payload, returning ``None`` when it is unusable.

    Args:
        value: Candidate payload. Only non-blank strings are decoded. A
            leading ``data:...,`` header (matched case-insensitively) is
            removed before decoding.

    Returns:
        The decoded bytes, or ``None`` when *value* is not a non-blank string,
        cannot be decoded, or decodes to zero bytes.
    """
    if not isinstance(value, str):
        return None
    payload = value.strip()
    if not payload:
        return None

    # Inspect only the short prefix so a page-sized payload is not copied by
    # ``lower()`` just to detect a data-URI header.
    if payload[:5].lower() == "data:":
        _header, separator, body = payload.partition(",")
        if separator:
            payload = body

    try:
        raw = base64.b64decode(payload)
    except ValueError as exc:
        # ``binascii.Error`` (bad padding etc.) subclasses ``ValueError``; a
        # plain ``ValueError`` is raised for strings with non-ASCII characters.
        # Both mean the row has no usable image and should be skipped.
        logger.warning("Skipping store row with invalid _image_b64 payload: %s", exc)
        return None

    # An empty result (e.g. a data URI with no body) carries no image; treat
    # it like a blank payload instead of letting an empty file be stored.
    return raw or None
def _build_object_key(*, raw: bytes, extension: str) -> str:
    """Build a content-addressed object name for *raw*.

    The key is the SHA-1 hex digest of the bytes followed by ``.`` and
    *extension*, so identical images always map to the same key.
    """
    digest = hashlib.sha1(raw).hexdigest()
    return "{}.{}".format(digest, extension)
def _join_storage_uri(storage_uri: str, object_key: str) -> str:
    """Join *storage_uri* and *object_key* with exactly one ``/`` between them.

    Trailing slashes on the base and leading slashes on the key are dropped
    before joining.
    """
    base = str(storage_uri).rstrip("/")
    key = object_key.lstrip("/")
    return base + "/" + key
def _stored_uri(dest_uri: str) -> str:
    """Return the URI to record for an object written to *dest_uri*.

    Destinations that already carry a URL scheme (``s3://``, ``file://``, ...)
    are returned unchanged. Anything else is treated as a local filesystem
    path and converted to an absolute ``file://`` URI.
    """
    scheme = urlparse(dest_uri).scheme
    # A one-letter "scheme" is how ``urlparse`` reports a Windows drive letter
    # (``C:/images/a.png``); that is a local path, not a remote URL.
    if len(scheme) > 1:
        return dest_uri
    return Path(dest_uri).resolve().as_uri()
def _row_image_b64_with_source(row: pd.Series) -> tuple[Any, bool]:
    """Pick the base64 payload to store for *row* and report where it came from.

    Returns:
        ``(payload, from_page_image)``. A non-blank string in ``_image_b64``
        wins and yields ``False``. Otherwise, when ``page_image`` is a dict,
        its ``image_b64`` entry (possibly ``None``) is returned with ``True``.
        With neither available the result is ``(None, False)``.
    """
    inline = row.get("_image_b64")
    if not isinstance(inline, str) or not inline.strip():
        # No usable row-level payload: fall back to the page image, if any.
        page = row.get("page_image")
        if isinstance(page, dict):
            return page.get("image_b64"), True
        return None, False
    return inline, False
def _row_image_represents_page(row: pd.Series, *, from_page_image: bool) -> bool:
    """Report whether the image stored for *row* is treated as the page image.

    This is ``True`` when the payload was taken from ``page_image``, when the
    row has no non-blank string ``_content_type``, or when ``_content_type``
    is exactly ``"text"``.
    """
    if not from_page_image:
        kind = row.get("_content_type")
        if isinstance(kind, str) and kind.strip():
            return kind == "text"
    return True
def _store_row_images(
    df: pd.DataFrame,
    *,
    storage_uri: str,
    storage_options: dict[str, Any] | None = None,
    image_format: str = "png",
    strip_base64: bool = True,
) -> pd.DataFrame:
    """Return a copy of *df* with ``_stored_image_uri`` set for stored rows.

    For each row the payload is taken from ``_image_b64`` or, failing that,
    from ``page_image["image_b64"]``. Decodable payloads are written under
    *storage_uri* using a content-hash file name, and the resulting URI is
    recorded in ``_stored_image_uri``.

    Args:
        df: Input frame. It is returned as-is (not copied) when it is empty or
            has neither an ``_image_b64`` nor a ``page_image`` column.
        storage_uri: Base directory or URL that objects are written under.
        storage_options: Extra keyword arguments forwarded to ``fsspec.open``.
        image_format: Extension used when the bytes are not recognisably PNG
            or JPEG.
        strip_base64: When true, clear ``_image_b64`` and the ``image_b64``
            entry of ``page_image`` on stored rows; ``page_image`` also gets a
            ``stored_image_uri`` entry when the stored image represents the
            page.

    Raises:
        ValueError: If *image_format* is not supported.
        RuntimeError: If writing an image fails.
    """
    if df.empty or ("_image_b64" not in df.columns and "page_image" not in df.columns):
        return df

    out = df.copy()
    fallback_format = _normalize_image_format(image_format)
    fsspec_options = dict(storage_options or {})
    has_inline_column = "_image_b64" in out.columns
    # Object keys are content hashes, so a destination already written during
    # this call holds identical bytes and does not need a second upload (rows
    # from the same page typically share one page image).
    written_uris: set[str] = set()

    def _set_cell(position: int, column: str, value: Any) -> None:
        # Positional write: label-based ``.at`` would update every row that
        # shares an index label when the frame's index is not unique.
        out.iat[position, out.columns.get_loc(column)] = value

    # Iterate the untouched input so updates to ``out`` cannot leak into rows.
    for position, (label, row) in enumerate(df.iterrows()):
        image_b64, from_page_image = _row_image_b64_with_source(row)
        raw = _decode_image_b64(image_b64)
        if raw is None:
            continue

        extension = _sniff_image_format(raw) or fallback_format
        object_key = _build_object_key(raw=raw, extension=extension)
        dest_uri = _join_storage_uri(storage_uri, object_key)
        if dest_uri not in written_uris:
            try:
                with fsspec.open(dest_uri, mode="wb", **fsspec_options) as f:
                    f.write(raw)
            except Exception as exc:
                raise RuntimeError(f"Failed to store image for row {label!r} to {dest_uri!r}: {exc}") from exc
            written_uris.add(dest_uri)

        stored_uri = _stored_uri(dest_uri)
        if "_stored_image_uri" not in out.columns:
            # Create the column lazily so frames with nothing stored keep
            # their original columns; rows without a stored image hold NaN.
            out["_stored_image_uri"] = pd.Series(float("nan"), index=out.index, dtype=object)
        _set_cell(position, "_stored_image_uri", stored_uri)

        if strip_base64:
            if has_inline_column:
                _set_cell(position, "_image_b64", None)
            page_image = row.get("page_image")
            if isinstance(page_image, dict):
                # Replace the dict rather than mutating it: the same object
                # may be shared with the caller's frame or with other rows.
                updated_page_image = dict(page_image)
                updated_page_image["image_b64"] = None
                if _row_image_represents_page(row, from_page_image=from_page_image):
                    updated_page_image["stored_image_uri"] = stored_uri
                _set_cell(position, "page_image", updated_page_image)
    return out
class StoreOperator(AbstractOperator, CPUOperator):
    """Persist row-level image payloads to local or object storage.

    The operator consumes ``_image_b64`` produced by content transforms and
    writes ``_stored_image_uri`` for downstream vector DB upload. By default it
    clears inline base64 after successful writes to avoid carrying page-sized
    payloads into VDB upload.
    """

    def __init__(self, *, params: Any = None) -> None:
        """Keep *params* for use by :meth:`process`.

        Args:
            params: Either an object exposing ``model_dump`` or a plain dict.
                The keys ``storage_uri``, ``storage_options``,
                ``image_format`` and ``strip_base64`` are read; any other
                value means "use the defaults".
        """
        super().__init__()
        self._params = params

    def preprocess(self, data: Any, **kwargs: Any) -> Any:
        """Return *data* unchanged."""
        return data

    def process(self, data: Any, **kwargs: Any) -> Any:
        """Store row images for a DataFrame; pass any other input through.

        Returns:
            The frame produced by ``_store_row_images`` when *data* is a
            ``pandas.DataFrame``, otherwise *data* itself.
        """
        if not isinstance(data, pd.DataFrame):
            return data

        params = self._params
        if hasattr(params, "model_dump"):
            store_kwargs = params.model_dump(mode="python")
        elif isinstance(params, dict):
            store_kwargs = params
        else:
            store_kwargs = {}

        return _store_row_images(
            data,
            # A missing, ``None`` or empty URI falls back to the default so
            # images are never written under a literal "None" directory or
            # directly at the filesystem root.
            storage_uri=store_kwargs.get("storage_uri") or "stored_images",
            storage_options=store_kwargs.get("storage_options") or {},
            image_format=store_kwargs.get("image_format", "png"),
            strip_base64=bool(store_kwargs.get("strip_base64", True)),
        )

    def postprocess(self, data: Any, **kwargs: Any) -> Any:
        """Return *data* unchanged."""
        return data