# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import (
Any,
Callable,
Concatenate,
Literal,
Optional,
ParamSpec,
Protocol,
Type,
TypeVar,
Union,
cast,
overload,
runtime_checkable,
)
import fiddle as fdl
from fiddle.experimental import auto_config as _auto_config
from rich.pretty import Pretty
from rich.table import Table
from nemo_run.config import Config, Partial
from nemo_run.core.execution.base import Executor
from nemo_run.core.frontend.console.api import CONSOLE, CustomConfigRepr
F = TypeVar("F", bound=Callable[..., Any])
T = TypeVar("T")
P = ParamSpec("P")
# Default name string; its consumers are not visible in this section.
DEFAULT_NAME = "default"

# Dotted namespace identifiers rooted at ``nemo_run``.
ROOT_TASK_NAMESPACE = "nemo_run.task"
ROOT_TASK_FACTORY_NAMESPACE = "nemo_run.task_factory"
ROOT_TYPER_NAMESPACE = "nemo_run.typer"

# Classes grouped for automatic building; presumably used in isinstance/issubclass
# checks by code outside this section -- confirm against callers.
AUTOBUILD_CLASSES = (Executor,)
def default_autoconfig_buildable(
    fn: Callable[P, T],
    cls: Type[Union[Partial, Config]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> Config[T] | Partial[T] | list[Config[T]] | list[Partial[T]]:
    """Trace ``fn`` with fiddle's auto_config and cast the result to ``cls``.

    ``fn`` is wrapped by ``fiddle.experimental.auto_config.auto_config`` (control flow
    disallowed, dataclass attribute access allowed) and invoked through
    ``as_buildable(*args, **kwargs)``.

    Args:
        fn: The function to trace.
        cls: Buildable class (``Partial`` or ``Config``) that each produced buildable
            is cast to with ``fdl.cast``.
        *args: Positional arguments forwarded to ``as_buildable``.
        **kwargs: Keyword arguments forwarded to ``as_buildable``.

    Returns:
        ``fdl.cast(cls, result)`` for a single result. If ``as_buildable`` returns a
        list, each element is cast and a new list is returned.
    """

    def exemption_policy(candidate):
        # Config/Partial themselves, and callables carrying a truthy ``__auto_config__``
        # marker, are exempted from auto_config conversion (fiddle's
        # ``experimental_exemption_policy`` hook).
        if candidate in (Partial, Config):
            return True
        return getattr(candidate, "__auto_config__", False)

    traced_fn = _auto_config.auto_config(
        fn,
        experimental_allow_control_flow=False,
        experimental_allow_dataclass_attribute_access=True,
        experimental_exemption_policy=exemption_policy,
    )
    buildable = traced_fn.as_buildable(*args, **kwargs)

    if not isinstance(buildable, list):
        return fdl.cast(cls, buildable)
    return [fdl.cast(cls, element) for element in buildable]
# Typing-only overloads for ``autoconvert``; the runtime implementation follows them.
# Type checkers try overloads in declaration order and pick the first match, so the
# order below is significant.


# ``@autoconvert`` applied to a function that already returns a Config.
@overload
def autoconvert(
    fn: Callable[P, Config[T]],
    *,
    partial: bool = False,
) -> Callable[P, Config[T]]: ...


# ``@autoconvert`` applied to a function that already returns a Partial.
@overload
def autoconvert(  # type: ignore
    fn: Callable[P, Partial[T]],
    *,
    partial: bool = False,
) -> Callable[P, Partial[T]]: ...


# ``@autoconvert`` applied to a function returning a plain object of type T.
@overload
def autoconvert(
    fn: Callable[P, T],
    *,
    partial: bool = False,
) -> Callable[P, Config[T]]: ...


# ``autoconvert(partial=True)`` used as a decorator factory -> Partial-returning functions.
# ``partial`` deliberately has NO default here. With a default, a bare ``autoconvert()``
# would match this overload first and be typed as producing Partial, while the runtime
# default (``partial=False``) produces Config.
@overload
def autoconvert(
    *,
    partial: Literal[True],
) -> Callable[
    [Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]],
    Callable[P, Partial[T]],
]: ...


# ``autoconvert()`` / ``autoconvert(partial=False)`` used as a decorator factory
# -> Config-returning functions.
@overload
def autoconvert(
    *,
    partial: Literal[False] = False,
) -> Callable[
    [Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]],
    Callable[P, Config[T]],
]: ...
def autoconvert(
    fn: Optional[Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]] = None,
    *,
    partial: bool = False,
    to_buildable_fn: Callable[
        Concatenate[Callable[P, T], Type[Union[Partial, Config]], P],
        Config[T] | Partial[T],
    ] = default_autoconfig_buildable,
) -> (
    Callable[P, Config[T] | Partial[T]]
    | Callable[
        [Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]],
        Callable[P, Config[T] | Partial[T]],
    ]
):
    """
    Decorator that converts the object a function returns into a run.Config or run.Partial.

    The returned object is converted in a nested manner: to run.Config when ``partial`` is
    False, or to run.Partial when ``partial`` is True. The conversion is delegated to
    ``to_buildable_fn``, which defaults to ``default_autoconfig_buildable``. That default
    uses `fiddle's autoconfig <https://fiddle.readthedocs.io/en/latest/api_reference/autoconfig.html>`_
    to parse the function's AST and convert objects to their run.Config/run.Partial
    counterparts.

    You can use it in two different ways:

    - Directly as a decorator for a function you define:

      .. code-block:: python

          @autoconvert
          def my_func(param1: int, param2: str) -> MyType:
              return MyType(param1=param1, param2=param2)

      Calling ``my_func`` returns ``run.Config(MyType, param1=param1, param2=param2)``,
      assuming ``partial=False`` (otherwise it would be a run.Partial instance).

    - Indirectly, as a way to convert an existing function:

      .. code-block:: python

          def my_func(param1: int, param2: str) -> MyType:
              return MyType(param1=param1, param2=param2)

          my_new_func = autoconvert(partial=True)(my_func)

      Calling ``my_new_func`` returns ``run.Partial(MyType, param1=param1, param2=param2)``
      rather than a ``MyType`` instance.

    Parameters:
        - fn:
            The function to be decorated. Optional; if not provided, ``autoconvert`` acts
            as a decorator factory. Defaults to None.
        - partial:
            Whether the return value of ``fn`` is converted to Partial[T] (True) or
            Config[T] (False). Defaults to False.
        - to_buildable_fn:
            The conversion function. It is called as
            ``to_buildable_fn(fn, cls, *args, **kwargs)``, where ``cls`` is ``Partial`` when
            ``partial`` is True and ``Config`` otherwise, and ``args``/``kwargs`` are the
            arguments the decorated function was called with. It should return a Config[T]
            or Partial[T]. Defaults to ``default_autoconfig_buildable``.

    Returns:
        The wrapped function when ``fn`` is given; otherwise a decorator that wraps a
        function. The wrapped function exposes the original callable as ``.wrapped`` and
        carries ``__auto_config__ = True``.
    """
    # ``partial`` cannot change after decoration, so pick the target buildable class once
    # instead of re-evaluating it on every call of the wrapped function.
    buildable_cls: Type[Union[Partial, Config]] = Partial if partial else Config

    def wrapper(
        fn: Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]],
    ) -> Callable[P, Config[T] | Partial[T]]:
        @wraps(fn)
        def autobuilder(*args: P.args, **kwargs: P.kwargs) -> Config[T] | Partial[T]:
            return to_buildable_fn(
                cast(Callable[P, T], fn),
                buildable_cls,
                *args,
                **kwargs,
            )

        # Keep a handle to the original callable.
        autobuilder.wrapped = fn  # type: ignore
        # Marker read by ``default_autoconfig_buildable``'s exemption policy.
        autobuilder.__auto_config__ = True  # type: ignore
        return autobuilder

    return wrapper if fn is None else wrapper(fn)
