PhysicsNeMo Active Learning#
Developing Active Learning Workflows#
For a high-level overview and understanding of how to construct active learning workflows using PhysicsNeMo, users should consult the User Guide. The guide motivates the need for active learning, describes the abstraction provided by PhysicsNeMo, and offers additional tips for developing custom components such as query and labeling strategies.
API Reference#
Protocols#
This module contains base classes for active learning protocols.
These are protocols intended to be abstract, and importing these classes specifically is intended to either be subclassed, or for type annotations.
Protocol Architecture#
Python typing.Protocol classes are used for structural typing: essentially, they
describe an expected interface in a way that helps static type checkers
verify that concrete implementations provide everything that is needed for a workflow
to function. typing.Protocol classes are not actually enforced at runtime, and
inheritance is not required for them to work: as long as an implementation
provides the expected attributes and methods, it will be compatible with the protocol.
The active learning framework is built around several key protocol abstractions that work together to orchestrate the active learning workflow:
- Core Infrastructure Protocols:
AbstractQueue[T] - Generic queue protocol for passing data between components
DataPool[T] - Protocol for data reservoirs that support appending and sampling
ActiveLearningProtocol - Base protocol providing common interface for all AL strategies
- Strategy Protocols (inherit from ActiveLearningProtocol):
QueryStrategy - Defines how to select data points for labeling
LabelStrategy - Defines processes for adding ground truth labels to unlabeled data
MetrologyStrategy - Defines procedures that assess model improvements beyond validation metrics
- Model Interface Protocols:
TrainingProtocol - Interface for training step functions
ValidationProtocol - Interface for validation step functions
InferenceProtocol - Interface for inference step functions
TrainingLoop - Interface for complete training loop implementations
LearnerProtocol - Comprehensive interface for learner modules (combines training/validation/inference)
- Orchestration Protocol:
DriverProtocol - Main orchestrator that coordinates all components in the active learning loop
Active Learning Workflow#
The typical active learning workflow orchestrated by DriverProtocol follows this sequence (sketched schematically after the list):
Training Phase: Use LearnerProtocol or TrainingLoop to train the model on training_pool
Metrology Phase (optional): Apply MetrologyStrategy instances to assess model performance
Query Phase: Apply QueryStrategy instances to select samples from unlabeled_pool → query_queue
Labeling Phase (optional): Apply LabelStrategy instances to label queued samples → label_queue
Data Integration: Move labeled data from label_queue to training_pool
Type Parameters#
T: Data structure containing both inputs and ground truth labels
S: Data structure containing only inputs (no ground truth labels)
- class physicsnemo.active_learning.protocols.AbstractQueue(*args, **kwargs)[source]#
Bases:
Protocol[T]
Defines a generic queue protocol for data that is passed between active learning components.
This can be a simple local queue.Queue, or a more sophisticated distributed queue system.
The primary use case is to allow a query strategy to enqueue some data structure for the labeling strategy to consume, and once the labeling is done, to enqueue labeled data to a data serialization workflow. While there is no explicit restriction on the type of queue that is implemented, a reasonable assumption is a FIFO queue, unless otherwise specified by the concrete implementation.
Optional Serialization Methods#
Implementations may optionally provide to_list() and from_list() methods for checkpoint serialization. If not provided, the queue will be serialized using torch.save() as a fallback.
Type Parameters#
- T
The type of items that will be stored in the queue.
See also
QueryStrategy: Enqueues data to be labeled
LabelStrategy: Dequeues data for labeling and enqueues labeled data
DriverProtocol: Uses queues to pass data between strategies
- empty() bool[source]#
Method to check if the queue is empty/has been depleted.
- Returns:
True if the queue is empty, False otherwise.
- Return type:
bool
- class physicsnemo.active_learning.protocols.ActiveLearningPhase(
- value,
- names=<not given>,
- *values,
- module=None,
- qualname=None,
- type=None,
- start=1,
- boundary=None,
- )
Bases:
StrEnum
An enumeration of the different phases of the active learning workflow.
This is primarily used in the metadata for restarting an ongoing active learning experiment.
See also
ActiveLearningProtocol: Base protocol for active learning strategies
DriverProtocol: Main orchestrator that uses this enumeration
- DATA_INTEGRATION = 'data_integration'#
- LABELING = 'labeling'#
- METROLOGY = 'metrology'#
- QUERY = 'query'#
- TRAINING = 'training'#
- class physicsnemo.active_learning.protocols.ActiveLearningProtocol(*args: Any, **kwargs: Any)[source]#
Bases:
Protocol
This protocol acts as a basis for all active learning protocols.
This ensures that all protocols have some common interface, for example the ability to attach() to another object for scope management.
- __protocol_name__#
The name of the protocol. This is primarily used for repr and str f-strings. This should be defined by concrete implementations.
- Type:
str
- _args#
A dictionary of arguments that were used to instantiate the protocol. This is used for serialization and deserialization of the protocol, and follows the same pattern as the _args attribute of physicsnemo.Module.
- Type:
dict[str, Any]
- attach(self, other: object) None[source]#
This method is used to attach the current object to another, allowing the protocol to access the attached object’s scope. The use case for this is to allow a protocol access to the driver’s scope to access dataset, model, etc. as needed. This needs to be implemented by concrete implementations.
- is_attached: bool
Whether the current object is attached to another object. This is left abstract, as it depends on how attach() is implemented.
- logger: Logger
The logger for this protocol. This is used to log information about the protocol’s progress.
- _setup_logger(self) None[source]#
This method is used to set up the logger for the protocol. The default implementation is to configure the logger similarly to how physicsnemo loggers are configured.
See also
QueryStrategy: Query strategy protocol (child)
LabelStrategy: Label strategy protocol (child)
MetrologyStrategy: Metrology strategy protocol (child)
DriverProtocol: Main orchestrator that uses these protocols
- attach(other: object) None[source]#
This method is used to attach the current protocol to another object, allowing the protocol to access the attached object’s scope. The primary reason for this is to allow the protocol to access things like the dataset, the learner model, etc. as needed.
Example use cases would be for a query strategy to access the
unlabeled_pool; for a metrology strategy to access the validation_pool; and for any strategy to be able to access the surrogate/learner model.
This method can be as simple as setting self.driver = other, but is left abstract in case there are other potential use cases where multiple protocols could share information.
- Parameters:
other (object) – The object to attach to.
- property checkpoint_dir: Path#
Utility property for strategies to conveniently access the checkpoint directory.
This is useful for (de)serializing data tied to checkpointing.
- Returns:
The checkpoint directory, which includes the active learning step index.
- Return type:
Path
- Raises:
RuntimeError – If the strategy is not attached to a driver yet.
- property is_attached: bool#
Property to check if the current object is already attached.
This is left abstract, as it depends on how attach is implemented.
- Returns:
True if the current object is attached, False otherwise.
- Return type:
bool
- property logger: Logger#
Property to access the logger for this protocol.
If the logger has not been configured yet, the property will call the _setup_logger method to configure it.
- Returns:
The logger for this protocol.
- Return type:
Logger
- property strategy_dir: Path#
Returns the directory that the underlying strategy can use to persist data.
Depending on the strategy abstraction, further nesting may be required (e.g. active learning step index, phase, etc.).
- Returns:
The directory where the strategy will persist its records.
- Return type:
Path
- Raises:
RuntimeError – If the strategy is not attached to a driver yet.
- class physicsnemo.active_learning.protocols.DataPool(*args, **kwargs)[source]#
Bases:
Protocol[T]
An abstract protocol for some reservoir of data that is used for some part of active learning, parametrized such that it will return data structures of an arbitrary type T.
All methods are left abstract, and need to be defined by concrete implementations. For the most part, a torch.utils.data.Dataset would match this protocol, provided that it implements the append() method, which allows data to be persisted to a filesystem.
See also
DriverProtocol: Uses data pools for training, validation, and unlabeled data
AbstractQueue: Queue protocol for passing data between components
- class physicsnemo.active_learning.protocols.DriverProtocol[source]#
Bases:
object
This protocol specifies the expected interface for an active learning driver: for a concrete implementation, refer to the driver module instead. The specification is provided mostly as a reference, and for ease of type hinting to prevent circular imports.
- learner#
The learner module that will be used as the surrogate within the active learning loop.
- Type:
LearnerProtocol
- query_strategies#
The query strategies that will be used for selecting data points to label. A list of
QueryStrategy instances can be included, and will sequentially be used to populate the query_queue that passes samples over to labeling.
- Type:
list
- query_queue#
The queue containing data samples to be labeled.
QueryStrategy instances should enqueue samples to this queue.
- Type:
AbstractQueue
- label_strategy#
The label strategy that will be used for labeling data points. In contrast to the other strategies, only a single label strategy is supported. This strategy will consume the
query_queue and enqueue labeled data to the label_queue.
- Type:
LabelStrategy or None
- label_queue#
The queue containing freshly labeled data.
LabelStrategy instances should enqueue labeled data to this queue, and the driver will subsequently serialize data contained within this queue to a persistent format.
- Type:
AbstractQueue or None
- metrology_strategies#
The metrology strategies that will be used for assessing the performance of the surrogate. A list of
MetrologyStrategy instances can be included, and will sequentially be used to populate the metrology_queue that passes data over to the learner.
- Type:
list or None
- training_pool#
The pool of data to be used for training. This data will be used to train the underlying model, and is assumed to be mutable in that additional data can be added to the pool over the course of active learning.
- Type:
DataPool
- validation_pool#
The pool of data to be used for validation. This data will be used for both conventional validation, as well as for metrology. This dataset is considered to be immutable, and should not be modified over the course of active learning. This dataset is considered optional, as both validation and metrology are.
- Type:
DataPool or None
- unlabeled_pool#
An optional pool of data to be used for querying and labeling. If supplied, this dataset can be depleted by a query strategy to select data points for labeling. In principle, this could also represent a generative model, i.e. not just a static dataset, but at a high level represents a distribution of data.
- Type:
DataPool or None
See also
QueryStrategy: Query strategy protocol
LabelStrategy: Label strategy protocol
MetrologyStrategy: Metrology strategy protocol
LearnerProtocol: Learner protocol
DataPool: Data pool protocol
AbstractQueue: Queue protocol
- active_learning_step(
- *args: Any,
- **kwargs: Any,
- )
Implements the active learning step.
This step performs a single pass of the active learning loop, with the intended order being: training, metrology, query, labeling, with the metrology and labeling steps being optional.
- Parameters:
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- attach_strategies() None[source]#
Attaches all provided strategies.
This method relies on the
attach method of the strategies, which will subsequently give the strategy access to the driver’s scope.
Example use cases would be for any strategy (apart from the label strategy) to access the underlying model (LearnerProtocol); for a query strategy to access the unlabeled_pool; for a metrology strategy to access the validation_pool.
- label_queue: AbstractQueue[T] | None#
- label_strategy: LabelStrategy | None#
- learner: LearnerProtocol#
- metrology_strategies: list[MetrologyStrategy] | None#
- query_queue: AbstractQueue[T]#
- query_strategies: list[QueryStrategy]#
- class physicsnemo.active_learning.protocols.InferenceProtocol(*args, **kwargs)[source]#
Bases:
Protocol
This protocol defines the interface for inference steps: given a model and some input data, return the output of the model’s forward pass.
A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
See also
LearnerProtocol: Learner protocol with an inference_step method
QueryStrategy: Uses inference for query strategies
MetrologyStrategy: Uses inference for metrology strategies
- class physicsnemo.active_learning.protocols.LabelStrategy(*args: Any, **kwargs: Any)[source]#
Bases:
ActiveLearningProtocol
This protocol defines a label strategy for active learning.
A label strategy is responsible for labeling data points; this may be a simple Python function for demonstrating a concept, or an external, potentially time-consuming and complex, process.
- __is_external_process__#
Whether the label strategy is running in an external process.
- Type:
bool
- __provides_fields__#
The fields that the label strategy provides. This should be set by concrete implementations, and should be used to write and map labeled data to fields within the data structure
T.
- Type:
set or None
See also
ActiveLearningProtocol: Base protocol for all active learning strategies
AbstractQueue: Queue protocol for dequeuing and enqueuing data
QueryStrategy: Produces queued data for labeling
DriverProtocol: Orchestrates the label strategy
- label(
- queue_to_label: AbstractQueue[T],
- serialize_queue: AbstractQueue[T],
- *args: Any,
- **kwargs: Any,
- )
Method that implements the logic behind labeling data.
This method should be implemented by concrete implementations, and assume that an active learning driver will pass a queue for this method to dequeue data to be labeled.
- Parameters:
queue_to_label (AbstractQueue[T]) – Queue containing data structures to be labeled. Generally speaking, this should be passed over after running the query strategy (or strategies).
serialize_queue (AbstractQueue[T]) – Queue for enqueuing labeled data to be serialized.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- class physicsnemo.active_learning.protocols.LearnerProtocol[source]#
Bases:
object
This protocol represents the learner part of an active learning algorithm.
This corresponds to a set of trainable parameters that are optimized, and subsequently used for inference and evaluation.
The required methods ensure that classes implementing this protocol provide all the required functionality across all active learning steps. Keep in mind that, similar to all other protocols in this module, this is merely the required interface and not the actual implementation.
See also
DriverProtocol: Uses the learner protocol in the active learning loop
TrainingProtocol: Training step protocol
ValidationProtocol: Validation step protocol
InferenceProtocol: Inference step protocol
TrainingLoop: Training loop protocol that can use a learner
- forward(
- *args: Any,
- **kwargs: Any,
- )
Implements the forward pass for a single batch.
This method is called across all active learning steps, and should contain the logic for how a model ingests data and produces predictions.
- Parameters:
args (Any) – Additional arguments to pass to the model.
kwargs (Any) – Additional keyword arguments to pass to the model.
- Returns:
The output of the model’s forward pass.
- Return type:
Any
- inference_step(
- data: T | S,
- *args: Any,
- **kwargs: Any,
- )
Implements the inference logic for a single batch.
This can match the forward pass exactly, but provides an opportunity to differentiate behavior, for example by skipping gradient computation (no pun intended). Specifically, this method will be called during query and metrology steps.
This should mirror the InferenceProtocol definition, except that the model corresponds to this object.
- Parameters:
data (T | S) – The data to infer on. Typically assumed to be a batch of data.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- property parameters: Iterator[Tensor]#
Returns an iterator over the parameters of the learner.
If subclassing from torch.nn.Module, this will automatically return the parameters of the module.
- Returns:
An iterator over the parameters of the learner.
- Return type:
Iterator[torch.Tensor]
- training_step(
- data: T,
- *args: Any,
- **kwargs: Any,
- )
Implements the training logic for a single batch.
This method will be called in training steps only, and not used for validation, query, or metrology steps. Specifically this means that gradients will be computed and used to update parameters.
In cases where gradients are not needed, consider implementing the
validation_step() method instead.
This should mirror the TrainingProtocol definition, except that the model corresponds to this object.
- Parameters:
data (T) – The data to train on. Typically assumed to be a batch of data.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- validation_step(
- data: T,
- *args: Any,
- **kwargs: Any,
- )
Implements the validation logic for a single batch.
This can match the forward pass, without the need for weight updates. This method will be called in validation steps only, and not used for query or metrology steps. In those cases, implement the
inference_step() method instead.
This should mirror the ValidationProtocol definition, except that the model corresponds to this object.
- Parameters:
data (T) – The data to validate on. Typically assumed to be a batch of data.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- class physicsnemo.active_learning.protocols.MetrologyStrategy(*args: Any, **kwargs: Any)[source]#
Bases:
ActiveLearningProtocol
This protocol defines a metrology strategy for active learning.
A metrology strategy is responsible for assessing the improvements to the underlying model, beyond simple validation metrics. This should reflect the application requirements of the model, which may include running a simulation.
- records#
A sequence of record data structures that records the history of the active learning process, as viewed by this particular metrology strategy.
- Type:
list
See also
ActiveLearningProtocol: Base protocol for all active learning strategies
DriverProtocol: Orchestrates metrology strategies
DataPool: Data pool protocol for accessing validation data
- append(
- record: S,
- )
Method to append a record to the metrology strategy.
- Parameters:
record (S) – The record to append to the metrology strategy.
- compute(
- *args: Any,
- **kwargs: Any,
- )
Method to compute the metrology strategy. No data is passed to this method, as it is expected that the data be drawn as needed from the various DataPool instances connected to the driver.
This method defines the core logic for computing a particular view of performance by the underlying model on the data. Once computed, the data needs to be formatted into a record data structure S, which is then appended to the records attribute.
- Parameters:
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- load_records(
- path: Path | None = None,
- *args: Any,
- **kwargs: Any,
- )
Method to load the records of the metrology strategy, i.e. the reverse of serialize_records.
This should be defined by a concrete implementation, which dictates how the records are loaded, e.g. from a JSON file, database, etc.
If no path is provided, the strategy should load the latest records as sensible defaults. The records attribute should then be overwritten in-place.
- Parameters:
path (Path | None) – The path to load the records from. If not provided, the strategy should load the latest records as sensible defaults.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- records: list[S]#
- reset() None[source]#
Method to reset any stateful attributes of the metrology strategy.
By default, the
records attribute is reset to an empty list.
- serialize_records(
- path: Path | None = None,
- *args: Any,
- **kwargs: Any,
- )
Method to serialize the records of the metrology strategy.
This should be defined by a concrete implementation, which dictates how the records are persisted, e.g. to a JSON file, database, etc.
The strategy_dir property can be used to determine the directory where the records should be persisted.
- Parameters:
path (Path | None) – The path to serialize the records to. If not provided, the strategy should provide a reasonable default, such as with the checkpointing or within the corresponding metrology directory via strategy_dir.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- class physicsnemo.active_learning.protocols.QueryStrategy(*args: Any, **kwargs: Any)[source]#
Bases:
ActiveLearningProtocol
This protocol defines a query strategy for active learning.
A query strategy is responsible for selecting data points for labeling. In the most general sense, concrete instances of this protocol will specify how many samples to query, and the heuristics for selecting samples.
- max_samples#
The maximum number of samples to query. This can be interpreted as the exact number of samples to query, or as an upper limit for querying methods that are threshold-based.
- Type:
int
See also
ActiveLearningProtocol: Base protocol for all active learning strategies
AbstractQueue: Queue protocol for enqueuing data
LabelStrategy: Consumes queued data for labeling
DriverProtocol: Orchestrates query strategies
- max_samples: int#
- sample(
- query_queue: AbstractQueue[T],
- *args: Any,
- **kwargs: Any,
- )
Method that implements the logic behind querying data to be labeled.
This method should be implemented by concrete implementations, and assume that an active learning driver will pass a queue for this method to enqueue data to be labeled.
Additional
args and kwargs are passed to the method, and can be used to pass additional information to the query strategy.
This method will enqueue in place, and should not return anything.
- Parameters:
query_queue (AbstractQueue[T]) – The queue to enqueue data to be labeled.
args (Any) – Additional arguments to pass to the method.
kwargs (Any) – Additional keyword arguments to pass to the method.
- class physicsnemo.active_learning.protocols.TrainingLoop(*args, **kwargs)[source]#
Bases:
Protocol
Defines a protocol that implements a training loop.
This protocol is intended to be called within the active learning loop during the training phase, where the model is trained on a specified number of epochs or training steps, and optionally validated on a dataset.
If a
LearnerProtocol is provided, then train_fn and validate_fn become optional as they will be defined within the LearnerProtocol. If they are provided, however, then they should override the LearnerProtocol variants.
If graph capture/compilation is intended, then train_fn and validate_fn should be wrapped with StaticCaptureTraining and StaticCaptureEvaluateNoGrad, respectively.
See also
DriverProtocol: Uses training loops in the training phase
TrainingProtocol: Training step protocol
ValidationProtocol: Validation step protocol
LearnerProtocol: Learner protocol with training/validation methods
- class physicsnemo.active_learning.protocols.TrainingProtocol(*args, **kwargs)[source]#
Bases:
Protocol
This protocol defines the interface for training steps: given a model and some input data, compute the reduced, differentiable loss tensor and return it.
A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
See also
TrainingLoop: Training loop protocol that uses this protocol
LearnerProtocol: Learner protocol with a training_step method
ValidationProtocol: Validation step protocol
- class physicsnemo.active_learning.protocols.ValidationProtocol(*args, **kwargs)[source]#
Bases:
Protocol
This protocol defines the interface for validation steps: given a model and some input data, compute metrics of interest and, if relevant to do so, log the results.
A concrete implementation can simply be a function with a signature that matches what is defined in __call__().
See also
TrainingLoop: Training loop protocol that uses this protocol
LearnerProtocol: Learner protocol with a validation_step method
TrainingProtocol: Training step protocol
Configuration Classes#
These data structures are used to modify the behavior of different components of the active learning workflow. The general pattern is to ensure that they are JSON-serializable so that they can be checkpointed and restarted.
Configuration dataclasses for the active learning driver.
This module provides structured configuration classes that separate different concerns in the active learning workflow: optimization, training, strategies, and driver orchestration.
- class physicsnemo.active_learning.config.DriverConfig(
- batch_size: int,
- max_active_learning_steps: int | None = None,
- run_id: str = <factory>,
- fine_tuning_lr: float | None = None,
- reset_optim_states: bool = True,
- skip_training: bool = False,
- skip_metrology: bool = False,
- skip_labeling: bool = False,
- checkpoint_interval: int = 1,
- checkpoint_on_training: bool = False,
- checkpoint_on_metrology: bool = False,
- checkpoint_on_query: bool = False,
- checkpoint_on_labeling: bool = True,
- model_checkpoint_frequency: int = 0,
- root_log_dir: str | Path = PosixPath('active_learning_logs'),
- dist_manager: DistributedManager | None = None,
- collate_fn: callable | None = None,
- num_dataloader_workers: int = 0,
- device: str | torch.device | None = None,
- dtype: torch.dtype | None = None,
- )
Bases:
object
Configuration for driver orchestration and infrastructure.
This contains parameters that control the overall loop execution, logging, checkpointing, and distributed training setup, orthogonal to the specific AL or training logic.
- batch_size#
The batch size to use for data loaders.
- Type:
int
- max_active_learning_steps#
Maximum number of AL iterations to perform. None means infinite.
- Type:
int or None
- run_id#
Unique identifier for this run. Auto-generated if not provided.
- Type:
str
- fine_tuning_lr#
Learning rate to switch to after the first AL step for fine-tuning.
- Type:
float or None
- reset_optim_states#
Whether to reset optimizer states between AL steps. Defaults to True.
- Type:
bool
- skip_training#
If True, skip the training phase entirely. Defaults to False.
- Type:
bool
- skip_metrology#
If True, skip the metrology phase entirely. Defaults to False.
- Type:
bool
- skip_labeling#
If True, skip the labeling phase entirely. Defaults to False.
- Type:
bool
- checkpoint_interval#
Save model checkpoint every N AL steps. 0 disables checkpointing. Defaults to 1.
- Type:
int
- checkpoint_on_training#
If True, save checkpoint at the start of the training phase. Defaults to False.
- Type:
bool
- checkpoint_on_metrology#
If True, save checkpoint at the start of the metrology phase. Defaults to False.
- Type:
bool
- checkpoint_on_query#
If True, save checkpoint at the start of the query phase. Defaults to False.
- Type:
bool
- checkpoint_on_labeling#
If True, save checkpoint at the start of the labeling phase. Defaults to True.
- Type:
bool
- model_checkpoint_frequency#
Save model weights every N epochs during training. 0 means only save between active learning phases. Useful for mid-training restarts. Defaults to 0.
- Type:
int
- root_log_dir#
Directory to save logs and checkpoints to. Defaults to an ‘active_learning_logs’ directory in the current working directory.
- Type:
str or
pathlib.Path
- dist_manager#
Manager for distributed training configuration.
- Type:
DistributedManager or None
- collate_fn#
Custom collate function for batching data.
- Type:
callable or None
- num_dataloader_workers#
Number of worker processes for data loading. Defaults to 0.
- Type:
int
- device#
Device to use for model and data. This is intended for single process workflows; for distributed workflows, the device should be set in
DistributedManager instead. If not specified, then the device will default to torch.get_default_device().
- Type:
str or torch.device or None
- dtype#
The dtype to use for model and data, and AMP contexts. If not provided, then the dtype will default to
torch.get_default_dtype().
- Type:
torch.dtype or None
See also
Driver: Uses this config for orchestration
TrainingConfig: Training configuration
StrategiesConfig: Strategies configuration
DataPool: Data pool protocol
AbstractQueue: Queue protocol
- batch_size: int#
- checkpoint_interval: int = 1#
- checkpoint_on_labeling: bool = True#
- checkpoint_on_metrology: bool = False#
- checkpoint_on_query: bool = False#
- checkpoint_on_training: bool = False#
- collate_fn: callable | None = None#
- device: str | torch.device | None = None#
- dist_manager: DistributedManager | None = None#
- dtype: torch.dtype | None = None#
- fine_tuning_lr: float | None = None#
- classmethod from_json(
- json_str: str,
- **kwargs: Any,
- )
Creates a DriverConfig instance from a JSON string.
This method reconstructs a DriverConfig from JSON. Note that certain fields cannot be serialized and must be provided via kwargs:
- dist_manager: DistributedManager instance (optional)
- collate_fn: Custom collate function (optional)
- Parameters:
json_str (str) – A JSON string that was previously serialized using to_json().
**kwargs (Any) – Additional keyword arguments to override or provide non-serializable fields like dist_manager and collate_fn.
- Returns:
A new DriverConfig instance.
- Return type:
DriverConfig
- Raises:
ValueError – If the JSON cannot be parsed or required fields are missing.
Notes
The device and dtype fields are reconstructed from their string representations. The
log_dir field in JSON is mapped to root_log_dir in the config.
- max_active_learning_steps: int | None = None#
- model_checkpoint_frequency: int = 0#
- num_dataloader_workers: int = 0#
- reset_optim_states: bool = True#
- root_log_dir: str | Path = PosixPath('active_learning_logs')#
- run_id: str#
- skip_labeling: bool = False#
- skip_metrology: bool = False#
- skip_training: bool = False#
- class physicsnemo.active_learning.config.OptimizerConfig(
- optimizer_cls: type[torch.optim.Optimizer] = <class 'torch.optim.adamw.AdamW'>,
- optimizer_kwargs: dict[str, typing.Any] = <factory>,
- scheduler_cls: type[torch.optim.lr_scheduler._LRScheduler] | None = None,
- scheduler_kwargs: dict[str, typing.Any] = <factory>,
- )
Bases:
object
Configuration for optimizer and learning rate scheduler.
This encapsulates all training optimization parameters, keeping them separate from the active learning orchestration logic.
- optimizer_cls#
The optimizer class to use. Defaults to AdamW.
- Type:
type
- optimizer_kwargs#
Keyword arguments to pass to the optimizer constructor. Defaults to {“lr”: 1e-4}.
- Type:
dict
- scheduler_cls#
The learning rate scheduler class to use. If None, no scheduler will be configured.
- Type:
type or None
- scheduler_kwargs#
Keyword arguments to pass to the scheduler constructor.
- Type:
dict
See also
TrainingConfig: Uses this config for optimizer setup
Driver: Configures optimizer using this config
- classmethod from_dict(
- data: dict[str, Any],
- )
Creates an OptimizerConfig instance from a dictionary.
This method assumes that the optimizer and scheduler classes are included in the physicsnemo.active_learning.registry, or a module path is specified to import the class from.
- Parameters:
data (dict[str, Any]) – A dictionary that was previously serialized using the to_dict method.
- Returns:
A new OptimizerConfig instance.
- Return type:
OptimizerConfig
- optimizer_cls#
alias of
AdamW
- optimizer_kwargs: dict[str, Any]#
- scheduler_cls: type[_LRScheduler] | None = None#
- scheduler_kwargs: dict[str, Any]#
- class physicsnemo.active_learning.config.StrategiesConfig(
- query_strategies: list[QueryStrategy],
- queue_cls: type[AbstractQueue],
- label_strategy: LabelStrategy | None = None,
- metrology_strategies: list[MetrologyStrategy] | None = None,
- unlabeled_datapool: DataPool | None = None,
- )
Bases:
object
Configuration for active learning strategies and data acquisition.
This encapsulates the query-label-metrology cycle that is at the heart of active learning: strategies for selecting data, labeling it, and measuring model uncertainty/performance.
- query_strategies#
The query strategies to use for selecting data to label. Each element should be a
QueryStrategy instance.
- Type:
list
- queue_cls#
The queue implementation to use for passing data between query and labeling phases. Should implement
AbstractQueue.
- Type:
type
- label_strategy#
The strategy to use for labeling queried data. If None, labeling will be skipped.
- Type:
LabelStrategy or None
- metrology_strategies#
Strategies for measuring model performance and uncertainty. Each element should be a
MetrologyStrategy instance. If None, metrology will be skipped.
- Type:
list or None
- unlabeled_datapool#
Pool of unlabeled data that query strategies can sample from. Not all strategies require this (some may generate synthetic data).
- Type:
DataPool or None
See also
Driver: Uses this config for strategy orchestration
QueryStrategy: Query strategy protocol
LabelStrategy: Label strategy protocol
MetrologyStrategy: Metrology strategy protocol
- classmethod from_dict(
- data: dict[str, Any],
- **kwargs: Any,
- )
Create a StrategiesConfig instance from a dictionary.
This method heavily relies on classes being added to the physicsnemo.active_learning.registry, which is used to instantiate all strategies and custom types used in active learning. As a fallback, the registry.construct method will try to import the class from the module path if it is not found in the registry.
- Parameters:
data (dict[str, Any]) – A dictionary that was previously serialized using the
to_dict method.
**kwargs (Any) – Additional keyword arguments to pass to the constructor.
- Returns:
A new StrategiesConfig instance.
- Return type:
StrategiesConfig
- Raises:
ValueError – If the data contains unexpected keys or if the object construction fails.
NameError – If a class is not found in the registry and no module path is provided.
ModuleNotFoundError – If a module is not found with the specified module path.
- label_strategy: LabelStrategy | None = None#
- metrology_strategies: list[MetrologyStrategy] | None = None#
- query_strategies: list[QueryStrategy]#
- queue_cls: type[AbstractQueue]#
- to_dict() dict[str, Any][source]#
Method that converts the present
StrategiesConfig instance into a dictionary that can be JSON serialized.
This method, for the most part, assumes that strategies are subclasses of ActiveLearningProtocol and/or they have an _args attribute that captures the arguments to the constructor.
One issue is the inability to reliably serialize the unlabeled_datapool, which, for the most part, likely does not need serialization as a dataset. Regardless, this method will trigger a warning if unlabeled_datapool is not None.
- Returns:
A dictionary that can be JSON serialized.
- Return type:
dict[str, Any]
- class physicsnemo.active_learning.config.TrainingConfig(
- train_datapool: physicsnemo.active_learning.protocols.DataPool,
- max_training_epochs: int,
- val_datapool: physicsnemo.active_learning.protocols.DataPool | None = None,
- optimizer_config: physicsnemo.active_learning.config.OptimizerConfig = <factory>,
- max_fine_tuning_epochs: int | None = None,
- train_loop_fn: physicsnemo.active_learning.protocols.TrainingLoop = <factory>,
- )
Bases:
object
Configuration for the training phase of active learning.
This groups all training-related components together, making it clear when training is or isn’t being used in the AL workflow.
- max_training_epochs#
The maximum number of epochs to train for. If
max_fine_tuning_epochs isn’t specified, this value is used for all active learning steps.
- Type:
int
- optimizer_config#
Configuration for the optimizer and scheduler. Defaults to AdamW with lr=1e-4, no scheduler.
- Type:
OptimizerConfig
- max_fine_tuning_epochs#
The maximum number of epochs used during fine-tuning steps, i.e. after the first active learning step. If None, then max_training_epochs is used for all active learning steps.
- Type:
int or None
- train_loop_fn#
The training loop function that orchestrates the training process. This defaults to a concrete implementation,
DefaultTrainingLoop, which provides a very typical loop that includes the use of static capture, etc.
- Type:
TrainingLoop
See also
Driver: Uses this config for training
OptimizerConfig: Optimizer configuration
StrategiesConfig: Strategies configuration
DefaultTrainingLoop: Default training loop implementation
- classmethod from_dict(
- data: dict[str, Any],
- **kwargs: Any,
- )
Creates a TrainingConfig instance from a dictionary.
This method assumes that the training loop class is included in the physicsnemo.active_learning.registry, or a module path is specified to import the class from. Note that datapools must be provided via kwargs as they are not serialized.
- Parameters:
data (dict[str, Any]) – A dictionary that was previously serialized using the
to_dict method.
**kwargs (Any) – Additional keyword arguments to pass to the constructor. This is where you must provide train_datapool and optionally val_datapool.
- Returns:
A new TrainingConfig instance.
- Return type:
TrainingConfig
- Raises:
ValueError – If required datapools are not provided in kwargs, if the data contains unexpected keys, or if object construction fails.
- max_fine_tuning_epochs: int | None = None#
- max_training_epochs: int#
- optimizer_config: OptimizerConfig#
- to_dict() dict[str, Any][source]#
Returns a JSON-serializable dictionary representation of the TrainingConfig.
For round-tripping, the registry is used to de-serialize the training loop and optimizer configuration. Note that datapools (train_datapool and val_datapool) are NOT serialized as they typically contain large datasets, file handles, or other non-serializable state.
- Returns:
A dictionary that can be JSON serialized. Excludes datapools.
- Return type:
dict[str, Any]
Warning
This method will issue a warning about the exclusion of datapools.
- train_loop_fn: TrainingLoop#
Default Training Loop#
This module and the corresponding
DefaultTrainingLoop class
implement the TrainingLoop interface,
and should provide most of the necessary boilerplate for model training
and fine-tuning; users will need to provide the training, validation,
and testing step protocols when configuring the loop.
- class physicsnemo.active_learning.loop.DefaultTrainingLoop(*args: Any, **kwargs: Any)[source]#
Bases:
TrainingLoop
Default implementation of the TrainingLoop protocol.
This provides a functional training loop with support for static capture, progress bars, checkpointing, and distributed training. It implements the standard epoch-based training pattern with optional validation.
See also
TrainingLoop: Protocol specification for training loops
Driver: Uses training loops in the training phase
TrainingConfig: Configuration for training
TrainingProtocol: Training step protocol
ValidationProtocol: Validation step protocol
- property amp_type: dtype#
- static load_training_checkpoint(
- checkpoint_dir: Path,
- model: Module | LearnerProtocol,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler | None = None,
- )
Load training state from checkpoint directory.
Model weights are loaded separately. Optimizer, scheduler, and epoch metadata are loaded from the combined training_state.pt file.
- Parameters:
checkpoint_dir (pathlib.Path) – Directory containing checkpoint files.
model (Module or LearnerProtocol) – Model to load weights into.
optimizer (Optimizer) – Optimizer to load state into.
lr_scheduler (_LRScheduler or None) – Optional LR scheduler to load state into.
- Returns:
Training epoch from metadata if available, else None.
- Return type:
int or None
- save_training_checkpoint(
- checkpoint_dir: Path,
- model: Module | LearnerProtocol,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler | None = None,
- training_epoch: int | None = None,
- )
Save training state to checkpoint directory.
Model weights are saved separately. Optimizer, scheduler, and epoch metadata are combined into a single training_state.pt file.
- Parameters:
checkpoint_dir (pathlib.Path) – Directory to save checkpoint files.
model (Module or LearnerProtocol) – Model to save weights for.
optimizer (Optimizer) – Optimizer to save state from.
lr_scheduler (_LRScheduler or None) – Optional LR scheduler to save state from.
training_epoch (int or None) – Current training epoch for metadata.
Active Learning Driver#
This module and its Driver class implement the
DriverProtocol interface,
and are usable out-of-the-box for most active learning workflows. The
Driver class is configured
by DriverConfig, and serves
as the focal point for orchestrating the active learning process.
This module contains the definition for an active learning driver class, which is responsible for orchestration and automation of the end-to-end active learning process.
- class physicsnemo.active_learning.driver.ActiveLearningCheckpoint(
- driver_config: DriverConfig,
- strategies_config: StrategiesConfig,
- active_learning_step_idx: int,
- active_learning_phase: ActiveLearningPhase,
- physicsnemo_version: str = '1.3.0a0',
- training_config: TrainingConfig | None = None,
- optimizer_state: dict[str, Any] | None = None,
- lr_scheduler_state: dict[str, Any] | None = None,
- has_query_queue: bool = False,
- has_label_queue: bool = False,
- )
Bases:
object
Metadata associated with an ongoing (or completed) active learning experiment.
The information contained in this metadata should be sufficient to restart the active learning experiment at the nearest point: for example, training should be able to continue from an epoch, while for querying/sampling, etc. we continue from a pre-existing queue.
- driver_config#
Infrastructure and orchestration configuration.
- Type:
DriverConfig
- strategies_config#
Active learning strategies configuration.
- Type:
StrategiesConfig
- active_learning_step_idx#
Current iteration index of the active learning loop.
- Type:
int
- active_learning_phase#
Current phase of the active learning workflow.
- Type:
ActiveLearningPhase
- physicsnemo_version#
Version of PhysicsNeMo used to create the checkpoint.
- Type:
str
- training_config#
Training components configuration, if training is used.
- Type:
TrainingConfig or None
- optimizer_state#
Optimizer state dictionary for checkpointing.
- Type:
dict or None
- lr_scheduler_state#
Learning rate scheduler state dictionary for checkpointing.
- Type:
dict or None
- has_query_queue#
Whether the checkpoint includes a query queue.
- Type:
bool
- has_label_queue#
Whether the checkpoint includes a label queue.
- Type:
bool
See also
Driver: Uses this dataclass for checkpointing
DriverConfig: Driver configuration
StrategiesConfig: Strategies configuration
TrainingConfig: Training configuration
- active_learning_phase: ActiveLearningPhase#
- active_learning_step_idx: int#
- driver_config: DriverConfig#
- has_label_queue: bool = False#
- has_query_queue: bool = False#
- lr_scheduler_state: dict[str, Any] | None = None#
- optimizer_state: dict[str, Any] | None = None#
- physicsnemo_version: str = '1.3.0a0'#
- strategies_config: StrategiesConfig#
- training_config: TrainingConfig | None = None#
- class physicsnemo.active_learning.driver.Driver(
- config: DriverConfig,
- learner: Module | LearnerProtocol,
- strategies_config: StrategiesConfig,
- training_config: TrainingConfig | None = None,
- inference_fn: InferenceProtocol | None = None,
- )
Bases:
DriverProtocol
Provides a simple implementation of the DriverProtocol used to orchestrate an active learning process within PhysicsNeMo.
At a high level, the active learning process is broken down into four phases: training, metrology, query, and labeling.
To understand the orchestration, start by inspecting the active_learning_step() method, which defines a single iteration of the active learning loop and is dispatched by the run() method. From there, it should be relatively straightforward to trace the remaining components.
- config#
Infrastructure and orchestration configuration.
- Type:
DriverConfig
- learner#
The learner module for the active learning process.
- Type:
Module or LearnerProtocol
- strategies_config#
Active learning strategies (query, label, metrology).
- Type:
StrategiesConfig
- training_config#
Training components. None if training is skipped.
- Type:
TrainingConfig or None
- inference_fn#
Custom inference function.
- Type:
InferenceProtocol or None
- active_learning_step_idx#
Current iteration index of the active learning loop.
- Type:
int
- query_queue#
Queue populated with data by query strategies.
- Type:
AbstractQueue
- label_queue#
Queue populated with labeled data by the label strategy.
- Type:
AbstractQueue
- optimizer#
Configured optimizer (set after configure_optimizer is called).
- Type:
torch.optim.Optimizer or None
- lr_scheduler#
Configured learning rate scheduler.
- Type:
torch.optim.lr_scheduler._LRScheduler or None
- logger#
Persistent logger for the active learning process.
- Type:
logging.Logger
See also
DriverProtocol: Protocol specification for active learning drivers
DriverConfig: Configuration for the driver
StrategiesConfig: Configuration for active learning strategies
TrainingConfig: Configuration for training
- active_learning_step(
- train_step_fn: TrainingProtocol | None = None,
- validate_step_fn: ValidationProtocol | None = None,
- *args: Any,
- **kwargs: Any,
- )
Performs a single active learning iteration.
This method will perform the following sequence of steps:
1. Train the model stored in Driver.learner by creating data loaders with Driver.train_datapool and Driver.val_datapool.
2. Run the metrology strategies stored in Driver.metrology_strategies.
3. Run the query strategies stored in Driver.query_strategies, if available.
4. Run the labeling strategy stored in Driver.label_strategy, if available.
When entering each stage, we check to ensure all components necessary for the minimum function of that stage are available before proceeding.
If current_phase is set (e.g., from checkpoint resumption), only phases at or after current_phase will be executed. After completing all phases, current_phase is reset to None for the next AL step.
- Parameters:
train_step_fn (p.TrainingProtocol | None = None) – The training function to use for training. If not provided, then the
Driver.train_loop_fn will be used.
validate_step_fn (p.ValidationProtocol | None = None) – The validation function to use for validation. If not provided, then validation will not be performed.
args (Any) – Additional arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.
kwargs (Any) – Additional keyword arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.
- Raises:
ValueError – If any of the required components for a stage are not available.
- property active_learning_step_idx: int#
Returns the current active learning step index.
This represents the number of times the active learning step has been called, i.e. the number of iterations of the loop.
- Returns:
The current active learning step index.
- Return type:
int
- barrier() None[source]#
Wrapper to call barrier on the correct device.
Becomes a no-op if distributed is not initialized, otherwise will attempt to read the local device ID from either the distributed manager or the default device.
- property device: device#
Return a consistent device interface to use across the driver.
- property dist_manager: DistributedManager | None#
Returns the distributed manager, if it was specified as part of the DriverConfig configuration.
- Returns:
The distributed manager.
- Return type:
DistributedManager | None
- property is_lr_scheduler_configured: bool#
Returns whether the LR scheduler is configured.
- property is_optimizer_configured: bool#
Returns whether the optimizer is configured.
- property label_strategy: LabelStrategy | None#
Returns the label strategy from strategies_config.
- property last_checkpoint: Path | None#
Returns path to the most recently saved checkpoint.
- Returns:
Path to the last checkpoint directory, or None if no checkpoint has been saved yet.
- Return type:
Path | None
- classmethod load_checkpoint(
- checkpoint_path: str | Path,
- learner: Module | LearnerProtocol | None = None,
- train_datapool: DataPool | None = None,
- val_datapool: DataPool | None = None,
- unlabeled_datapool: DataPool | None = None,
- **kwargs: Any,
- )
Load a Driver instance from a checkpoint.
Given a checkpoint directory, this method will attempt to reconstruct the driver and its associated components from the checkpoint. The checkpoint path must contain a
checkpoint.pt file, which contains the metadata associated with the experiment.
Additional parameters that might not be serialized with the checkpointing mechanism can (or need to) be provided to this method; for example, when using non-physicsnemo.Module learners, and any data pools associated with the workflow.
Important
Currently, the strategy states are not reloaded from the checkpoint. This will be addressed in a future patch, but for now it is recommended to back up your strategy states (e.g. metrology records) manually before restarting experiments.
- Parameters:
checkpoint_path (str | Path) – Path to checkpoint directory containing checkpoint.pt and model weights.
learner (Module | p.LearnerProtocol | None) – Learner model to load weights into. If None, will attempt to reconstruct from checkpoint (only works for physicsnemo.Module).
train_datapool (p.DataPool | None) – Training datapool. Required if training_config exists in checkpoint.
val_datapool (p.DataPool | None) – Validation datapool. Optional.
unlabeled_datapool (p.DataPool | None) – Unlabeled datapool for query strategies. Optional.
**kwargs (Any) – Additional keyword arguments to override config values.
- Returns:
Reconstructed Driver instance ready to resume execution.
- Return type:
- property log_dir: Path#
Returns the log directory.
Note that this is the
DriverConfig.root_log_dir combined with the shortened run ID for the current run.
Effectively, this means that each run will have its own directory for logs, checkpoints, etc.
- Returns:
The log directory.
- Return type:
Path
- property metrology_strategies: list[MetrologyStrategy] | None#
Returns the metrology strategies from strategies_config.
- property query_strategies: list[QueryStrategy]#
Returns the query strategies from strategies_config.
- run(
- train_step_fn: TrainingProtocol | None = None,
- validate_step_fn: ValidationProtocol | None = None,
- *args: Any,
- **kwargs: Any,
- )
Runs the active learning loop until the maximum number of active learning steps is reached.
- Parameters:
train_step_fn (p.TrainingProtocol | None = None) – The training function to use for training. If not provided, then the
Driver.train_loop_fn will be used.
validate_step_fn (p.ValidationProtocol | None = None) – The validation function to use for validation. If not provided, then validation will not be performed.
args (Any) – Additional arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.
kwargs (Any) – Additional keyword arguments to pass to the method. These will be passed to the training loop, metrology strategies, query strategies, and labeling strategies.
- property run_id: str#
Returns the run id from the
DriverConfig.
- Returns:
The run id.
- Return type:
str
- save_checkpoint(
- path: str | Path | None = None,
- training_epoch: int | None = None,
- )
Save a checkpoint of the active learning experiment.
Saves AL orchestration state (configs, queues, step index, phase) and model weights. Training-specific state (optimizer, scheduler) is handled by DefaultTrainingLoop and saved to training_state.pt during training.
- Parameters:
path (str | Path | None) – Path to save checkpoint. If None, creates path based on current AL step index and phase: log_dir/checkpoints/step_{idx}/{phase}/
training_epoch (int | None) – Optional epoch number for mid-training checkpoints.
- Returns:
Checkpoint directory path, or None if checkpoint not saved (non-rank-0 in distributed).
- Return type:
Path | None
- property short_run_id: str#
Returns the first 8 characters of the run id.
The 8 character limit assumes that the run ID is a UUID4. This is particularly useful for user-facing interfaces, where you do not necessarily want to reference the full UUID.
- Returns:
The first 8 characters of the run id.
- Return type:
str
- property train_loop_fn: TrainingLoop | None#
Returns the training loop function from training_config.
Active Learning Registry#
The registry provides a centralized location for registering and constructing custom active learning strategies. It enables string-based lookups for checkpointing and provides argument validation when constructing protocol instances.
Note
Users should not use the class directly, but rather the instance of the
class through the registry object,
documented below.
- physicsnemo.active_learning.registry = ActiveLearningRegistry()#
Registry for active learning protocols.
This class provides a centralized registry for user-defined active learning protocols that implement the
ActiveLearningProtocol. It enables string-based lookups for checkpointing and provides argument validation when constructing protocol instances.
The registry supports two primary modes of interaction:
1. Registration via decorator: @registry.register("my_strategy")
2. Construction with validation: registry.construct("my_strategy", **kwargs)
- physicsnemo.active_learning._registry#
Internal dictionary mapping protocol names to their class types.
- Type:
dict
- physicsnemo.active_learning.register(cls_name: str) Callable#
Decorator to register a protocol class with a given name.
- physicsnemo.active_learning.construct(
- cls_name: str,
- **kwargs,
- )
Construct an instance of a registered protocol with argument validation.
- physicsnemo.active_learning.is_registered(cls_name: str) bool#
Check if a protocol name is registered.
Properties
registered_names : list
A list of all registered protocol names, sorted alphabetically.
See also
ActiveLearningProtocol: Base protocol for active learning strategies
QueryStrategy: Query strategy protocol
LabelStrategy: Label strategy protocol
MetrologyStrategy: Metrology strategy protocol
Examples
Register a custom strategy:
>>> from physicsnemo.active_learning._registry import registry
>>> @registry.register("my_custom_strategy")
... class MyCustomStrategy:
...     def __init__(self, param1: int, param2: str):
...         self.param1 = param1
...         self.param2 = param2
Construct an instance with validation:
>>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test")
Global registry instance for active learning protocols.