nemo_rl.algorithms.async_utils.replay_buffer
Module Contents
Classes
ReplayBufferImpl | Replay buffer storing per-prompt groups. |
ReplayBuffer | |
ReplayBufferNew | Staleness-window replay buffer. |
API
- class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl(max_size: int)
Bases:
nemo_rl.algorithms.async_utils.interfaces.ReplayBufferProtocol
Replay buffer storing per-prompt groups.
A single entry corresponds to one prompt repeated grpo.num_generations_per_prompt times; the full group is required to compute per-prompt advantages.
Initialization
- add(
- trajectory: dict[str, Any],
- weight_version: int,
- target_weight_version: int,
Add a per-prompt trajectory group with metadata.
- Parameters:
trajectory – data dict
weight_version – version of the model weights used for generation
target_weight_version – weight version of the training step this trajectory is intended for
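As a rough illustration of the bookkeeping described above, the sketch below stores each per-prompt group next to its two version numbers. `ToyGroupBuffer` is a hypothetical stand-in, not the NeMo RL class, and both the parallel-list layout and the reject-when-full behavior are assumptions made for the example.

```python
from typing import Any


class ToyGroupBuffer:
    """Hypothetical stand-in for illustration; not the NeMo RL implementation."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        # Parallel lists: index i describes one per-prompt trajectory group.
        self.trajectories: list[dict[str, Any]] = []
        self.weight_versions: list[int] = []
        self.target_weight_versions: list[int] = []

    def add(
        self,
        trajectory: dict[str, Any],
        weight_version: int,
        target_weight_version: int,
    ) -> bool:
        # Assumption: a full buffer rejects the new group instead of evicting.
        if len(self.trajectories) >= self.max_size:
            return False
        self.trajectories.append(trajectory)
        self.weight_versions.append(weight_version)
        self.target_weight_versions.append(target_weight_version)
        return True


buf = ToyGroupBuffer(max_size=2)
added = [
    buf.add({"prompt_id": i}, weight_version=3, target_weight_version=4)
    for i in range(3)
]
```

Here the third `add` call is rejected because the toy buffer already holds `max_size` groups.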
- get_debug_info() → dict
Get debug information about buffer state.
- get_last_target_weight_already_generated() → int
- get_existing_target_weights() → set[int]
Get set of target weight versions that already have trajectories.
- _remove_indices(indices: Iterable[int]) → None
Remove trajectories at the given indices.
- sample(
- num_prompt_groups: int,
- current_weight_version: int,
- max_age_steps: int,
Sample per-prompt trajectory groups intended for the current training step.
Only returns trajectories with target_weight_version == current_weight_version. If insufficient trajectories are available, returns None to stall training until the remaining trajectories are generated. This ensures no trajectory loses its last chance to be used for its intended training step.
- Returns:
Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None if insufficient data
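The stall-or-return behavior described for `sample()` can be sketched as a standalone function. `sample_for_step` and the `(trajectory, weight_version, target_weight_version)` tuple layout are hypothetical. Computing age as `current_weight_version - weight_version` and removing sampled entries from the buffer are assumptions, not documented behavior.

```python
from typing import Any, Optional

# (trajectory, weight_version, target_weight_version); layout is an assumption.
Entry = tuple[dict[str, Any], int, int]


def sample_for_step(
    entries: list[Entry],
    num_prompt_groups: int,
    current_weight_version: int,
) -> Optional[dict[str, Any]]:
    # Only groups generated for exactly this training step are eligible.
    matching = [
        i
        for i, (_, _, target) in enumerate(entries)
        if target == current_weight_version
    ]
    if len(matching) < num_prompt_groups:
        return None  # stall: the caller retries once more groups arrive

    chosen = matching[:num_prompt_groups]
    # Assumption: age is how many versions behind the generating weights were.
    ages = [current_weight_version - entries[i][1] for i in chosen]
    result = {
        "trajectories": [entries[i][0] for i in chosen],
        "avg_trajectory_age": sum(ages) / max(len(ages), 1),
    }
    # Assumption: sampled groups are consumed; delete from the back so the
    # remaining indices stay valid.
    for i in sorted(chosen, reverse=True):
        del entries[i]
    return result


entries: list[Entry] = [
    ({"id": "a"}, 3, 4),
    ({"id": "b"}, 2, 4),
    ({"id": "c"}, 4, 5),
]
stalled = sample_for_step(entries, num_prompt_groups=3, current_weight_version=4)
batch = sample_for_step(entries, num_prompt_groups=2, current_weight_version=4)
```

In this sketch the first call returns `None` because only two groups target version 4, and nothing is removed from the buffer while training is stalled.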
- size() → int
Return current buffer size.
- clear() → None
Clear the buffer.
- class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBuffer(max_size: int)
Bases:
nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl
- class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferNew(
- max_size: int,
- max_staleness: int,
- sample_freshest_first: bool = True,
Bases:
nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl
Staleness-window replay buffer.
– WIP: DO NOT USE – This class is WIP and may be changed without notice, please DO NOT USE it.
Differences from ReplayBuffer:
_evict(): Stale rows (trainer_version - weight_version > max_staleness) are evicted at the start of every sample() call.
sample(): selects trajectories in freshest-first order (default) or FIFO order, controlled by the sample_freshest_first flag, from whatever remains in the buffer after eviction.
TODO: remove when cleaning up
max_age_steps won’t be used in ReplayBufferNew;
self.target_weight_versions won’t be used in ReplayBufferNew and will be removed when cleaning up. target_weight_versions gates generation on specific trainer steps, which causes generation pauses; ReplayBufferNew intentionally avoids this.
add this class to nemo_rl/algorithms/async_utils/__init__.py
Initialization
- _evict(current_weight_version: int) → None
Evict rows where trainer_version - weight_version > max_staleness.
Must be called with self._lock held.
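The eviction rule and its locking contract might look roughly like the sketch below. `ToyStalenessBuffer` and `evict_stale` are hypothetical names, and the parallel-list storage is an assumption. Only the predicate `trainer_version - weight_version > max_staleness` and the requirement to hold `self._lock` come from the description above.

```python
import threading
from typing import Any


class ToyStalenessBuffer:
    """Hypothetical stand-in for illustration; not the NeMo RL implementation."""

    def __init__(self, max_staleness: int) -> None:
        self.max_staleness = max_staleness
        self.trajectories: list[dict[str, Any]] = []
        self.weight_versions: list[int] = []
        self._lock = threading.Lock()

    def _evict(self, current_weight_version: int) -> None:
        # Caller must already hold self._lock.
        keep = [
            i
            for i, version in enumerate(self.weight_versions)
            if current_weight_version - version <= self.max_staleness
        ]
        self.trajectories = [self.trajectories[i] for i in keep]
        self.weight_versions = [self.weight_versions[i] for i in keep]

    def evict_stale(self, current_weight_version: int) -> int:
        # Public wrapper (an assumption) that takes the lock, then evicts.
        with self._lock:
            self._evict(current_weight_version)
            return len(self.trajectories)


toy = ToyStalenessBuffer(max_staleness=2)
toy.trajectories = [{"id": "old"}, {"id": "edge"}, {"id": "new"}]
toy.weight_versions = [1, 3, 5]
remaining = toy.evict_stale(current_weight_version=5)
```

With `max_staleness=2` and a trainer at version 5, the group generated at version 1 is dropped, while the group at version 3 sits exactly on the boundary and is kept.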
- sample(
- num_prompt_groups: int,
- current_weight_version: int,
- max_age_steps: int,
Sample num_prompt_groups trajectories, freshest-first by default.
Stale rows are evicted before sampling, so every remaining trajectory has a weight_version in the range [current_weight_version - self.max_staleness, current_weight_version].
- Returns:
Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None.
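The two selection orders controlled by `sample_freshest_first` can be illustrated with a small helper. `pick_indices` is a hypothetical function, and returning `None` when fewer than `num_prompt_groups` rows remain is an assumption, since the conditions for the `None` return are not spelled out above.

```python
from typing import Optional


def pick_indices(
    weight_versions: list[int],
    num_prompt_groups: int,
    sample_freshest_first: bool = True,
) -> Optional[list[int]]:
    # Assumption: too few rows after eviction means "no batch yet".
    if len(weight_versions) < num_prompt_groups:
        return None
    # FIFO order is simply insertion order.
    order = list(range(len(weight_versions)))
    if sample_freshest_first:
        # Highest weight_version first; the stable sort keeps insertion
        # order among rows with the same version.
        order.sort(key=lambda i: weight_versions[i], reverse=True)
    return order[:num_prompt_groups]


versions = [3, 5, 4, 5]
freshest = pick_indices(versions, 2, sample_freshest_first=True)
fifo = pick_indices(versions, 2, sample_freshest_first=False)
too_few = pick_indices(versions, 5)
```

Freshest-first favors on-policy data at the cost of possibly letting older rows age out of the staleness window, whereas FIFO consumes rows in arrival order.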