def dryrun_fn(
    configured_fn: Partial,
    executor: Optional[Executor] = None,
    build: bool = False,
) -> None:
    """Print a dry-run summary of a configured task, optionally building it.

    Prints the target's ``module:name``. It then prints a table with one row per name
    returned by ``dir(configured_fn)``, each paired with a ``CustomConfigRepr`` of that
    attribute's value. For fiddle buildables these names are presumably the configured
    parameters (via the buildable's ``__dir__``) -- confirm. When ``executor`` is truthy,
    it also prints a table showing the executor.

    Args:
        configured_fn: The configured task. A ``run.Partial`` is expected, but a
            ``run.Config`` is accepted as well.
        executor: Optional executor to display alongside the task.
        build: If True, call ``fdl.build(configured_fn)`` after printing. The built
            result is discarded.

    Raises:
        TypeError: If ``configured_fn`` is neither a Config nor a Partial.
    """
    if not isinstance(configured_fn, (Config, Partial)):
        raise TypeError(f"Need a run Partial for dryrun. Got {configured_fn}.")

    fn = configured_fn.__fn_or_cls__
    CONSOLE.print(f"[bold cyan]Dry run for task {fn.__module__}:{fn.__name__}[/bold cyan]")

    table_resolved_args = Table(show_header=True, header_style="bold magenta")
    table_resolved_args.add_column("Argument Name", style="dim", width=20)
    table_resolved_args.add_column("Resolved Value", width=60)
    for arg_name in dir(configured_fn):
        # Named ``value_repr`` (not ``repr``) so the builtin is not shadowed.
        value_repr = CustomConfigRepr(getattr(configured_fn, arg_name))
        table_resolved_args.add_row(arg_name, Pretty(value_repr))

    CONSOLE.print("[bold green]Resolved Arguments[/bold green]")
    CONSOLE.print(table_resolved_args)

    if executor:
        CONSOLE.print("[bold green]Executor[/bold green]")
        table_executor = Table(show_header=False, header_style="bold magenta")
        table_executor.add_column("Executor")
        table_executor.add_row(Pretty(CustomConfigRepr(executor)))
        CONSOLE.print(table_executor)

    if build:
        fdl.build(configured_fn)
@runtime_checkable
class AutoConfigProtocol(Protocol):
    """Structural type for objects exposing an ``__auto_config__`` member.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, AutoConfigProtocol)``
    only checks that ``obj`` has an ``__auto_config__`` attribute. It neither calls nor
    inspects that attribute. Note that ``autoconvert`` sets the member as the boolean
    attribute ``True`` on the functions it returns, not as a method.
    """

    def __auto_config__(self) -> bool:
        """Marker member; only its presence is checked at runtime."""