nemo_rl.models.policy.megatron_policy_worker#
Module Contents#
Classes#
Functions#
| broadcast_object_across_pp_ranks | Broadcast an object across pipeline parallel ranks. |
| setup_megatron_model | |
| destroy_parallel_state | Safely destroy parallel state and reset async call tracking. |
Data#
API#
- nemo_rl.models.policy.megatron_policy_worker.TokenizerType#
‘TypeVar(…)’
- nemo_rl.models.policy.megatron_policy_worker.broadcast_object_across_pp_ranks(obj)#
Broadcast an object across pipeline parallel ranks.
This utility function handles broadcasting an object from the rank that owns it to all other pipeline parallel ranks. If only one rank has the object (non-None), it will be broadcast to all other ranks.
- Parameters:
obj – The object to broadcast. Can be None on ranks that don’t own it.
- Returns:
The object on all ranks (either the original or the broadcast copy).
- Raises:
ValueError – If the object doesn’t exist on any pipeline parallel rank.
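A minimal sketch of the broadcast pattern this helper describes, assuming `torch.distributed.broadcast_object_list` and Megatron Core's pipeline-parallel process group; the real implementation may differ in details such as ownership detection.

```python
import torch.distributed as dist
from megatron.core import parallel_state


def broadcast_object_across_pp_ranks_sketch(obj):
    """Illustrative only: broadcast obj from the single PP rank that owns it."""
    pp_group = parallel_state.get_pipeline_model_parallel_group()
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
    if pp_size == 1:
        if obj is None:
            raise ValueError("Object does not exist on any pipeline parallel rank.")
        return obj

    # Discover which PP rank actually holds a non-None object.
    flags = [None] * pp_size
    dist.all_gather_object(flags, obj is not None, group=pp_group)
    if not any(flags):
        raise ValueError("Object does not exist on any pipeline parallel rank.")
    src = dist.get_global_rank(pp_group, flags.index(True))

    # Broadcast the Python object from the owning rank to the rest of the group.
    container = [obj]
    dist.broadcast_object_list(container, src=src, group=pp_group)
    return container[0]
```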
- nemo_rl.models.policy.megatron_policy_worker.setup_megatron_model(
- policy_cfg: nemo_rl.models.policy.PolicyConfig,
- cfg: megatron.bridge.training.config.ConfigContainer,
- load_optimizer: bool = True,
- get_embedding_ranks=None,
- get_position_embedding_ranks=None,
- )#
- nemo_rl.models.policy.megatron_policy_worker.destroy_parallel_state()#
Safely destroy parallel state and reset async call tracking.
This function is called during initialization to clean up temporary distributed state from model import operations. Resetting async call tracking ensures that when the main Megatron distributed context is created, all ranks start with consistent call_idx values for async checkpointing.
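A hedged sketch of the intended call ordering during worker setup; the exact call sites are an assumption, not taken from the implementation.

```python
# Assumed setup ordering: tear down any temporary parallel state left over from
# checkpoint import before building the long-lived Megatron distributed context,
# so every rank starts async-checkpoint call tracking from the same call_idx.
destroy_parallel_state()
setup_megatron_model(policy_cfg, cfg, load_optimizer=True)  # then build the real context
```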
- class nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: nemo_rl.models.policy.megatron_policy_worker.TokenizerType,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- *,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- pre_init_communication_queue: ray.util.queue.Queue,
- **kwargs: Any,
- )#
Initialization
- __repr__()#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
- )#
Initialize the collective communication.
- is_alive()#
- reset_peak_memory_stats() → None#
- get_gpu_info()#
Return information about the GPU being used by this worker.
- enable_forward_pre_hook()#
- disable_forward_pre_hook(param_sync=True)#
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
- )#
Train the policy on a batch of data with a given loss function.
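A hedged, driver-side usage sketch, assuming the worker is deployed as a Ray actor (as the `ray.util.queue.Queue` constructor argument suggests); `policy_worker`, `train_batch`, and `nll_loss` are placeholders.

```python
import ray

# Hypothetical call from the training driver; the real loop lives in nemo_rl's algorithms.
metrics = ray.get(
    policy_worker.train.remote(
        data=train_batch,   # BatchedDataDict with tokens, masks, and loss inputs
        loss_fn=nll_loss,   # any nemo_rl LossFunction implementation
        eval_mode=False,    # True runs without weight updates (assumed)
        gbs=None,           # optionally override the configured global batch size
        mbs=None,           # optionally override the configured micro batch size
    )
)
```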
- get_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
- )#
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
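A short illustration of the alignment convention described above; variable names are placeholders and the call pattern assumes a Ray actor handle.

```python
import ray

out = ray.get(policy_worker.get_logprobs.remote(data=batch))
logprobs = out["logprobs"]  # shape [batch_size, sequence_length]

# Convention: logprobs[:, 0] == 0 and logprobs[:, i] scores input token i given
# tokens 0..i-1, so the tensor lines up one-to-one with the input tokens.
# Summing over non-padded response positions gives a per-sequence logprob.
sequence_logprob = (logprobs * response_mask).sum(dim=-1)
```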
- use_reference_model()#
Context manager that temporarily swaps the reference model and active model.
On entry: moves the active model to CPU, moves the reference_model to CUDA, and swaps the references. On exit: restores the original references and flips the CUDA/CPU placement back.
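A minimal sketch of using the swap from inside the worker, assuming a standard Python context manager as described above.

```python
# Temporarily evaluate with the frozen reference weights on GPU while the
# trained weights sit on CPU; placement is restored when the block exits.
with self.use_reference_model():
    ref_out = self.get_logprobs(data=batch)
```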
- get_reference_policy_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
- )#
Get the logprobs from the reference policy for a batch of data.
If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
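Because `logprobs` and `reference_logprobs` share the same shape and token alignment, a per-token KL penalty of the kind used in many RLHF objectives can be formed directly; the estimator below is illustrative, not part of the worker API.

```python
import torch

# Both tensors: [batch_size, sequence_length], position i scoring input token i.
log_ratio = logprobs - reference_logprobs
# "k3" estimator of KL(pi || pi_ref), a common low-variance choice.
per_token_kl = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = (per_token_kl * response_mask).sum(-1) / response_mask.sum(-1)
```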
- get_topk_logits(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- k: int,
- micro_batch_size: Optional[int] = None,
- )#
Get the top-k logits and indices for a batch of data.
The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence.
- Returns:
topk_logits: Tensor of top-k logits for each position in the sequence
topk_indices: Tensor of top-k indices for each position in the sequence
- Return type:
BatchedDataDict containing
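A hedged sketch of consuming the output, for example as teacher targets for distillation; key names follow the Returns section above, while the trailing-k layout is an assumption.

```python
import ray
import torch

out = ray.get(policy_worker.get_topk_logits.remote(data=batch, k=64))
topk_logits = out["topk_logits"]    # assumed [batch_size, sequence_length, k]
topk_indices = out["topk_indices"]  # assumed [batch_size, sequence_length, k]

# Example: a sparse teacher distribution restricted to the top-k vocabulary entries.
teacher_probs = torch.softmax(topk_logits, dim=-1)
```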
- generate(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )#
Generate a batch of data using the Hugging Face generation framework.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
- Returns:
output_ids: input + generated token IDs
logprobs: Log probabilities for each token
generation_lengths: Lengths of each response
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
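A hedged usage sketch; the input batch is assumed to carry `input_ids` and `input_lengths` as stated in Parameters, and key names follow the Returns section.

```python
import ray

gen = ray.get(policy_worker.generate.remote(data=prompt_batch, greedy=True))
output_ids = gen["output_ids"]                # prompt + generated token ids
gen_logprobs = gen["logprobs"]                # per-token log probabilities
response_lengths = gen["generation_lengths"]  # length of each generated response
```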
- zero_out_weights()#
Zero out the weights of the model.
- report_device_id() → str#
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
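A minimal sketch of producing such a UUID with the NVML Python bindings (`pynvml`); the worker's actual implementation and device-index handling may differ.

```python
import pynvml
import torch

pynvml.nvmlInit()
# Assumes the default device ordering maps the CUDA index onto the NVML index.
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
uuid = pynvml.nvmlDeviceGetUUID(handle)
uuid = uuid.decode() if isinstance(uuid, bytes) else uuid  # older pynvml returns bytes
pynvml.nvmlShutdown()
```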
- get_zmq_address()#
Get the ZMQ address for the current device.
- maybe_init_zmq()#
Initialize the ZMQ socket if it doesn’t exist.
- prepare_refit_info() → None#
Prepare state dict metadata for weight refitting and IPC streaming.
- _calculate_refit_param_info() → list[tuple[str, int]]#
Calculate parameter information for refit.
Each task contains:
param_name: Local parameter name without module prefixes
mapping: MegatronParamMapping instance for weight transformation
pp_rank: Pipeline-parallel rank owning the parameter
vp_stage: Virtual-pipeline stage index
megatron_module: Reference to Megatron model/submodule
param_weight: Target parameter tensor for converted weight
- Returns:
List of (parameter_name, size_in_bytes) tuples.
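Illustrative only: the returned (name, size) pairs can be used to pack parameters into fixed-size transfer buckets during refit; the greedy scheme below is an assumption about how a consumer might use the list, not the worker's actual logic.

```python
def bucket_params(param_info: list[tuple[str, int]], bucket_bytes: int) -> list[list[str]]:
    """Greedily group parameter names so each bucket stays under bucket_bytes."""
    buckets: list[list[str]] = []
    current: list[str] = []
    current_bytes = 0
    for name, size in param_info:
        if current and current_bytes + size > bucket_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets
```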
- get_free_memory_bytes() → int#
Get the available free memory in bytes.
- stream_weights_via_ipc_zmq(buffer_size_bytes: int = 0) → None#
Stream model weights to peer process via ZMQ IPC socket.
- broadcast_weights_for_collective() → None#
Broadcast the weights for collective communication.
- prepare_for_lp_inference()#
- prepare_for_training(*args, **kwargs)#
- offload_before_refit()#
Offload the optimizer and buffers to the CPU.
- offload_after_refit()#
Offload as much as possible to the CPU.
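A hedged sketch of how a driver might sequence a weight refit using the refit and offload methods above; the ordering and the choice between ZMQ streaming and collective broadcast are assumptions, and the real orchestration lives elsewhere in nemo_rl.

```python
import ray

# Hypothetical driver-side refit sequence (Ray actor handle assumed).
ray.get(policy_worker.prepare_refit_info.remote())    # exchange state-dict metadata once
ray.get(policy_worker.offload_before_refit.remote())  # free optimizer/buffers on the GPU

# Either stream weights to a co-located inference process over ZMQ IPC ...
ray.get(policy_worker.stream_weights_via_ipc_zmq.remote(buffer_size_bytes=1 << 30))
# ... or broadcast them through the group set up by init_collective:
# ray.get(policy_worker.broadcast_weights_for_collective.remote())

ray.get(policy_worker.offload_after_refit.remote())   # push remaining state to the CPU
```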
- move_model(
- model: torch.nn.Module,
- device: str,
- move_params: bool = True,
- move_grads: bool = True,
- )#
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- **kwargs,
- )#
Save a training checkpoint.
- Parameters:
weights_path – The specific directory path where the checkpoint will be saved.
optimizer_path – If not None, optimizer and scheduler states are saved if they exist.
- abstractmethod load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- )#
Load a training checkpoint.
- Parameters:
weights_path – The exact directory path from which to load the checkpoint.
optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
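A hedged usage sketch for checkpointing; the directory paths are placeholders.

```python
import ray

ckpt_dir = "/results/checkpoints/step_0100"

# Save weights plus optimizer/scheduler state (omit optimizer_path to save weights only).
ray.get(policy_worker.save_checkpoint.remote(
    weights_path=f"{ckpt_dir}/policy",
    optimizer_path=f"{ckpt_dir}/optim",
))

# Later, restore from the same directories.
ray.get(policy_worker.load_checkpoint.remote(
    weights_path=f"{ckpt_dir}/policy",
    optimizer_path=f"{ckpt_dir}/optim",
))
```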
- shutdown()#
Shutdown the policy.
- start_gpu_profiling() → None#
Start GPU profiling.
- stop_gpu_profiling() → None#
Stop GPU profiling.
- report_node_ip_and_gpu_id() → list[tuple[str, int]]#
Report the node IP and GPU ID of the current worker.