nemo_rl.models.policy.workers.megatron_policy_worker#

Module Contents#

Classes#

MegatronPolicyWorkerImpl

MegatronPolicyWorker

Data#

TokenizerType

API#

nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType#

‘TypeVar(…)’

class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorkerImpl(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
*,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
**kwargs: Any,
)#

Bases: nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface

__repr__()#

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

enable_forward_pre_hook()#
disable_forward_pre_hook(param_sync=True)#
train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) dict[str, Any]#

Train the policy on a batch of data with a given loss function.
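A minimal usage sketch, assuming worker is an already-constructed policy worker and nll_loss is some LossFunction implementation; all names here are illustrative stand-ins, not fixed names from this module:

    # Hypothetical usage: `worker`, `train_batch`, and `nll_loss` are
    # stand-ins constructed elsewhere; only the call shape is shown.
    metrics = worker.train(
        data=train_batch,   # BatchedDataDict with tokens/masks for one step
        loss_fn=nll_loss,   # any LossFunction implementation
        eval_mode=False,    # True runs forward-only evaluation (assumption)
        gbs=128,            # optional global batch size override
        mbs=8,              # optional micro batch size override
    )
    print(metrics)          # per-step training metrics, e.g. the loss value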

get_logprobs(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get the logprobs of the model for a batch of data.

Uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.

Returns:

A BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is maintained; the logprob of input token i is stored at position i of the output logprobs tensor.
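The positional convention is easiest to see in a short sketch (worker and batch are hypothetical stand-ins; the key name follows the spec above):

    import torch

    # Hypothetical objects; only the output convention is illustrated.
    out = worker.get_logprobs(data=batch, micro_batch_size=4)
    logprobs = out["logprobs"]   # shape [batch_size, sequence_length]

    # Position 0 is defined to be 0.0; position i holds log p(token_i | tokens_<i).
    assert torch.all(logprobs[:, 0] == 0.0)
    seq_logprob = logprobs[:, 1:].sum(dim=-1)   # total logprob per sequence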

use_reference_model()#

Context manager that temporarily swaps the reference model and active model.

On entry: moves the active model to CPU and the reference model to CUDA, then swaps the two references. Top-k/top-p filtering is also disabled, because the reference policy’s distribution differs from the current policy’s, which would make filtered logprobs incompatible. On exit: restores the original references, moves each model back to its original device, and restores sampling_params.
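A typical use is computing reference-policy logprobs for a KL term; a hedged sketch with stand-in worker and batch objects:

    # Compute logprobs under the frozen reference policy, then under the
    # current policy; their difference is the per-token log-ratio for KL.
    with worker.use_reference_model():
        ref = worker.get_logprobs(data=batch)
    cur = worker.get_logprobs(data=batch)
    log_ratio = cur["logprobs"] - ref["logprobs"]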

get_topk_logits(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
k: int,
micro_batch_size: Optional[int] = None,
)#

Get the top-k logits and indices for a batch of data.

The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence.

Returns:

  • topk_logits: Tensor of top-k logits for each position in the sequence

  • topk_indices: Tensor of top-k indices for each position in the sequence

Return type:

BatchedDataDict containing topk_logits and topk_indices
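Hedged usage sketch; the [batch, seq_len, k] shapes are an assumption, not stated on this page:

    # Hypothetical usage for, e.g., building sparse distillation targets.
    out = worker.get_topk_logits(data=batch, k=8, micro_batch_size=4)
    topk_logits = out["topk_logits"]    # assumed shape [batch, seq_len, k]
    topk_indices = out["topk_indices"]  # assumed shape [batch, seq_len, k]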

generate(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of data using Hugging Face framework generation.

Parameters:

data – BatchedDataDict containing input_ids and input_lengths tensors

Returns:

  • output_ids: input + generated token IDs

  • logprobs: Log probabilities for each token

  • generation_lengths: Lengths of each response

Return type:

BatchedDataDict conforming to GenerationOutputSpec
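Hedged sketch (tokenizer is the tokenizer passed at construction; the other objects are stand-ins):

    # Greedy decode, then inspect each response; output_ids holds the
    # prompt followed by the generated tokens.
    out = worker.generate(data=batch, greedy=True)
    for ids, n in zip(out["output_ids"], out["generation_lengths"]):
        print(tokenizer.decode(ids, skip_special_tokens=True), "len:", int(n))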

prepare_refit_info() None#

Prepare state dict metadata for weight refitting and IPC streaming.

_calculate_refit_param_info() list[tuple[str, int]]#

Calculate parameter information for refit.

One refit task is built per parameter; each task contains:

  • param_name: Local parameter name without module prefixes

  • mapping: MegatronParamMapping instance for weight transformation

  • pp_rank: Pipeline-parallel rank owning the parameter

  • vp_stage: Virtual-pipeline stage index

  • megatron_module: Reference to Megatron model/submodule

  • param_weight: Target parameter tensor for converted weight

Returns:

List of (parameter_name, size_in_bytes) tuples.
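Sketch of inspecting the result (the method is internal; this is for debugging only):

    # Summarize the refit payload from the (name, size_in_bytes) tuples.
    infos = worker._calculate_refit_param_info()
    total_bytes = sum(size for _, size in infos)
    print(f"{len(infos)} params, {total_bytes / 2**30:.2f} GiB to transfer")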

_iter_params_with_optional_kv_scales(
kv_scales: Optional[dict[str, float]] = None,
) Iterator[tuple[str, torch.Tensor]]#

Yield exported HF parameters and optionally append FP8 KV/Q scale tensors.

This helper is used by both IPC-based streaming and collective broadcast so that the logic for adding KV scales stays consistent in one place.
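Sketch of how a transport layer might consume the iterator; send and the kv_scales key format are illustrative assumptions:

    # kv_scales is optional; when given, the iterator yields the extra
    # scale tensors after the exported HF parameters.
    kv_scales = {"model.layers.0.self_attn": 0.023}  # key format assumed
    for name, tensor in worker._iter_params_with_optional_kv_scales(kv_scales):
        send(name, tensor)  # placeholder for ZMQ IPC or collective broadcast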

stream_weights_via_ipc_zmq(
buffer_size_bytes: int = 0,
kv_scales: Optional[dict[str, float]] = None,
) None#

Stream model weights to peer process via ZMQ IPC socket.

broadcast_weights_for_collective(
kv_scales: Optional[dict[str, float]] = None,
) None#

Broadcast the weights for collective communication.

prepare_for_lp_inference()#
prepare_for_training(*args, **kwargs)#
offload_before_refit()#

Offload the optimizer and buffers to the CPU.

offload_after_refit()#

Offload as much as possible to the CPU.
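Put together, a colocated refit cycle might look like the following hedged sketch; the ordering is inferred from the method names on this page, and the generation engine is a stand-in:

    # End-to-end refit sketch for a colocated trainer/generator setup.
    worker.prepare_refit_info()          # publish state-dict metadata once
    worker.offload_before_refit()        # free optimizer/buffer memory first
    worker.stream_weights_via_ipc_zmq()  # or broadcast_weights_for_collective()
    worker.offload_after_refit()         # then offload the rest to CPU
    # ... the generation engine runs with the refreshed weights ...
    worker.prepare_for_training()        # restore state before the next step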

move_model(
model: torch.nn.Module,
device: str,
move_params: bool = True,
move_grads: bool = True,
) torch.nn.Module#
move_optimizer(device: str)#
save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
**kwargs,
)#

Save a training checkpoint.

Parameters:
  • weights_path – The specific directory path where the checkpoint will be saved.

  • optimizer_path – If not None, optimizer and scheduler states are saved if they exist.

abstractmethod load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
)#

Load a training checkpoint.

Parameters:
  • weights_path – The exact directory path from which to load the checkpoint.

  • optimizer_path – If not None, attempts to load optimizer and scheduler states if self.optimizer and self.scheduler are initialized.
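Hedged save/load sketch on a concrete worker (the paths are hypothetical):

    # Optimizer and scheduler state travels with optimizer_path.
    worker.save_checkpoint(
        weights_path="/ckpts/step_1000/weights",
        optimizer_path="/ckpts/step_1000/optim",
    )
    worker.load_checkpoint(
        weights_path="/ckpts/step_1000/weights",
        optimizer_path="/ckpts/step_1000/optim",
    )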

check_tensor_parallel_attributes() dict[str, Any]#

Check tensor parallel attributes on model parameters.

Returns:

  • tp_params: List of parameter names that have tensor_model_parallel=True

  • non_tp_params: List of parameter names that have tensor_model_parallel=False

  • total_params: Total number of parameters checked

  • tp_size: Tensor parallel size from config

Return type:

Dictionary containing information about tensor parallel parameters
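Sketch of a sharding sanity check after model init; the keys follow the list above:

    info = worker.check_tensor_parallel_attributes()
    print(f"TP size {info['tp_size']}: {len(info['tp_params'])} sharded, "
          f"{len(info['non_tp_params'])} replicated, "
          f"{info['total_params']} total")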

calibrate_qkv_fp8_scales(
*,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
percentile: float = 99.9,
margin: float = 1.05,
include_q: bool = False,
) dict[str, Any]#

One-shot calibration of Q/K/V activation scales (for FP8 KV cache).

  • Captures each layer’s query_key_value output through forward hooks, splits Q/K/V, and computes percentile amax.

  • In parallel (DP/TP/PP) environments, local percentiles are computed first, then the max is taken across all ranks to stay conservative.

  • By default, only the K/V scales are returned and saved; the Q scale can optionally be included.

Parameters:
  • data – Representative sample batch for calibration, following get_logprobs input conventions.

  • micro_batch_size – Micro batch size during calibration; if None, reuses logprob_batch_size.

  • percentile – Percentile for amax (e.g. 99.9).

  • margin – Margin factor, e.g. 1.05.

  • save_path – If provided, rank0 will save results as JSON.

  • include_q – Whether to also return Q scale (usually only K/V needed).

Returns:

A dict of the form { “format”: “fp8”, “percentile”: float, “margin”: float, “layers”: { layer_name: { “k_scale”: float, “v_scale”: float[, “q_scale”: float] } } }
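A plausible way to wire the calibration result into a refit, assuming the per-layer scales can be flattened into the kv_scales mapping accepted by the streaming methods; the flattened key names below are an assumption:

    # Calibrate once on a representative batch, then ship the scales
    # alongside the weights. Key naming below is illustrative only.
    result = worker.calibrate_qkv_fp8_scales(data=calib_batch, percentile=99.9)
    kv_scales: dict[str, float] = {}
    for layer, scales in result["layers"].items():
        kv_scales[f"{layer}.k_scale"] = scales["k_scale"]
        kv_scales[f"{layer}.v_scale"] = scales["v_scale"]
    worker.stream_weights_via_ipc_zmq(kv_scales=kv_scales)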

class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
*,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
**kwargs: Any,
)#

Bases: nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorkerImpl