core.typed_torch

Utilities for improved type hinting with torch interfaces.

Module Contents

Classes

_Module

Protocol allowing us to unwrap forward.

Functions

apply_module

Returns the provided module unchanged, but with correct type hints.

not_none

Asserts that the provided value is not None and returns it.

copy_signature

Decorator to copy the signature from one function to another.

Data

API

core.typed_torch.P

‘ParamSpec(…)’

core.typed_torch.R_co

‘TypeVar(…)’

core.typed_torch.T

‘TypeVar(…)’

class core.typed_torch._Module

Bases: typing.Generic[core.typed_torch.P, core.typed_torch.R_co], typing.Protocol

Protocol allowing us to unwrap forward.

forward(
*args: core.typed_torch.P.args,
**kwargs: core.typed_torch.P.kwargs,
) -> core.typed_torch.R_co

Forward method of the matching torch.nn.Module.

core.typed_torch.apply_module(
m: core.typed_torch._Module[core.typed_torch.P, core.typed_torch.R_co],
*,
check_subclass: bool = True,
) -> collections.abc.Callable[core.typed_torch.P, core.typed_torch.R_co]

Returns the provided module unchanged, but with correct type hints.

Parameters:
  • m – An instance of a subclass of torch.nn.Module.

  • check_subclass – If True, checks that m is an instance of torch.nn.Module (or of a subclass) and raises a TypeError if it is not.

Returns:

That module unchanged, but with correct type hints.

core.typed_torch.not_none(value: core.typed_torch.T | None) -> core.typed_torch.T

Asserts that the provided value is not None and returns it.

Parameters:

value – An optional value.

Returns:

The provided value, guaranteed to be not None.
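A helper of this kind is typically used to narrow an optional value for the type checker. The following is a minimal sketch, assuming the check is a plain assert; not_none_sketch and find_user are hypothetical names, and the actual function may report failures differently.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none_sketch(value: Optional[T]) -> T:
    # Assumption: a plain assert; it narrows Optional[T] to T for type checkers.
    assert value is not None, "expected a value, got None"
    return value


def find_user(user_id: int) -> Optional[str]:
    """Hypothetical lookup that may return None."""
    return {1: "ada"}.get(user_id)


# Without the helper, calling .upper() on Optional[str] is a type error.
name = not_none_sketch(find_user(1)).upper()
```

Note that the check is against None specifically, so falsy values such as 0 or an empty string pass through unchanged.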

core.typed_torch.R_src

‘TypeVar(…)’

core.typed_torch.R_dst

‘TypeVar(…)’

core.typed_torch.P_src

‘ParamSpec(…)’

core.typed_torch.P_dst

‘ParamSpec(…)’

core.typed_torch.First_dst

‘TypeVar(…)’

core.typed_torch.copy_signature(
source: collections.abc.Callable[..., Any],
/,
*,
handle_return_type: Literal['preserve', 'overwrite'] = 'preserve',
handle_first_src_param: Literal['copy', 'skip'] = 'copy',
handle_first_dst_param: Literal['preserve', 'drop'] = 'drop',
)

Decorator to copy the signature from one function to another.

Similar to functools.wraps, but preserves the signature instead of the metadata. Useful when writing adapter/wrapper functions that forward arguments to another function, as in:

  def function_with_lots_of_args(
      a: int,
      b: str,
      c: float,
      # ... further parameters elided
  ) -> BigObject:
      ...

  @copy_signature(function_with_lots_of_args)
  def convenient_wrapper(*args: Any, **kwargs: Any) -> str:
      return function_with_lots_of_args(*args, **kwargs).to_string()

Parameters:
  • source – The function or callable from which to copy the signature.

  • handle_return_type – How to handle the return type of the decorated function. ‘preserve’ to keep the decorated function’s return type (the default, since many wrappers are specifically written to return a different type), or ‘overwrite’ to copy the source function’s return type as well.

  • handle_first_src_param – Whether to include the first parameter of the source function. ‘copy’ to include it in the decorated function’s signature (the default), ‘skip’ to exclude it (useful for removing ‘self’ or ‘cls’).

  • handle_first_dst_param – Whether to keep the first parameter of the decorated function. ‘drop’ to overwrite it just like any other parameter (the default), or ‘preserve’ to keep it in the decorated function’s signature (useful for preserving ‘self’ or ‘cls’).

Returns:

A decorator that copies the signature from source to the decorated function.