Source code for physicsnemo.core.registry
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from importlib.metadata import EntryPoint, entry_points
from typing import TYPE_CHECKING, Dict, List, Union
# Handle both stdlib and backport EntryPoint classes for isinstance checks.
# Some packages (e.g., mlflow, opentelemetry) import importlib_metadata which
# can cause entry points to be instances of importlib_metadata.EntryPoint
# instead of importlib.metadata.EntryPoint.
# For future maintainers who think, "I could clean this up ..."
# As much as I WANT to remove this, it is not worth it. If there is any
# usage of importlib_metadata in other packages it could break. It certainly
# breaks docstring testing in CI. And it was very difficult to debug,
# since the breaking import that changes the EntryPoint definition typically
# shows up OUTSIDE of the physicsnemo code base, you'll go crazy trying to debug.
try:
import importlib_metadata
_ENTRYPOINT_TYPES: tuple = (EntryPoint, importlib_metadata.EntryPoint)
except ImportError:
_ENTRYPOINT_TYPES = (EntryPoint,)
if TYPE_CHECKING:
from physicsnemo.core.module import Module
# This model registry follows conventions similar to fsspec,
# https://github.com/fsspec/filesystem_spec/blob/master/fsspec/registry.py#L62C2-L62C2
# Tutorial on entrypoints: https://amir.rachum.com/blog/2017/07/28/python-entry-points/
# Borg singleton pattern: https://stackoverflow.com/questions/1318406/why-is-the-borg-pattern-better-than-the-singleton-pattern-in-python
[docs]
class ModelRegistry:
_shared_state = {"_model_registry": None}
def __new__(cls, *args, **kwargs):
obj = super(ModelRegistry, cls).__new__(cls)
obj.__dict__ = cls._shared_state
if cls._shared_state["_model_registry"] is None:
cls._shared_state["_model_registry"] = cls._construct_registry()
return obj
@staticmethod
def _construct_registry() -> Dict[str, type["Module"] | EntryPoint]:
registry: Dict[str, type["Module"] | EntryPoint] = {}
entrypoints = entry_points(group="physicsnemo.models")
for entry_point in entrypoints:
registry[entry_point.name] = entry_point
# Pull in any modulus models for backwards compatibility
entrypoints = entry_points(group="modulus.models")
for entry_point in entrypoints:
if entry_point.name not in registry:
# Add depricated warning
warnings.warn(
f"Model {entry_point.name} is being loaded from the 'modulus.models' group. "
f"This probably means it is being exposed from a package that has not yet been "
f"updated to use the 'physicsnemo.models' group. This group may be removed in a "
f"future release. Please contact the package maintainer to update the entry point.",
DeprecationWarning,
stacklevel=2,
)
registry[entry_point.name] = entry_point
return registry
[docs]
def register(self, model: type["Module"], name: Union[str, None] = None) -> None:
"""
Registers a physicsnemo model class in the model registry under the provided name. If no name
is provided, the model's name (from its `__name__` attribute) is used. If the
name is already in use, raises a ValueError.
Parameters
----------
model : physicsnemo.core.Module
The model class to be registered.
name : str, optional
The name to register the model under. If None, the model class name
is used.
Raises
------
ValueError
If the provided name is already in use in the registry.
Examples
--------
Example 1: Register a model class using its default name (from ``__name__``):
>>> from physicsnemo.core import Module, ModelRegistry
>>> # Define a custom model class
>>> class MyCustomModel(Module):
... def __init__(self, hidden_size):
... super().__init__()
... self.hidden_size = hidden_size
...
... def forward(self, x):
... return x
>>> # Get the registry instance
>>> registry = ModelRegistry()
>>> # Register the model without specifying a name
>>> # The class name 'MyCustomModel' will be used automatically
>>> registry.register(MyCustomModel)
>>> # Retrieve the model class from the registry
>>> ModelClass = registry.factory('MyCustomModel')
>>> # Instantiate the model
>>> model = ModelClass(hidden_size=128)
"""
# If no name provided, use the model class name
if name is None:
name = model.__name__
# Check if name already in use
if name in self._model_registry:
raise ValueError(
f"Name {name} already in use.\n"
f"Current registered models are: {sorted(self.list_models())}"
)
# Add this class to the dict of model registry
self._model_registry[name] = model
[docs]
def factory(self, name: str) -> type["Module"]:
"""
Returns a registered model class given its name.
Parameters
----------
name : str
The name of the registered model.
Returns
-------
model : physicsnemo.core.Module
The registered model.
Raises
------
KeyError
If no model is registered under the provided name.
"""
model = self._model_registry.get(name)
if model is not None:
if isinstance(model, _ENTRYPOINT_TYPES):
model = model.load()
# Update the registry with the loaded object:
self._model_registry[name] = model
return model
raise KeyError(
f"No model is registered under the name {name}. "
f"Current registered models are: {sorted(self.list_models())}"
)
[docs]
def list_models(self) -> List[str]:
"""
Returns a list of the names of all models currently registered in the registry.
Returns
-------
List[str]
A list of the names of all registered models. The order of the names is not
guaranteed to be consistent.
"""
return list(self._model_registry.keys())
def __clear_registry__(self):
# NOTE: This is only used for testing purposes
self._model_registry = {}
def __restore_registry__(self):
# NOTE: This is only used for testing purposes
self._model_registry = self._construct_registry()