Dynamic mode checkpointing#

This notebook demonstrates how to save and restore the state of a DALI dynamic mode data loading loop. In dynamic mode, the stateful objects are individual readers and random number generators - they expose get_state / set_state methods, and the `Checkpoint <checkpointing.rst>`__ helper aggregates the state of a registered set of them into a single serialized blob.

The example reads JPEG images from DALI_EXTRA_PATH/db/single/mixed with ndd.readers.File, decodes them, and applies a rotation by a random angle in the range (-30, 30) degrees. We will run the loop for a couple of warmup iterations, save a checkpoint, and then build a fresh reader + RNG pair that resumes from exactly the captured point.

[1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import nvidia.dali.experimental.dynamic as ndd

IMAGE_DIR = os.path.join(os.environ["DALI_EXTRA_PATH"], "db", "single", "mixed")
BATCH_SIZE = 8

Building the augmentation#

augment takes a (jpegs, labels) batch from the reader and returns the rotated images together with the per-sample labels and the random angles that were used. Returning the angles lets us later confirm that the RNG produced the same draws after the checkpoint was restored.

[2]:
def augment(jpegs, labels, rng):
    images = ndd.decoders.image(jpegs, device="gpu")
    angles = ndd.random.uniform(
        batch_size=BATCH_SIZE, range=(-30.0, 30.0), rng=rng
    )
    rotated = ndd.rotate(images, angle=angles, fill_value=0)
    return rotated, labels, angles

Saving a checkpoint#

We construct the reader with enable_checkpointing=True so that its prefetch backend keeps the per-iteration snapshots needed for saving a checkpoint. The reader and the RNG are then registered in Checkpoint object. We let the pipeline run for a couple of iterations after which the checkpoint is collected and saved to a file.

[3]:
reader = ndd.readers.File(
    file_root=IMAGE_DIR,
    random_shuffle=True,
    enable_checkpointing=True,
    seed=1234,
    name="image_reader",
)
rng = ndd.random.RNG(seed=42)

ckpt = ndd.checkpoint.current()
ckpt.register(reader, "image_reader")
ckpt.register(rng, "rng")

warmup_iters = 3
epoch = reader.next_epoch(batch_size=BATCH_SIZE)
for _ in range(warmup_iters):
    jpegs, labels = next(epoch)
    augment(jpegs, labels, rng)

ckpt.collect()
checkpoint_path = ckpt.save("ndd_checkpoint_{seq:04d}.json")
print(f"Saved checkpoint to {checkpoint_path}")
Saved checkpoint to ndd_checkpoint_0000.json

Restoring the state#

In the previous step we’ve saved the state of the reader and the RNG. Once we restore that state in a new reader and RNG, the next batch they will produce will be the same as in the original pipeline. Let’s run a comparison.

Producing a reference batch#

We pull one more batch out of the original reader and RNG. This is the batch we will compare against the first batch produced by the restored loop.

[4]:
ref_jpegs, ref_labels = next(epoch)
ref_rotated, ref_labels, ref_angles = augment(ref_jpegs, ref_labels, rng)


def to_numpy_list(batch):
    return [np.array(sample) for sample in batch.cpu()]


ref_label_values = [int(t.flatten()[0]) for t in to_numpy_list(ref_labels)]
ref_angle_values = [float(t.flatten()[0]) for t in to_numpy_list(ref_angles)]

images_np = to_numpy_list(ref_rotated)
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(images_np[i])
    ax.set_title(
        f"label={ref_label_values[i]}, angle={ref_angle_values[i]:.2f}"
    )
    ax.axis("off")
plt.tight_layout()
plt.show()
../_images/dali_dynamic_checkpointing_tutorial_7_0.png

Restoring from the checkpoint#

On the restore side we build a fresh reader and a fresh RNG, load the saved checkpoint, and register both objects under the same names as before. Checkpoint.register applies the loaded state to each object as soon as it is added - the reader’s state must be applied before its prefetch thread starts, so we do this before the first next_epoch call.

[5]:
restored_reader = ndd.readers.File(
    file_root=IMAGE_DIR,
    random_shuffle=True,
    enable_checkpointing=True,
    name="image_reader",
)
restored_rng = ndd.random.RNG()

restored_ckpt = ndd.checkpoint.Checkpoint()
restored_ckpt.load("ndd_checkpoint_{seq:04d}.json")
restored_ckpt.register(restored_reader, "image_reader")
restored_ckpt.register(restored_rng, "rng")

restored_epoch = restored_reader.next_epoch(batch_size=BATCH_SIZE)
rest_jpegs, rest_labels = next(restored_epoch)
rest_rotated, rest_labels, rest_angles = augment(
    rest_jpegs, rest_labels, restored_rng
)

rest_label_values = [int(t.flatten()[0]) for t in to_numpy_list(rest_labels)]
rest_angle_values = [float(t.flatten()[0]) for t in to_numpy_list(rest_angles)]

Verifying the resume#

The labels track which files the reader picked, in order - matching labels mean the reader resumed at the right position and kept the same shuffle. Matching angles mean the RNG was restored to the same internal state.

[6]:
print("labels (original): ", ref_label_values)
print("labels (restored): ", rest_label_values)
print("angles (original): ", [f"{a:.4f}" for a in ref_angle_values])
print("angles (restored): ", [f"{a:.4f}" for a in rest_angle_values])

assert ref_label_values == rest_label_values, "reader order diverged"
assert ref_angle_values == rest_angle_values, "RNG state diverged"
print("\nFile order and random angles match exactly.")
labels (original):  [1, 1, 2, 4, 3, 3, 1, 3]
labels (restored):  [1, 1, 2, 4, 3, 3, 1, 3]
angles (original):  ['13.1024', '-9.7800', '-26.1646', '-25.4199', '-9.6380', '16.9450', '-29.3992', '26.0642']
angles (restored):  ['13.1024', '-9.7800', '-26.1646', '-25.4199', '-9.6380', '16.9450', '-29.3992', '26.0642']

File order and random angles match exactly.

Finally, let’s display the output.

[7]:
images_np = to_numpy_list(rest_rotated)
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(images_np[i])
    ax.set_title(
        f"label={rest_label_values[i]}, angle={rest_angle_values[i]:.2f}"
    )
    ax.axis("off")
plt.tight_layout()
plt.show()
../_images/dali_dynamic_checkpointing_tutorial_13_0.png