Source code for nemo_retriever.graph.abstract_operator
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
import inspect
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from nemo_retriever.graph.pipeline_graph import Graph, Node
[docs]
class AbstractOperator(ABC):
    """Common base for pipeline operators.

    Subclasses implement the three stage hooks (``preprocess``, ``process``
    and ``postprocess``); :meth:`run` invokes them in that order.  Keyword
    arguments given to the constructor are recorded and also exposed as
    instance attributes.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Record *kwargs* and mirror each entry onto the instance.

        A shallow copy is stored in ``_graph_init_kwargs`` so that
        :meth:`get_constructor_kwargs` can report it later.
        """
        self._graph_init_kwargs = dict(kwargs)
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    @abstractmethod
    def preprocess(self, data: Any, **kwargs: Any) -> Any:
        """First stage hook; its result is handed to :meth:`process`."""

    @abstractmethod
    def process(self, data: Any, **kwargs: Any) -> Any:
        """Second stage hook; its result is handed to :meth:`postprocess`."""

    @abstractmethod
    def postprocess(self, data: Any, **kwargs: Any) -> Any:
        """Final stage hook; its result is what :meth:`run` returns."""

    def run(self, data: Any, **kwargs: Any) -> Any:
        """Push *data* through preprocess -> process -> postprocess.

        The same ``**kwargs`` are forwarded unchanged to every stage.
        """
        data = self.preprocess(data, **kwargs)
        data = self.process(data, **kwargs)
        return self.postprocess(data, **kwargs)

    def __call__(self, data: Any, **kwargs: Any) -> Any:
        """Alias for :meth:`run`, so an operator can serve as a Ray ``map_batches`` callable."""
        return self.run(data, **kwargs)

    def get_constructor_kwargs(self) -> dict[str, Any]:
        """Return a best-effort dict of kwargs for rebuilding this operator.

        Starts from a copy of the kwargs recorded by ``__init__`` (empty if
        none were recorded), then fills in every named ``__init__`` parameter
        of the concrete class that is still missing, reading the value from
        an attribute called ``<name>`` or, failing that, ``_<name>``.
        Parameters with no matching attribute are left out.
        """
        collected = dict(getattr(self, "_graph_init_kwargs", {}))
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        init_params = inspect.signature(type(self).__init__).parameters
        for param_name, param in init_params.items():
            # ``self``, ``*args`` / ``**kwargs`` and anything already
            # recorded need no attribute lookup.
            if (
                param_name == "self"
                or param.kind in variadic_kinds
                or param_name in collected
            ):
                continue
            # The public attribute wins over the underscore-prefixed one.
            for attr_name in (param_name, f"_{param_name}"):
                if hasattr(self, attr_name):
                    collected[param_name] = getattr(self, attr_name)
                    break
        return collected

    def __rshift__(self, other: "AbstractOperator | Node") -> "Graph":
        """Chain with ``>>``: wrap this operator in a Node and link it to *other*.

        The result is a :class:`Graph`, so pipelines can be written inline::

            graph = op_a >> op_b >> op_c
        """
        # Imported here rather than at module level; the module-level import
        # of pipeline_graph is only performed under TYPE_CHECKING.
        from nemo_retriever.graph.pipeline_graph import Node

        # Node.__rshift__ does the actual linking and yields the Graph.
        return Node(self) >> other