nemo_rl.algorithms.async_utils#

Module Contents#

Classes#

ReplayBuffer

Replay buffer storing per-prompt groups.

AsyncTrajectoryCollector

Collects trajectories asynchronously and adds them to the replay buffer.

Data#

API#

nemo_rl.algorithms.async_utils.TokenizerType#

None

class nemo_rl.algorithms.async_utils.ReplayBuffer(max_size: int)#

Replay buffer storing per-prompt groups.

A single entry corresponds to one prompt repeated grpo.num_generations_per_prompt times (required to compute per-prompt advantages).

Initialization

push_with_wait_signal(
trajectory: dict[str, Any],
weight_version: int,
target_weight_version: int,
) str#

Add a per-prompt trajectory group with metadata.

Parameters:
  • trajectory – dictionary holding the data for one per-prompt trajectory group

  • weight_version – version of the model weights used for generation

  • target_weight_version – version of the model weights (training step) that this trajectory is intended to be used for in training

get_debug_info() dict#

Get debug information about buffer state.

get_last_target_weight_already_generated() int#

get_existing_target_weights() set[int]#

Get set of target weight versions that already have trajectories.

sample(
num_prompt_groups: int,
current_weight_version: int,
max_age_steps: int,
) Optional[dict[str, Any]]#

Sample per-prompt trajectory groups intended for the current training step.

Only returns trajectories with target_weight_version == current_weight_version. If insufficient trajectories are available, returns None to stall training until the remaining trajectories are generated. This ensures no trajectory loses its last chance to be used for its intended training step.

Returns:

Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None if insufficient data
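The sampling rule described above can be illustrated with a small stand-in class. This is a minimal sketch, not the nemo_rl implementation: the name ToyReplayBuffer, the removal of sampled entries from the buffer, and the definition of trajectory age as current_weight_version minus the generation weight version are all assumptions made for illustration, and the max_size and max_age_steps arguments are left out.

```python
from typing import Any, Optional


class ToyReplayBuffer:
    """Hypothetical stand-in used only to illustrate the sampling rule."""

    def __init__(self) -> None:
        # Each entry: (trajectory group, generation version, target version).
        self._entries: list[tuple[dict[str, Any], int, int]] = []

    def push(
        self,
        trajectory: dict[str, Any],
        weight_version: int,
        target_weight_version: int,
    ) -> None:
        self._entries.append((trajectory, weight_version, target_weight_version))

    def sample(
        self, num_prompt_groups: int, current_weight_version: int
    ) -> Optional[dict[str, Any]]:
        if num_prompt_groups < 1:
            raise ValueError("num_prompt_groups must be at least 1")
        # Only groups whose target version equals the current step are eligible.
        matching = [
            i
            for i, (_, _, target) in enumerate(self._entries)
            if target == current_weight_version
        ]
        if len(matching) < num_prompt_groups:
            return None  # the caller stalls until the missing groups arrive
        picked = matching[:num_prompt_groups]
        chosen = [self._entries[i] for i in picked]
        # Assumption: sampled groups are consumed and leave the buffer.
        picked_set = set(picked)
        self._entries = [
            e for i, e in enumerate(self._entries) if i not in picked_set
        ]
        # Assumption: age = training step minus the version used to generate.
        ages = [current_weight_version - generated for _, generated, _ in chosen]
        return {
            "trajectories": [trajectory for trajectory, _, _ in chosen],
            "avg_trajectory_age": sum(ages) / len(ages),
        }
```

In this sketch, a request for more groups than are available for the current step returns None and leaves the buffer untouched, so a later call can succeed once the remaining groups have been pushed.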

size() int#

Return current buffer size.

clear() None#

Clear the buffer.

class nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
tokenizer: nemo_rl.algorithms.async_utils.TokenizerType,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
master_config: nemo_rl.algorithms.grpo.MasterConfig,
replay_buffer: Any,
start_step: int = 0,
)#

Collects trajectories asynchronously and adds them to the replay buffer.

Initialization

_calculate_target_weights(generation_weight_version: int) list[int]#

Calculate target weight versions for given generation weight version.

The returned list enumerates the possible versions a generation server can target. These versions are looped over to see which training step they can target. If all target versions are exhausted, this generation server will remain idle until the next weight update.

Example: with generation_weight_version = 10 and max_trajectory_age_steps = 4, this returns [11, 12, 13, 14], meaning this generation server can create trajectories for training steps 11, 12, 13, and 14.
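The example above is consistent with a simple range computation. The following is a minimal sketch under that assumption: the free function calculate_target_weights and its explicit max_trajectory_age_steps argument are hypothetical (the documented method takes only generation_weight_version), and edge cases such as the initial weight version are not covered.

```python
def calculate_target_weights(
    generation_weight_version: int, max_trajectory_age_steps: int
) -> list[int]:
    """Hypothetical helper reproducing the documented example only."""
    first_target = generation_weight_version + 1
    last_target = generation_weight_version + max_trajectory_age_steps
    return list(range(first_target, last_target + 1))


print(calculate_target_weights(10, 4))  # [11, 12, 13, 14]
```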

_get_next_target_for_generation(
generation_weight_version: int,
) Optional[int]#

Get the next target weight that needs generation (if any).
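One plausible reading of "next target that needs generation" is the first candidate target version that does not yet have trajectories. The sketch below assumes exactly that; the function name next_target_for_generation, the already_targeted set, and the candidate range are illustrative assumptions, and the real method may track additional state such as generations that are still in flight.

```python
from typing import Optional


def next_target_for_generation(
    generation_weight_version: int,
    max_trajectory_age_steps: int,
    already_targeted: set[int],
) -> Optional[int]:
    """Hypothetical sketch: pick the first target version still lacking data."""
    first = generation_weight_version + 1
    last = generation_weight_version + max_trajectory_age_steps
    for target in range(first, last + 1):
        if target not in already_targeted:
            return target
    # Every candidate is covered: stay idle until the next weight update.
    return None
```

Under these assumptions, a None result corresponds to the idle state described for _calculate_target_weights, where all target versions are exhausted.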

set_weight_version(version: int) None#

_should_pause_for_generation_limits() bool#

Check if collection should be paused due to generation limits.

start_collection(
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
) None#

Start collecting trajectories from the dataloader.

_collection_loop()#

Run the collection loop in a background thread.

_process_batch(
batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
) None#

Process a single batch and generate for one target weight.

get_weight_version() int#

pause() None#

Pause trajectory collection.

resume() None#

Resume trajectory collection.

prepare_for_refit() None#

Pause new generation starts and wait for pending generations to complete before refit.

resume_after_refit() None#

Resume new generation starts after refit is complete.

wait_for_pending_generations() None#

Wait for all in-flight generation threads to complete.
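The refit handshake described by prepare_for_refit, resume_after_refit, and wait_for_pending_generations can be sketched with standard threading primitives. This is an illustration only: ToyCollector, start_generation, and the use of threading.Event are assumptions, and the actual collector may coordinate its threads differently.

```python
import threading
import time


class ToyCollector:
    """Hypothetical stand-in illustrating the pause-and-drain pattern."""

    def __init__(self) -> None:
        # When set, new generation threads may start.
        self._starts_allowed = threading.Event()
        self._starts_allowed.set()
        self._threads: list[threading.Thread] = []
        self._lock = threading.Lock()
        self.completed: list[int] = []

    def start_generation(self, job_id: int, duration: float = 0.05) -> bool:
        if not self._starts_allowed.is_set():
            return False  # refit in progress: refuse to start new work
        thread = threading.Thread(target=self._work, args=(job_id, duration))
        with self._lock:
            self._threads.append(thread)
        thread.start()
        return True

    def _work(self, job_id: int, duration: float) -> None:
        time.sleep(duration)  # placeholder for an actual generation call
        with self._lock:
            self.completed.append(job_id)

    def wait_for_pending_generations(self) -> None:
        with self._lock:
            pending = list(self._threads)
        for thread in pending:
            thread.join()  # join outside the lock so workers can finish
        with self._lock:
            self._threads = [t for t in self._threads if t.is_alive()]

    def prepare_for_refit(self) -> None:
        self._starts_allowed.clear()  # block new generation starts
        self.wait_for_pending_generations()  # drain in-flight work

    def resume_after_refit(self) -> None:
        self._starts_allowed.set()
```

The ordering in prepare_for_refit matters in this sketch: new starts are blocked first, so the set of threads being drained cannot grow while the caller waits.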

get_dataloader_state() dict#

Get the current dataloader state for checkpointing.

_cleanup_finished_threads() None#

_run_prompt_group_worker(
repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
generation_weight_version: int,
target_weight_version: int,
prompt_idx: int,
) None#