nemo_rl.models.policy.workers.megatron_policy_worker#
Module Contents#
Classes#
Data#
API#
- nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType#
'TypeVar(…)'
- class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorkerImpl(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- *,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- **kwargs: Any,
Bases:
nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface
- __repr__()#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- enable_forward_pre_hook()#
- disable_forward_pre_hook(param_sync=True)#
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
Train the policy on a batch of data with a given loss function.
- get_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
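The indexing convention above can be illustrated with a small framework-free sketch (a hypothetical helper using plain Python lists in place of tensors): the logits at position i−1 score token i, and position 0 is fixed at 0.0 so the output length matches the sequence length.

```python
import math

def sequence_logprobs(logits: list[list[float]], token_ids: list[int]) -> list[float]:
    """Toy illustration of the logprob convention: logits[i] scores the
    *next* token, so the logprob of token i comes from logits[i - 1],
    and position 0 is fixed to 0.0 to preserve the sequence length."""
    out = [0.0]
    for i in range(1, len(token_ids)):
        row = logits[i - 1]
        # log-softmax: subtract the log of the normalizing sum
        lse = math.log(sum(math.exp(x) for x in row))
        out.append(row[token_ids[i]] - lse)
    return out
```

With a uniform two-way distribution at every position, each non-initial token gets logprob −log 2, and the output has the same length as the input.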
- use_reference_model()#
Context manager that temporarily swaps the reference model and active model.
On entry: moves the active model to CPU, moves the reference model to CUDA, and swaps the references. Top-k/top-p filtering is also disabled, since the reference policy’s distribution differs from the current policy’s, which would make filtered logprobs incompatible. On exit: restores the original references, moves each model back to its original device, and restores sampling_params.
- get_topk_logits(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- k: int,
- micro_batch_size: Optional[int] = None,
Get the top-k logits and indices for a batch of data.
The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence.
- Returns:
topk_logits: Tensor of top-k logits for each position in the sequence
topk_indices: Tensor of top-k indices for each position in the sequence
- Return type:
BatchedDataDict containing
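A per-position top-k selection can be sketched without any framework (hypothetical helper, plain lists in place of tensors): for each sequence position, keep the k largest logits together with their vocabulary indices.

```python
def topk_per_position(logits: list[list[float]], k: int):
    """For each position, return the k largest logits and their vocabulary
    indices, mirroring the shape contract described for get_topk_logits."""
    topk_logits, topk_indices = [], []
    for row in logits:
        # Sort indices by logit value, descending; keep the first k.
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        topk_indices.append(order)
        topk_logits.append([row[j] for j in order])
    return topk_logits, topk_indices
```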
- generate(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
Generate a batch of data using the Hugging Face generation framework.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
- Returns:
output_ids: input + generated token IDs
logprobs: Log probabilities for each token
generation_lengths: Lengths of each response
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
- prepare_refit_info() → None#
Prepare state dict metadata for weight refitting and IPC streaming.
- _calculate_refit_param_info() → list[tuple[str, int]]#
Calculate parameter information for refit.
Each task contains:
param_name: Local parameter name without module prefixes
mapping: MegatronParamMapping instance for weight transformation
pp_rank: Pipeline-parallel rank owning the parameter
vp_stage: Virtual-pipeline stage index
megatron_module: Reference to Megatron model/submodule
param_weight: Target parameter tensor for converted weight
- Returns:
List of (parameter_name, size_in_bytes) tuples.
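The size half of each (parameter_name, size_in_bytes) tuple is just element count times element width. A minimal sketch (hypothetical helper, not the implementation):

```python
def param_size_bytes(shape: tuple[int, ...], dtype_bytes: int) -> int:
    """Byte size of one parameter tensor: number of elements times the
    per-element size (e.g. 2 for bf16, 4 for fp32)."""
    n = 1
    for d in shape:
        n *= d
    return n * dtype_bytes
```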
- _iter_params_with_optional_kv_scales(
- kv_scales: Optional[dict[str, float]] = None,
Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.
This helper is used by both IPC-based streaming and collective broadcast so that the logic for adding KV scales stays consistent in one place.
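The shared-iteration idea can be sketched as a generator (hypothetical names, plain dicts in place of tensors): yield every exported parameter first, then append the KV-scale entries when FP8 scales are supplied, so both transport paths consume the same stream.

```python
from typing import Iterator, Optional

def iter_params_with_optional_kv_scales(
    params: dict[str, object],
    kv_scales: Optional[dict[str, float]] = None,
) -> Iterator[tuple[str, object]]:
    """Yield all exported parameters, then optionally append per-layer
    KV-scale entries, keeping the append logic in one place."""
    yield from params.items()
    if kv_scales:
        for name, scale in kv_scales.items():
            yield name, scale
```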
- stream_weights_via_ipc_zmq(
- buffer_size_bytes: int = 0,
- kv_scales: Optional[dict[str, float]] = None,
Stream model weights to peer process via ZMQ IPC socket.
- broadcast_weights_for_collective(
- kv_scales: Optional[dict[str, float]] = None,
Broadcast the weights for collective communication.
- prepare_for_lp_inference()#
- prepare_for_training(*args, **kwargs)#
- offload_before_refit()#
Offload the optimizer and buffers to the CPU.
- offload_after_refit()#
Offload as much as possible to the CPU.
- move_model(
- model: torch.nn.Module,
- device: str,
- move_params: bool = True,
- move_grads: bool = True,
- move_optimizer(device: str)#
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- **kwargs,
Save a training checkpoint.
- Parameters:
weights_path – The specific directory path where the checkpoint will be saved.
optimizer_path – If not None, optimizer and scheduler states are saved if they exist.
- abstractmethod load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
Load a training checkpoint.
- Parameters:
weights_path – The exact directory path from which to load the checkpoint.
optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
- check_tensor_parallel_attributes() → dict[str, Any]#
Check tensor parallel attributes on model parameters.
- Returns:
tp_params: List of parameter names that have tensor_model_parallel=True
non_tp_params: List of parameter names that have tensor_model_parallel=False
total_params: Total number of parameters checked
tp_size: Tensor parallel size from config
- Return type:
Dictionary containing information about tensor parallel parameters
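Building such a summary dict amounts to partitioning parameter names by their tensor_model_parallel flag. A minimal sketch (hypothetical helper; the input maps parameter names to that flag):

```python
from typing import Any

def summarize_tp_attributes(params: dict[str, bool], tp_size: int) -> dict[str, Any]:
    """Partition parameter names by their tensor_model_parallel flag and
    collect counts, mirroring the keys documented above."""
    tp = [name for name, is_tp in params.items() if is_tp]
    non_tp = [name for name, is_tp in params.items() if not is_tp]
    return {
        "tp_params": tp,
        "non_tp_params": non_tp,
        "total_params": len(params),
        "tp_size": tp_size,
    }
```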
- calibrate_qkv_fp8_scales(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
- percentile: float = 99.9,
- margin: float = 1.05,
- include_q: bool = False,
One-shot calibration of Q/K/V activation scales (for FP8 KV cache).
Captures each layer’s query_key_value output through forward hooks, splits Q/K/V, and computes the percentile amax. In parallel (DP/TP/PP) environments, it first computes local percentiles, then takes the max across all ranks to be conservative.
By default, only K/V scales are returned and saved; Q scales are returned optionally.
- Parameters:
data – Representative sample batch for calibration, following get_logprobs input conventions.
micro_batch_size – Micro batch size during calibration; if None, reuses logprob_batch_size.
percentile – Percentile for amax (e.g. 99.9).
margin – Margin factor, e.g. 1.05.
save_path – If provided, rank0 will save results as JSON.
include_q – Whether to also return Q scale (usually only K/V needed).
- Returns:
A dictionary of the form { “format”: “fp8”, “percentile”: float, “margin”: float, “layers”: { layer_name: { “k_scale”: float, “v_scale”: float[, “q_scale”: float] } } }.
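The percentile-with-margin step can be sketched in isolation (hypothetical helper; a simplified nearest-rank percentile, not necessarily the interpolation the implementation uses): take the given percentile of absolute activation values, then scale it by the margin factor.

```python
def percentile_amax_scale(values: list[float], percentile: float, margin: float) -> float:
    """Percentile-of-|x| amax with a safety margin, as used for FP8 KV-cache
    calibration: nearest-rank percentile of absolute values, times margin."""
    mags = sorted(abs(v) for v in values)
    # Nearest-rank index for the requested percentile, clamped to bounds.
    rank = max(0, min(len(mags) - 1, round(percentile / 100 * (len(mags) - 1))))
    return mags[rank] * margin
```

With the defaults above (percentile=99.9, margin=1.05), the scale is slightly above the near-maximum absolute activation, leaving headroom against clipping.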
- class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- *,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- **kwargs: Any,
Bases:
nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorkerImpl