Checkpointing#

Dynamic mode pipelines can produce checkpoints that capture the state of all stateful operators - readers and random number generators - so that processing can be resumed at the captured iteration. Operators that do not maintain user-observable state (decoders, resizes, normalizations, etc.) are conceptually stateless and are not part of a checkpoint.

This page describes the two checkpointing layers that DALI Dynamic exposes:

  1. A manual get_state / set_state interface on individual readers and RNGs.

  2. A semi-automatic Checkpoint aggregator that collects, serializes, and restores the state of a registered set of objects.

Note

A Reader’s state can be applied only to a freshly constructed reader, before its first iteration. The underlying prefetch thread starts on the first call into the reader, after which the snapshot queue is locked. Calls to set_state() after iteration has begun raise a RuntimeError.

Manual checkpointing#

Note

Manual checkpointing is an advanced feature. It provides a set of building blocks for higher-level systems, including the built-in semi-automatic checkpointing. It gives fine-grained control over individual reader or RNG states, which enables integration with pre-existing checkpoint systems or transferring the state of compatible objects across a process boundary. In typical usage scenarios, it is more convenient to use semi-automatic checkpointing.

Both RNG and the readers exposed via ndd.readers.* provide get_state and set_state methods. The state object returned by get_state can be converted to a string with str(), and set_state accepts either the state object or its string representation:

import nvidia.dali.experimental.dynamic as ndd

# Readers must opt in to checkpointing (see Limitations).
reader = ndd.readers.File(file_root="...", enable_checkpointing=True)
it = reader.next_epoch(batch_size=16)

# Iterate for a while...
first = next(it)
second = next(it)

# Capture a checkpoint after the second batch.
reader_state = reader.get_state()
serialized = str(reader_state)  # safe to write to disk, send over the wire, etc.

# Later, on a fresh reader:
resumed = ndd.readers.File(file_root="...", enable_checkpointing=True)
resumed.set_state(serialized)
for batch in resumed.next_epoch(batch_size=16):
    ...  # produces the third batch first

The RNG interface is symmetric:

rng = ndd.random.RNG(seed=42)
rng_state = rng.get_state()
...
rng.set_state(rng_state)
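
One way to plug the manual interface into an existing checkpoint system is to store the string form of each state under a key of your choosing. The sketch below is illustrative only: Counter is a hypothetical stand-in that mimics the get_state / set_state contract so that the example runs without DALI, and the helper names snapshot and apply_snapshot are assumptions, not part of the library.

```python
import json


class Counter:
    """Hypothetical stand-in for a reader or RNG. It only mimics the
    get_state / set_state contract and is not a DALI class."""

    def __init__(self):
        self.position = 0

    def step(self):
        self.position += 1
        return self.position

    def get_state(self):
        return self.position  # str() of this value is the serialized form

    def set_state(self, state):
        self.position = int(state)  # accepts the object or its string form


def snapshot(objects):
    """Serialize the state of every object into one JSON document."""
    return json.dumps({name: str(obj.get_state()) for name, obj in objects.items()})


def apply_snapshot(payload, objects):
    """Apply previously saved states to freshly constructed objects."""
    states = json.loads(payload)
    for name, obj in objects.items():
        obj.set_state(states[name])


source = Counter()
source.step()
source.step()
payload = snapshot({"counter": source})

resumed_counter = Counter()
apply_snapshot(payload, {"counter": resumed_counter})
next_value = resumed_counter.step()  # continues after the captured position
```

With real readers, the equivalent of apply_snapshot would have to run on freshly constructed readers, before their first iteration.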

Semi-automatic checkpointing with Checkpoint#

The Checkpoint class collects the state of a registered set of stateful objects, serializes it to a single string, and restores the state of new objects from that string.

A typical save/restore cycle looks like this:

import nvidia.dali.experimental.dynamic as ndd

reader = ndd.readers.File(file_root="...")
rng = ndd.random.RNG(seed=42)

ckpt = ndd.checkpoint.Checkpoint()
ckpt.register(reader, "reader")
ckpt.register(rng, "rng")

# ... iterate for some time ...

ckpt.collect()                          # capture the current state
ckpt.save("ckpt_{seq:04d}.json")        # writes ckpt_0000.json, ckpt_0001.json, ...

Restoring from disk is the symmetric operation:

reader = ndd.readers.File(file_root="...")
rng = ndd.random.RNG()

ckpt = ndd.checkpoint.Checkpoint()
ckpt.load("ckpt_{seq:04d}.json")        # picks up the most recent file
ckpt.register(reader, "reader")          # state applied implicitly here
ckpt.register(rng, "rng")                # ditto

for batch in reader.next_epoch(batch_size=16):
    ...

Current checkpoint#

The convenience function checkpoint.current() returns the Checkpoint bound to the current EvalContext. It lets code buried behind function calls use checkpointing without changing those functions' signatures to pass a checkpoint explicitly.

Using checkpoint.current():

ckpt = ndd.checkpoint.current()
ckpt.register(reader, "reader")

Registration semantics#

register() accepts an optional name argument:

  • If name is provided, the entry is stored under that key. Any previous op registered under the same key is replaced.

  • If name is omitted, the checkpoint first looks up the op by identity. If it is already registered, the existing key is returned. Otherwise, internally generated sequential names are used.

When the checkpoint is in loaded state and the registered key is present in the loaded dictionary, the saved state is applied to the op immediately. This makes the load/restore flow above a single line per op.
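
The registration rules above can be modeled with a small toy class. This is a simplified sketch under stated assumptions, not DALI's implementation: ToyCheckpoint, Box, and load_states are hypothetical names, and the exact format of the auto-generated keys is assumed (only the reserved "__op_" prefix comes from the API reference).

```python
class ToyCheckpoint:
    """Hypothetical, simplified model of the registration rules; not DALI code."""

    def __init__(self):
        self.ops = {}       # key -> registered object
        self.states = {}    # key -> stored state
        self.dirty = set()  # keys whose state has not been applied yet
        self._counter = 0

    def load_states(self, states):
        """Stand-in for load()/deserialize(): every loaded entry becomes dirty."""
        self.states = dict(states)
        self.dirty = set(states)

    def register(self, op, name=None):
        if name is None:
            # Anonymous: reuse the existing key if this very object is known.
            for key, existing in self.ops.items():
                if existing is op:
                    return key
            name = f"__op_{self._counter}"  # assumed naming scheme
            self._counter += 1
        self.ops[name] = op  # a named registration replaces the previous op
        if name in self.dirty:
            op.set_state(self.states[name])  # applied immediately
            self.dirty.discard(name)
        return name


class Box:
    """Minimal object exposing the duck-typed state contract."""

    def __init__(self):
        self.value = None

    def get_state(self):
        return self.value

    def set_state(self, state):
        self.value = state


# Loaded checkpoint: registering under a known key applies the saved state.
loaded = ToyCheckpoint()
loaded.load_states({"rng": "saved-rng-state"})
rng_box = Box()
named_key = loaded.register(rng_box, "rng")

# Fresh checkpoint: anonymous registration is idempotent per object.
fresh = ToyCheckpoint()
anonymous = Box()
first_key = fresh.register(anonymous)
second_key = fresh.register(anonymous)
```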

Lifecycle flags#

The is_complete and is_loaded properties reflect the most recent operation that populated the state dictionary:

  • collect() sets is_complete and clears is_loaded. New ops cannot be registered (call clear() to reset).

  • deserialize() (and load()) set is_loaded and clear is_complete. Subsequent register() calls must use keys that exist in the loaded state.

Filename patterns#

save() and load() take a Python format string that can contain a single {seq} placeholder. save() substitutes the next free sequence number; load() picks the highest one that matches the pattern on disk. Format specifiers (e.g. {seq:04d}) are honored.
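
The sequence-number resolution can be pictured with plain str.format calls. The helpers below are a hypothetical sketch, not the library's code: the names highest_existing, next_save_path, and latest_load_path are invented for illustration, and it is an assumption that numbering starts at 0 and that the next free number is one past the highest existing one.

```python
import os
import tempfile


def highest_existing(pattern, limit=1000):
    """Return the largest seq whose formatted path exists, or None.

    Brute-force scan, kept simple for illustration.
    """
    found = None
    for seq in range(limit):
        if os.path.exists(pattern.format(seq=seq)):
            found = seq
    return found


def next_save_path(pattern):
    """Path a save would write to (assumed: highest existing number + 1)."""
    highest = highest_existing(pattern)
    return pattern.format(seq=0 if highest is None else highest + 1)


def latest_load_path(pattern):
    """Path a load would read from: the highest sequence number on disk."""
    highest = highest_existing(pattern)
    if highest is None:
        raise FileNotFoundError(pattern)
    return pattern.format(seq=highest)


with tempfile.TemporaryDirectory() as tmp:
    pattern = os.path.join(tmp, "ckpt_{seq:04d}.json")
    for _ in range(3):
        with open(next_save_path(pattern), "w") as f:
            f.write("{}")
    written = sorted(os.listdir(tmp))
    latest = os.path.basename(latest_load_path(pattern))
```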

Manual restore#

The state of every registered op is applied implicitly when the op is added to a loaded checkpoint via register(). The restore() method is the explicit counterpart - it applies all currently dirty states in one call, and is mostly useful when the ops were registered before the state was supplied (e.g. via set_state() or deserialize()).

Limitations#

A few constraints to keep in mind when using Checkpoint:

  • Readers must opt in. A Reader only supports checkpointing when constructed with enable_checkpointing=True; the backend then maintains a prefetch_queue_depth + 1 deep snapshot queue, which has a small runtime cost. Registering a reader that was not opted in is allowed only if its backend has not yet been initialized - the call to register() will then enable checkpointing retroactively; otherwise it raises a RuntimeError.

  • Compiled mode is not supported. Calling Reader.next_epoch with compile=True on a reader that has checkpointing enabled (or vice versa) raises NotImplementedError.

  • Reader state must be applied early. Reader.set_state (and any buffered state propagated through register()) must run before the reader’s first iteration; the prefetch thread cannot be restored once it has started.

  • The order of anonymous registrations matters. When a checkpoint is loaded and ops are re-added without explicit names, the same number of ops must be registered in the same order as at save time. The count is validated, and a stored type tag is checked at apply time - so cross-type swaps fail loudly - but registering ops of compatible types in a different order is not detected. Prefer named registration when in doubt.

  • Format version is strict. deserialize() rejects payloads whose version does not match the current format with a ValueError; there is no automatic upgrade path.

  • Not thread-safe. A single Checkpoint instance must not be accessed concurrently from multiple threads. The EvalContext-bound checkpoint shared by checkpoint.current() follows the same rule.

  • The default EvalContext is reused. ndd.checkpoint.current() returns the checkpoint bound to the thread-local default EvalContext, which lives for the lifetime of the process (or the enclosing with EvalContext(...): block). Registrations accumulate across unrelated runs unless you call clear() between them.
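
The ordering pitfall of anonymous registration can be reproduced with stand-in objects. The following is a toy illustration, not DALI code: ToyRNG, restore_anonymous, and restore_named are hypothetical, and the "__op_<index>" key format is an assumption based on the reserved "__op_" prefix.

```python
class ToyRNG:
    """Hypothetical stand-in exposing get_state / set_state; not a DALI class."""

    def __init__(self):
        self.state = None

    def get_state(self):
        return self.state

    def set_state(self, state):
        self.state = state


def restore_anonymous(saved, ops):
    """Pair states with ops by position, as auto-generated keys would."""
    if len(saved) != len(ops):
        raise RuntimeError("number of registered ops does not match the checkpoint")
    for index, op in enumerate(ops):
        op.set_state(saved[f"__op_{index}"])  # assumed key format


def restore_named(saved, ops):
    """Pair states with ops by explicit name; order is irrelevant."""
    for name, op in ops.items():
        op.set_state(saved[name])


# Saved with the crop RNG registered first and the flip RNG second.
saved_anonymous = {"__op_0": "state-of-crop-rng", "__op_1": "state-of-flip-rng"}
crop_rng, flip_rng = ToyRNG(), ToyRNG()
# Re-registered in the opposite order: same types, so nothing fails.
restore_anonymous(saved_anonymous, [flip_rng, crop_rng])
swapped_state = crop_rng.state  # silently received the flip RNG's state

saved_named = {"crop": "state-of-crop-rng", "flip": "state-of-flip-rng"}
named_crop, named_flip = ToyRNG(), ToyRNG()
restore_named(saved_named, {"flip": named_flip, "crop": named_crop})
```

In the anonymous case the count check passes and both objects are of the same type, so the swap goes unnoticed; the named variant is insensitive to registration order.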

API reference#

ReaderState#

class nvidia.dali.experimental.dynamic._ops.ReaderState(op, serialized)#

Serialized checkpoint state of a Reader.

Wraps the serialized representation produced by the underlying operator’s SerializeCheckpoint(). It can be converted to a string with str(), saved to disk, and later passed to Reader.set_state() to restore the reader to the captured iteration position.

The object also keeps a reference to the originating operator so that future extensions can re-serialize the live state on demand.

Reader checkpoint methods#

Reader.get_state(*, cuda_stream=None)#

Returns the current checkpoint state of this reader.

The returned state object captures the iteration position of the underlying loader. It can be passed back to set_state() to resume processing from this point. The state is serialized to a string by str(state).

Warning

The methods get_state and set_state are not inherently thread-safe with respect to running the reader. External synchronization is necessary.

Note

If there are any pending asynchronous or deferred calls to this operator, the function will wait for them to finish before getting the state.

Parameters:

cuda_stream – The CUDA stream on which the reader is running, or None.

Returns:

An opaque state object that wraps a serialized checkpoint string.

Return type:

ReaderState

Raises:

RuntimeError – If the reader has not started any epoch yet (the backend has not been initialized) or if the underlying reader does not support checkpointing.

Reader.set_state(state)#

Restores the reader’s iteration position from a saved state.

Parameters:

state (ReaderState or str) – Either a state object obtained from get_state(), or its string representation (as produced by str(state)).

Warning

The methods get_state and set_state are not inherently thread-safe with respect to running the reader. External synchronization is necessary.

Note

If there are any pending asynchronous or deferred calls to this operator, the function will wait for them to finish before setting the state.

If the reader's backend has not yet been initialized (i.e. no epoch has been started), the state is buffered and applied automatically the first time the backend is created.

RNG checkpoint methods#

RNG.get_state(*, cuda_stream=None)#

Returns the internal state of the generator.

Equivalent to reading the state property. Provided so that an RNG exposes the same get_state / set_state interface as a Reader, which is the contract used by the checkpointing API.

Parameters:

cuda_stream (Any) – Not used.

Returns:

An opaque state object. The object can be converted to a string with str(state) and later used to set the state or to construct an RNG.

RNG.set_state(value)#

Sets the internal state of the generator.

Equivalent to assigning to the state property. Provided so that an RNG exposes the same get_state / set_state interface as a Reader, which is the contract used by the checkpointing API.

Parameters:

value (object | str) – Either a state object obtained from get_state() (or the state property) or its string representation.

Checkpoint#

class nvidia.dali.experimental.dynamic.checkpoint.Checkpoint#

Aggregates the state of stateful objects (readers and RNGs) for resume.

A checkpoint stores a mapping from a key (operator name) to that operator’s serialized state. Objects can be added with register(), after which their state can be collected with collect() or restored with restore(). The serialized representation can be persisted with save() and loaded back with load().

The supported objects are:

  • Reader instances (the ndd.readers.* operators)

  • RNG instances

Both expose get_state and set_state methods, which is the duck-typed contract used by this class.

Examples

>>> import nvidia.dali.experimental.dynamic as ndd
>>>
>>> reader = ndd.readers.File(file_root="...")
>>> rng = ndd.random.RNG(seed=42)
>>>
>>> ckpt = ndd.checkpoint.Checkpoint()
>>> ckpt.register(reader, "reader")
>>> ckpt.register(rng, "rng")
>>>
>>> # ... iterate for a while ...
>>>
>>> ckpt.collect()
>>> ckpt.save("ckpt_{seq:04d}.json")

clear()#

Resets the checkpoint to its initial state.

Drops all registered ops, all stored states, and resets the complete and loaded flags as well as the auto-generated key counter.

collect()#

Collects the state of every registered op into the checkpoint.

Sets the complete flag, clears the loaded flag, and marks all states as clean. Verifies that the state dictionary does not contain any keys without a corresponding registered op.

deserialize(data)#

Replaces the checkpoint state dictionary with a deserialized version.

Parameters:

data (str) – A string previously produced by serialize().

Notes

Sets the loaded flag and clears the complete flag. Marks every loaded entry as dirty so that subsequent register() calls (or restore()) will apply the state.

get_state(name)#

Returns the stored state for the op registered under name.

Parameters:

name (str) – The key under which the op is registered.

Returns:

The state object passed to set_state or obtained from the operator during collect. If the checkpoint was deserialized or loaded from a file, the state is returned in its string form.

Raises:

KeyError – If no op is registered under name.

property is_complete#

True if collect() was the last operation to populate the state.

property is_loaded#

True if load() (or deserialize()) populated the state.

load(filename)#

Reads a previously saved checkpoint from a file.

Parameters:

filename (str) – A Python format string, optionally containing {seq} (e.g. "ckpt_{seq:04d}.json"). If multiple files match, the one with the highest sequence number is loaded.

Returns:

The path of the file that was read.

Return type:

str

Raises:

FileNotFoundError – If no file matching the pattern exists.

Notes

Marks all loaded entries as dirty (see deserialize()).

property names#

Names of all currently registered ops.

register(op, name=None)#

Adds a stateful op to the checkpoint.

Parameters:
  • op (Reader or RNG) – The stateful object to register. Must expose get_state / set_state methods.

  • name (str, optional) – The key under which to store the op. If omitted, a sequential key is generated, unless op was already registered (under any name), in which case its existing key is returned. If name is provided, any existing operator under that key is replaced. The key must not start with "__op_" - this prefix is reserved for automatically generated names.

Returns:

The key under which op is registered.

Return type:

str

Notes

If a state is currently associated with name and is marked dirty (e.g. it came from a recent load()), the state is immediately applied to op via op.set_state and the entry is marked clean.

restore()#

Restores the state of every registered op from the checkpoint.

Intended for manual use - typically you’d register ops one by one and rely on register() to apply the state implicitly. This method applies all dirty states at once.

Raises:

RuntimeError – If the state dictionary contains entries with no corresponding registered ops, or if any registered op is missing a state in the dictionary (unless the dictionary is empty, in which case this is a no-op).

save(filename)#

Serializes the checkpoint and writes it to a file.

Parameters:

filename (str) – A Python format string, optionally containing {seq} (e.g. "ckpt_{seq:04d}.json"). The placeholder is replaced with a sequential number that does not collide with any existing file.

Returns:

The path of the file that was written.

Return type:

str

Raises:

RuntimeError – If the state dictionary is empty.

serialize()#

Serializes the checkpoint state dictionary to a JSON string.

Raises:

RuntimeError – If the state dictionary is empty.

set_state(name, value, op_type=None)#

Manually sets the state for an operator, whether it is already registered or will be registered later.

Parameters:
  • name (str) – The key under which the op is registered.

  • value – The state value. It can be of any type that the respective operator’s set_state accepts.

  • op_type (class or str, optional) – The class (or its qualified name) of the operator that this state belongs to. When provided, the type name is stored alongside the state and verified against the actual operator’s type before the state is applied (in _maybe_apply()).

Raises:
  • KeyError – If the checkpoint is complete and no operator is registered under name.

  • TypeError – If op_type is neither a type nor a string.

Notes

The state is marked dirty. It is applied to the op the next time register() is called for that key, or when restore() is invoked.

current#

nvidia.dali.experimental.dynamic.checkpoint.current()#

Returns the Checkpoint bound to the current EvalContext.