nemo_rl.algorithms.async_utils
Module Contents
Classes
- ReplayBuffer: Replay buffer storing per-prompt groups.
- AsyncTrajectoryCollector: Collects trajectories asynchronously and adds them to replay buffer.
Data
API
- nemo_rl.algorithms.async_utils.TokenizerType
None
- class nemo_rl.algorithms.async_utils.ReplayBuffer(max_size: int)
Replay buffer storing per-prompt groups.
A single entry corresponds to one prompt repeated grpo.num_generations_per_prompt times (the full group is required to compute per-prompt advantages).
Initialization
- push_with_wait_signal(
- trajectory: dict[str, Any],
- weight_version: int,
- target_weight_version: int,
Add a per-prompt trajectory group with metadata.
- Parameters:
trajectory – data dict
weight_version – version of the model weights used for generation
target_weight_version – version of the model weights (training step) that this trajectory is intended to be used for in training
- get_debug_info() → dict
Get debug information about buffer state.
- get_last_target_weight_already_generated() → int
- get_existing_target_weights() → set[int]
Get set of target weight versions that already have trajectories.
- sample(
- num_prompt_groups: int,
- current_weight_version: int,
- max_age_steps: int,
Sample per-prompt trajectory groups intended for the current training step.
Only returns trajectories with target_weight_version == current_weight_version. If insufficient trajectories are available, returns None to stall training until the remaining trajectories are generated. This ensures no trajectory loses its last chance to be used for its intended training step.
- Returns:
Dictionary with ‘trajectories’ and ‘avg_trajectory_age’ keys, or None if insufficient data
- size() → int
Return current buffer size.
- clear() → None
Clear the buffer.
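The push and sample behavior described above can be illustrated with a small stand-in. This is a sketch under stated assumptions, not the nemo_rl implementation: the class name `ToyReplayBuffer`, the boolean return value of `push`, the removal of sampled entries, and the definition of age as `current_weight_version - weight_version` are all hypothetical, and the `max_age_steps` argument is omitted for brevity.

```python
from typing import Any, Optional


class ToyReplayBuffer:
    """Illustrative stand-in only; NOT nemo_rl.algorithms.async_utils.ReplayBuffer."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        # Each entry: (per-prompt trajectory group, generation version, target version)
        self._entries: list[tuple[dict[str, Any], int, int]] = []

    def push(
        self,
        trajectory: dict[str, Any],
        weight_version: int,
        target_weight_version: int,
    ) -> bool:
        """Store one per-prompt group. Returning False when full is an assumption."""
        if len(self._entries) >= self.max_size:
            return False
        self._entries.append((trajectory, weight_version, target_weight_version))
        return True

    def sample(
        self, num_prompt_groups: int, current_weight_version: int
    ) -> Optional[dict[str, Any]]:
        """Return groups targeted at the current step, or None to stall training."""
        matching = [
            i
            for i, entry in enumerate(self._entries)
            if entry[2] == current_weight_version
        ]
        if len(matching) < num_prompt_groups:
            return None  # not enough groups yet for this training step
        chosen_idx = set(matching[:num_prompt_groups])
        chosen = [self._entries[i] for i in sorted(chosen_idx)]
        # Assumption: sampled entries are consumed and leave the buffer.
        self._entries = [
            e for i, e in enumerate(self._entries) if i not in chosen_idx
        ]
        # Assumption: age is the gap between training and generation versions.
        ages = [current_weight_version - e[1] for e in chosen]
        return {
            "trajectories": [e[0] for e in chosen],
            "avg_trajectory_age": sum(ages) / len(ages),
        }

    def size(self) -> int:
        return len(self._entries)
```

In this sketch, a group pushed with `target_weight_version=12` is never returned while training is at version 11, which mirrors the rule that only trajectories with `target_weight_version == current_weight_version` are sampled.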
- class nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- tokenizer: nemo_rl.algorithms.async_utils.TokenizerType,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- replay_buffer: Any,
- start_step: int = 0,
Collects trajectories asynchronously and adds them to replay buffer.
Initialization
- _calculate_target_weights(generation_weight_version: int) → list[int]
Calculate target weight versions for given generation weight version.
The returned list enumerates the versions that a generation server can target. The versions are looped over to see which training step the server can generate for. If all target versions are exhausted, this generation server remains idle until the next weight update.
Example: with generation_weight_version = 10 and max_trajectory_age_steps = 4
- Returns:
[11, 12, 13, 14], meaning this generation server can create trajectories for training steps 11, 12, 13, and 14
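The documented example can be reproduced with a short standalone function. This is a minimal sketch that only matches the example above: `calculate_target_weights` is a hypothetical free function that takes `max_trajectory_age_steps` as an argument, whereas the real method takes only `generation_weight_version`, must obtain the age limit elsewhere, and may handle edge cases differently.

```python
def calculate_target_weights(
    generation_weight_version: int, max_trajectory_age_steps: int
) -> list[int]:
    """Hypothetical helper: versions v+1 .. v+max_age, as in the documented example."""
    return [
        generation_weight_version + offset
        for offset in range(1, max_trajectory_age_steps + 1)
    ]


targets = calculate_target_weights(10, 4)  # [11, 12, 13, 14]
```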
- _get_next_target_for_generation(
- generation_weight_version: int,
Get the next target weight that needs generation (if any).
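One plausible reading of this method is "pick the first candidate target that does not yet have trajectories". The sketch below illustrates that reading only. The function name, its parameters, and the selection rule are assumptions; the real method may also account for generations that are still in flight.

```python
from typing import Optional


def next_target_for_generation(
    candidate_targets: list[int], existing_targets: set[int]
) -> Optional[int]:
    """Hypothetical helper: first candidate not already covered, else None."""
    for target in candidate_targets:
        if target not in existing_targets:
            return target
    return None  # every candidate is covered; the server would stay idle
```

Here `candidate_targets` stands in for the output of `_calculate_target_weights`, and `existing_targets` for the set returned by `ReplayBuffer.get_existing_target_weights()`.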
- set_weight_version(version: int) → None
- _should_pause_for_generation_limits() → bool
Check if collection should be paused due to generation limits.
- start_collection(
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
Start collecting trajectories from dataloader.
- _collection_loop()
Run the collection loop in background thread.
- _process_batch( ) None #
Process a single batch and generate for one target weight.
- get_weight_version() → int
- pause() → None
Pause trajectory collection.
- resume() → None
Resume trajectory collection.
- prepare_for_refit() → None
Pause new generation starts and wait for pending generations to complete before refit.
- resume_after_refit() → None
Resume new generation starts after refit is complete.
- wait_for_pending_generations() → None
Wait for all in-flight generation threads to complete.
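The refit sequence described above (block new starts, drain in-flight work, then reopen) can be sketched with standard-library threading primitives. This is an illustration under assumptions, not the collector's actual code: `ToyRefitGate` and `start_generation` are hypothetical names, and the real class may coordinate its threads differently.

```python
import threading
from typing import Any, Callable


class ToyRefitGate:
    """Illustrative stand-in for the refit gating; NOT the nemo_rl collector."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._can_start = True
        self._inflight: list[threading.Thread] = []

    def start_generation(self, fn: Callable[..., Any], *args: Any) -> bool:
        """Start a worker thread unless a refit is in progress."""
        with self._lock:
            if not self._can_start:
                return False
            worker = threading.Thread(target=fn, args=args, daemon=True)
            self._inflight.append(worker)
            worker.start()
            return True

    def wait_for_pending_generations(self) -> None:
        """Join every worker that was started before this call."""
        with self._lock:
            pending = list(self._inflight)
        for worker in pending:
            worker.join()
        with self._lock:
            self._inflight = [w for w in self._inflight if w.is_alive()]

    def prepare_for_refit(self) -> None:
        """Block new starts, then drain in-flight generations."""
        with self._lock:
            self._can_start = False
        self.wait_for_pending_generations()

    def resume_after_refit(self) -> None:
        """Allow new generation starts again."""
        with self._lock:
            self._can_start = True
```

The start check and the thread registration share one lock so that a worker cannot slip in between the moment new starts are blocked and the moment pending workers are joined.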
- get_dataloader_state() → dict
Get the current dataloader state for checkpointing.
- _cleanup_finished_threads() → None
- _run_prompt_group_worker(
- repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- generation_weight_version: int,
- target_weight_version: int,
- prompt_idx: int,