nemo_rl.environments.reward_model_environment#
Module Contents#
Classes#
- RewardModelEnvironmentConfig – Configuration for RewardModelEnvironment.
- RewardModelEnvironment – Environment that uses a reward model to score conversations.
API#
- class nemo_rl.environments.reward_model_environment.RewardModelEnvironmentConfig#
Bases:
typing.TypedDict
Configuration for RewardModelEnvironment.
- Attributes:
enabled – Whether the reward model environment is enabled
model_name – Name of the reward model to use (e.g., “Skywork/Skywork-Reward-V2-Qwen3-0.6B”)
tokenizer – Tokenizer configuration
precision – Model precision (e.g., “bfloat16”, “float16”, “float32”)
batch_size – Batch size for processing conversations
checkpoint_path – Path to model checkpoint (optional)
max_model_len – Maximum sequence length for the model
logprob_batch_size – Batch size for log probability computation
resources – Resource allocation configuration
reward_model_cfg – Reward model specific configuration
dtensor_cfg – DTensor configuration for distributed training
dynamic_batching – Dynamic batching configuration
sequence_packing – Sequence packing configuration
max_grad_norm – Maximum gradient norm for training
generation – Generation configuration for VLLM
Initialization
Initialize self. See help(type(self)) for accurate signature.
- enabled: bool#
None
- model_name: str#
None
- precision: str#
None
- batch_size: int#
None
- checkpoint_path: str#
None
- logprob_batch_size: int#
None
- resources: Dict[str, Any]#
None
- dtensor_cfg: Optional[Dict[str, Any]]#
None
- dynamic_batching: nemo_rl.models.policy.DynamicBatchingConfig#
None
- sequence_packing: NotRequired[nemo_rl.models.policy.SequencePackingConfig]#
None
- max_grad_norm: Optional[float]#
None
- generation: Optional[nemo_rl.models.generation.vllm.VllmConfig]#
None
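The fields above map onto a plain configuration dict. Below is a minimal sketch, assuming the config is assembled directly in Python rather than loaded from YAML; the nested tokenizer, resources, reward_model_cfg, and dtensor_cfg values are illustrative placeholders, not the exact sub-config schemas.

```python
# Illustrative RewardModelEnvironmentConfig-style dict.
# Nested sub-config keys are assumptions; consult the corresponding
# nemo_rl config schemas for the exact fields.
reward_model_env_config = {
    "enabled": True,
    "model_name": "Skywork/Skywork-Reward-V2-Qwen3-0.6B",
    "tokenizer": {"name": "Skywork/Skywork-Reward-V2-Qwen3-0.6B"},  # assumed key
    "precision": "bfloat16",
    "batch_size": 32,
    "checkpoint_path": None,  # optional checkpoint override
    "max_model_len": 4096,
    "logprob_batch_size": 32,
    "resources": {"gpus_per_node": 1, "num_nodes": 1},  # assumed keys
    "reward_model_cfg": {"enabled": True},               # assumed keys
    "dtensor_cfg": {"enabled": False},                   # assumed keys
    "dynamic_batching": {"enabled": False},
    "sequence_packing": {"enabled": False},
    "max_grad_norm": None,
    "generation": None,  # no VLLM generation config needed for scoring-only use
}
```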
- class nemo_rl.environments.reward_model_environment.RewardModelEnvironment(config: Dict[str, Any])#
Bases:
nemo_rl.environments.interfaces.EnvironmentInterface
Environment that uses a reward model to score conversations.
This environment implements a reward model-based scoring system for reinforcement learning tasks. It takes conversation logs as input and returns rewards based on the quality of the assistant’s responses as judged by a pre-trained reward model.
- Attributes:
config – Configuration dictionary containing all environment settings
virtual_cluster – Ray virtual cluster for resource management
tokenizer – Tokenizer for text processing
reward_model_policy – Policy object containing the reward model
Initialization
Initialize the reward model environment.
- Parameters:
config – Configuration dictionary containing reward model settings. Must include model_name, tokenizer, resources, and other required parameters as defined in RewardModelEnvironmentConfig.
- DEFAULT_PY_EXECUTABLE#
None
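A minimal construction sketch, reusing the hypothetical reward_model_env_config dict from the configuration sketch above:

```python
from nemo_rl.environments.reward_model_environment import RewardModelEnvironment

# Allocates the Ray virtual cluster and loads the reward model policy
# according to the supplied configuration.
env = RewardModelEnvironment(reward_model_env_config)
```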
- preprocess_data(message_logs: List[nemo_rl.data.interfaces.LLMMessageLogType])#
Preprocess the message logs for the reward model.
This method tokenizes and formats conversation logs into the format expected by the reward model. It handles:
- Tokenization of user and assistant messages
- Formatting with proper special tokens
- Batching and padding for efficient processing
- Sequence length validation and truncation
- Parameters:
message_logs – List of conversation message logs, where each log contains a list of messages with ‘role’ and ‘content’ fields.
- Returns:
BatchedDataDict containing tokenized and formatted data ready for reward model inference.
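A usage sketch for preprocess_data; the conversation content is made up, but each message follows the ‘role’/‘content’ structure described above:

```python
message_logs = [
    [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
]

# Tokenizes, formats, and pads the conversations into a BatchedDataDict
# ready for reward model inference.
batched_inputs = env.preprocess_data(message_logs)
```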
- step(message_logs: List[nemo_rl.data.interfaces.LLMMessageLogType], env_infos: List[Dict[str, Any]])#
Calculate rewards for the given message logs using the reward model.
This method processes conversation logs through the reward model to compute quality scores for each conversation. The rewards are based on the reward model’s assessment of how well the assistant’s responses align with human preferences.
- Parameters:
message_logs – List of conversation message logs to be scored. Each log should contain alternating user and assistant messages.
env_infos – List of environment info dictionaries (currently unused but required by the interface).
- Returns:
observations: List of observation dictionaries with reward information
metadata: List of metadata dictionaries (currently None)
next_stop_strings: List of stop strings (currently None)
rewards: Tensor of computed rewards for each conversation
terminateds: Tensor indicating episode termination (all True)
answers: List of assistant responses from the conversations
- Return type:
EnvironmentReturn containing the fields listed above
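A scoring sketch using step, assuming the env and message_logs from the previous sketches and that the returned EnvironmentReturn exposes the fields above as attributes:

```python
# env_infos is required by the interface but currently unused by this environment.
env_infos = [{} for _ in message_logs]

result = env.step(message_logs, env_infos)

print(result.rewards)      # one scalar reward per conversation
print(result.terminateds)  # all True: scoring is a single-step episode
```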
- global_post_process_and_metrics(batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict) -> Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict]#
Post-processing function that runs after all rollouts for the batch are complete and returns metrics.
This method computes aggregate statistics and metrics from the processed batch. It provides insights into reward distribution and processing statistics.
- Parameters:
batch – The batch data dictionary containing processed conversations and rewards.
- Returns:
processed_batch: The input batch (no modifications)
metrics_dict: Dictionary containing computed metrics including:
reward_model_env/num_samples: Number of samples processed
reward_model_env/mean_reward: Average reward across the batch
reward_model_env/std_reward: Standard deviation of rewards
reward_model_env/min_reward: Minimum reward in the batch
reward_model_env/max_reward: Maximum reward in the batch
- Return type:
Tuple of (processed_batch, metrics_dict) as described above
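A sketch of reading the reported metrics, assuming batch is a BatchedDataDict produced by the rollout pipeline:

```python
processed_batch, metrics = env.global_post_process_and_metrics(batch)

# Metric keys as documented above.
print(metrics["reward_model_env/num_samples"])
print(metrics["reward_model_env/mean_reward"])
```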
- shutdown()#
Shutdown the reward model worker and virtual cluster.
This method properly cleans up resources by shutting down the reward model policy and virtual cluster. It should be called when the environment is no longer needed to prevent resource leaks.
.. note::
The environment will also automatically call this method in its destructor, but it’s recommended to call it explicitly for better resource management.
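A cleanup sketch; wrapping use of the environment in try/finally ensures the worker and virtual cluster are released even if scoring raises:

```python
try:
    result = env.step(message_logs, env_infos)
finally:
    # Shuts down the reward model policy and the Ray virtual cluster.
    env.shutdown()
```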
- __del__()#
Destructor that ensures proper cleanup when the object is garbage collected.
This is an extra safety net in case the user forgets to call shutdown() and the reference to the object is lost when it goes out of scope. It is still recommended to call shutdown() explicitly for better resource management.