nemo_rl.models.policy.megatron_policy_worker
Module Contents
Classes
MegatronPolicyWorker
Functions
setup_megatron_model
destroy_parallel_state – Safely destroy parallel state and reset async call tracking.
Data
TokenizerType
API
- nemo_rl.models.policy.megatron_policy_worker.TokenizerType
‘TypeVar(…)’
- nemo_rl.models.policy.megatron_policy_worker.setup_megatron_model(
- policy_cfg: nemo_rl.models.policy.PolicyConfig,
- cfg: nemo.tron.config.ConfigContainer,
- load_optimizer: bool = True,
- get_embedding_ranks=None,
- get_position_embedding_ranks=None,
)
- nemo_rl.models.policy.megatron_policy_worker.destroy_parallel_state()
Safely destroy parallel state and reset async call tracking.
This function is called during initialization to clean up temporary distributed state from model import operations. Resetting async call tracking ensures that when the main Megatron distributed context is created, all ranks start with consistent call_idx values for async checkpointing.
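Why the reset matters can be shown with a toy model. In this sketch, `AsyncCallTracker` is a hypothetical stand-in for whatever per-process counter backs `call_idx`; it is not the Megatron implementation. The assumption is that temporary state created during model import may issue a different number of async calls on different ranks, so without a reset the ranks would start real checkpointing at different indices.

```python
from dataclasses import dataclass


@dataclass
class AsyncCallTracker:
    """Hypothetical per-rank counter standing in for the async checkpoint call_idx."""

    call_idx: int = 0

    def schedule_async_call(self) -> int:
        # Each scheduled async call consumes the current index, then advances it.
        idx = self.call_idx
        self.call_idx += 1
        return idx

    def reset(self) -> None:
        self.call_idx = 0


def simulate(import_calls_per_rank: list[int], reset_after_import: bool) -> list[int]:
    """Return the call_idx each rank would use for its first real async save."""
    first_save_idx = []
    for n_import_calls in import_calls_per_rank:
        tracker = AsyncCallTracker()
        # Temporary distributed state during model import may issue a
        # different number of async calls on each rank (an assumption).
        for _ in range(n_import_calls):
            tracker.schedule_async_call()
        if reset_after_import:
            tracker.reset()
        first_save_idx.append(tracker.schedule_async_call())
    return first_save_idx


print(simulate([2, 0, 1, 2], reset_after_import=False))  # [2, 0, 1, 2]: ranks disagree
print(simulate([2, 0, 1, 2], reset_after_import=True))   # [0, 0, 0, 0]: ranks agree
```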
- class nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: nemo_rl.models.policy.megatron_policy_worker.TokenizerType,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- *,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- pre_init_communication_queue: ray.util.queue.Queue,
- **kwargs: Any,
)
Initialization
- __repr__()
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
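Ray uses the return value of an actor's `__repr__` in place of the actor name in log prefixes. A minimal sketch of the pattern; the class name and the `rank`/`world_size` fields are hypothetical and do not reflect the exact prefix format MegatronPolicyWorker produces. In real use the class would be decorated with `@ray.remote`; it is left undecorated here so the example runs without Ray.

```python
class ToyPolicyWorker:
    """Hypothetical stand-in for a Ray actor class."""

    def __init__(self, rank: int, world_size: int) -> None:
        self.rank = rank
        self.world_size = world_size

    def __repr__(self) -> str:
        # Ray shows this string as the log prefix, so encoding the rank here
        # makes interleaved logs from many workers easy to tell apart.
        return f"{type(self).__name__}[rank={self.rank}/{self.world_size}]"


workers = [ToyPolicyWorker(rank=r, world_size=2) for r in range(2)]
print([repr(w) for w in workers])
# ['ToyPolicyWorker[rank=0/2]', 'ToyPolicyWorker[rank=1/2]']
```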
- is_alive()
- reset_peak_memory_stats() → None
- get_gpu_info()
Return information about the GPU being used by this worker.
- enable_forward_pre_hook()
- disable_forward_pre_hook(param_sync=True)
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
)
Train the policy on a batch of data with a given loss function.
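`train` accepts optional `gbs` (global batch size) and `mbs` (micro batch size) overrides. Under the standard Megatron relation global_batch_size = micro_batch_size × num_microbatches × data_parallel_size, and assuming `gbs` and `mbs` play those roles here, the number of gradient-accumulation steps per optimizer step works out as below. The helper name and the `data_parallel_size` argument are hypothetical.

```python
def num_microbatches(gbs: int, mbs: int, data_parallel_size: int) -> int:
    """Gradient-accumulation steps per optimizer step on each data-parallel rank.

    Hypothetical helper assuming gbs = mbs * num_microbatches * data_parallel_size.
    """
    if gbs <= 0 or mbs <= 0 or data_parallel_size <= 0:
        raise ValueError("gbs, mbs and data_parallel_size must be positive")
    samples_per_rank, rem = divmod(gbs, data_parallel_size)
    if rem:
        raise ValueError(f"gbs={gbs} is not divisible by data_parallel_size={data_parallel_size}")
    n, rem = divmod(samples_per_rank, mbs)
    if rem:
        raise ValueError(f"per-rank batch {samples_per_rank} is not divisible by mbs={mbs}")
    return n


# 512 samples per optimizer step, 4 samples per forward/backward pass, 8 DP ranks:
print(num_microbatches(gbs=512, mbs=4, data_parallel_size=8))  # 16
```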
- get_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
)
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
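The output convention (0 at position 0, and the logprob of input token i at position i) can be reproduced from raw next-token logits. A minimal pure-Python sketch for one sequence, assuming `logits[t]` is the model output at position t and therefore scores the token at position t + 1; the function names are hypothetical.

```python
import math


def log_softmax(row: list[float]) -> list[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def token_logprobs(logits: list[list[float]], token_ids: list[int]) -> list[float]:
    """Return one logprob per input token; position 0 is 0.0 by convention."""
    assert len(logits) == len(token_ids)
    out = [0.0]  # the first token has no preceding context, so it gets 0
    for i in range(1, len(token_ids)):
        # The logits at position i - 1 predict the token at position i.
        out.append(log_softmax(logits[i - 1])[token_ids[i]])
    return out


# Vocabulary of 2 tokens, sequence of 3 tokens.
logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 0.0]]
lp = token_logprobs(logits, [1, 0, 1])
print(len(lp))  # 3, the same as the input sequence length
```

Keeping the output the same length as the input lets it be indexed with the same attention or loss masks as `input_ids`.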
- use_reference_model()
Context manager that temporarily swaps the reference model and active model.
On entry: moves model to CPU, moves reference_model to CUDA, and swaps the references. On exit: restores the original references and reverses the CUDA/CPU placement.
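The swap-and-restore behavior maps naturally onto `contextlib.contextmanager`. A minimal sketch, assuming a toy `Module` class whose `.to(device)` only records the device; the class names are hypothetical and no real CUDA transfers happen. The `try`/`finally` ensures the original placement is restored even if the body raises.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class Module:
    """Hypothetical model stand-in that only tracks which device it is on."""

    name: str
    device: str = "cpu"

    def to(self, device: str) -> "Module":
        self.device = device
        return self


class Worker:
    def __init__(self) -> None:
        self.model = Module("policy", device="cuda")
        self.reference_model = Module("reference", device="cpu")

    @contextmanager
    def use_reference_model(self):
        # On entry: move the active model off the GPU, bring the reference model on,
        # and swap the references so code inside the block sees the reference model.
        self.model.to("cpu")
        self.reference_model.to("cuda")
        self.model, self.reference_model = self.reference_model, self.model
        try:
            yield
        finally:
            # On exit: swap back and reverse the device placement.
            self.model, self.reference_model = self.reference_model, self.model
            self.reference_model.to("cpu")
            self.model.to("cuda")


w = Worker()
with w.use_reference_model():
    print(w.model.name, w.model.device)  # reference cuda
print(w.model.name, w.model.device)  # policy cuda
```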
- get_reference_policy_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
)
Get the logprobs from the reference policy for a batch of data.
If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
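Both logprob methods use the configured logprob_batch_size for microbatching unless `micro_batch_size` is passed. A minimal sketch of that selection and of splitting a batch into microbatches; the config dict layout and both helper names are hypothetical.

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def resolve_micro_batch_size(cfg: dict, micro_batch_size: Optional[int] = None) -> int:
    # An explicit argument wins; otherwise fall back to the configured logprob batch size.
    return micro_batch_size if micro_batch_size is not None else cfg["logprob_batch_size"]


def microbatches(samples: Sequence[T], mbs: int) -> list[Sequence[T]]:
    """Split samples into consecutive chunks of at most mbs items."""
    if mbs <= 0:
        raise ValueError("micro batch size must be positive")
    return [samples[i : i + mbs] for i in range(0, len(samples), mbs)]


cfg = {"logprob_batch_size": 4}  # hypothetical config layout
batch = list(range(10))
print([len(mb) for mb in microbatches(batch, resolve_micro_batch_size(cfg))])     # [4, 4, 2]
print([len(mb) for mb in microbatches(batch, resolve_micro_batch_size(cfg, 3))])  # [3, 3, 3, 1]
```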
- generate(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)
Generate a batch of data using huggingface framework generation.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
- Returns:
output_ids: input + generated token IDs
logprobs: Log probabilities for each token
generation_lengths: Lengths of each response
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
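The output pairs `output_ids` (prompt followed by generated tokens) with `generation_lengths`. A minimal sketch of assembling such a right-padded batch from per-sample prompts and generated tokens; the helper name and the `pad_id` value are hypothetical, and logprobs are omitted for brevity.

```python
def assemble_generation_output(
    input_ids: list[list[int]],
    input_lengths: list[int],
    generated: list[list[int]],
    pad_id: int = 0,
) -> dict[str, list]:
    """Concatenate each prompt with its generated tokens and right-pad the batch."""
    rows = []
    for ids, n_in, gen in zip(input_ids, input_lengths, generated):
        # Drop any right padding on the prompt before appending the response.
        rows.append(ids[:n_in] + gen)
    width = max(len(r) for r in rows)
    return {
        "output_ids": [r + [pad_id] * (width - len(r)) for r in rows],
        "generation_lengths": [len(g) for g in generated],
    }


out = assemble_generation_output(
    input_ids=[[5, 6, 7], [8, 0, 0]],  # the second prompt is right-padded
    input_lengths=[3, 1],
    generated=[[9], [10, 11]],
)
print(out["output_ids"])          # [[5, 6, 7, 9], [8, 10, 11, 0]]
print(out["generation_lengths"])  # [1, 2]
```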
- zero_out_weights()
Zero out the weights of the model.
- report_device_id() → str
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
- prepare_refit_info() → None
- prepare_weights_for_ipc() → tuple[list[tuple[str, int]], float]
Prepare Megatron model weights for IPC transfer to vLLM.
Collects information about weight tensors (names and sizes). Returns a tuple whose first element is a list of (parameter_name, size_in_bytes) tuples and whose second element is a float.
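A (parameter_name, size_in_bytes) list like this is what a receiver needs to plan a transfer. A minimal sketch, assuming each parameter exposes `numel()` and `element_size()` as `torch.Tensor` does (so with PyTorch you could pass `model.named_parameters()`); a small stand-in class is used so the example runs without PyTorch, and the helper name is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Stand-in exposing the same two size methods as torch.Tensor."""

    shape: tuple[int, ...]
    bytes_per_element: int

    def numel(self) -> int:
        n = 1
        for d in self.shape:
            n *= d
        return n

    def element_size(self) -> int:
        return self.bytes_per_element


def collect_param_sizes(named_params) -> list[tuple[str, int]]:
    # Size in bytes = number of elements * bytes per element.
    return [(name, p.numel() * p.element_size()) for name, p in named_params]


params = {
    "embedding.weight": FakeParam((1000, 64), 2),  # bf16: 2 bytes per element
    "layer0.bias": FakeParam((64,), 4),            # fp32: 4 bytes per element
}
sizes = collect_param_sizes(params.items())
print(sizes)  # [('embedding.weight', 128000), ('layer0.bias', 256)]
print(sum(s for _, s in sizes))  # 128256
```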
- get_weights_ipc_handles(*, keys: list[str]) → dict[str, Any]
Get IPC handles for the requested Megatron model weights.
- Parameters:
keys – List of parameter names to get handles for
- Returns:
Dict mapping device UUID to list of (mapped_key, handle) tuples
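Because the result is keyed by device UUID (presumably the same "GPU-xxxxx" form that report_device_id returns), a consumer typically looks up the entry for its own GPU. A minimal sketch of that lookup; the UUIDs, parameter keys and handle values are hypothetical placeholders, and real handles would be opaque IPC objects rather than strings.

```python
from typing import Any


def handles_for_device(
    ipc_handles: dict[str, list[tuple[str, Any]]], device_uuid: str
) -> dict[str, Any]:
    """Return {mapped_key: handle} for one device, failing loudly if it is missing."""
    try:
        entries = ipc_handles[device_uuid]
    except KeyError:
        raise KeyError(f"no IPC handles for device {device_uuid}") from None
    return dict(entries)


ipc_handles = {
    "GPU-aaaa": [("model.layers.0.weight", "handle-0"), ("model.norm.weight", "handle-1")],
    "GPU-bbbb": [("model.layers.0.weight", "handle-2")],
}
print(handles_for_device(ipc_handles, "GPU-bbbb"))  # {'model.layers.0.weight': 'handle-2'}
```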
- prepare_for_lp_inference()
- prepare_for_training(*args, **kwargs)
- offload_before_refit()
Offload the optimizer and buffers to the CPU.
- offload_after_refit()
- move_model(model, device: str, move_params=True, move_grads=True)
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- **kwargs,
)
Save a training checkpoint.
- Parameters:
weights_path – The specific directory path where the checkpoint will be saved.
optimizer_path – If not None, optimizer and scheduler states are saved if they exist.
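The optimizer_path contract (weights always written; optimizer and scheduler state written only when a path is given and the state exists) can be sketched with plain files. This is a hypothetical illustration using JSON and `tempfile`, not the Megatron checkpoint format; the file names are made up.

```python
import json
import os
import tempfile
from typing import Optional


def save_checkpoint(
    weights: dict,
    weights_path: str,
    optimizer_state: Optional[dict] = None,
    scheduler_state: Optional[dict] = None,
    optimizer_path: Optional[str] = None,
) -> list[str]:
    """Write weights, plus optimizer/scheduler state only if requested and present."""
    written = []
    os.makedirs(weights_path, exist_ok=True)
    path = os.path.join(weights_path, "weights.json")
    with open(path, "w") as f:
        json.dump(weights, f)
    written.append(path)
    if optimizer_path is not None:
        os.makedirs(optimizer_path, exist_ok=True)
        for name, state in (("optimizer", optimizer_state), ("scheduler", scheduler_state)):
            if state is not None:  # skip states that do not exist
                p = os.path.join(optimizer_path, f"{name}.json")
                with open(p, "w") as f:
                    json.dump(state, f)
                written.append(p)
    return written


with tempfile.TemporaryDirectory() as root:
    files = save_checkpoint(
        {"w": [1.0]},
        os.path.join(root, "weights"),
        optimizer_state={"step": 10},
        optimizer_path=os.path.join(root, "optim"),
    )
    print([os.path.basename(f) for f in files])  # ['weights.json', 'optimizer.json']
```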
- abstractmethod load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
)
Load a training checkpoint.
- Parameters:
weights_path – The exact directory path from which to load the checkpoint.
optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
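On the load side, optimizer and scheduler states are restored only when optimizer_path is given and both `self.optimizer` and `self.scheduler` already exist. A minimal sketch of that guard; `Stateful`, `ToyWorker`, `load_optimizer_states` and the `read` callable are hypothetical, and loading the weights themselves is left out.

```python
from typing import Callable, Optional


class Stateful:
    """Toy object with a load_state_dict method, like torch optimizers and schedulers."""

    def __init__(self) -> None:
        self.state: dict = {}

    def load_state_dict(self, state: dict) -> None:
        self.state = dict(state)


class ToyWorker:
    def __init__(self, optimizer: Optional[Stateful], scheduler: Optional[Stateful]) -> None:
        self.optimizer = optimizer
        self.scheduler = scheduler

    def load_optimizer_states(
        self, optimizer_path: Optional[str], read: Callable[[str, str], dict]
    ) -> bool:
        """Return True if optimizer and scheduler states were loaded."""
        if optimizer_path is None or self.optimizer is None or self.scheduler is None:
            return False  # nothing requested, or nothing to load into
        self.optimizer.load_state_dict(read(optimizer_path, "optimizer"))
        self.scheduler.load_state_dict(read(optimizer_path, "scheduler"))
        return True


def fake_read(path: str, name: str) -> dict:
    # Stands in for reading a saved state file under `path`.
    return {"source": f"{path}/{name}"}


w = ToyWorker(Stateful(), Stateful())
print(w.load_optimizer_states("/ckpt/optim", fake_read))  # True
print(ToyWorker(None, None).load_optimizer_states("/ckpt/optim", fake_read))  # False
```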
- shutdown()
Shutdown the policy.
- start_gpu_profiling() → None
Start GPU profiling.
- stop_gpu_profiling() → None
Stop GPU profiling.