nemo_rl.algorithms.async_utils.trajectory_collector
Module Contents
Classes
AsyncTrajectoryCollector | Collects trajectories asynchronously and adds them to the replay buffer.
Data
API
- nemo_rl.algorithms.async_utils.trajectory_collector.TokenizerType
None
- class nemo_rl.algorithms.async_utils.trajectory_collector.AsyncTrajectoryCollector(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- tokenizer: nemo_rl.algorithms.async_utils.trajectory_collector.TokenizerType,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- replay_buffer: Any,
- start_step: int = 0,
Collects trajectories asynchronously and adds them to the replay buffer.
Initialization
- _calculate_target_weights(generation_weight_version: int) -> list[int]
Calculate the target weight versions for a given generation weight version.
The returned list enumerates the weight versions that a generation server can target. These versions are looped over to find a training step the server can generate for. If all target versions are exhausted, the generation server remains idle until the next weight update.
Example: generation_weight_version = 10, max_trajectory_age_steps = 4
- Returns:
[11, 12, 13, 14] # This generation server can create trajectories for training steps 11, 12, 13, and 14
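The documented example can be reproduced with a small stand-alone sketch. It assumes the targets are simply the next `max_trajectory_age_steps` versions after the generation weight version, which is all the example above shows. The free function `calculate_target_weights` is hypothetical and is not the library's implementation, which may handle special cases differently.

```python
def calculate_target_weights(
    generation_weight_version: int,
    max_trajectory_age_steps: int,
) -> list[int]:
    """Hypothetical stand-in for the documented method (an assumption, not library code)."""
    # The first version a server at version v can target is v + 1.
    first_target = generation_weight_version + 1
    # Enumerate max_trajectory_age_steps consecutive target versions.
    return list(range(first_target, first_target + max_trajectory_age_steps))


targets = calculate_target_weights(10, 4)
print(targets)  # [11, 12, 13, 14]
```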
- _get_next_target_for_generation(
- generation_weight_version: int,
Get the next target weight that needs generation (if any).
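One way to picture this selection is as a scan over the candidate target versions that returns the first one still needing a trajectory. The sketch below assumes that "needs generation" means no generation has been reserved for that target yet. The function name, the `already_reserved` set, and the `None` return for the exhausted case are all hypothetical and are not taken from the library.

```python
from typing import Optional


def next_target_for_generation(
    target_versions: list[int],
    already_reserved: set[int],
) -> Optional[int]:
    """Return the first target version not yet reserved, or None (hypothetical helper)."""
    for version in target_versions:
        if version not in already_reserved:
            return version
    # Every candidate is taken: the caller would idle until the next weight update.
    return None


print(next_target_for_generation([11, 12, 13, 14], {11, 12}))  # 13
```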
- set_weight_version(version: int) -> None
- _should_pause_for_generation_limits() -> bool
Check if collection should be paused due to generation limits.
- start_collection(
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
Start collecting trajectories from the dataloader.
- _collection_loop()
Run the collection loop in a background thread.
- _process_batch( ) -> None
Process a single batch and generate for one target weight.
- get_weight_version() -> int
- pause() -> None
Pause trajectory collection.
- resume() -> None
Resume trajectory collection.
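A common way to pause and resume a background loop in Python is to gate it on a `threading.Event`. The sketch below illustrates that general pattern only. `PausableWorker` and its attributes are hypothetical, and nothing here is taken from the collector's actual implementation.

```python
import threading
import time


class PausableWorker:
    """Hypothetical pausable background loop (not the library's class)."""

    def __init__(self) -> None:
        self._running = threading.Event()  # set: the loop may do work
        self._stopped = threading.Event()  # set: the loop should exit
        self._running.set()
        self.iterations = 0
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def _loop(self) -> None:
        while not self._stopped.is_set():
            # Block while paused, waking periodically to notice a stop request.
            if not self._running.wait(timeout=0.01):
                continue
            self.iterations += 1  # placeholder for "collect one batch"
            time.sleep(0.001)

    def pause(self) -> None:
        self._running.clear()

    def resume(self) -> None:
        self._running.set()

    def stop(self) -> None:
        self._stopped.set()
        self._running.set()  # release the loop if it is currently paused
        self._thread.join()
```

With this pattern, `pause()` does not interrupt an iteration that is already running; it only prevents the next one from starting until `resume()` is called.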
- prepare_for_refit() -> None
Pause new generation starts and optionally wait for pending generations.
For the vLLM V1 async engine, this leverages in-flight weight updates via collective_rpc, allowing ongoing generations to continue with their current KV caches while the weights are updated. This significantly improves async performance.
For non-async engines, this waits for all pending generations to complete before the refit.
- resume_after_refit() -> None
Resume new generation starts after refit is complete.
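The refit-related methods suggest a bracket around each weight update. The sketch below shows one plausible call order using a stub that only records calls. The ordering, the `refit_step` helper, the `refit_fn` callback, and `StubCollector` are assumptions for illustration; the real collector, and how the training loop invokes it, are not shown here.

```python
from typing import Callable


class StubCollector:
    """Hypothetical stand-in that records calls; the real collector is not used."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.weight_version = 0

    def prepare_for_refit(self) -> None:
        self.calls.append("prepare_for_refit")

    def set_weight_version(self, version: int) -> None:
        self.weight_version = version
        self.calls.append(f"set_weight_version({version})")

    def resume_after_refit(self) -> None:
        self.calls.append("resume_after_refit")


def refit_step(
    collector: StubCollector,
    new_version: int,
    refit_fn: Callable[[], None],
) -> None:
    """Assumed ordering: pause new starts, update weights, publish version, resume."""
    collector.prepare_for_refit()
    refit_fn()  # placeholder for the actual weight synchronization
    collector.set_weight_version(new_version)
    collector.resume_after_refit()


collector = StubCollector()
refit_step(collector, 5, lambda: collector.calls.append("refit"))
print(collector.calls)
```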
- wait_for_pending_generations() -> None
Wait for all in-flight generation threads to complete.
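Waiting for in-flight threads usually means joining a tracked list of `threading.Thread` objects and pruning the ones that have finished. The sketch below shows that general pattern under the assumption that worker threads are kept in a lock-protected list. `ThreadTracker` and its method names are hypothetical and do not reflect the collector's internals.

```python
import threading
import time


class ThreadTracker:
    """Hypothetical tracker for in-flight worker threads (not library code)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._threads: list[threading.Thread] = []

    def spawn(self, target, *args) -> None:
        thread = threading.Thread(target=target, args=args, daemon=True)
        with self._lock:
            # Start before tracking so is_alive() is never checked on an unstarted thread.
            thread.start()
            self._threads.append(thread)

    def cleanup_finished(self) -> int:
        """Drop finished threads and return how many are still running."""
        with self._lock:
            self._threads = [t for t in self._threads if t.is_alive()]
            return len(self._threads)

    def wait_for_pending(self) -> None:
        with self._lock:
            pending = list(self._threads)
        # Join outside the lock so the lock is not held while waiting.
        for thread in pending:
            thread.join()
        self.cleanup_finished()
```

Joining a snapshot of the list means threads spawned during the wait are not covered, which is why a collector would typically stop new generation starts before waiting.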
- get_dataloader_state() -> dict
Get the current dataloader state for checkpointing.
- _cleanup_finished_threads() -> None
- _run_prompt_group_worker(
- repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- generation_weight_version: int,
- target_weight_version: int,
- prompt_idx: int,