nemo_rl.models.policy.workers.megatron_policy_worker#

Module Contents#

Classes#

MegatronPolicyWorker

Functions#

broadcast_object_across_pp_ranks

Broadcast an object across pipeline parallel ranks.

Data#

TokenizerType

API#

nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType#

‘TypeVar(…)’

nemo_rl.models.policy.workers.megatron_policy_worker.broadcast_object_across_pp_ranks(obj)#

Broadcast an object across pipeline parallel ranks.

This utility function handles broadcasting an object from the rank that owns it to all other pipeline parallel ranks. If only one rank has the object (non-None), it will be broadcast to all other ranks.

Parameters:

obj – The object to broadcast. Can be None on ranks that don’t own it.

Returns:

The object on all ranks (either the original or the broadcast copy).

Raises:

ValueError – If the object doesn’t exist on any pipeline parallel rank.
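A minimal usage sketch (hypothetical; compute_metadata and owns_final_stage are stand-ins for caller-side logic, and Megatron parallel state is assumed to already be initialized):

```python
from nemo_rl.models.policy.workers.megatron_policy_worker import (
    broadcast_object_across_pp_ranks,
)

# Only the pipeline stage that owns the object builds it; every other
# pipeline-parallel rank passes None and receives the broadcast copy back.
obj = compute_metadata() if owns_final_stage else None  # assumed helpers
obj = broadcast_object_across_pp_ranks(obj)
# After the call, `obj` is identical on every pipeline-parallel rank.
```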

class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
*,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
**kwargs: Any,
)#

Bases: nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface
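A hedged construction sketch. In practice this worker is typically launched as a Ray actor by higher-level policy code; config, tokenizer, and sharding below are assumed to be built elsewhere.

```python
from nemo_rl.models.policy.workers.megatron_policy_worker import MegatronPolicyWorker

# Assumed to exist already:
#   config    - a nemo_rl PolicyConfig
#   tokenizer - the tokenizer used by the policy
#   sharding  - a NamedSharding describing the worker group layout
worker = MegatronPolicyWorker(
    config,
    tokenizer,
    weights_path=None,            # fresh initialization (not resuming)
    init_optimizer=True,
    init_reference_model=True,
    worker_sharding_annotations=sharding,
)
```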

__repr__()#

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

enable_forward_pre_hook()#
disable_forward_pre_hook(param_sync=True)#
train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) dict[str, Any]#

Train the policy on a batch of data with a given loss function.
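A hedged call-site sketch; worker, train_batch (a BatchedDataDict), and my_loss_fn (a LossFunction) are assumed to exist:

```python
metrics = worker.train(
    train_batch,
    my_loss_fn,
    eval_mode=False,
    gbs=None,  # None falls back to the configured global batch size
    mbs=None,  # None falls back to the configured micro batch size
)
print(metrics)  # dict[str, Any] of training statistics
```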

get_logprobs(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get the logprobs of the model for a batch of data.

Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.

Returns:

a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
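A short sketch of the output convention, assuming worker and a right-padded BatchedDataDict batch already exist:

```python
out = worker.get_logprobs(data=batch, micro_batch_size=None)
logprobs = out["logprobs"]  # shape [batch_size, sequence_length]
# By convention the first position is 0, and position i holds the logprob
# of input token i under the model.
assert bool((logprobs[:, 0] == 0).all())
```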

use_reference_model()#

Context manager that temporarily swaps the reference model and active model.

On entry: moves the active model to CPU and the reference model to CUDA, then swaps the references. On exit: restores the original references and moves each model back to its previous device.
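A typical use is computing reference-policy logprobs for a KL penalty; this sketch assumes worker and batch already exist:

```python
with worker.use_reference_model():
    # Inside the context the reference model is the active model on GPU.
    ref_logprobs = worker.get_logprobs(data=batch)["logprobs"]
# On exit the original training model and device placement are restored.
```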

get_topk_logits(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
k: int,
micro_batch_size: Optional[int] = None,
)#

Get the top-k logits and indices for a batch of data.

The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence.

Returns:

  • topk_logits: Tensor of top-k logits for each position in the sequence

  • topk_indices: Tensor of top-k indices for each position in the sequence

Return type:

BatchedDataDict containing the tensors listed above
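Usage sketch, assuming worker and batch exist and k is much smaller than the vocabulary size:

```python
topk = worker.get_topk_logits(data=batch, k=8)
topk_logits = topk["topk_logits"]    # top-k logit values per position
topk_indices = topk["topk_indices"]  # token ids of those logits per position
```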

generate(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of data using Hugging Face framework generation.

Parameters:

data – BatchedDataDict containing input_ids and input_lengths tensors

Returns:

  • output_ids: input + generated token IDs

  • logprobs: Log probabilities for each token

  • generation_lengths: Lengths of each response

Return type:

BatchedDataDict conforming to GenerationOutputSpec
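Usage sketch, assuming worker and a BatchedDataDict prompts with input_ids and input_lengths:

```python
result = worker.generate(data=prompts, greedy=True)
output_ids = result["output_ids"]           # prompt + generated token ids
logprobs = result["logprobs"]               # per-token log probabilities
gen_lengths = result["generation_lengths"]  # length of each response
```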

prepare_refit_info() None#

Prepare state dict metadata for weight refitting and IPC streaming.

_calculate_refit_param_info() list[tuple[str, int]]#

Calculate parameter information for refit and build the per-parameter refit tasks.

Each refit task contains:

  • param_name: Local parameter name without module prefixes

  • mapping: MegatronParamMapping instance for weight transformation

  • pp_rank: Pipeline-parallel rank owning the parameter

  • vp_stage: Virtual-pipeline stage index

  • megatron_module: Reference to Megatron model/submodule

  • param_weight: Target parameter tensor for converted weight

Returns:

List of (parameter_name, size_in_bytes) tuples.
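One plausible use of the returned list is estimating the total transfer size for a refit; a sketch assuming worker exists:

```python
param_info = worker._calculate_refit_param_info()
total_bytes = sum(size for _, size in param_info)
print(f"{len(param_info)} parameters, {total_bytes / 2**30:.2f} GiB to transfer")
```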

_iter_params_with_optional_kv_scales(
kv_scales: Optional[dict[str, float]] = None,
) Iterator[tuple[str, torch.Tensor]]#

Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.

This helper is used by both IPC-based streaming and collective broadcast so that the logic for adding KV scales stays consistent in one place.

stream_weights_via_ipc_zmq(
buffer_size_bytes: int = 0,
kv_scales: Optional[dict[str, float]] = None,
) None#

Stream model weights to peer process via ZMQ IPC socket.

broadcast_weights_for_collective(
kv_scales: Optional[dict[str, float]] = None,
) None#

Broadcast the weights for collective communication.

prepare_for_lp_inference()#
prepare_for_training(*args, **kwargs)#
offload_before_refit()#

Offload the optimizer and buffers to the CPU.

offload_after_refit()#

Offload as much as possible to the CPU.
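A hedged sketch of how these refit and offload hooks might be sequenced by the caller; the actual orchestration lives in higher-level policy code and may differ:

```python
worker.prepare_refit_info()                # record state-dict metadata once
worker.offload_before_refit()              # free optimizer/buffer GPU memory
worker.broadcast_weights_for_collective()  # or stream_weights_via_ipc_zmq()
worker.offload_after_refit()               # push everything possible to CPU
```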

move_model(
model: torch.nn.Module,
device: str,
move_params: bool = True,
move_grads: bool = True,
) torch.nn.Module#
move_optimizer(device: str)#
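A sketch of parking a model and its optimizer on CPU; module stands in for a model owned by the worker (the attribute it lives under is not shown in this reference):

```python
import torch.nn as nn

module: nn.Module = ...  # assumed handle to the worker's Megatron module
module = worker.move_model(module, "cpu", move_params=True, move_grads=False)
worker.move_optimizer("cpu")
```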
save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
**kwargs,
)#

Save a training checkpoint.

Parameters:
  • weights_path – The specific directory path where the checkpoint will be saved.

  • optimizer_path – If not None, optimizer and scheduler states are saved if they exist.

abstractmethod load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
)#

Load a training checkpoint.

Parameters:
  • weights_path – The exact directory path from which to load the checkpoint.

  • optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
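A checkpointing sketch with hypothetical directory names, assuming worker and a step counter step exist:

```python
weights_dir = f"results/step_{step}/policy/weights"
optim_dir = f"results/step_{step}/policy/optimizer"

worker.save_checkpoint(weights_dir, optimizer_path=optim_dir)
# ... later, to resume training from the same point:
worker.load_checkpoint(weights_dir, optimizer_path=optim_dir)
```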

check_tensor_parallel_attributes() dict[str, Any]#

Check tensor parallel attributes on model parameters.

Returns:

  • tp_params: List of parameter names that have tensor_model_parallel=True

  • non_tp_params: List of parameter names that have tensor_model_parallel=False

  • total_params: Total number of parameters checked

  • tp_size: Tensor parallel size from config

Return type:

Dictionary containing information about tensor parallel parameters
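A quick sanity-check sketch, assuming worker exists:

```python
info = worker.check_tensor_parallel_attributes()
print("TP size:", info["tp_size"], "| params checked:", info["total_params"])
print(len(info["tp_params"]), "TP-sharded vs", len(info["non_tp_params"]), "replicated")
```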

calibrate_qkv_fp8_scales(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
percentile: float = 99.9,
margin: float = 1.05,
include_q: bool = False,
) dict[str, Any]#

One-shot calibration of Q/K/V activation scales (for FP8 KV cache).

  • Captures each layer’s query_key_value output through forward hooks, splits Q/K/V, and computes percentile amax.

  • In parallel (DP/TP/PP) environments, first computes local percentiles, then takes the max across all ranks to stay conservative.

  • By default, only K/V scales are returned and saved; Q scales can optionally be included.

Parameters:
  • data – Representative sample batch for calibration, following get_logprobs input conventions.

  • micro_batch_size – Micro batch size during calibration; if None, reuses logprob_batch_size.

  • percentile – Percentile for amax (e.g. 99.9).

  • margin – Margin factor, e.g. 1.05.

  • save_path – If provided, rank0 will save results as JSON.

  • include_q – Whether to also return Q scale (usually only K/V needed).

Returns:

{ “format”: “fp8”, “percentile”: float, “margin”: float, “layers”: { layer_name: { “k_scale”: float, “v_scale”: float[, “q_scale”: float] } } }

Return type:

dict[str, Any]
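A calibration sketch, assuming worker and a representative calib_batch already exist:

```python
scales = worker.calibrate_qkv_fp8_scales(
    data=calib_batch,
    percentile=99.9,
    margin=1.05,
    include_q=False,
)
for layer_name, s in scales["layers"].items():
    print(layer_name, s["k_scale"], s["v_scale"])
```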