PhysicsNeMo Active Learning#

Developing Active Learning Workflows#

For a high-level overview of how to construct active learning workflows with PhysicsNeMo, consult the User Guide. The guide motivates the need for active learning, describes the abstraction PhysicsNeMo provides, and offers tips for developing custom components such as query and labeling strategies.

API Reference#

Protocols#

This module contains base classes for active learning protocols.

These protocols are intended to be abstract: import these classes either to subclass them or to use them as type annotations.

Protocol Architecture#

Python typing.Protocol classes are used for structural typing: they describe an expected interface so that static type checkers can verify that concrete implementations provide everything a workflow needs to function. typing.Protocol classes are not enforced at runtime, and inheritance is not required: as long as an implementation provides the expected attributes and methods, it is compatible with the protocol.
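
As a minimal illustration of structural typing (a hypothetical protocol and implementation, not part of the library):

>>> from typing import Protocol
>>> class SupportsEmpty(Protocol):
...     def empty(self) -> bool: ...
...
>>> class ListBackedQueue:  # note: no inheritance from SupportsEmpty
...     def __init__(self) -> None:
...         self.items: list = []
...     def empty(self) -> bool:
...         return len(self.items) == 0
...
>>> def drain(queue: SupportsEmpty) -> None:
...     # static type checkers accept ListBackedQueue here, because it
...     # structurally matches SupportsEmpty
...     while not queue.empty():
...         ...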

The active learning framework is built around several key protocol abstractions that work together to orchestrate the active learning workflow:

Core Infrastructure Protocols:
  • AbstractQueue[T] - Generic queue protocol for passing data between components

  • DataPool[T] - Protocol for data reservoirs that support appending and sampling

  • ActiveLearningProtocol - Base protocol providing common interface for all AL strategies

Strategy Protocols (inherit from ActiveLearningProtocol):
  • QueryStrategy - Defines how to select data points for labeling

  • LabelStrategy - Defines processes for adding ground truth labels to unlabeled data

  • MetrologyStrategy - Defines procedures that assess model improvements beyond validation metrics

Model Interface Protocols:
  • TrainingProtocol - Interface for training step functions

  • ValidationProtocol - Interface for validation step functions

  • InferenceProtocol - Interface for inference step functions

  • TrainingLoop - Interface for complete training loop implementations

  • LearnerProtocol - Comprehensive interface for learner modules (combines training/validation/inference)

Orchestration Protocol:
  • DriverProtocol - Main orchestrator that coordinates all components in the active learning loop

Active Learning Workflow#

The typical active learning workflow orchestrated by DriverProtocol follows this sequence:

  1. Training Phase: Use LearnerProtocol or TrainingLoop to train the model on training_pool

  2. Metrology Phase (optional): Apply MetrologyStrategy instances to assess model performance

  3. Query Phase: Apply QueryStrategy instances to select samples from unlabeled_pool → query_queue

  4. Labeling Phase (optional): Apply LabelStrategy instances to label queued samples → label_queue

  5. Data Integration: Move labeled data from label_queue to training_pool
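
Schematically, a single step of this workflow can be sketched as follows. This is an illustration only: train() is a hypothetical stand-in for the training phase, and the real driver additionally handles checkpointing, phase tracking, and error checking.

>>> def sketch_active_learning_step(driver):
...     # 1. training phase
...     train(driver.learner, driver.training_pool)  # hypothetical helper
...     # 2. metrology phase (optional)
...     for metrology in driver.metrology_strategies or []:
...         metrology.compute()
...     # 3. query phase: strategies enqueue samples for labeling
...     for query in driver.query_strategies:
...         query.sample(driver.query_queue)
...     # 4. labeling phase (optional) and 5. data integration
...     if driver.label_strategy is not None:
...         driver.label_strategy.label(driver.query_queue, driver.label_queue)
...         while not driver.label_queue.empty():
...             driver.training_pool.append(driver.label_queue.get())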

Type Parameters#

  • T: Data structure containing both inputs and ground truth labels

  • S: Data structure containing only inputs (no ground truth labels)


class physicsnemo.active_learning.protocols.AbstractQueue(*args, **kwargs)[source]#

Bases: Protocol[T]

Defines a generic queue protocol for data that is passed between active learning components.

This can be a simple local queue.Queue, or a more sophisticated distributed queue system.

The primary use case for this is to allow a query strategy to enqueue some data structure for the labeling strategy to consume, and once the labeling is done, enqueue to a data serialization workflow. While there is no explicit restriction on the type of queue that is implemented, a reasonable assumption is FIFO ordering, unless otherwise specified by the concrete implementation.

Optional Serialization Methods#

Implementations may optionally provide to_list() and from_list() methods for checkpoint serialization. If not provided, the queue will be serialized using torch.save() as a fallback.
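
A minimal in-memory implementation that satisfies this protocol might look like the following sketch (hypothetical, not shipped with the library; the exact to_list()/from_list() signatures used by checkpointing are assumptions here):

>>> from collections import deque
>>> class SimpleQueue:
...     def __init__(self) -> None:
...         self._items: deque = deque()
...     def put(self, item) -> None:
...         self._items.append(item)
...     def get(self):
...         return self._items.popleft()  # FIFO ordering
...     def empty(self) -> bool:
...         return len(self._items) == 0
...     def to_list(self) -> list:
...         # optional hook for checkpoint serialization
...         return list(self._items)
...     def from_list(self, items: list) -> None:
...         self._items.extend(items)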

Type Parameters#

T

The type of items that will be stored in the queue.

See also

QueryStrategy

Enqueues data to be labeled

LabelStrategy

Dequeues data for labeling and enqueues labeled data

DriverProtocol

Uses queues to pass data between strategies

empty() bool[source]#

Method to check if the queue is empty/has been depleted.

Returns:

True if the queue is empty, False otherwise.

Return type:

bool

get() T[source]#

Method to get a data structure from the queue.

This method should remove the data structure from the queue, and return it to a consumer.

Returns:

The data structure that was removed from the queue.

Return type:

T

put(
item: T,
) None[source]#

Method to put a data structure into the queue.

Parameters:

item (T) – The data structure to put into the queue.

class physicsnemo.active_learning.protocols.ActiveLearningPhase(
value,
names=<not given>,
*values,
module=None,
qualname=None,
type=None,
start=1,
boundary=None,
)[source]#

Bases: StrEnum

An enumeration of the different phases of the active learning workflow.

This is primarily used in the metadata for restarting an ongoing active learning experiment.

See also

ActiveLearningProtocol

Base protocol for active learning strategies

DriverProtocol

Main orchestrator that uses this enumeration

DATA_INTEGRATION = 'data_integration'#
LABELING = 'labeling'#
METROLOGY = 'metrology'#
QUERY = 'query'#
TRAINING = 'training'#
class physicsnemo.active_learning.protocols.ActiveLearningProtocol(*args: Any, **kwargs: Any)[source]#

Bases: Protocol

This protocol acts as a basis for all active learning protocols.

This ensures that all protocols have some common interface, for example the ability to attach() to another object for scope management.

__protocol_name__#

The name of the protocol. This is primarily used for repr and str f-strings. This should be defined by concrete implementations.

Type:

str

_args#

A dictionary of arguments that were used to instantiate the protocol. This is used for serialization and deserialization of the protocol, and follows the same pattern as the _args attribute of physicsnemo.Module.

Type:

dict[str, Any]

attach(self, other: object) None[source]#

This method is used to attach the current object to another, allowing the protocol to access the attached object’s scope. The use case for this is to allow a protocol access to the driver’s scope to access dataset, model, etc. as needed. This needs to be implemented by concrete implementations.

is_attached: bool

Whether the current object is attached to another object. This is left abstract, as it depends on how attach() is implemented.

logger: Logger

The logger for this protocol. This is used to log information about the protocol’s progress.

_setup_logger(self) None[source]#

This method is used to setup the logger for the protocol. The default implementation is to configure the logger similarly to how physicsnemo loggers are configured.

See also

QueryStrategy

Query strategy protocol (child)

LabelStrategy

Label strategy protocol (child)

MetrologyStrategy

Metrology strategy protocol (child)

DriverProtocol

Main orchestrator that uses these protocols

attach(other: object) None[source]#

This method is used to attach the current object to another object (typically the driver), allowing this protocol to access the attached object’s scope. The primary reason for this is to allow the protocol to access things like the dataset, the learner model, etc. as needed.

Example use cases would be for a query strategy to access the unlabeled_pool; for a metrology strategy to access the validation_pool, and for any strategy to be able to access the surrogate/learner model.

This method can be as simple as setting self.driver = other, but is left abstract in case there are other potential use cases where multiple protocols could share information.

Parameters:

other (object) – The object to attach to.
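
A straightforward implementation consistent with this description could be (a sketch; storing the driver reference directly is only one reasonable choice):

>>> class AttachableStrategy:
...     driver = None
...     def attach(self, other: object) -> None:
...         # keep a reference to the driver so this strategy can reach
...         # its data pools, queues, and learner as needed
...         self.driver = other
...     @property
...     def is_attached(self) -> bool:
...         return self.driver is not None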

property checkpoint_dir: Path#

Utility property for strategies to conveniently access the checkpoint directory.

This is useful for (de)serializing data tied to checkpointing.

Returns:

The checkpoint directory, which includes the active learning step index.

Return type:

Path

Raises:

RuntimeError – If the strategy is not attached to a driver yet.

property is_attached: bool#

Property to check if the current object is already attached.

This is left abstract, as it depends on how attach is implemented.

Returns:

True if the current object is attached, False otherwise.

Return type:

bool

property logger: Logger#

Property to access the logger for this protocol.

If the logger has not been configured yet, the property will call the _setup_logger method to configure it.

Returns:

The logger for this protocol.

Return type:

Logger

property strategy_dir: Path#

Returns the directory that the underlying strategy can use to persist data.

Depending on the strategy abstraction, further nesting may be required (e.g. active learning step index, phase, etc.).

Returns:

The directory where the strategy will persist its records.

Return type:

Path

Raises:

RuntimeError – If the strategy is not attached to a driver yet.

class physicsnemo.active_learning.protocols.DataPool(*args, **kwargs)[source]#

Bases: Protocol[T]

An abstract protocol for a reservoir of data used in some part of active learning, parametrized such that it returns data structures of an arbitrary type T.

All methods are left abstract and need to be defined by concrete implementations. For the most part, a torch.utils.data.Dataset would match this protocol, provided that it implements the append() method, which allows data to be persisted to a filesystem.

__getitem__(self, index: int) T[source]#

Method to get a single data structure from the data pool.

__len__(self) int[source]#

Method to get the length of the data pool.

__iter__(self) Iterator[T][source]#

Method to iterate over the data pool.

append(self, item: T) None[source]#

Method to append a data structure to the data pool.

See also

DriverProtocol

Uses data pools for training, validation, and unlabeled data

AbstractQueue

Queue protocol for passing data between components

append(
item: T,
) None[source]#

Method to append a data structure to the data pool.

For persistent storage pools, this will actually mean that the item is serialized to a filesystem.

Parameters:

item (T) – The data structure to append to the data pool.
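
For illustration, a minimal in-memory pool that matches this protocol could be sketched as follows (hypothetical; a persistent pool would serialize appended items to disk in append()):

>>> from typing import Iterator
>>> class InMemoryPool:
...     def __init__(self) -> None:
...         self._items: list = []
...     def __getitem__(self, index: int):
...         return self._items[index]
...     def __len__(self) -> int:
...         return len(self._items)
...     def __iter__(self) -> Iterator:
...         return iter(self._items)
...     def append(self, item) -> None:
...         # a persistent implementation would write the item to storage here
...         self._items.append(item)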

class physicsnemo.active_learning.protocols.DriverProtocol[source]#

Bases: object

This protocol specifies the expected interface for an active learning driver: for a concrete implementation, refer to the driver module instead. The specification is provided mostly as a reference, and for ease of type hinting to prevent circular imports.

learner#

The learner module that will be used as the surrogate within the active learning loop.

Type:

LearnerProtocol

query_strategies#

The query strategies that will be used for selecting data points to label. A list of QueryStrategy instances can be included, and will sequentially be used to populate the query_queue that passes samples over to labeling.

Type:

list

query_queue#

The queue containing data samples to be labeled. QueryStrategy instances should enqueue samples to this queue.

Type:

AbstractQueue

label_strategy#

The label strategy that will be used for labeling data points. In contrast to the other strategies, only a single label strategy is supported. This strategy will consume the query_queue and enqueue labeled data to the label_queue.

Type:

LabelStrategy or None

label_queue#

The queue containing freshly labeled data. LabelStrategy instances should enqueue labeled data to this queue, and the driver will subsequently serialize data contained within this queue to a persistent format.

Type:

AbstractQueue or None

metrology_strategies#

The metrology strategies that will be used for assessing the performance of the surrogate. A list of MetrologyStrategy instances can be included, and each will be run sequentially to assess the model beyond simple validation metrics.

Type:

list or None

training_pool#

The pool of data to be used for training. This data will be used to train the underlying model, and is assumed to be mutable in that additional data can be added to the pool over the course of active learning.

Type:

DataPool

validation_pool#

The pool of data to be used for validation. This data will be used for both conventional validation, as well as for metrology. This dataset is considered to be immutable, and should not be modified over the course of active learning. This dataset is considered optional, as both validation and metrology are.

Type:

DataPool or None

unlabeled_pool#

An optional pool of data to be used for querying and labeling. If supplied, this dataset can be depleted by a query strategy to select data points for labeling. In principle, this could also represent a generative model, i.e. not just a static dataset, but at a high level represents a distribution of data.

Type:

DataPool or None

See also

QueryStrategy

Query strategy protocol

LabelStrategy

Label strategy protocol

MetrologyStrategy

Metrology strategy protocol

LearnerProtocol

Learner protocol

DataPool

Data pool protocol

AbstractQueue

Queue protocol

active_learning_step(
*args: Any,
**kwargs: Any,
) None[source]#

Implements the active learning step.

This step performs a single pass of the active learning loop, with the intended order being: training, metrology, query, labeling, with the metrology and labeling steps being optional.

Parameters:
  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

attach_strategies() None[source]#

Attaches all provided strategies.

This method relies on the attach method of the strategies, which will subsequently give the strategy access to the driver’s scope.

Example use cases would be for any strategy (apart from label strategy) to access the underlying model (LearnerProtocol); for a query strategy to access the unlabeled_pool; for a metrology strategy to access the validation_pool.

label_queue: AbstractQueue[T] | None#
label_strategy: LabelStrategy | None#
learner: LearnerProtocol#
metrology_strategies: list[MetrologyStrategy] | None#
query_queue: AbstractQueue[T]#
query_strategies: list[QueryStrategy]#
training_pool: DataPool[T]#
unlabeled_pool: DataPool[T] | None#
validation_pool: DataPool[T] | None#
class physicsnemo.active_learning.protocols.InferenceProtocol(*args, **kwargs)[source]#

Bases: Protocol

This protocol defines the interface for inference steps: given a model and some input data, return the output of the model’s forward pass.

A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
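
For example, a matching implementation could be as simple as the following sketch (the exact __call__ signature should be taken from the protocol definition):

>>> import torch
>>> @torch.no_grad()
... def inference_step(model, data):
...     # forward pass only: no loss computation, no gradients
...     return model(data)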

See also

LearnerProtocol

Learner protocol with an inference_step method

QueryStrategy

Uses inference for query strategies

MetrologyStrategy

Uses inference for metrology strategies

class physicsnemo.active_learning.protocols.LabelStrategy(*args: Any, **kwargs: Any)[source]#

Bases: ActiveLearningProtocol

This protocol defines a label strategy for active learning.

A label strategy is responsible for labeling data points; this may be a simple Python function for demonstrating a concept, or an external process that is potentially time-consuming and complex.

__is_external_process__#

Whether the label strategy is running in an external process.

Type:

bool

__provides_fields__#

The fields that the label strategy provides. This should be set by concrete implementations, and should be used to write and map labeled data to fields within the data structure T.

Type:

set or None

See also

ActiveLearningProtocol

Base protocol for all active learning strategies

AbstractQueue

Queue protocol for dequeuing and enqueuing data

QueryStrategy

Produces queued data for labeling

DriverProtocol

Orchestrates the label strategy

label(
queue_to_label: AbstractQueue[T],
serialize_queue: AbstractQueue[T],
*args: Any,
**kwargs: Any,
) None[source]#

Method that implements the logic behind labeling data.

This method should be implemented by concrete implementations, and assume that an active learning driver will pass a queue for this method to dequeue data to be labeled.

Parameters:
  • queue_to_label (AbstractQueue[T]) – Queue containing data structures to be labeled. Generally speaking, this should be passed over after running the query strategy (or strategies).

  • serialize_queue (AbstractQueue[T]) – Queue for enqueuing labeled data to be serialized.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.
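
As an illustration, a toy label strategy might be sketched as follows (hypothetical: compute_label stands in for the actual labeling process, and the field name is assumed to be declared in __provides_fields__):

>>> class ToyLabelStrategy:
...     __protocol_name__ = "ToyLabelStrategy"
...     __is_external_process__ = False
...     __provides_fields__ = {"ground_truth"}
...     def label(self, queue_to_label, serialize_queue, *args, **kwargs):
...         while not queue_to_label.empty():
...             sample = queue_to_label.get()
...             sample.ground_truth = compute_label(sample)  # hypothetical
...             serialize_queue.put(sample)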

class physicsnemo.active_learning.protocols.LearnerProtocol[source]#

Bases: object

This protocol represents the learner part of an active learning algorithm.

This corresponds to a set of trainable parameters that are optimized, and subsequently used for inference and evaluation.

The required methods ensure that classes implementing this protocol provide all of the functionality needed across all active learning steps. Keep in mind that, as with all other protocols in this module, this is merely the required interface and not an actual implementation.

See also

DriverProtocol

Uses the learner protocol in the active learning loop

TrainingProtocol

Training step protocol

ValidationProtocol

Validation step protocol

InferenceProtocol

Inference step protocol

TrainingLoop

Training loop protocol that can use a learner

forward(
*args: Any,
**kwargs: Any,
) Any[source]#

Implements the forward pass for a single batch.

This method is called between all active learning steps, and should contain the logic for how a model ingests data and produces predictions.

Parameters:
  • args (Any) – Additional arguments to pass to the model.

  • kwargs (Any) – Additional keyword arguments to pass to the model.

Returns:

The output of the model’s forward pass.

Return type:

Any

inference_step(
data: T | S,
*args: Any,
**kwargs: Any,
) None[source]#

Implements the inference logic for a single batch.

This can match the forward pass exactly, but provides an opportunity to specialize inference behavior, for example by disabling gradient computation. Specifically, this method will be called during query and metrology steps.

This should mirror the InferenceProtocol definition, except that the model corresponds to this object.

Parameters:
  • data (T | S) – The data to infer on. Typically assumed to be a batch of data.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

property parameters: Iterator[Tensor]#

Returns an iterator over the parameters of the learner.

If subclassing from torch.nn.Module, this will automatically return the parameters of the module.

Returns:

An iterator over the parameters of the learner.

Return type:

Iterator[torch.Tensor]

training_step(
data: T,
*args: Any,
**kwargs: Any,
) None[source]#

Implements the training logic for a single batch.

This method will be called in training steps only, and not used for validation, query, or metrology steps. Specifically this means that gradients will be computed and used to update parameters.

In cases where gradients are not needed, consider implementing the validation_step() method instead.

This should mirror the TrainingProtocol definition, except that the model corresponds to this object.

Parameters:
  • data (T) – The data to train on. Typically assumed to be a batch of data.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

validation_step(
data: T,
*args: Any,
**kwargs: Any,
) None[source]#

Implements the validation logic for a single batch.

This can match the forward pass, without the need for weight updates. This method will be called in validation steps only, and not used for query or metrology steps. In those cases, implement the inference_step() method instead.

This should mirror the ValidationProtocol definition, except that the model corresponds to this object.

Parameters:
  • data (T) – The data to validate on. Typically assumed to be a batch of data.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.
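
A minimal torch.nn.Module-based learner satisfying this interface might look like the following sketch (illustrative only; the data unpacking, loss, and optimizer handling are placeholders):

>>> import torch
>>> class ToyLearner(torch.nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.net = torch.nn.Linear(4, 1)
...     def forward(self, x):
...         return self.net(x)
...     def training_step(self, data, *args, **kwargs):
...         inputs, targets = data  # placeholder unpacking
...         loss = torch.nn.functional.mse_loss(self(inputs), targets)
...         loss.backward()  # an optimizer step would follow in a full loop
...     def validation_step(self, data, *args, **kwargs):
...         inputs, targets = data
...         with torch.no_grad():
...             _ = torch.nn.functional.mse_loss(self(inputs), targets)
...     def inference_step(self, data, *args, **kwargs):
...         with torch.no_grad():
...             return self(data)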

class physicsnemo.active_learning.protocols.MetrologyStrategy(*args: Any, **kwargs: Any)[source]#

Bases: ActiveLearningProtocol

This protocol defines a metrology strategy for active learning.

A metrology strategy is responsible for assessing the improvements to the underlying model, beyond simple validation metrics. This should reflect the application requirements of the model, which may include running a simulation.

records#

A sequence of record data structures that captures the history of the active learning process, as seen from this particular metrology perspective.

Type:

list

See also

ActiveLearningProtocol

Base protocol for all active learning strategies

DriverProtocol

Orchestrates metrology strategies

DataPool

Data pool protocol for accessing validation data

append(
record: S,
) None[source]#

Method to append a record to the metrology strategy.

Parameters:

record (S) – The record to append to the metrology strategy.

compute(
*args: Any,
**kwargs: Any,
) None[source]#

Method to compute the metrology strategy. No data is passed to this method; the data is expected to be drawn as needed from the various DataPool instances connected to the driver.

This method defines the core logic for computing a particular view of the underlying model’s performance on the data. Once computed, the results need to be formatted into a record data structure S, which is then appended to the records attribute.

Parameters:
  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

load_records(
path: Path | None = None,
*args: Any,
**kwargs: Any,
) None[source]#

Method to load the records of the metrology strategy, i.e. the reverse of serialize_records.

This should be defined by a concrete implementation, which dictates how the records are loaded, e.g. from a JSON file, database, etc.

If no path is provided, the strategy should load the latest records as sensible defaults. The records attribute should then be overwritten in-place.

Parameters:
  • path (Path | None) – The path to load the records from. If not provided, the strategy should load the latest records as sensible defaults.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

records: list[S]#
reset() None[source]#

Method to reset any stateful attributes of the metrology strategy.

By default, the records attribute is reset to an empty list.

serialize_records(
path: Path | None = None,
*args: Any,
**kwargs: Any,
) None[source]#

Method to serialize the records of the metrology strategy.

This should be defined by a concrete implementation, which dictates how the records are persisted, e.g. to a JSON file, database, etc.

The strategy_dir property can be used to determine the directory where the records should be persisted.

Parameters:
  • path (Path | None) – The path to serialize the records to. If not provided, the strategy should choose a reasonable default, such as alongside the checkpoint or within the corresponding metrology directory via strategy_dir.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.

class physicsnemo.active_learning.protocols.QueryStrategy(*args: Any, **kwargs: Any)[source]#

Bases: ActiveLearningProtocol

This protocol defines a query strategy for active learning.

A query strategy is responsible for selecting data points for labeling. In the most general sense, concrete instances of this protocol will specify how many samples to query, and the heuristics for selecting samples.

max_samples#

The maximum number of samples to query. This can be interpreted as the exact number of samples to query, or as an upper limit for querying methods that are threshold based.

Type:

int

See also

ActiveLearningProtocol

Base protocol for all active learning strategies

AbstractQueue

Queue protocol for enqueuing data

LabelStrategy

Consumes queued data for labeling

DriverProtocol

Orchestrates query strategies

max_samples: int#
sample(
query_queue: AbstractQueue[T],
*args: Any,
**kwargs: Any,
) None[source]#

Method that implements the logic behind querying data to be labeled.

This method should be implemented by concrete implementations, and assume that an active learning driver will pass a queue for this method to enqueue data to be labeled.

Additional args and kwargs are passed to the method, and can be used to pass additional information to the query strategy.

This method will enqueue in place, and should not return anything.

Parameters:
  • query_queue (AbstractQueue[T]) – The queue to enqueue data to be labeled.

  • args (Any) – Additional arguments to pass to the method.

  • kwargs (Any) – Additional keyword arguments to pass to the method.
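
For example, a simple random query strategy that draws from the driver’s unlabeled pool might be sketched as follows (hypothetical; the attach()/logger plumbing from ActiveLearningProtocol is omitted, and the strategy is assumed to already be attached to a driver):

>>> import random
>>> class RandomQueryStrategy:
...     __protocol_name__ = "RandomQueryStrategy"
...     def __init__(self, max_samples: int) -> None:
...         self.max_samples = max_samples
...     def sample(self, query_queue, *args, **kwargs):
...         pool = self.driver.unlabeled_pool
...         chosen = random.sample(
...             range(len(pool)), k=min(self.max_samples, len(pool))
...         )
...         for index in chosen:
...             query_queue.put(pool[index])  # enqueue in place, return nothing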

class physicsnemo.active_learning.protocols.TrainingLoop(*args, **kwargs)[source]#

Bases: Protocol

Defines a protocol that implements a training loop.

This protocol is intended to be called within the active learning loop during the training phase, where the model is trained on a specified number of epochs or training steps, and optionally validated on a dataset.

If a LearnerProtocol is provided, then train_fn and validate_fn become optional as they will be defined within the LearnerProtocol. If they are provided, however, then they should override the LearnerProtocol variants.

If graph capture/compilation is intended, then train_fn and validate_fn should be wrapped with StaticCaptureTraining and StaticCaptureEvaluateNoGrad, respectively.

See also

DriverProtocol

Uses training loops in the training phase

TrainingProtocol

Training step protocol

ValidationProtocol

Validation step protocol

LearnerProtocol

Learner protocol with training/validation methods

class physicsnemo.active_learning.protocols.TrainingProtocol(*args, **kwargs)[source]#

Bases: Protocol

This protocol defines the interface for training steps: given a model and some input data, compute the reduced, differentiable loss tensor and return it.

A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
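
A matching implementation can be an ordinary function, for example (a sketch; the exact __call__ signature should be taken from the protocol definition):

>>> import torch
>>> def train_step(model, data):
...     inputs, targets = data  # placeholder unpacking
...     # return the reduced, differentiable loss tensor
...     return torch.nn.functional.mse_loss(model(inputs), targets)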

See also

TrainingLoop

Training loop protocol that uses this protocol

LearnerProtocol

Learner protocol with a training_step method

ValidationProtocol

Validation step protocol

class physicsnemo.active_learning.protocols.ValidationProtocol(*args, **kwargs)[source]#

Bases: Protocol

This protocol defines the interface for validation steps: given a model and some input data, compute metrics of interest and, where relevant, log the results.

A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
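
Analogously, a validation step can be an ordinary function, for example (a sketch under the same assumptions as the training step above):

>>> import torch
>>> @torch.no_grad()
... def validate_step(model, data):
...     inputs, targets = data  # placeholder unpacking
...     # compute a metric of interest; a real implementation might also log it
...     return torch.nn.functional.l1_loss(model(inputs), targets)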

See also

TrainingLoop

Training loop protocol that uses this protocol

LearnerProtocol

Learner protocol with a validation_step method

TrainingProtocol

Training step protocol

Configuration Classes#

These data structures are used to modify the behavior of different components of the active learning workflow. The general pattern is to ensure that they are JSON-serializable so that they can be checkpointed and restarted.

Configuration dataclasses for the active learning driver.

This module provides structured configuration classes that separate different concerns in the active learning workflow: optimization, training, strategies, and driver orchestration.

class physicsnemo.active_learning.config.DriverConfig(
batch_size: int,
max_active_learning_steps: int | None = None,
run_id: str = <factory>,
fine_tuning_lr: float | None = None,
reset_optim_states: bool = True,
skip_training: bool = False,
skip_metrology: bool = False,
skip_labeling: bool = False,
checkpoint_interval: int = 1,
checkpoint_on_training: bool = False,
checkpoint_on_metrology: bool = False,
checkpoint_on_query: bool = False,
checkpoint_on_labeling: bool = True,
model_checkpoint_frequency: int = 0,
root_log_dir: str | Path = PosixPath('active_learning_logs'),
dist_manager: DistributedManager | None = None,
collate_fn: callable | None = None,
num_dataloader_workers: int = 0,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
)[source]#

Bases: object

Configuration for driver orchestration and infrastructure.

This contains parameters that control the overall loop execution, logging, checkpointing, and distributed training setup - orthogonal to the specific AL or training logic.

batch_size#

The batch size to use for data loaders.

Type:

int

max_active_learning_steps#

Maximum number of AL iterations to perform. None means infinite.

Type:

int or None

run_id#

Unique identifier for this run. Auto-generated if not provided.

Type:

str

fine_tuning_lr#

Learning rate to switch to after the first AL step for fine-tuning.

Type:

float or None

reset_optim_states#

Whether to reset optimizer states between AL steps. Defaults to True.

Type:

bool

skip_training#

If True, skip the training phase entirely. Defaults to False.

Type:

bool

skip_metrology#

If True, skip the metrology phase entirely. Defaults to False.

Type:

bool

skip_labeling#

If True, skip the labeling phase entirely. Defaults to False.

Type:

bool

checkpoint_interval#

Save model checkpoint every N AL steps. 0 disables checkpointing. Defaults to 1.

Type:

int

checkpoint_on_training#

If True, save checkpoint at the start of the training phase. Defaults to False.

Type:

bool

checkpoint_on_metrology#

If True, save checkpoint at the start of the metrology phase. Defaults to False.

Type:

bool

checkpoint_on_query#

If True, save checkpoint at the start of the query phase. Defaults to False.

Type:

bool

checkpoint_on_labeling#

If True, save checkpoint at the start of the labeling phase. Defaults to True.

Type:

bool

model_checkpoint_frequency#

Save model weights every N epochs during training. 0 means only save between active learning phases. Useful for mid-training restarts. Defaults to 0.

Type:

int

root_log_dir#

Directory to save logs and checkpoints to. Defaults to an ‘active_learning_logs’ directory in the current working directory.

Type:

str or pathlib.Path

dist_manager#

Manager for distributed training configuration.

Type:

DistributedManager or None

collate_fn#

Custom collate function for batching data.

Type:

callable or None

num_dataloader_workers#

Number of worker processes for data loading. Defaults to 0.

Type:

int

device#

Device to use for model and data. This is intended for single process workflows; for distributed workflows, the device should be set in DistributedManager instead. If not specified, then the device will default to torch.get_default_device().

Type:

str or torch.device or None

dtype#

The dtype to use for model and data, and AMP contexts. If not provided, then the dtype will default to torch.get_default_dtype().

Type:

torch.dtype or None

See also

Driver

Uses this config for orchestration

TrainingConfig

Training configuration

StrategiesConfig

Strategies configuration

DataPool

Data pool protocol

AbstractQueue

Queue protocol

batch_size: int#
checkpoint_interval: int = 1#
checkpoint_on_labeling: bool = True#
checkpoint_on_metrology: bool = False#
checkpoint_on_query: bool = False#
checkpoint_on_training: bool = False#
collate_fn: callable | None = None#
device: str | torch.device | None = None#
dist_manager: DistributedManager | None = None#
dtype: torch.dtype | None = None#
fine_tuning_lr: float | None = None#
classmethod from_json(
json_str: str,
**kwargs: Any,
) DriverConfig[source]#

Creates a DriverConfig instance from a JSON string.

This method reconstructs a DriverConfig from JSON. Note that certain fields cannot be serialized and must be provided via kwargs:

  • dist_manager: DistributedManager instance (optional)

  • collate_fn: Custom collate function (optional)

Parameters:
  • json_str (str) – A JSON string that was previously serialized using to_json().

  • **kwargs (Any) – Additional keyword arguments to override or provide non-serializable fields like dist_manager and collate_fn.

Returns:

A new DriverConfig instance.

Return type:

DriverConfig

Raises:

ValueError – If the JSON cannot be parsed or required fields are missing.

Notes

The device and dtype fields are reconstructed from their string representations. The log_dir field in JSON is mapped to root_log_dir in the config.

max_active_learning_steps: int | None = None#
model_checkpoint_frequency: int = 0#
num_dataloader_workers: int = 0#
reset_optim_states: bool = True#
root_log_dir: str | Path = PosixPath('active_learning_logs')#
run_id: str#
skip_labeling: bool = False#
skip_metrology: bool = False#
skip_training: bool = False#
to_json() str[source]#

Returns a JSON string representation of the DriverConfig.

Note that certain fields are not serialized and must be provided when deserializing: dist_manager, collate_fn.

Returns:

A JSON string representation of the config.

Return type:

str
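
For example, a round trip through JSON (non-serializable fields such as dist_manager and collate_fn would be re-supplied via kwargs on deserialization):

>>> from physicsnemo.active_learning.config import DriverConfig
>>> config = DriverConfig(batch_size=8, max_active_learning_steps=4)
>>> restored = DriverConfig.from_json(config.to_json())
>>> restored.batch_size
8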

class physicsnemo.active_learning.config.OptimizerConfig(
optimizer_cls: type[~torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>,
optimizer_kwargs: dict[str, ~typing.Any] = <factory>,
scheduler_cls: type[~torch.optim.lr_scheduler._LRScheduler] | None = None,
scheduler_kwargs: dict[str, ~typing.Any] = <factory>,
)[source]#

Bases: object

Configuration for optimizer and learning rate scheduler.

This encapsulates all training optimization parameters, keeping them separate from the active learning orchestration logic.

optimizer_cls#

The optimizer class to use. Defaults to AdamW.

Type:

type

optimizer_kwargs#

Keyword arguments to pass to the optimizer constructor. Defaults to {“lr”: 1e-4}.

Type:

dict

scheduler_cls#

The learning rate scheduler class to use. If None, no scheduler will be configured.

Type:

type or None

scheduler_kwargs#

Keyword arguments to pass to the scheduler constructor.

Type:

dict

See also

TrainingConfig

Uses this config for optimizer setup

Driver

Configures optimizer using this config

classmethod from_dict(
data: dict[str, Any],
) OptimizerConfig[source]#

Creates an OptimizerConfig instance from a dictionary.

This method assumes that the optimizer and scheduler classes are included in the physicsnemo.active_learning.registry, or a module path is specified to import the class from.

Parameters:

data (dict[str, Any]) – A dictionary that was previously serialized using the to_dict method.

Returns:

A new OptimizerConfig instance.

Return type:

OptimizerConfig

optimizer_cls#

alias of AdamW

optimizer_kwargs: dict[str, Any]#
scheduler_cls: type[_LRScheduler] | None = None#
scheduler_kwargs: dict[str, Any]#
to_dict() dict[str, Any][source]#

Returns a JSON-serializable dictionary representation of the OptimizerConfig.

For round-tripping, the registry is used to de-serialize the optimizer and scheduler classes.

Returns:

A dictionary that can be JSON serialized.

Return type:

dict[str, Any]
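
For example, configuring Adam with a cosine annealing schedule (a sketch using standard torch classes):

>>> import torch
>>> from physicsnemo.active_learning.config import OptimizerConfig
>>> optim_config = OptimizerConfig(
...     optimizer_cls=torch.optim.Adam,
...     optimizer_kwargs={"lr": 1e-3},
...     scheduler_cls=torch.optim.lr_scheduler.CosineAnnealingLR,
...     scheduler_kwargs={"T_max": 100},
... )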

class physicsnemo.active_learning.config.StrategiesConfig(
query_strategies: list[QueryStrategy],
queue_cls: type[AbstractQueue],
label_strategy: LabelStrategy | None = None,
metrology_strategies: list[MetrologyStrategy] | None = None,
unlabeled_datapool: DataPool | None = None,
)[source]#

Bases: object

Configuration for active learning strategies and data acquisition.

This encapsulates the query-label-metrology cycle that is at the heart of active learning: strategies for selecting data, labeling it, and measuring model uncertainty/performance.

query_strategies#

The query strategies to use for selecting data to label. Each element should be a QueryStrategy instance.

Type:

list

queue_cls#

The queue implementation to use for passing data between query and labeling phases. Should implement AbstractQueue.

Type:

type

label_strategy#

The strategy to use for labeling queried data. If None, labeling will be skipped.

Type:

LabelStrategy or None

metrology_strategies#

Strategies for measuring model performance and uncertainty. Each element should be a MetrologyStrategy instance. If None, metrology will be skipped.

Type:

list or None

unlabeled_datapool#

Pool of unlabeled data that query strategies can sample from. Not all strategies require this (some may generate synthetic data).

Type:

DataPool or None

See also

Driver

Uses this config for strategy orchestration

QueryStrategy

Query strategy protocol

LabelStrategy

Label strategy protocol

MetrologyStrategy

Metrology strategy protocol

classmethod from_dict(
data: dict[str, Any],
**kwargs: Any,
) StrategiesConfig[source]#

Create a StrategiesConfig instance from a dictionary.

This method heavily relies on classes being added to the physicsnemo.active_learning.registry, which is used to instantiate all strategies and custom types used in active learning. As a fallback, the registry.construct method will try to import the class from the module path if it is not found in the registry.

Parameters:
  • data (dict[str, Any]) – A dictionary that was previously serialized using the to_dict method.

  • **kwargs (Any) – Additional keyword arguments to pass to the constructor.

Returns:

A new StrategiesConfig instance.

Return type:

StrategiesConfig

Raises:
  • ValueError: – If the data contains unexpected keys or if the object construction fails.

  • NameError: – If a class is not found in the registry and no module path is provided.

  • ModuleNotFoundError: – If a module is not found with the specified module path.

label_strategy: LabelStrategy | None = None#
metrology_strategies: list[MetrologyStrategy] | None = None#
query_strategies: list[QueryStrategy]#
queue_cls: type[AbstractQueue]#
to_dict() dict[str, Any][source]#

Method that converts the present StrategiesConfig instance into a dictionary that can be JSON serialized.

This method, for the most part, assumes that strategies are subclasses of ActiveLearningProtocol and/or they have an _args attribute that captures the arguments to the constructor.

One issue is the inability to reliably serialize the unlabeled_datapool which, as a dataset, for the most part does not need serialization. Regardless, this method will emit a warning if unlabeled_datapool is not None.

Returns:

A dictionary that can be JSON serialized.

Return type:

dict[str, Any]

unlabeled_datapool: DataPool | None = None#
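
For example (a sketch; RandomQueryStrategy and SimpleQueue refer to the hypothetical implementations sketched earlier on this page):

>>> from physicsnemo.active_learning.config import StrategiesConfig
>>> strategies = StrategiesConfig(
...     query_strategies=[RandomQueryStrategy(max_samples=16)],
...     queue_cls=SimpleQueue,
... )
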
class physicsnemo.active_learning.config.TrainingConfig(
train_datapool: ~physicsnemo.active_learning.protocols.DataPool,
max_training_epochs: int,
val_datapool: ~physicsnemo.active_learning.protocols.DataPool | None = None,
optimizer_config: ~physicsnemo.active_learning.config.OptimizerConfig = <factory>,
max_fine_tuning_epochs: int | None = None,
train_loop_fn: ~physicsnemo.active_learning.protocols.TrainingLoop = <factory>,
)[source]#

Bases: object

Configuration for the training phase of active learning.

This groups all training-related components together, making it clear when training is or isn’t being used in the AL workflow.

train_datapool#

The pool of labeled data to use for training.

Type:

DataPool

max_training_epochs#

The maximum number of epochs to train for. If max_fine_tuning_epochs isn’t specified, this value is used for all active learning steps.

Type:

int

val_datapool#

Optional pool of data to use for validation during training.

Type:

DataPool or None

optimizer_config#

Configuration for the optimizer and scheduler. Defaults to AdamW with lr=1e-4, no scheduler.

Type:

OptimizerConfig

max_fine_tuning_epochs#

The maximum number of epochs used during fine-tuning steps, i.e. after the first active learning step. If None, then max_training_epochs is used for every active learning step.

Type:

int or None

train_loop_fn#

The training loop function that orchestrates the training process. This defaults to a concrete implementation, DefaultTrainingLoop, which provides a very typical loop that includes the use of static capture, etc.

Type:

TrainingLoop

See also

Driver

Uses this config for training

OptimizerConfig

Optimizer configuration

StrategiesConfig

Strategies configuration

DefaultTrainingLoop

Default training loop implementation

classmethod from_dict(
data: dict[str, Any],
**kwargs: Any,
) TrainingConfig[source]#

Creates a TrainingConfig instance from a dictionary.

This method assumes that the training loop class is included in the physicsnemo.active_learning.registry, or a module path is specified to import the class from. Note that datapools must be provided via kwargs as they are not serialized.

Parameters:
  • data (dict[str, Any]) – A dictionary that was previously serialized using the to_dict method.

  • **kwargs (Any) – Additional keyword arguments to pass to the constructor. This is where you must provide train_datapool and optionally val_datapool.

Returns:

A new TrainingConfig instance.

Return type:

TrainingConfig

Raises:

ValueError – If required datapools are not provided in kwargs, if the data contains unexpected keys, or if object construction fails.

max_fine_tuning_epochs: int | None = None#
max_training_epochs: int#
optimizer_config: OptimizerConfig#
to_dict() dict[str, Any][source]#

Returns a JSON-serializable dictionary representation of the TrainingConfig.

For round-tripping, the registry is used to de-serialize the training loop and optimizer configuration. Note that datapools (train_datapool and val_datapool) are NOT serialized as they typically contain large datasets, file handles, or other non-serializable state.

Returns:

A dictionary that can be JSON serialized. Excludes datapools.

Return type:

dict[str, Any]

Warning

This method will issue a warning about the exclusion of datapools.

train_datapool: DataPool#
train_loop_fn: TrainingLoop#
val_datapool: DataPool | None = None#
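
For example (a sketch; InMemoryPool and optim_config refer to the hypothetical objects sketched earlier on this page):

>>> from physicsnemo.active_learning.config import TrainingConfig
>>> training = TrainingConfig(
...     train_datapool=InMemoryPool(),
...     max_training_epochs=10,
...     optimizer_config=optim_config,
... )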

Default Training Loop#

This module and the corresponding DefaultTrainingLoop class implement the TrainingLoop interface, and should provide most of the necessary boilerplate for model training and fine-tuning; users will need to provide the training and validation step protocols when configuring the loop.

class physicsnemo.active_learning.loop.DefaultTrainingLoop(*args: Any, **kwargs: Any)[source]#

Bases: TrainingLoop

Default implementation of the TrainingLoop protocol.

This provides a functional training loop with support for static capture, progress bars, checkpointing, and distributed training. It implements the standard epoch-based training pattern with optional validation.

See also

TrainingLoop

Protocol specification for training loops

Driver

Uses training loops in the training phase

TrainingConfig

Configuration for training

TrainingProtocol

Training step protocol

ValidationProtocol

Validation step protocol

property amp_type: dtype#
static load_training_checkpoint(
checkpoint_dir: Path,
model: Module | LearnerProtocol,
optimizer: Optimizer,
lr_scheduler: _LRScheduler | None = None,
) int | None[source]#

Load training state from checkpoint directory.

Model weights are loaded separately. Optimizer, scheduler, and epoch metadata are loaded from the combined training_state.pt file.

Parameters:
  • checkpoint_dir (pathlib.Path) – Directory containing checkpoint files.

  • model (Module or LearnerProtocol) – Model to load weights into.

  • optimizer (Optimizer) – Optimizer to load state into.

  • lr_scheduler (_LRScheduler or None) – Optional LR scheduler to load state into.

Returns:

Training epoch from metadata if available, else None.

Return type:

int or None

save_training_checkpoint(
checkpoint_dir: Path,
model: Module | LearnerProtocol,
optimizer: Optimizer,
lr_scheduler: _LRScheduler | None = None,
training_epoch: int | None = None,
) None[source]#

Save training state to checkpoint directory.

Model weights are saved separately. Optimizer, scheduler, and epoch metadata are combined into a single training_state.pt file.

Parameters:
  • checkpoint_dir (pathlib.Path) – Directory to save checkpoint files.

  • model (Module or LearnerProtocol) – Model to save weights for.

  • optimizer (Optimizer) – Optimizer to save state from.

  • lr_scheduler (_LRScheduler or None) – Optional LR scheduler to save state from.

  • training_epoch (int or None) – Current training epoch for metadata.

Active Learning Driver#

This module and its Driver class implement the DriverProtocol interface, and are usable out of the box for most active learning workflows. The Driver class is configured by DriverConfig, and serves as the focal point for orchestrating the active learning process.

This module contains the definition for an active learning driver class, which is responsible for orchestration and automation of the end-to-end active learning process.

class physicsnemo.active_learning.driver.ActiveLearningCheckpoint(
driver_config: DriverConfig,
strategies_config: StrategiesConfig,
active_learning_step_idx: int,
active_learning_phase: ActiveLearningPhase,
physicsnemo_version: str = '1.3.0a0',
training_config: TrainingConfig | None = None,
optimizer_state: dict[str, Any] | None = None,
lr_scheduler_state: dict[str, Any] | None = None,
has_query_queue: bool = False,
has_label_queue: bool = False,
)[source]#

Bases: object

Metadata associated with an ongoing (or completed) active learning experiment.

The information contained in this metadata should be sufficient to restart the active learning experiment from the nearest point: for example, training should be able to continue from a given epoch, while querying/sampling should continue from a pre-existing queue.

driver_config#

Infrastructure and orchestration configuration.

Type:

DriverConfig

strategies_config#

Active learning strategies configuration.

Type:

StrategiesConfig

active_learning_step_idx#

Current iteration index of the active learning loop.

Type:

int

active_learning_phase#

Current phase of the active learning workflow.

Type:

ActiveLearningPhase

physicsnemo_version#

Version of PhysicsNeMo used to create the checkpoint.

Type:

str

training_config#

Training components configuration, if training is used.

Type:

TrainingConfig or None

optimizer_state#

Optimizer state dictionary for checkpointing.

Type:

dict or None

lr_scheduler_state#

Learning rate scheduler state dictionary for checkpointing.

Type:

dict or None

has_query_queue#

Whether the checkpoint includes a query queue.

Type:

bool

has_label_queue#

Whether the checkpoint includes a label queue.

Type:

bool

See also

Driver

Uses this dataclass for checkpointing

DriverConfig

Driver configuration

StrategiesConfig

Strategies configuration

TrainingConfig

Training configuration

active_learning_phase: ActiveLearningPhase#
active_learning_step_idx: int#
driver_config: DriverConfig#
has_label_queue: bool = False#
has_query_queue: bool = False#
lr_scheduler_state: dict[str, Any] | None = None#
optimizer_state: dict[str, Any] | None = None#
physicsnemo_version: str = '1.3.0a0'#
strategies_config: StrategiesConfig#
training_config: TrainingConfig | None = None#
class physicsnemo.active_learning.driver.Driver(
config: DriverConfig,
learner: Module | LearnerProtocol,
strategies_config: StrategiesConfig,
training_config: TrainingConfig | None = None,
inference_fn: InferenceProtocol | None = None,
)[source]#

Bases: DriverProtocol

Provides a simple implementation of the DriverProtocol used to orchestrate an active learning process within PhysicsNeMo.

At a high level, the active learning process is broken down into four phases: training, metrology, query, and labeling.

To understand the orchestration, start by inspecting the active_learning_step() method, which defines a single iteration of the active learning loop and is dispatched by the run() method. From there, it should be relatively straightforward to trace the remaining components.
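
Putting the pieces together, a typical construction and launch might look like the following sketch (reusing the hypothetical components and configs sketched earlier on this page):

>>> from physicsnemo.active_learning.config import DriverConfig
>>> from physicsnemo.active_learning.driver import Driver
>>> driver = Driver(
...     config=DriverConfig(batch_size=8, max_active_learning_steps=4),
...     learner=ToyLearner(),
...     strategies_config=strategies,
...     training_config=training,
... )
>>> driver.run(train_step_fn=train_step, validate_step_fn=validate_step)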

config#

Infrastructure and orchestration configuration.

Type:

DriverConfig

learner#

The learner module for the active learning process.

Type:

Module or LearnerProtocol

strategies_config#

Active learning strategies (query, label, metrology).

Type:

StrategiesConfig

training_config#

Training components. None if training is skipped.

Type:

TrainingConfig or None

inference_fn#

Custom inference function.

Type:

InferenceProtocol or None

active_learning_step_idx#

Current iteration index of the active learning loop.

Type:

int

query_queue#

Queue populated with data by query strategies.

Type:

AbstractQueue

label_queue#

Queue populated with labeled data by the label strategy.

Type:

AbstractQueue

optimizer#

Configured optimizer (set after configure_optimizer is called).

Type:

torch.optim.Optimizer or None

lr_scheduler#

Configured learning rate scheduler.

Type:

torch.optim.lr_scheduler._LRScheduler or None

logger#

Persistent logger for the active learning process.

Type:

logging.Logger

See also

DriverProtocol

Protocol specification for active learning drivers

DriverConfig

Configuration for the driver

StrategiesConfig

Configuration for active learning strategies

TrainingConfig

Configuration for training

active_learning_step(
train_step_fn: TrainingProtocol | None = None,
validate_step_fn: ValidationProtocol | None = None,
*args: Any,
**kwargs: Any,
) None[source]#

Performs a single active learning iteration.

This method will perform the following sequence of steps:

  1. Train the model stored in Driver.learner by creating data loaders with Driver.train_datapool and Driver.val_datapool.

  2. Run the metrology strategies stored in Driver.metrology_strategies.

  3. Run the query strategies stored in Driver.query_strategies, if available.

  4. Run the labeling strategy stored in Driver.label_strategy, if available.

When entering each stage, we check to ensure all components necessary for the minimum function for that stage are available before proceeding.

If current_phase is set (e.g., from checkpoint resumption), only phases at or after current_phase will be executed. After completing all phases, current_phase is reset to None for the next AL step.

Parameters:
  • train_step_fn (p.TrainingProtocol | None = None) – The training function to use for training. If not provided, then the Driver.train_loop_fn will be used.

  • validate_step_fn (p.ValidationProtocol | None = None) – The validation function to use for validation. If not provided, then validation will not be performed.

  • args (Any) – Additional arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.

  • kwargs (Any) – Additional keyword arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.

Raises:

ValueError – If any of the required components for a stage are not available.

property active_learning_step_idx: int#

Returns the current active learning step index.

This represents the number of times the active learning step has been called, i.e. the number of iterations of the loop.

Returns:

The current active learning step index.

Return type:

int

attach_strategies() None[source]#

Calls strategy.attach for all available strategies.

barrier() None[source]#

Wrapper to call barrier on the correct device.

Becomes a no-op if distributed is not initialized, otherwise will attempt to read the local device ID from either the distributed manager or the default device.

configure_optimizer() None[source]#

Setup optimizer and LR schedulers from training_config.

property device: device#

Return a consistent device interface to use across the driver.

property dist_manager: DistributedManager | None#

Returns the distributed manager, if it was specified as part of the DriverConfig configuration.

Returns:

The distributed manager.

Return type:

DistributedManager | None

property is_lr_scheduler_configured: bool#

Returns whether the LR scheduler is configured.

property is_optimizer_configured: bool#

Returns whether the optimizer is configured.

property label_strategy: LabelStrategy | None#

Returns the label strategy from strategies_config.

property last_checkpoint: Path | None#

Returns path to the most recently saved checkpoint.

Returns:

Path to the last checkpoint directory, or None if no checkpoint has been saved yet.

Return type:

Path | None

classmethod load_checkpoint(
checkpoint_path: str | Path,
learner: Module | LearnerProtocol | None = None,
train_datapool: DataPool | None = None,
val_datapool: DataPool | None = None,
unlabeled_datapool: DataPool | None = None,
**kwargs: Any,
) Driver[source]#

Load a Driver instance from a checkpoint.

Given a checkpoint directory, this method will attempt to reconstruct the driver and its associated components from the checkpoint. The checkpoint path must contain a checkpoint.pt file, which contains the metadata associated with the experiment.

Additional parameters that are not serialized by the checkpointing mechanism can (or must) be provided to this method; for example, when using non-physicsnemo.Module learners, and for any data pools associated with the workflow.

Important

Currently, the strategy states are not reloaded from the checkpoint. This will be addressed in a future patch, but for now it is recommended to back up your strategy states (e.g. metrology records) manually before restarting experiments.

Parameters:
  • checkpoint_path (str | Path) – Path to checkpoint directory containing checkpoint.pt and model weights.

  • learner (Module | p.LearnerProtocol | None) – Learner model to load weights into. If None, will attempt to reconstruct from checkpoint (only works for physicsnemo.Module).

  • train_datapool (p.DataPool | None) – Training datapool. Required if training_config exists in checkpoint.

  • val_datapool (p.DataPool | None) – Validation datapool. Optional.

  • unlabeled_datapool (p.DataPool | None) – Unlabeled datapool for query strategies. Optional.

  • **kwargs (Any) – Additional keyword arguments to override config values.

Returns:

Reconstructed Driver instance ready to resume execution.

Return type:

Driver
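
For example, resuming an experiment from a saved checkpoint (the path and pool are illustrative):

>>> driver = Driver.load_checkpoint(
...     "active_learning_logs/abc12345/checkpoints/step_0/labeling",
...     train_datapool=InMemoryPool(),  # hypothetical pool sketched above
... )
>>> driver.run()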

property log_dir: Path#

Returns the log directory.

Note that this is the DriverConfig.root_log_dir combined with the shortened run ID for the current run.

Effectively, this means that each run will have its own directory for logs, checkpoints, etc.

Returns:

The log directory.

Return type:

Path

property metrology_strategies: list[MetrologyStrategy] | None#

Returns the metrology strategies from strategies_config.

property query_strategies: list[QueryStrategy]#

Returns the query strategies from strategies_config.

run(
train_step_fn: TrainingProtocol | None = None,
validate_step_fn: ValidationProtocol | None = None,
*args: Any,
**kwargs: Any,
) None[source]#

Runs the active learning loop until the maximum number of active learning steps is reached.

Parameters:
  • train_step_fn (p.TrainingProtocol | None = None) – The training function to use for training. If not provided, then the Driver.train_loop_fn will be used.

  • validate_step_fn (p.ValidationProtocol | None = None) – The validation function to use for validation. If not provided, then validation will not be performed.

  • args (Any) – Additional arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.

  • kwargs (Any) – Additional keyword arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.

property run_id: str#

Returns the run id from the DriverConfig.

Returns:

The run id.

Return type:

str

save_checkpoint(
path: str | Path | None = None,
training_epoch: int | None = None,
) Path | None[source]#

Save a checkpoint of the active learning experiment.

Saves AL orchestration state (configs, queues, step index, phase) and model weights. Training-specific state (optimizer, scheduler) is handled by DefaultTrainingLoop and saved to training_state.pt during training.

Parameters:
  • path (str | Path | None) – Path to save checkpoint. If None, creates path based on current AL step index and phase: log_dir/checkpoints/step_{idx}/{phase}/

  • training_epoch (int | None) – Optional epoch number for mid-training checkpoints.

Returns:

Checkpoint directory path, or None if checkpoint not saved (non-rank-0 in distributed).

Return type:

Path | None

property short_run_id: str#

Returns the first 8 characters of the run id.

The 8 character limit assumes that the run ID is a UUID4. This is particularly useful for user-facing interfaces, where you do not necessarily want to reference the full UUID.

Returns:

The first 8 characters of the run id.

Return type:

str

property train_datapool: DataPool | None#

Returns the training datapool from training_config.

property train_loop_fn: TrainingLoop | None#

Returns the training loop function from training_config.

property unlabeled_datapool: DataPool | None#

Returns the unlabeled datapool from strategies_config.

property val_datapool: DataPool | None#

Returns the validation datapool from training_config.

Active Learning Registry#

The registry provides a centralized location for registering and constructing custom active learning strategies. It enables string-based lookups for checkpointing and provides argument validation when constructing protocol instances.

Note

Users should not use the class directly, but rather the instance of the class through the registry object, documented below.

physicsnemo.active_learning.registry = ActiveLearningRegistry()#

Registry for active learning protocols.

This class provides a centralized registry for user-defined active learning protocols that implement the ActiveLearningProtocol. It enables string-based lookups for checkpointing and provides argument validation when constructing protocol instances.

The registry supports two primary modes of interaction:

  1. Registration via decorator: @registry.register("my_strategy")

  2. Construction with validation: registry.construct("my_strategy", **kwargs)

physicsnemo.active_learning._registry#

Internal dictionary mapping protocol names to their class types.

Type:

dict

physicsnemo.active_learning.register(cls_name: str) Callable#

Decorator to register a protocol class with a given name.

physicsnemo.active_learning.construct(
cls_name: str,
**kwargs,
) ActiveLearningProtocol#

Construct an instance of a registered protocol with argument validation.

physicsnemo.active_learning.is_registered(cls_name: str) bool#

Check if a protocol name is registered.

physicsnemo.active_learning.registered_names: list#

A list of all registered protocol names, sorted alphabetically.

See also

ActiveLearningProtocol

Base protocol for active learning strategies

QueryStrategy

Query strategy protocol

LabelStrategy

Label strategy protocol

MetrologyStrategy

Metrology strategy protocol

Examples

Register a custom strategy:

>>> from physicsnemo.active_learning._registry import registry
>>> @registry.register("my_custom_strategy")
... class MyCustomStrategy:
...     def __init__(self, param1: int, param2: str):
...         self.param1 = param1
...         self.param2 = param2

Construct an instance with validation:

>>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test")

Global registry instance for active learning protocols.