Logging and Checkpointing
Logging and checkpointing are important components of the model training workflow. Logging keeps a record of the model hyperparameters and the model's performance during training, while checkpointing saves the training state so that a run can be resumed later.
In this tutorial, you will explore some of the utilities from PhysicsNeMo to simplify this important aspect of model training.
Logging in PhysicsNeMo
PhysicsNeMo provides utilities to standardize the logs of different training runs. With these utilities, you can choose between good old console logging and more advanced ML experiment trackers such as MLflow and Weights & Biases. You could implement these loggers yourself, but in this tutorial you will use the PhysicsNeMo utilities, which both simplify the process and produce a standardized output format. Let’s get started.
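The standardized format is visible in the console output later in this tutorial, where every line has the layout [HH:MM:SS - name - LEVEL] message, for example [21:23:57 - main - INFO] Starting Training!. As a rough illustration of that layout only, the sketch below reproduces it with Python's standard logging module. It is inferred from the sample output and is not PhysicsNeMo's implementation; the helper make_console_logger is hypothetical.

```python
import io
import logging


def make_console_logger(name: str, stream=None) -> logging.Logger:
    """Hypothetical helper producing lines like [HH:MM:SS - name - LEVEL] message.

    The layout mimics the sample output in this tutorial; it is not taken
    from PhysicsNeMo's own code.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate lines through the root logger
    handler = logging.StreamHandler(stream)  # stream=None means sys.stderr
    handler.setFormatter(
        logging.Formatter(
            "[%(asctime)s - %(name)s - %(levelname)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    logger.handlers = [handler]  # replace handlers left over from earlier calls
    return logger


if __name__ == "__main__":
    buffer = io.StringIO()
    log = make_console_logger("main", stream=buffer)
    log.info("Starting Training!")
    print(buffer.getvalue(), end="")  # e.g. [21:23:57 - main - INFO] Starting Training!
```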
Console Logging
The example below shows a setup using console logging.
import torch
import physicsnemo
from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.metrics.general.mse import mse
from physicsnemo.models.fno.fno import FNO
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)
# Initialize the loggers
logger = PythonLogger("main")  # general Python logger
LaunchLogger.initialize()
# Use logger methods to track various information during training
logger.info("Starting Training!")
# Train for 20 epochs, each running 5 iterations
dataloader = iter(dataloader)
for i in range(20):
    # Wrap each epoch in a LaunchLogger context to control how often console logs are written
    with LaunchLogger("train", epoch=i) as launchlog:
        # Iterate over different batches
        for _ in range(5):
            optimizer.zero_grad()  # clear gradients from the previous iteration
            batch = next(dataloader)
            truth = batch["darcy"]
            pred = model(batch["permeability"])
            loss = mse(pred, truth)
            loss.backward()
            optimizer.step()
            scheduler.step()
            launchlog.log_minibatch({"Loss": loss.detach().cpu().numpy()})
        launchlog.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
logger.info("Finished Training!")
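Conceptually, the LaunchLogger context collects the per-minibatch values passed to log_minibatch and reports one summary per epoch when the with block exits. The sample output below is consistent with this: Time/Iter equals the epoch execution time divided by the number of iterations (5.664 s / 5 = 1.133e+03 ms). The stdlib-only sketch below mimics that behaviour with a hypothetical EpochSummary class; it is not the PhysicsNeMo API, and treating the reported epoch loss as the mean of the minibatch losses is an assumption.

```python
import time
from collections import defaultdict


class EpochSummary:
    """Hypothetical stand-in for an epoch-level logger (not the PhysicsNeMo API).

    Assumption: the epoch value of a minibatch metric is its mean over minibatches.
    """

    def __init__(self, name: str, epoch: int):
        self.name = name
        self.epoch = epoch
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)
        self.epoch_values = {}
        self.num_iters = 0

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def log_minibatch(self, metrics: dict) -> None:
        # Accumulate running sums so per-epoch means can be reported on exit.
        for key, value in metrics.items():
            self.sums[key] += float(value)
            self.counts[key] += 1
        self.num_iters += 1

    def log_epoch(self, metrics: dict) -> None:
        # Epoch-level values (e.g. the learning rate) are reported as given.
        self.epoch_values.update({k: float(v) for k, v in metrics.items()})

    def summary(self) -> dict:
        out = dict(self.epoch_values)
        out.update({k: self.sums[k] / self.counts[k] for k in self.sums})
        return out

    def __exit__(self, exc_type, exc, tb):
        elapsed = time.perf_counter() - self.start
        per_iter_ms = 1e3 * elapsed / max(self.num_iters, 1)
        metrics = ", ".join(f"{k} = {v:.3e}" for k, v in sorted(self.summary().items()))
        print(f"[{self.name}] Epoch {self.epoch} Metrics: {metrics}")
        print(f"[{self.name}] Epoch Execution Time: {elapsed:.3e}s, Time/Iter: {per_iter_ms:.3e}ms")
        return False  # do not swallow exceptions raised inside the with block


if __name__ == "__main__":
    with EpochSummary("train", epoch=0) as log:
        for loss in (1.2, 1.0, 0.8):
            log.log_minibatch({"Loss": loss})
        log.log_epoch({"Learning Rate": 4.437e-3})
```

Computing Time/Iter from the wall-clock time of the whole with block, rather than timing each iteration, is a simplification that matches the arithmetic visible in the log.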
The logger output can be seen below.
Warp 0.10.1 initialized:
CUDA Toolkit: 11.5, Driver: 12.2
Devices:
"cpu" | x86_64
"cuda:0" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:1" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:2" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:3" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:4" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:5" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:6" | Tesla V100-SXM2-16GB-N (sm_70)
"cuda:7" | Tesla V100-SXM2-16GB-N (sm_70)
Kernel cache: /root/.cache/warp/0.10.1
/usr/local/lib/python3.10/dist-packages/pydantic/_internal/_fields.py:128: UserWarning: Field "model_server_url" has conflict with protected namespace "model_".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/pydantic/_internal/_config.py:317: UserWarning: Valid config keys have changed in V2:
* 'schema_extra' has been renamed to 'json_schema_extra'
warnings.warn(message, UserWarning)
[21:23:57 - main - INFO] Starting Training!
Module physicsnemo.datapipes.benchmarks.kernels.initialization load on device 'cuda:0' took 73.06 ms
Module physicsnemo.datapipes.benchmarks.kernels.utils load on device 'cuda:0' took 314.91 ms
Module physicsnemo.datapipes.benchmarks.kernels.finite_difference load on device 'cuda:0' took 149.86 ms
[21:24:02 - train - INFO] Epoch 0 Metrics: Learning Rate = 4.437e-03, Loss = 1.009e+00
[21:24:02 - train - INFO] Epoch Execution Time: 5.664e+00s, Time/Iter: 1.133e+03ms
[21:24:06 - train - INFO] Epoch 1 Metrics: Learning Rate = 1.969e-03, Loss = 6.040e-01
[21:24:06 - train - INFO] Epoch Execution Time: 4.013e+00s, Time/Iter: 8.025e+02ms
...
[21:25:32 - train - INFO] Epoch 19 Metrics: Learning Rate = 8.748e-10, Loss = 1.384e-01
[21:25:32 - train - INFO] Epoch Execution Time: 4.010e+00s, Time/Iter: 8.020e+02ms
[21:25:32 - main - INFO] Finished Training!
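The learning rates in this log follow directly from the scheduler setup. LambdaLR multiplies the base rate of 0.01 by 0.85**step, and scheduler.step() is called once per iteration, so the rate logged at the end of epoch e with k iterations per epoch is 0.01 * 0.85**(k * (e + 1)). The check below recomputes the logged values with plain Python arithmetic as a stand-in for PyTorch, so it illustrates the formula only.

```python
def lr_after_epoch(epoch: int, iters_per_epoch: int,
                   base_lr: float = 0.01, gamma: float = 0.85) -> float:
    """Plain-arithmetic stand-in for LambdaLR with lr_lambda=gamma**step.

    The scheduler steps once per iteration, so after epoch `epoch` (0-based)
    it has stepped iters_per_epoch * (epoch + 1) times.
    """
    steps = iters_per_epoch * (epoch + 1)
    return base_lr * gamma**steps


if __name__ == "__main__":
    # Console and MLflow examples: 5 iterations per epoch
    for e in (0, 1, 19):
        print(f"epoch {e}: {lr_after_epoch(e, 5):.3e}")
    # epoch 0: 4.437e-03
    # epoch 1: 1.969e-03
    # epoch 19: 8.748e-10
    # Weights & Biases example: 10 iterations per epoch
    print(f"epoch 19: {lr_after_epoch(19, 10):.3e}")  # epoch 19: 7.652e-17
```

With a decay this aggressive the learning rate is effectively zero after a few epochs, which is why the loss plateaus early in the logs.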
MLflow Logging
The example below shows a setup using MLflow logging. The only difference from the previous example is that you will use the initialize_mlflow function to initialize the MLflow client and also set use_mlflow=True when initializing the LaunchLogger.
import torch
import physicsnemo
from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.logging.mlflow import initialize_mlflow
from physicsnemo.metrics.general.mse import mse
from physicsnemo.models.fno.fno import FNO
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)
# Initialize the console logger
logger = PythonLogger("main")  # general Python logger
# Initialize the MLflow logger
initialize_mlflow(
    experiment_name="PhysicsNeMo Tutorials",
    experiment_desc="Simple PhysicsNeMo Tutorials",
    run_name="PhysicsNeMo MLFLow Tutorial",
    run_desc="PhysicsNeMo Tutorial Training",
    user_name="PhysicsNeMo User",
    mode="offline",
)
LaunchLogger.initialize(use_mlflow=True)
# Use logger methods to track various information during training
logger.info("Starting Training!")
# Train for 20 epochs, each running 5 iterations
dataloader = iter(dataloader)
for i in range(20):
    # Wrap each epoch in a LaunchLogger context to control how often console logs are written
    with LaunchLogger("train", epoch=i) as launchlog:
        for _ in range(5):
            optimizer.zero_grad()  # clear gradients from the previous iteration
            batch = next(dataloader)
            truth = batch["darcy"]
            pred = model(batch["permeability"])
            loss = mse(pred, truth)
            loss.backward()
            optimizer.step()
            scheduler.step()
            launchlog.log_minibatch({"Loss": loss.detach().cpu().numpy()})
        launchlog.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
logger.info("Finished Training!")
During the run, a directory named mlruns_0 is created to store the MLflow logs. To visualize the logs interactively, run the following:
mlflow ui --backend-store-uri mlruns_0/
Then navigate to localhost:5000 in your browser.
Warning
Currently, the MLflow logger logs the output of each process separately, so multi-process runs create multiple log directories. This is a known issue and will be fixed in future releases.
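Until then, a small script can list the directories a multi-process run leaves behind so that you can open one of them with mlflow ui. The sketch below assumes that the numeric suffix in names such as mlruns_0 identifies the process (rank) that wrote the logs; the tutorial does not state this explicitly, and the helper find_mlflow_run_dirs is hypothetical.

```python
import re
from pathlib import Path


def find_mlflow_run_dirs(root: str = ".") -> dict:
    """Hypothetical helper: map the suffix N of each mlruns_N directory to its path.

    Assumption: N identifies the process (rank) that wrote the logs.
    """
    pattern = re.compile(r"mlruns_(\d+)")
    found = {}
    for path in Path(root).iterdir():
        match = pattern.fullmatch(path.name)
        if match and path.is_dir():
            found[int(match.group(1))] = path
    return dict(sorted(found.items()))  # ordered by suffix


if __name__ == "__main__":
    for rank, path in find_mlflow_run_dirs().items():
        # Each directory can be opened on its own with the mlflow ui command above.
        print(rank, f"mlflow ui --backend-store-uri {path}/")
```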
Weights & Biases Logging
The example below shows a setup using Weights & Biases logging. Compared with the console logging example, the differences are that you use the initialize_wandb function to initialize the Weights & Biases logger and set use_wandb=True when initializing the LaunchLogger. This example also runs 10 iterations per epoch instead of 5.
import torch
import physicsnemo
from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.logging.wandb import initialize_wandb
from physicsnemo.metrics.general.mse import mse
from physicsnemo.models.fno.fno import FNO
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)
# Initialize the console logger
logger = PythonLogger("main")  # general Python logger
# Initialize the Weights & Biases logger
initialize_wandb(
    project="PhysicsNeMo Tutorials",
    name="Simple PhysicsNeMo Tutorials",
    entity="PhysicsNeMo MLFLow Tutorial",
    mode="offline",
)
LaunchLogger.initialize(use_wandb=True)
# Use logger methods to track various information during training
logger.info("Starting Training!")
# Train for 20 epochs, each running 10 iterations
dataloader = iter(dataloader)
for i in range(20):
    # Wrap each epoch in a LaunchLogger context to control how often console logs are written
    with LaunchLogger("train", epoch=i) as launchlog:
        # Iterate over different batches
        for _ in range(10):
            optimizer.zero_grad()  # clear gradients from the previous iteration
            batch = next(dataloader)
            truth = batch["darcy"]
            pred = model(batch["permeability"])
            loss = mse(pred, truth)
            loss.backward()
            optimizer.step()
            scheduler.step()
            launchlog.log_minibatch({"Loss": loss.detach().cpu().numpy()})
        launchlog.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
logger.info("Finished Training!")
During the run, a directory named wandb is created to store the W&B logs.
The logger output can also be seen below.
...
wandb: Tracking run with wandb version 0.15.12
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[21:26:38 - main - INFO] Starting Training!
Module physicsnemo.datapipes.benchmarks.kernels.initialization load on device 'cuda:0' took 74.11 ms
Module physicsnemo.datapipes.benchmarks.kernels.utils load on device 'cuda:0' took 310.06 ms
Module physicsnemo.datapipes.benchmarks.kernels.finite_difference load on device 'cuda:0' took 151.24 ms
[21:26:48 - train - INFO] Epoch 0 Metrics: Learning Rate = 1.969e-03, Loss = 7.164e-01
[21:26:48 - train - INFO] Epoch Execution Time: 9.703e+00s, Time/Iter: 9.703e+02ms
...
[21:29:47 - train - INFO] Epoch 19 Metrics: Learning Rate = 7.652e-17, Loss = 3.519e-01
[21:29:47 - train - INFO] Epoch Execution Time: 1.125e+01s, Time/Iter: 1.125e+03ms
[21:29:47 - main - INFO] Finished Training!
wandb: Waiting for W&B process to finish... (success).
wandb:
wandb: Run history:
wandb: epoch ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
wandb: train/Epoch Time (s) ▃▁▃▃▃▃▁█▁▁▁▃▃▃▃▆▁▃▃▆
wandb: train/Learning Rate █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: train/Loss █▁▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
wandb: train/Time per iter (ms) ▃▁▃▃▃▃▁█▁▁▁▃▃▃▃▆▁▃▃▆
wandb:
wandb: Run summary:
wandb: epoch 19
wandb: train/Epoch Time (s) 11.24806
wandb: train/Learning Rate 0.0
wandb: train/Loss 0.35193
wandb: train/Time per iter (ms) 1124.80645
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /workspace/physicsnemo/docs/test_scripts/wandb/wandb/offline-run-20231115_212638-ib4ylq4e
wandb: Find logs at: ./wandb/wandb/offline-run-20231115_212638-ib4ylq4e/logs
To view the logs interactively, follow the instructions printed in the output; for example, run the printed wandb sync command to upload the offline run to the W&B cloud.
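In offline mode, each run is stored in a directory named offline-run-<timestamp>-<id>, as the output above shows, and can be uploaded later with the printed wandb sync command. The sketch below collects such directories and builds one wandb sync command per run. The directory naming is inferred from the sample output, and both the helper offline_sync_commands and the script's --upload flag are hypothetical; uploading needs the wandb CLI, a logged-in account, and network access.

```python
import subprocess
import sys
from pathlib import Path


def offline_sync_commands(wandb_dir: str = "wandb") -> list:
    """Hypothetical helper: build a `wandb sync <run_dir>` command per offline run.

    Assumption: offline runs live in directories named offline-run-*,
    as in the sample output of this tutorial.
    """
    runs = sorted(p for p in Path(wandb_dir).rglob("offline-run-*") if p.is_dir())
    return [["wandb", "sync", str(run)] for run in runs]


if __name__ == "__main__":
    commands = offline_sync_commands()
    for cmd in commands:
        print(" ".join(cmd))
    if "--upload" in sys.argv:
        # Requires the wandb CLI, a logged-in account, and network access.
        for cmd in commands:
            subprocess.run(cmd, check=True)
```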
Checkpointing in PhysicsNeMo
PhysicsNeMo provides utilities to save and load checkpoints of the model, optimizer, scheduler, and scaler during training and inference. As with logging, you can use a custom implementation, but this example uses the PhysicsNeMo utilities and highlights some of their benefits.
Loading and Saving Checkpoints During Training
The example below shows how to save and load a checkpoint during training so that training can be resumed from the last saved checkpoint. It demonstrates the load_checkpoint and save_checkpoint functions.
import torch
import physicsnemo
from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.metrics.general.mse import mse
from physicsnemo.models.fno.fno import FNO
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)
# Load the model, optimizer, and scheduler states from the checkpoint in
# ./checkpoints if one exists. load_checkpoint returns the epoch of the loaded
# checkpoint, which is used below to resume training.
loaded_epoch = load_checkpoint(
    "./checkpoints",
    models=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device="cuda",
)
# Train epochs 1 to 20 (5 iterations each), starting from the loaded epoch
dataloader = iter(dataloader)
for i in range(max(1, loaded_epoch), 21):
    # Iterate over different batches
    for _ in range(5):
        optimizer.zero_grad()  # clear gradients from the previous iteration
        batch = next(dataloader)
        true = batch["darcy"]
        pred = model(batch["permeability"])
        loss = mse(pred, true)
        loss.backward()
        optimizer.step()
        scheduler.step()
    # Save a checkpoint every 5th epoch
    if i % 5 == 0:
        save_checkpoint(
            "./checkpoints",
            models=model,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=i,
        )
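The loop bounds decide which epochs run and which checkpoints are written on a fresh start versus a resume. The plain-Python sketch below replays a loop of the form range(max(1, loaded_epoch), last_epoch + 1) with a checkpoint every 5th epoch, without any training. The function replay_schedule is hypothetical, and treating loaded_epoch as 0 when no checkpoint exists is an assumption suggested by the max(1, loaded_epoch) guard. With last_epoch = 20, resuming from epoch 10 writes checkpoints at epochs 10, 15, and 20, matching the sample output that follows.

```python
def replay_schedule(loaded_epoch: int, last_epoch: int = 20, save_every: int = 5):
    """Hypothetical helper: return (epochs_run, epochs_saved) for a resume loop
    range(max(1, loaded_epoch), last_epoch + 1) that saves every `save_every` epochs.
    """
    epochs_run = list(range(max(1, loaded_epoch), last_epoch + 1))
    epochs_saved = [e for e in epochs_run if e % save_every == 0]
    return epochs_run, epochs_saved


if __name__ == "__main__":
    # Fresh start: loaded_epoch is assumed to be 0 when no checkpoint exists.
    print(replay_schedule(0)[1])   # [5, 10, 15, 20]
    # Resume from the epoch-10 checkpoint: epoch 10 is run again.
    print(replay_schedule(10)[1])  # [10, 15, 20]
```

Starting the loop at loaded_epoch + 1 instead would skip re-running the epoch whose checkpoint was just loaded.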
When the script is resumed from a partially trained checkpoint, its output looks similar to the following.
$ python test_scripts/test_basic_checkpointing.py
...
[23:11:09 - checkpoint - INFO] Loaded model state dictionary /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/FourierNeuralOperator.0.10.mdlus to device cuda
[23:11:09 - checkpoint - INFO] Loaded checkpoint file /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/checkpoint.0.10.pt to device cuda
[23:11:09 - checkpoint - INFO] Loaded optimizer state dictionary
[23:11:09 - checkpoint - INFO] Loaded scheduler state dictionary
...
[23:11:11 - checkpoint - INFO] Saved model state dictionary: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/FourierNeuralOperator.0.10.mdlus
[23:11:12 - checkpoint - INFO] Saved training checkpoint: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/checkpoint.0.10.pt
[23:11:16 - checkpoint - INFO] Saved model state dictionary: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/FourierNeuralOperator.0.15.mdlus
[23:11:16 - checkpoint - INFO] Saved training checkpoint: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/checkpoint.0.15.pt
[23:11:21 - checkpoint - INFO] Saved model state dictionary: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/FourierNeuralOperator.0.20.mdlus
[23:11:21 - checkpoint - INFO] Saved training checkpoint: /workspace/release_23.11/docs_upgrade/physicsnemo/docs/checkpoints/checkpoint.0.20.pt
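The log shows the naming pattern of the files in ./checkpoints: model weights go to <ModelName>.0.<epoch>.mdlus and the optimizer and scheduler state to checkpoint.0.<epoch>.pt. As a sketch of how the most recent checkpoint could be located from those names, the hypothetical helper below parses the epoch from each checkpoint.*.pt file. Reading the middle field (0 here) as a process rank is an assumption, and the code is not taken from PhysicsNeMo itself.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed pattern, inferred from the log above: checkpoint.<rank>.<epoch>.pt
CHECKPOINT_RE = re.compile(r"checkpoint\.(\d+)\.(\d+)\.pt")


def latest_checkpoint(directory: str, rank: int = 0) -> Optional[tuple]:
    """Hypothetical helper: return (epoch, path) of the newest checkpoint for
    `rank`, or None if the directory is missing or holds no matching file."""
    folder = Path(directory)
    if not folder.is_dir():
        return None
    best = None
    for path in folder.iterdir():
        match = CHECKPOINT_RE.fullmatch(path.name)
        if match and int(match.group(1)) == rank:
            epoch = int(match.group(2))
            # Compare epochs numerically; as strings, "5" would sort after "20".
            if best is None or epoch > best[0]:
                best = (epoch, path)
    return best


if __name__ == "__main__":
    print(latest_checkpoint("./checkpoints"))
```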
Loading Checkpoints During Inference
Loading a checkpoint for inference is straightforward; refer to the examples in Running Inference on Trained Models and Saving and Loading PhysicsNeMo Models.