Checkpointing#
Dynamic mode pipelines can produce checkpoints that capture the state of all stateful operators - readers and random number generators - so that processing can be resumed at the captured iteration. Operators that do not maintain user-observable state (decoders, resizes, normalizations, etc.) are conceptually stateless and are not part of a checkpoint.
This page describes the two checkpointing layers that DALI Dynamic exposes:
- A manual get_state/set_state interface on individual readers and RNGs.
- A semi-automatic Checkpoint aggregator that collects, serializes, and restores the state of a registered set of objects.
Note
A Reader’s state can be applied only to a freshly constructed
reader, before its first iteration. The underlying prefetch thread starts
on the first call into the reader, after which the snapshot queue is locked.
Calls to set_state() after iteration has begun raise a
RuntimeError.
Manual checkpointing#
Note
Manual checkpointing is an advanced feature. It is a set of building blocks for higher-level systems, including the built-in semi-automatic checkpointing. It allows fine-grained control over individual reader or RNG states, which enables integration with pre-existing checkpoint systems or transferring the state of compatible objects across process boundaries. In typical usage scenarios, it is more convenient to use semi-automatic checkpointing.
Both RNG and the readers
exposed via ndd.readers.* provide get_state and set_state methods.
The state object returned by get_state can be converted to a string with
str(), and set_state accepts either the state object or its string
representation:
import nvidia.dali.experimental.dynamic as ndd
reader = ndd.readers.File(file_root="...")
it = reader.next_epoch(batch_size=16)
# Iterate for a while...
first = next(it)
second = next(it)
# Capture a checkpoint after the second batch.
reader_state = reader.get_state()
serialized = str(reader_state) # safe to write to disk, send over the wire, etc.
# Later, on a fresh reader:
resumed = ndd.readers.File(file_root="...")
resumed.set_state(serialized)
for batch in resumed.next_epoch(batch_size=16):
    ...  # produces the third batch first
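Because the state converts to a plain string, persisting it needs nothing DALI-specific. The following is a minimal sketch, assuming the string is ordinary text that can be stored as UTF-8. The helpers write_state and read_state and the FakeState stand-in are hypothetical names used for illustration; they are not part of DALI.

```python
import tempfile
from pathlib import Path


def write_state(path, state):
    """Store str(state) in a text file (hypothetical helper, not a DALI API)."""
    Path(path).write_text(str(state), encoding="utf-8")


def read_state(path):
    """Read the string form back; this is what would be handed to set_state()."""
    return Path(path).read_text(encoding="utf-8")


class FakeState:
    """Stand-in for the object returned by get_state(); only str() matters here."""

    def __str__(self):
        return "opaque-serialized-state"


# Round trip through a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    state_file = Path(tmp) / "reader_state.txt"
    write_state(state_file, FakeState())
    restored_text = read_state(state_file)
```

With a real reader, the string returned by read_state would be passed to set_state() on a freshly constructed reader, before its first iteration.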
The RNG interface is symmetric:
rng = ndd.random.RNG(seed=42)
rng_state = rng.get_state()
...
rng.set_state(rng_state)
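The capture-and-replay idea can be tried without DALI. The sketch below uses Python's standard random.Random purely as an analogy; its methods are named getstate and setstate and it is unrelated to ndd.random.RNG. It shows that restoring a captured state makes a generator repeat the same sequence.

```python
import random

rng = random.Random(42)
state = rng.getstate()                         # capture before drawing
first_run = [rng.random() for _ in range(3)]

rng.setstate(state)                            # rewind to the captured point
second_run = [rng.random() for _ in range(3)]  # replays the same three draws
```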
Semi-automatic checkpointing with Checkpoint#
The Checkpoint class collects the state of a registered
set of stateful objects, serializes it to a single string, and restores the
state of new objects from that string.
A typical save/restore cycle looks like this:
import nvidia.dali.experimental.dynamic as ndd
reader = ndd.readers.File(file_root="...")
rng = ndd.random.RNG(seed=42)
ckpt = ndd.checkpoint.Checkpoint()
ckpt.register(reader, "reader")
ckpt.register(rng, "rng")
# ... iterate for some time ...
ckpt.collect() # capture the current state
ckpt.save("ckpt_{seq:04d}.json") # writes ckpt_0000.json, ckpt_0001.json, ...
Restoring from disk is the symmetric operation:
reader = ndd.readers.File(file_root="...")
rng = ndd.random.RNG()
ckpt = ndd.checkpoint.Checkpoint()
ckpt.load("ckpt_{seq:04d}.json") # picks up the most recent file
ckpt.register(reader, "reader") # state applied implicitly here
ckpt.register(rng, "rng") # ditto
for batch in reader.next_epoch(batch_size=16):
    ...
Current checkpoint#
The convenience function checkpoint.current() returns the
Checkpoint bound to the current EvalContext.
This function allows the code hidden behind function calls to use checkpointing
without modifying the API to pass the context explicitly.
Using checkpoint.current():
ckpt = ndd.checkpoint.current()
ckpt.register(reader, "reader")
Registration semantics#
register() accepts an optional name argument:
- If name is provided, the entry is stored under that key. Any previous op registered under the same key is replaced.
- If name is omitted, the checkpoint first looks up the op by identity. If it is already registered, the existing key is returned. Otherwise, an internally generated sequential name is used.
When the checkpoint is in the loaded state and the registered key is present in the loaded dictionary, the saved state is applied to the op immediately. This reduces the load/restore flow above to a single line per op.
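The naming rules can be summarized with a toy model. This is an illustrative sketch, not DALI's implementation: the class ToyRegistry is hypothetical, the exact format of generated keys (here "__op_0", "__op_1", ...) is an assumption based only on the reserved "__op_" prefix mentioned in the API reference, and the exception raised for a reserved name is also an assumption.

```python
class ToyRegistry:
    """Toy model of the naming rules; not DALI's implementation."""

    AUTO_PREFIX = "__op_"  # prefix reserved for generated names

    def __init__(self):
        self._ops = {}     # key -> registered object
        self._counter = 0  # source of sequential generated keys

    def register(self, op, name=None):
        if name is not None:
            if name.startswith(self.AUTO_PREFIX):
                # The exception type is an assumption made for this sketch.
                raise ValueError("names starting with '__op_' are reserved")
            self._ops[name] = op  # replaces any previous entry under this key
            return name
        for key, existing in self._ops.items():
            if existing is op:    # lookup by identity, not equality
                return key
        key = f"{self.AUTO_PREFIX}{self._counter}"  # assumed key format
        self._counter += 1
        self._ops[key] = op
        return key


registry = ToyRegistry()
reader, rng = object(), object()
named_key = registry.register(reader, "reader")  # explicit name
same_key = registry.register(reader)             # already registered: existing key
auto_key = registry.register(rng)                # new object: generated key
```

Because generated keys depend on registration order, the sketch also makes it visible why anonymous registrations must be repeated in the same order when restoring.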
Lifecycle flags#
The is_complete and
is_loaded properties reflect the most recent
operation that populated the state dictionary:
- collect() sets is_complete and clears is_loaded. New ops cannot be registered (call clear() to reset).
- deserialize() (and load()) set is_loaded and clear is_complete. Subsequent register() calls must use keys that exist in the loaded state.
Filename patterns#
save() and load()
take a Python format string with a single {seq} placeholder. save
substitutes the next free sequence number; load picks the highest one
matching the pattern on disk. Format specifiers (e.g. {seq:04d}) are
honored.
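How such a pattern might be resolved can be sketched with the standard library alone. This is not DALI's code: the helper names are hypothetical, the sketch assumes the {seq} placeholder appears only in the file name (not in the directory part), and it takes "next free" to mean one past the highest existing number, which is an assumption.

```python
import os
import re
import string
import tempfile


def _name_regex(name_pattern):
    """Turn 'ckpt_{seq:04d}.json' into a regex with one group for the number."""
    parts = []
    for literal, field, _spec, _conv in string.Formatter().parse(name_pattern):
        parts.append(re.escape(literal))
        if field == "seq":
            parts.append(r"(\d+)")
        elif field is not None:
            raise ValueError(f"unexpected placeholder: {field!r}")
    return re.compile("".join(parts))


def existing_sequence_numbers(pattern):
    """Sequence numbers of the files on disk that match the pattern."""
    directory, name_pattern = os.path.split(pattern)
    regex = _name_regex(name_pattern)
    numbers = []
    for entry in os.listdir(directory or "."):
        match = regex.fullmatch(entry)
        if match and match.groups():
            numbers.append(int(match.group(1)))
    return numbers


def latest_path(pattern):
    """Path with the highest sequence number, as a load would pick."""
    numbers = existing_sequence_numbers(pattern)
    if not numbers:
        raise FileNotFoundError(pattern)
    return pattern.format(seq=max(numbers))


def next_free_path(pattern):
    """Path numbered one past the highest existing file (assumed policy)."""
    numbers = existing_sequence_numbers(pattern)
    return pattern.format(seq=max(numbers, default=-1) + 1)


# Demonstration in a temporary directory with three existing files.
with tempfile.TemporaryDirectory() as tmp:
    pattern = os.path.join(tmp, "ckpt_{seq:04d}.json")
    for seq in (0, 1, 2):
        open(pattern.format(seq=seq), "w").close()
    latest_name = os.path.basename(latest_path(pattern))
    next_name = os.path.basename(next_free_path(pattern))
```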
Manual restore#
The state of every registered op is applied implicitly when the op is added to
a loaded checkpoint via register(). The
restore() method is the explicit counterpart -
it applies all currently dirty states in one call, and is mostly useful when
the ops were registered before the state was supplied (e.g. via
set_state() or
deserialize()).
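The difference between implicit and explicit restore can be illustrated with a toy model of the dirty flag. This is a sketch under stated assumptions, not DALI's implementation: ToyCheckpoint and Counter are hypothetical stand-ins, and the consistency checks that the real restore() performs (raising RuntimeError on mismatched entries) are omitted.

```python
class ToyCheckpoint:
    """Toy model of dirty-state application; not DALI's implementation."""

    def __init__(self):
        self._ops = {}     # key -> registered object
        self._states = {}  # key -> (state, dirty flag)

    def set_state(self, name, value):
        self._states[name] = (value, True)  # buffered and marked dirty

    def register(self, op, name):
        self._ops[name] = op
        self._apply(name)                   # implicit restore on registration
        return name

    def restore(self):
        for name in self._ops:
            self._apply(name)               # explicit restore of all dirty states

    def _apply(self, name):
        state, dirty = self._states.get(name, (None, False))
        if dirty:
            self._ops[name].set_state(state)
            self._states[name] = (state, False)  # mark clean


class Counter:
    """Stand-in stateful object that records how often set_state ran."""

    def __init__(self):
        self.value = 0
        self.calls = 0

    def set_state(self, state):
        self.value = int(state)
        self.calls += 1


early, late = Counter(), Counter()
ckpt = ToyCheckpoint()
ckpt.register(early, "early")  # registered before any state exists
ckpt.set_state("early", "7")
ckpt.set_state("late", "9")
ckpt.register(late, "late")    # dirty state applied during registration
before_restore = early.value   # still untouched at this point
ckpt.restore()                 # applies the remaining dirty state
ckpt.restore()                 # nothing is dirty any more, so nothing happens
```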
Limitations#
A few constraints to keep in mind when using Checkpoint:
- Readers must opt in. A Reader only supports checkpointing when constructed with enable_checkpointing=True; the backend then maintains a prefetch_queue_depth + 1 deep snapshot queue, which has a small runtime cost. Registering a reader that was not opted in is allowed only if its backend has not yet been initialized - the call to register() will then enable checkpointing retroactively; otherwise it raises a RuntimeError.
- Compiled mode is not supported. Calling Reader.next_epoch with compile=True on a reader that has checkpointing enabled (or vice versa) raises NotImplementedError.
- Reader state must be applied early. Reader.set_state (and any buffered state propagated through register()) must run before the reader’s first iteration; the prefetch thread cannot be restored once it has started.
- The order of anonymous registrations matters. When a checkpoint is loaded and ops are re-added without explicit names, the same number of ops must be registered in the same order as at save time. The count is validated, and a stored type tag is checked at apply time - so cross-type swaps fail loudly - but registering ops of compatible types in a different order is not detected. Prefer named registration when in doubt.
- Format version is strict. deserialize() rejects payloads whose version does not match the current format with a ValueError; there is no automatic upgrade path.
- Not thread-safe. A single Checkpoint instance must not be accessed concurrently from multiple threads. The EvalContext-bound checkpoint shared by checkpoint.current() follows the same rule.
- The default EvalContext is reused. ndd.checkpoint.current() returns the checkpoint bound to the thread-local default EvalContext, which lives for the lifetime of the process (or the enclosing with EvalContext(...): block). Registrations accumulate across unrelated runs unless you call clear() between them.
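One way to respect the thread-safety rule is to funnel every pipeline step and every checkpoint call through a single lock. The sketch below uses only the standard threading module. FakeCheckpoint, step, and snapshot are hypothetical stand-ins that mimic the collect()/serialize() shape described on this page; they do not call DALI.

```python
import threading

pipeline_lock = threading.Lock()  # one lock shared by every thread touching the pipeline


class FakeCheckpoint:
    """Stand-in with a collect()/serialize() shape; not the DALI class."""

    def __init__(self):
        self.iteration = 0
        self._collected = None

    def collect(self):
        self._collected = self.iteration

    def serialize(self):
        return str(self._collected)


def step(ckpt):
    """Advance the fake pipeline by one iteration while holding the lock."""
    with pipeline_lock:
        ckpt.iteration += 1


def snapshot(ckpt):
    """Collect and serialize without interleaving with step()."""
    with pipeline_lock:
        ckpt.collect()
        return ckpt.serialize()


ckpt = FakeCheckpoint()
workers = [threading.Thread(target=step, args=(ckpt,)) for _ in range(8)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
serialized = snapshot(ckpt)  # taken after all eight steps have finished
```

Holding the same lock around iteration and around collection matters because a reader's get_state and set_state are documented as not thread-safe with respect to running the reader.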
API reference#
ReaderState#
- class nvidia.dali.experimental.dynamic._ops.ReaderState(op, serialized)#
Serialized checkpoint state of a Reader.
Wraps the serialized representation produced by the underlying operator’s SerializeCheckpoint(). It can be converted to a string with str(), saved to disk, and later passed to Reader.set_state() to restore the reader to the captured iteration position.
The object also keeps a reference to the originating operator so that future extensions can re-serialize the live state on demand.
Reader checkpoint methods#
- Reader.get_state(*, cuda_stream=None)#
Returns the current checkpoint state of this reader.
The returned state object captures the iteration position of the underlying loader. It can be passed back to set_state() to resume processing from this point. The state is serialized to a string by str(state).
Warning
The methods get_state and set_state are not inherently thread-safe with respect to running the reader. External synchronization is necessary.
Note
If there are any pending asynchronous or deferred calls to this operator, the function will wait for them to finish before getting the state.
- Parameters:
cuda_stream – The CUDA stream on which the reader is running, or None.
- Returns:
An opaque state object that wraps a serialized checkpoint string.
- Return type:
ReaderState
- Raises:
RuntimeError – If the reader has not started any epoch yet (the backend has not been initialized) or if the underlying reader does not support checkpointing.
- Reader.set_state(state)#
Restores the reader’s iteration position from a saved state.
If the reader's backend has not yet been initialized (i.e. no epoch has been started), the state is buffered and applied automatically the first time the backend is created.
Warning
The methods get_state and set_state are not inherently thread-safe with respect to running the reader. External synchronization is necessary.
Note
If there are any pending asynchronous or deferred calls to this operator, the function will wait for them to finish before setting the state.
- Parameters:
state (ReaderState or str) – Either a state object obtained from get_state(), or its string representation (as produced by str(state)).
RNG checkpoint methods#
- RNG.get_state(*, cuda_stream=None)#
Returns the internal state of the generator.
Equivalent to reading the state property. Provided so that an RNG exposes the same get_state/set_state interface as a Reader, which is the contract used by the checkpointing API.
- Parameters:
cuda_stream (Any) – Not used.
- Returns:
Opaque state object. The object can be converted to a string with str(state) and later used to set the state or construct an RNG.
- RNG.set_state(value)#
Sets the internal state of the generator.
Equivalent to assigning to the state property. Provided so that an RNG exposes the same get_state/set_state interface as a Reader, which is the contract used by the checkpointing API.
- Parameters:
value (object | str) – Either a state object obtained from get_state() (or the state property) or its string representation.
Checkpoint#
- class nvidia.dali.experimental.dynamic.checkpoint.Checkpoint#
Aggregates the state of stateful objects (readers and RNGs) for resume.
A checkpoint stores a mapping from a key (operator name) to that operator’s serialized state. Objects can be added with register(), after which their state can be collected with collect() or restored with restore(). The serialized representation can be persisted with save() and loaded back with load().
The supported objects are:
- Reader instances (the ndd.readers.* operators)
- RNG instances
Both expose get_state and set_state methods, which is the duck-typed contract used by this class.
Examples
>>> import nvidia.dali.experimental.dynamic as ndd
>>>
>>> reader = ndd.readers.File(file_root="...")
>>> rng = ndd.random.RNG(seed=42)
>>>
>>> ckpt = ndd.checkpoint.Checkpoint()
>>> ckpt.register(reader, "reader")
>>> ckpt.register(rng, "rng")
>>>
>>> # ... iterate for a while ...
>>>
>>> ckpt.collect()
>>> ckpt.save("ckpt_{seq:04d}.json")
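The duck-typed contract mentioned above can be written down as a structural type. The sketch below is illustrative only: the Stateful protocol and the Cursor class are hypothetical names, and nothing here suggests that DALI defines or checks such a protocol itself.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """Structural type for objects exposing get_state/set_state (hypothetical)."""

    def get_state(self) -> Any: ...

    def set_state(self, state: Any) -> None: ...


class Cursor:
    """Small stand-in object that satisfies the contract."""

    def __init__(self):
        self.position = 0

    def get_state(self):
        return str(self.position)        # state is representable as a string

    def set_state(self, state):
        self.position = int(str(state))  # accepts the string representation


cursor = Cursor()
cursor.position = 5
saved = cursor.get_state()

fresh = Cursor()
fresh.set_state(saved)                   # a new object resumes at the saved position
is_stateful = isinstance(fresh, Stateful)
plain_is_stateful = isinstance(object(), Stateful)
```

A runtime_checkable protocol only verifies that the methods exist; it does not validate their signatures or behavior.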
- clear()#
Resets the checkpoint to its initial state.
Drops all registered ops, all stored states, and resets the complete and loaded flags as well as the auto-generated key counter.
- collect()#
Collects the state of every registered op into the checkpoint.
Sets the complete flag, clears the loaded flag, and marks all states as clean. Verifies that the state dictionary does not contain any keys without a corresponding registered op.
- deserialize(data)#
Replaces the checkpoint state dictionary with a deserialized version.
- Parameters:
data (str) – A string previously produced by serialize().
Notes
Sets the loaded flag and clears the complete flag. Marks every loaded entry as dirty so that subsequent register() calls (or restore()) will apply the state.
- get_state(name)#
Returns the stored state for the op registered under name.
- Parameters:
name (str) – The key under which the op is registered.
- Returns:
The state object passed to set_state or obtained from the operator during collect. If the checkpoint was deserialized or loaded from a file, it will be stringified.
- Raises:
KeyError – If no op is registered under name.
- property is_loaded#
True if load() (or deserialize()) populated the state.
- load(filename)#
Reads a previously saved checkpoint from a file.
- Parameters:
filename (str) – A Python format string, optionally containing {seq} (e.g. "ckpt_{seq:04d}.json"). If multiple files match, the one with the highest sequence number is loaded.
- Returns:
The path of the file that was read.
- Return type:
str
- Raises:
FileNotFoundError – If no file matching the pattern exists.
Notes
Marks all loaded entries as dirty (see
deserialize()).
- property names#
Names of all currently registered ops.
- register(op, name=None)#
Adds a stateful op to the checkpoint.
- Parameters:
op (Reader or RNG) – The stateful object to register. Must expose get_state/set_state methods.
name (str, optional) – The key under which to store the op. If omitted, a sequential key is generated, unless op was already registered (under any name), in which case its existing key is returned. If name is provided, any existing operator under that key is replaced. The key must not start with "__op_" - this prefix is reserved for automatically generated names.
- Returns:
The key under which op is registered.
- Return type:
str
Notes
If a state is currently associated with name and is marked dirty (e.g. it came from a recent load()), the state is immediately applied to op via op.set_state and the entry is marked clean.
- restore()#
Restores the state of every registered op from the checkpoint.
Intended for manual use - typically you’d register ops one by one and rely on register() to apply the state implicitly. This method applies all dirty states at once.
- Raises:
RuntimeError – If the state dictionary contains entries with no corresponding registered ops, or if any registered op is missing a state in the dictionary (unless the dictionary is empty, in which case this is a no-op).
- save(filename)#
Serializes the checkpoint and writes it to a file.
- Parameters:
filename (str) – A Python format string, optionally containing {seq} (e.g. "ckpt_{seq:04d}.json"). The placeholder is replaced with a sequential number that does not collide with any existing file.
- Returns:
The path of the file that was written.
- Return type:
str
- Raises:
RuntimeError – If the state dictionary is empty.
- serialize()#
Serializes the checkpoint state dictionary to a JSON string.
- Raises:
RuntimeError – If the state dictionary is empty.
- set_state(name, value, op_type=None)#
Manually sets the state for an operator (registered or future).
- Parameters:
name (str) – The key under which the op is registered.
value – The state value. It can be of any type that the respective operator’s set_state accepts.
op_type (class or str, optional) – The class (or its qualified name) of the operator that this state belongs to. When provided, the type name is stored alongside the state and verified against the actual operator’s type before the state is applied (in _maybe_apply()).
- Raises:
KeyError – If the checkpoint is complete and no operator is registered under name.
TypeError – If op_type is neither a type nor a string.
Notes
The state is marked dirty. It is applied to the op the next time register() is called for that key, or when restore() is invoked.
current#
- nvidia.dali.experimental.dynamic.checkpoint.current()#
Returns the Checkpoint bound to the current EvalContext.