core.typed_torch

Utilities for improved type hinting with torch interfaces.

Module Contents

Classes

_Module

Protocol that lets us unwrap the type signature of a module's forward method.

Functions

apply_module

Returns the provided module unchanged, but with correct type hints.

not_none

Asserts that the provided value is not None and returns it.

Data

P

R_co

T

API

core.typed_torch.P

‘ParamSpec(…)’

core.typed_torch.R_co

‘TypeVar(…)’

core.typed_torch.T

‘TypeVar(…)’

class core.typed_torch._Module

Bases: typing.Generic[core.typed_torch.P, core.typed_torch.R_co], typing.Protocol

Protocol that lets us unwrap the type signature of a module's forward method.

forward(
*args: core.typed_torch.P.args,
**kwargs: core.typed_torch.P.kwargs,
) -> core.typed_torch.R_co

Forward method of the matching torch.nn.Module.

core.typed_torch.apply_module(
m: core.typed_torch._Module[core.typed_torch.P, core.typed_torch.R_co],
*,
check_subclass: bool = True,
) -> collections.abc.Callable[core.typed_torch.P, core.typed_torch.R_co]

Returns the provided module unchanged, but with correct type hints.

Parameters:
  • m – An instance of a subclass of torch.nn.Module.

  • check_subclass – If True, checks that m is an instance of torch.nn.Module (or a subclass) and raises a TypeError if it is not.

Returns:

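The same module object, statically typed as collections.abc.Callable[P, R_co], so that calls to it are checked against the signature of its forward method.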

core.typed_torch.not_none(value: core.typed_torch.T | None) -> core.typed_torch.T

Asserts that the provided value is not None and returns it.

Parameters:

value – An optional value.

Returns:

The provided value, guaranteed not to be None.
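A minimal sketch, assuming the check is a plain assert statement (the docs only say that the function asserts). The config dictionary in the usage lines is hypothetical.

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


def not_none(value: T | None) -> T:
    # Fails loudly on None and narrows `T | None` to `T` for type checkers.
    # Only None is rejected; falsy values such as 0 or "" pass through.
    assert value is not None
    return value


config: dict[str, int] = {"batch_size": 32}
batch_size = not_none(config.get("batch_size"))  # typed as int, not int | None
print(batch_size)  # prints 32
```

One consequence of relying on assert is that the check is removed when Python runs with -O; a variant that must always fail at runtime would raise an explicit exception instead.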