nemo_rl.algorithms.async_utils.interfaces
Module Contents
Classes
ReplayBufferProtocol | Interface for the replay buffer used in async RL training.
API
- class nemo_rl.algorithms.async_utils.interfaces.ReplayBufferProtocol
Bases: typing.Protocol
Interface for the replay buffer used in async RL training.
- add(trajectory: dict[str, Any], weight_version: int, target_weight_version: int)
Add a per-prompt trajectory group with metadata.
- Parameters:
trajectory – data dict
weight_version – version of the model weights used for generation
target_weight_version – version of the model weights at which this trajectory is intended to be used for training
- sample(num_prompt_groups: int, current_weight_version: int, max_age_steps: int)
Sample per-prompt trajectory groups intended for the current training step.
Only returns trajectories with target_weight_version == current_weight_version. If insufficient trajectories are available, returns None to stall training until the remaining trajectories are generated. This ensures no trajectory loses its last chance to be used for its intended training step.
- Returns:
Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None if insufficient data
- size() → int
Return current buffer size.
- clear() → None
Clear the buffer.
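Example (illustrative sketch, not the NeMo RL implementation): the behavior described above can be approximated with a small in-memory class that structurally matches the protocol. The class name ToyReplayBuffer is hypothetical. Several details are assumptions not stated in this reference: add returns None, sampled groups are removed from the buffer, trajectory age is computed as current_weight_version minus weight_version, and max_age_steps is accepted but ignored because its semantics are not described here.

```python
from typing import Any, Optional


class ToyReplayBuffer:
    """Hypothetical in-memory buffer with the same method names as ReplayBufferProtocol."""

    def __init__(self) -> None:
        # Each entry is (trajectory, weight_version, target_weight_version).
        self._items: list[tuple[dict[str, Any], int, int]] = []

    def add(
        self,
        trajectory: dict[str, Any],
        weight_version: int,
        target_weight_version: int,
    ) -> None:  # Return type is an assumption; the reference does not state it.
        self._items.append((trajectory, weight_version, target_weight_version))

    def sample(
        self,
        num_prompt_groups: int,
        current_weight_version: int,
        max_age_steps: int,  # Accepted for signature compatibility; unused in this sketch.
    ) -> Optional[dict[str, Any]]:
        # Keep only groups whose target version matches the current training step.
        matching = [
            i for i, (_, _, target) in enumerate(self._items)
            if target == current_weight_version
        ]
        if len(matching) < num_prompt_groups:
            return None  # Caller should stall until more groups arrive.

        chosen = matching[:num_prompt_groups]
        picked = [self._items[i] for i in chosen]
        # Assumption: sampled groups are consumed, so drop them from the buffer.
        chosen_set = set(chosen)
        self._items = [
            item for i, item in enumerate(self._items) if i not in chosen_set
        ]
        # Assumption: age is measured in weight versions since generation.
        ages = [current_weight_version - generated for _, generated, _ in picked]
        return {
            "trajectories": [trajectory for trajectory, _, _ in picked],
            "avg_trajectory_age": sum(ages) / len(ages),
        }

    def size(self) -> int:
        return len(self._items)

    def clear(self) -> None:
        self._items.clear()


buf = ToyReplayBuffer()
buf.add({"prompt": "a"}, weight_version=3, target_weight_version=4)
buf.add({"prompt": "b"}, weight_version=2, target_weight_version=4)
buf.add({"prompt": "c"}, weight_version=4, target_weight_version=5)

# Only two groups target version 4, so asking for three returns None (stall).
stalled = buf.sample(num_prompt_groups=3, current_weight_version=4, max_age_steps=2)
# Asking for two succeeds and leaves the version-5 group in the buffer.
batch = buf.sample(num_prompt_groups=2, current_weight_version=4, max_age_steps=2)
```

In this sketch a trainer would retry sample while it returns None, which mirrors the stall behavior described for sample: a group that targets a given step is never skipped in favor of a partial batch.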