nemo_rl.models.policy.workers.megatron_policy_worker#
Module Contents#
Classes#
MegatronPolicyWorker
Functions#
broadcast_object_across_pp_ranks | Broadcast an object across pipeline parallel ranks.
Data#
TokenizerType
API#
- nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType#
‘TypeVar(…)’
- nemo_rl.models.policy.workers.megatron_policy_worker.broadcast_object_across_pp_ranks(obj)#
Broadcast an object across pipeline parallel ranks.
This utility function handles broadcasting an object from the rank that owns it to all other pipeline parallel ranks. If only one rank has the object (non-None), it will be broadcast to all other ranks.
- Parameters:
obj – The object to broadcast. Can be None on ranks that don’t own it.
- Returns:
The object on all ranks (either the original or the broadcast copy).
- Raises:
ValueError – If the object doesn’t exist on any pipeline parallel rank.
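A minimal usage sketch, assuming Megatron's `parallel_state` is already initialized; the dictionary value and the choice of the last pipeline stage as owner are illustrative only:

```python
# Hypothetical usage: only the last pipeline stage owns the value; every
# other PP rank passes None and receives the broadcast copy.
from megatron.core import parallel_state

from nemo_rl.models.policy.workers.megatron_policy_worker import (
    broadcast_object_across_pp_ranks,
)

local_value = {"vocab_size": 32000} if parallel_state.is_pipeline_last_stage() else None
shared_value = broadcast_object_across_pp_ranks(local_value)
# After the call, every pipeline parallel rank holds the same dictionary.
```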
- class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- *,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- **kwargs: Any,
Bases:
nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface
- __repr__()#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- enable_forward_pre_hook()#
- disable_forward_pre_hook(param_sync=True)#
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
Train the policy on a batch of data with a given loss function.
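A hedged call sketch; `worker`, `train_batch`, `my_loss_fn`, and the `metrics` return name are illustrative assumptions, not the documented contract:

```python
# Illustrative call (names are hypothetical): train on one global batch with
# an externally supplied loss function, overriding batch sizes explicitly.
metrics = worker.train(
    data=train_batch,    # BatchedDataDict from the data pipeline
    loss_fn=my_loss_fn,  # implements nemo_rl.algorithms.interfaces.LossFunction
    eval_mode=False,
    gbs=512,             # optional global batch size override
    mbs=8,               # optional micro batch size override
)
```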
- get_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
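A minimal sketch of the logprob convention described above (not the worker's internal implementation): position 0 is set to 0 so the output keeps the input sequence length.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32)             # [batch, seq_len, vocab]
input_ids = torch.randint(0, 32, (2, 5))   # [batch, seq_len]

token_logprobs = torch.zeros(2, 5)
log_probs = F.log_softmax(logits[:, :-1], dim=-1)  # predicts tokens 1..seq_len-1
token_logprobs[:, 1:] = log_probs.gather(
    -1, input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)
# token_logprobs[:, 0] stays 0; token_logprobs[:, i] is the logprob of token i
# given tokens 0..i-1, matching the convention above.
```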
- use_reference_model()#
Context manager that temporarily swaps the reference model and active model.
On entry: Moves the model to CPU and moves reference_model to CUDA, then swaps the references. On exit: Restores the original references and re-flips the CUDA/CPU placement.
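A hedged usage sketch, assuming `worker` and `batch` already exist; the variable names are illustrative:

```python
# Compute logprobs under the frozen reference model, then automatically
# restore the active policy weights when the context exits.
with worker.use_reference_model():
    reference_logprobs = worker.get_logprobs(data=batch)
```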
- get_topk_logits(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- k: int,
- micro_batch_size: Optional[int] = None,
Get the top-k logits and indices for a batch of data.
The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence.
- Returns:
topk_logits: Tensor of top-k logits for each position in the sequence
topk_indices: Tensor of top-k indices for each position in the sequence
- Return type:
BatchedDataDict containing the fields listed above
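An illustration of the per-position top-k convention (a generic sketch, not the worker's code):

```python
import torch

logits = torch.randn(2, 5, 32)                        # [batch, seq_len, vocab]
topk_logits, topk_indices = torch.topk(logits, k=8, dim=-1)
# topk_logits and topk_indices both have shape [2, 5, 8]: for every position
# in the sequence, the k largest logits and their vocabulary indices.
```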
- generate(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
Generate a batch of data using huggingface framework generation.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
- Returns:
output_ids: input + generated token IDs
logprobs: Log probabilities for each token
generation_lengths: Lengths of each response
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
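A hedged call sketch, assuming `batch` conforms to GenerationDatumSpec (contains input_ids and input_lengths); the variable names are illustrative:

```python
# Greedy decoding example; the returned keys follow GenerationOutputSpec as
# documented above.
outputs = worker.generate(data=batch, greedy=True)
output_ids = outputs["output_ids"]                  # prompt + generated tokens
generation_lengths = outputs["generation_lengths"]  # length of each response
```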
- prepare_refit_info() None#
Prepare state dict metadata for weight refitting and IPC streaming.
- _calculate_refit_param_info() list[tuple[str, int]]#
Calculate parameter information for refit.
Each task contains:
param_name: Local parameter name without module prefixes
mapping: MegatronParamMapping instance for weight transformation
pp_rank: Pipeline-parallel rank owning the parameter
vp_stage: Virtual-pipeline stage index
megatron_module: Reference to Megatron model/submodule
param_weight: Target parameter tensor for converted weight
- Returns:
List of (parameter_name, size_in_bytes) tuples.
- _iter_params_with_optional_kv_scales(
- kv_scales: Optional[dict[str, float]] = None,
Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.
This helper is used by both IPC-based streaming and collective broadcast so that the logic for adding KV scales stays consistent in one place.
- stream_weights_via_ipc_zmq(
- buffer_size_bytes: int = 0,
- kv_scales: Optional[dict[str, float]] = None,
Stream model weights to peer process via ZMQ IPC socket.
- broadcast_weights_for_collective(
- kv_scales: Optional[dict[str, float]] = None,
Broadcast the weights for collective communication.
- prepare_for_lp_inference()#
- prepare_for_training(*args, **kwargs)#
- offload_before_refit()#
Offload the optimizer and buffers to the CPU.
- offload_after_refit()#
Offload as much as possible to the CPU.
- move_model(
- model: torch.nn.Module,
- device: str,
- move_params: bool = True,
- move_grads: bool = True,
- move_optimizer(device: str)#
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- **kwargs,
Save a training checkpoint.
- Parameters:
weights_path – The specific directory path where the checkpoint will be saved.
optimizer_path – If not None, optimizer and scheduler states are saved if they exist.
- abstractmethod load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
Load a training checkpoint.
- Parameters:
weights_path – The exact directory path from which to load the checkpoint.
optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
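A minimal usage sketch of the checkpoint pair above; the paths are illustrative placeholders:

```python
# Save weights plus optimizer/scheduler state, then restore both later.
worker.save_checkpoint(
    weights_path="/checkpoints/step_1000/weights",
    optimizer_path="/checkpoints/step_1000/optimizer",
)
worker.load_checkpoint(
    weights_path="/checkpoints/step_1000/weights",
    optimizer_path="/checkpoints/step_1000/optimizer",
)
```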
- check_tensor_parallel_attributes() dict[str, Any]#
Check tensor parallel attributes on model parameters.
- Returns:
tp_params: List of parameter names that have tensor_model_parallel=True
non_tp_params: List of parameter names that have tensor_model_parallel=False
total_params: Total number of parameters checked
tp_size: Tensor parallel size from config
- Return type:
Dictionary containing information about tensor parallel parameters
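A conceptual sketch of this kind of check, relying on Megatron's convention of tagging sharded parameters with a `tensor_model_parallel` attribute; the function name is illustrative:

```python
import torch

def split_tp_params(model: torch.nn.Module) -> dict[str, list[str]]:
    """Group parameter names by Megatron's tensor_model_parallel attribute."""
    tp_params, non_tp_params = [], []
    for name, param in model.named_parameters():
        if getattr(param, "tensor_model_parallel", False):
            tp_params.append(name)
        else:
            non_tp_params.append(name)
    return {"tp_params": tp_params, "non_tp_params": non_tp_params}
```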
- calibrate_qkv_fp8_scales(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
- percentile: float = 99.9,
- margin: float = 1.05,
- include_q: bool = False,
One-shot calibration of Q/K/V activation scales (for FP8 KV cache).
Captures each layer's query_key_value output through forward hooks, splits Q/K/V, and computes the percentile amax. In parallel (DP/TP/PP) environments, it first computes local percentiles, then takes the max across all ranks to stay conservative.
By default only K/V scales are returned and saved; the Q scale is optional (a conceptual sketch follows this entry).
- Parameters:
data – Representative sample batch for calibration, following get_logprobs input conventions.
micro_batch_size – Micro batch size during calibration; if None, reuses logprob_batch_size.
percentile – Percentile for amax (e.g. 99.9).
margin – Margin factor, e.g. 1.05.
save_path – If provided, rank0 will save results as JSON.
include_q – Whether to also return Q scale (usually only K/V needed).
- Returns:
A dict of the form { "format": "fp8", "percentile": float, "margin": float, "layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } }.
- Return type:
dict
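A conceptual sketch of the percentile-amax rule described above (not the worker's hook-based implementation); the function name and sample tensor are illustrative:

```python
import torch

def percentile_amax_scale(
    activations: torch.Tensor, percentile: float = 99.9, margin: float = 1.05
) -> float:
    """Percentile-based amax of an activation tensor, widened by a margin."""
    amax = torch.quantile(activations.abs().float().flatten(), percentile / 100.0)
    return float(amax * margin)

# Example: a K projection output captured from a representative batch.
k_activation = torch.randn(8, 128, 1024)
k_scale = percentile_amax_scale(k_activation, percentile=99.9, margin=1.05)
```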