Source code for nemo_retriever.graph.operator_resolution

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from nemo_retriever.graph.abstract_operator import AbstractOperator
from nemo_retriever.graph.operator_archetype import ArchetypeOperator
from nemo_retriever.graph.pipeline_graph import Graph, Node
from nemo_retriever.utils import ray_resource_hueristics as _rrh
from nemo_retriever.utils.ray_resource_hueristics import ClusterResources, Resources


def resolve_operator_class(
    operator_class: type[AbstractOperator],
    resources: ClusterResources | Resources,
    operator_kwargs: dict | None = None,
) -> type[AbstractOperator]:
    """Pick the operator class that should actually be used for ``operator_class``.

    Args:
        operator_class: The operator class declared on a graph node.
        resources: Resource description forwarded to the archetype's
            ``resolve_operator_class`` hook.
        operator_kwargs: Optional constructor kwargs, forwarded unchanged to
            the archetype's hook.

    Returns:
        ``operator_class`` itself when it is not an ``ArchetypeOperator``
        subclass; otherwise whatever
        ``operator_class.resolve_operator_class(...)`` returns.
    """
    # Non-archetype operators need no resolution.
    if not issubclass(operator_class, ArchetypeOperator):
        return operator_class

    # Archetypes decide for themselves which class to hand back.
    return operator_class.resolve_operator_class(
        resources,
        operator_kwargs=operator_kwargs,
    )
def resolve_operator_kwargs(
    operator_class: type[AbstractOperator],
    resolved_class: type[AbstractOperator],
    operator_kwargs: dict | None = None,
) -> dict:
    """Produce the kwargs that go with an already-resolved operator class.

    Args:
        operator_class: The operator class originally declared on the node.
        resolved_class: The class previously chosen for ``operator_class``.
        operator_kwargs: Optional kwargs supplied for the original class.

    Returns:
        For ``ArchetypeOperator`` subclasses, the result of
        ``operator_class.variant_operator_kwargs(resolved_class, ...)``.
        For any other class, a new ``dict`` built from ``operator_kwargs``
        (empty when ``operator_kwargs`` is ``None`` or otherwise falsy).
    """
    if not issubclass(operator_class, ArchetypeOperator):
        # Always hand back a fresh dict so the caller's mapping is not shared.
        if not operator_kwargs:
            return dict({})
        return dict(operator_kwargs)

    # Archetypes translate the kwargs for the chosen class themselves.
    return operator_class.variant_operator_kwargs(
        resolved_class,
        operator_kwargs=operator_kwargs,
    )
def resolve_graph(
    graph: Graph,
    resources: ClusterResources | Resources,
) -> Graph:
    """Build a new ``Graph`` whose nodes carry resolved operator classes and kwargs.

    Every node reachable from ``graph.roots`` is copied into a new ``Node``.
    The copy keeps the node's name; its ``operator_class`` and
    ``operator_kwargs`` come from :func:`resolve_operator_class` and
    :func:`resolve_operator_kwargs`. A node reached through more than one
    path is copied only once, and the same copy is reused on every path.

    Args:
        graph: The graph to copy. It is read but not modified by this function.
        resources: Resource description passed to :func:`resolve_operator_class`.

    Returns:
        A new ``Graph`` with the copied roots. When ``graph._tail`` is set, the
        new graph's ``_tail`` is the copy of that node, or ``None`` if the tail
        was not reached from any root.
    """
    resolved = Graph()
    # Source node identity -> its copy; this is what keeps shared nodes shared.
    copies: dict[int, Node] = {}

    def _copy_node(source: Node) -> Node:
        """Return the copy of ``source``, creating it (and its subtree) on first use."""
        key = id(source)
        existing = copies.get(key)
        if existing is not None:
            return existing

        operator = source.operator
        if isinstance(operator, ArchetypeOperator):
            # Archetype instances are re-created from the node's kwargs rather
            # than reused.
            # NOTE(review): assumes source.operator_kwargs is a mapping here,
            # not None — confirm against Node.
            operator = type(operator)(**source.operator_kwargs)

        concrete_class = resolve_operator_class(
            source.operator_class,
            resources,
            operator_kwargs=source.operator_kwargs,
        )
        concrete_kwargs = resolve_operator_kwargs(
            source.operator_class,
            concrete_class,
            source.operator_kwargs,
        )
        copy = Node(
            operator,
            name=source.name,
            operator_class=concrete_class,
            operator_kwargs=concrete_kwargs,
        )

        # Register the copy before descending so that any later encounter of
        # this node (including from within its own subtree) reuses it.
        copies[key] = copy
        for child in source.children:
            copy.children.append(_copy_node(child))
        return copy

    for source_root in graph.roots:
        resolved.roots.append(_copy_node(source_root))

    source_tail = graph._tail
    if source_tail is not None:
        # Point the new tail at the copy of the old one.
        resolved._tail = copies.get(id(source_tail))

    return resolved
def resolve_graph_for_local_execution(graph: Graph) -> Graph:
    """Resolve ``graph`` using the resources reported by ``gather_local_resources()``.

    Args:
        graph: The graph to resolve.

    Returns:
        The result of :func:`resolve_graph` for ``graph`` and the gathered
        local resources.
    """
    local_resources = _rrh.gather_local_resources()
    return resolve_graph(graph, local_resources)