Source code for nemo_retriever.graph.abstract_operator

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
import inspect
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from nemo_retriever.graph.pipeline_graph import Graph, Node


class AbstractOperator(ABC):
    """Common base for pipeline operators.

    Subclasses implement the three abstract stages (``preprocess``,
    ``process``, ``postprocess``); :meth:`run` chains them in that order.
    Keyword arguments given to the constructor are remembered and also
    exposed as instance attributes.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Record *kwargs* and mirror each entry onto the instance.

        A shallow copy is stored in ``_graph_init_kwargs`` first; afterwards
        every key is assigned as an attribute via ``setattr``.
        """
        self._graph_init_kwargs = kwargs.copy()
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    @abstractmethod
    def preprocess(self, data: Any, **kwargs: Any) -> Any:
        """First stage invoked by :meth:`run`; must be overridden."""
        ...

    @abstractmethod
    def process(self, data: Any, **kwargs: Any) -> Any:
        """Second stage invoked by :meth:`run`; must be overridden."""
        ...

    @abstractmethod
    def postprocess(self, data: Any, **kwargs: Any) -> Any:
        """Final stage invoked by :meth:`run`; must be overridden."""
        ...

    def run(self, data: Any, **kwargs: Any) -> Any:
        """Feed *data* through preprocess, process and postprocess in turn.

        The same ``**kwargs`` are forwarded to every stage, and each stage
        receives the previous stage's return value.
        """
        prepared = self.preprocess(data, **kwargs)
        processed = self.process(prepared, **kwargs)
        return self.postprocess(processed, **kwargs)

    def __call__(self, data: Any, **kwargs: Any) -> Any:
        """Alias for :meth:`run`, so an operator instance is itself callable.

        Intended to let operators be passed directly as Ray ``map_batches``
        callables.
        """
        return self.run(data, **kwargs)

    def get_constructor_kwargs(self) -> dict[str, Any]:
        """Return a best-effort dict of kwargs for rebuilding this operator.

        Starts from a copy of the kwargs recorded by ``__init__`` (empty if
        none were recorded). Each remaining named parameter of the concrete
        class's ``__init__`` is then filled in from the instance attribute of
        the same name, or failing that from the ``_``-prefixed attribute.
        Parameters with no matching attribute are left out.
        """
        collected = dict(getattr(self, "_graph_init_kwargs", {}))
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        init_params = inspect.signature(type(self).__init__).parameters

        for param_name, param in init_params.items():
            is_skippable = param_name == "self" or param.kind in variadic_kinds
            if is_skippable or param_name in collected:
                continue
            # The public attribute wins; the underscore-prefixed one is the fallback.
            for candidate in (param_name, f"_{param_name}"):
                if hasattr(self, candidate):
                    collected[param_name] = getattr(self, candidate)
                    break

        return collected

    def __rshift__(self, other: "AbstractOperator | Node") -> "Graph":
        """Support ``operator_a >> operator_b`` chaining.

        This operator is wrapped in a :class:`Node` and the shift is handed
        to ``Node.__rshift__``, which returns a :class:`Graph`, so a pipeline
        can be written directly as::

            graph = op_a >> op_b >> op_c
        """
        # Imported here rather than at module level (the module-level import
        # of this name is guarded by TYPE_CHECKING).
        from nemo_retriever.graph.pipeline_graph import Node

        return Node(self) >> other