nemo_rl.algorithms.async_utils.replay_buffer

Module Contents

Classes

ReplayBufferImpl

Replay buffer storing per-prompt groups.

ReplayBuffer

ReplayBufferNew

Staleness-window replay buffer.

API

class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl(max_size: int)

Bases: nemo_rl.algorithms.async_utils.interfaces.ReplayBufferProtocol

Replay buffer storing per-prompt groups.

A single entry holds the generations for one prompt: the prompt repeated grpo.num_generations_per_prompt times. The group is stored as one unit because all of its generations are required to compute per-prompt advantages.
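Keeping a prompt's generations together matters because GRPO-style advantages use the other generations of the same prompt as the baseline. The function below is a minimal sketch assuming the common mean-centered, optionally standard-deviation-normalized variant (population standard deviation here); group_relative_advantages is a hypothetical name, and NeMo RL's own advantage computation may differ in detail.

```python
from statistics import mean, pstdev


def group_relative_advantages(
    rewards: list[float], normalize_std: bool = True, eps: float = 1e-6
) -> list[float]:
    """Hypothetical helper: advantages for the generations of ONE prompt."""
    baseline = mean(rewards)  # the group's own mean reward is the baseline
    centered = [r - baseline for r in rewards]
    if not normalize_std:
        return centered
    # eps avoids division by zero when every generation got the same reward
    scale = pstdev(rewards) + eps
    return [c / scale for c in centered]
```

With rewards [1.0, 0.0, 1.0, 0.0] from four generations, the unnormalized advantages are [0.5, -0.5, 0.5, -0.5]. If the same group were split into [1.0, 1.0] and [0.0, 0.0], every advantage would collapse to 0.0, which is why an entry must hold the whole group.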

Initialization

add(
    trajectory: dict[str, Any],
    weight_version: int,
    target_weight_version: int,
) → str

Add a per-prompt trajectory group with metadata.

Parameters:
  • trajectory – data dict for one per-prompt trajectory group

  • weight_version – version of the model weights used for generation

  • target_weight_version – version of the model weights at which this trajectory is intended to be used for training
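A minimal sketch of how a group and its two version numbers might be stored together. BufferEntry, ToyBuffer, the reject-when-full policy, and the "success"/"full" status strings are all hypothetical: the documentation above does not say what add() returns or what happens when max_size is reached.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class BufferEntry:
    """Hypothetical record: one per-prompt trajectory group plus its metadata."""

    trajectory: dict[str, Any]
    weight_version: int  # weights that produced the generations
    target_weight_version: int  # training step meant to consume the group


class ToyBuffer:
    """Hypothetical bounded buffer; not the NeMo RL implementation."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self.entries: list[BufferEntry] = []

    def add(
        self,
        trajectory: dict[str, Any],
        weight_version: int,
        target_weight_version: int,
    ) -> str:
        # Hypothetical policy: refuse new groups once max_size is reached.
        if len(self.entries) >= self.max_size:
            return "full"
        self.entries.append(
            BufferEntry(trajectory, weight_version, target_weight_version)
        )
        return "success"
```

In an asynchronous setup the two versions can differ: for example, buf.add(group, weight_version=3, target_weight_version=5) records a group generated with version-3 weights that is meant for training step 5.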

get_debug_info() → dict

Get debug information about buffer state.

get_last_target_weight_already_generated() → int

get_existing_target_weights() → set[int]

Get set of target weight versions that already have trajectories.

_remove_indices(indices: Iterable[int]) → None

Remove trajectories at the given indices.
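If the buffer keeps trajectories and their metadata in parallel lists (an assumption; the storage layout is not documented here), removing several positions at once needs care, because deleting in ascending order shifts every later position. A sketch with a hypothetical remove_indices helper that de-duplicates the indices and deletes from highest to lowest:

```python
from collections.abc import Iterable
from typing import Any


def remove_indices(columns: list[list[Any]], indices: Iterable[int]) -> None:
    """Hypothetical helper: delete the same positions from parallel lists in place.

    Deleting from the highest index down means removing one element never
    shifts the position of an element that is still waiting to be removed.
    """
    for idx in sorted(set(indices), reverse=True):
        for column in columns:
            del column[idx]
```

Passing every parallel list in one call keeps the trajectory, weight_version, and target_weight_version columns aligned.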

sample(
    num_prompt_groups: int,
    current_weight_version: int,
    max_age_steps: int,
) → Optional[dict[str, Any]]

Sample per-prompt trajectory groups intended for the current training step.

Only returns trajectories with target_weight_version == current_weight_version. If fewer than num_prompt_groups such trajectories are available, returns None, stalling training until the remaining trajectories are generated. This ensures that no trajectory misses the training step it was generated for, which is its last chance to be used.

Returns:

Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None if not enough trajectories are available
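The stall-rather-than-skip rule can be sketched as a hypothetical standalone function over parallel lists. Assumptions: trajectory age is current_weight_version - weight_version, max_age_steps is ignored because its exact role is not described here, and the sampled groups are not removed from storage.

```python
from typing import Any, Optional


def sample_for_step(
    trajectories: list[Any],
    weight_versions: list[int],
    target_weight_versions: list[int],
    num_prompt_groups: int,
    current_weight_version: int,
) -> Optional[dict[str, Any]]:
    """Hypothetical sketch of target-version-gated sampling."""
    if num_prompt_groups <= 0:
        raise ValueError("num_prompt_groups must be positive")
    # Only groups generated *for* this training step are eligible.
    eligible = [
        i for i, t in enumerate(target_weight_versions)
        if t == current_weight_version
    ]
    if len(eligible) < num_prompt_groups:
        return None  # stall: wait for the remaining groups to be generated
    chosen = eligible[:num_prompt_groups]
    # Assumed age definition: training step minus generation weight version.
    ages = [current_weight_version - weight_versions[i] for i in chosen]
    return {
        "trajectories": [trajectories[i] for i in chosen],
        "avg_trajectory_age": sum(ages) / len(ages),
    }
```

Returning None instead of padding the batch with groups meant for other steps is what guarantees that each group is consumed at the step it was generated for.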

size() → int

Return current buffer size.

clear() → None

Clear the buffer.

class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBuffer(max_size: int)

Bases: nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl

class nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferNew(
max_size: int,
max_staleness: int,
sample_freshest_first: bool = True,
)

Bases: nemo_rl.algorithms.async_utils.replay_buffer.ReplayBufferImpl

Staleness-window replay buffer.

– WIP: DO NOT USE – This class is a work in progress and may change without notice; please DO NOT USE it.

Differences from ReplayBuffer:

  • _evict(): Stale rows (trainer_version - weight_version > max_staleness) are evicted at the start of every sample() call.

  • sample(): selects trajectories in freshest-first order (default) or FIFO order, controlled by the sample_freshest_first flag, from whatever remains in the buffer after eviction.

TODO: remove when cleaning up

  • max_age_steps is not used by ReplayBufferNew.

  • self.target_weight_versions is not used by ReplayBufferNew and will be removed during cleanup. target_weight_versions gates generation on specific trainer steps, which causes generation pauses; ReplayBufferNew intentionally avoids this.

  • Add this class to nemo_rl/algorithms/async_utils/__init__.py.

Initialization

_evict(current_weight_version: int) → None

Evict rows where trainer_version - weight_version > max_staleness, with trainer_version given by current_weight_version.

Must be called with self._lock held.
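A toy illustration of staleness-window eviction and the lock convention. StalenessWindow and its evict() wrapper are hypothetical. Because threading.Lock is not reentrant, a helper that is called from a method already holding the lock must not try to acquire it again; that is one common reason for a "must be called with the lock held" contract.

```python
import threading
from typing import Any


class StalenessWindow:
    """Hypothetical toy store illustrating staleness-window eviction."""

    def __init__(self, max_staleness: int) -> None:
        self.max_staleness = max_staleness
        self._lock = threading.Lock()
        self.trajectories: list[Any] = []
        self.weight_versions: list[int] = []

    def add(self, trajectory: Any, weight_version: int) -> None:
        with self._lock:
            self.trajectories.append(trajectory)
            self.weight_versions.append(weight_version)

    def _evict(self, current_weight_version: int) -> None:
        # Caller must already hold self._lock; acquiring it here would deadlock.
        keep = [
            i for i, v in enumerate(self.weight_versions)
            if current_weight_version - v <= self.max_staleness
        ]
        self.trajectories = [self.trajectories[i] for i in keep]
        self.weight_versions = [self.weight_versions[i] for i in keep]

    def evict(self, current_weight_version: int) -> int:
        """Take the lock, evict stale rows, and return how many were dropped."""
        with self._lock:
            before = len(self.weight_versions)
            self._evict(current_weight_version)
            return before - len(self.weight_versions)
```

With max_staleness=1 and trainer version 3, a row generated with weights 1 is dropped (3 - 1 = 2 > 1), while rows generated with weights 2 and 3 stay.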

sample(
    num_prompt_groups: int,
    current_weight_version: int,
    max_age_steps: int,
) → Optional[dict[str, Any]]

Sample num_prompt_groups trajectory groups, freshest-first by default (FIFO order when sample_freshest_first is False).

Stale rows are evicted before sampling, so every returned trajectory has a weight_version in [current_weight_version - self.max_staleness, current_weight_version].
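Eviction followed by freshest-first or FIFO selection can be sketched as a single hypothetical function, sample_window. Two assumptions: it returns None when fewer than num_prompt_groups rows survive eviction, and avg_trajectory_age is computed as current_weight_version - weight_version.

```python
from typing import Any, Optional


def sample_window(
    trajectories: list[Any],
    weight_versions: list[int],
    num_prompt_groups: int,
    current_weight_version: int,
    max_staleness: int,
    sample_freshest_first: bool = True,
) -> Optional[dict[str, Any]]:
    """Hypothetical sketch: evict stale rows, then pick freshest-first or FIFO."""
    if num_prompt_groups <= 0:
        raise ValueError("num_prompt_groups must be positive")
    # 1) Eviction: keep rows with current_weight_version - weight_version <= max_staleness.
    valid = [
        i for i, v in enumerate(weight_versions)
        if current_weight_version - v <= max_staleness
    ]
    if len(valid) < num_prompt_groups:
        return None  # assumption: not enough in-window groups yet
    # 2) Ordering: newest weights first (stable sort), or insertion order (FIFO).
    if sample_freshest_first:
        order = sorted(valid, key=lambda i: weight_versions[i], reverse=True)
    else:
        order = valid
    chosen = order[:num_prompt_groups]
    ages = [current_weight_version - weight_versions[i] for i in chosen]
    return {
        "trajectories": [trajectories[i] for i in chosen],
        "avg_trajectory_age": sum(ages) / len(ages),
    }
```

Because Python's sort is stable (also with reverse=True), groups that share a weight_version keep their insertion order under freshest-first sampling.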

Returns:

Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None.