nemo_rl.models.value.workers.megatron_value_worker#

Module Contents#

Classes#

ValueHead

Simple linear value head that maps hidden states to scalar values.

_ValueOutputLayerBypass

Replaces the output_layer during value model forward.

MegatronValueWorkerImpl

Megatron-Core based value function worker for PPO.

MegatronValueWorker

Functions#

_unwrap_model

Unwrap a model from DDP/Float16Module wrappers to get the base GPTModel.

forward_step_value

Forward step for value model that captures hidden states and computes values.

Data#

TokenizerType

API#

nemo_rl.models.value.workers.megatron_value_worker.TokenizerType#

‘TypeVar(…)’

class nemo_rl.models.value.workers.megatron_value_worker.ValueHead(hidden_size: int, dtype: torch.dtype)#

Bases: torch.nn.Module

Simple linear value head that maps hidden states to scalar values.

Works correctly with tensor parallelism by operating on the full hidden dimension (which is not split across TP ranks at the output_layer input). With sequence parallelism, each TP rank processes its shard of the sequence independently; results are gathered later.

Initialization

forward(hidden_states: torch.Tensor) → torch.Tensor#

Map hidden states to scalar values.

Parameters:

hidden_states – [batch, seq, hidden_size] (may be seq-parallel sharded)

Returns:

values – [batch, seq, 1]
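The shape contract above, and the note that sequence shards can be processed independently, can be illustrated with a dependency-free sketch. Plain Python lists stand in for tensors here; the real class is a torch.nn.Module, and the function name `value_head`, the weights, and the toy input are all assumptions made for illustration.

```python
from typing import List

Vector = List[float]


def value_head(hidden_states: List[List[Vector]], weight: Vector, bias: float):
    """Map [batch, seq, hidden_size] to [batch, seq, 1] with one shared linear layer."""
    return [
        [[sum(w * h for w, h in zip(weight, token)) + bias] for token in sequence]
        for sequence in hidden_states
    ]


# Hypothetical toy input: batch=1, seq=4, hidden_size=2.
hidden = [[[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5]]]
weight, bias = [0.5, -1.0], 0.25

full = value_head(hidden, weight, bias)

# Split the sequence into two shards, as sequence parallelism would, apply the
# head to each shard on its own, then concatenate along the sequence axis.
shard_a = value_head([hidden[0][:2]], weight, bias)
shard_b = value_head([hidden[0][2:]], weight, bias)
gathered = [shard_a[0] + shard_b[0]]
```

Because the head acts on each token's full hidden vector, sharding along the sequence axis does not change the result.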

nemo_rl.models.value.workers.megatron_value_worker._unwrap_model(model)#

Unwrap a model from DDP/Float16Module wrappers to get the base GPTModel.
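A minimal sketch of the unwrapping idea, assuming each wrapper exposes the wrapped model through a `.module` attribute (as torch's DistributedDataParallel does). The real function may instead check wrapper types explicitly; `Wrapper` and `BaseModel` below are hypothetical stand-ins.

```python
def unwrap_model(model):
    """Follow nested `.module` attributes until the innermost model is reached."""
    while hasattr(model, "module"):
        model = model.module
    return model


class Wrapper:
    """Hypothetical stand-in for a DDP- or Float16Module-style wrapper."""

    def __init__(self, module):
        self.module = module


class BaseModel:
    """Hypothetical stand-in for the base GPTModel."""


base = BaseModel()
wrapped = Wrapper(Wrapper(base))  # two wrapper layers around the base model
unwrapped = unwrap_model(wrapped)
```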

class nemo_rl.models.value.workers.megatron_value_worker._ValueOutputLayerBypass(captured_hidden: dict)#

Bases: torch.nn.Module

Replaces the output_layer during value model forward.

Skips the expensive logits computation (hidden_size -> vocab_size). Instead of computing [S, B, vocab_size] logits, it captures the hidden states and returns a minimal [S, B, 1] tensor, saving both memory and FLOPs.
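The capture-and-bypass pattern can be sketched without any framework. The dictionary key `"hidden_states"`, the context manager `bypass_output_layer`, and `ToyModel` are assumptions for illustration only; how the real class stores the captured tensor, and how the worker installs and removes it, are not documented here.

```python
from contextlib import contextmanager


class CaptureBypass:
    """Stand-in for an output layer: records its input instead of projecting it."""

    def __init__(self, captured: dict):
        self.captured = captured

    def __call__(self, hidden_states, *args, **kwargs):
        self.captured["hidden_states"] = hidden_states
        # Cheap placeholder with one value per position, i.e. shape [S, B, 1].
        return [[[0.0] for _ in step] for step in hidden_states]


@contextmanager
def bypass_output_layer(model, captured: dict):
    """Temporarily swap model.output_layer for the bypass, then restore it."""
    original = model.output_layer
    model.output_layer = CaptureBypass(captured)
    try:
        yield captured
    finally:
        model.output_layer = original


class ToyModel:
    """Hypothetical model whose output_layer stands for the vocab projection."""

    def __init__(self):
        self.output_layer = lambda hidden_states: "full logits"

    def forward(self, hidden_states):
        return self.output_layer(hidden_states)


model = ToyModel()
captured = {}
hidden = [[[1.0, 2.0]], [[3.0, 4.0]]]  # toy [S=2, B=1, hidden_size=2]
with bypass_output_layer(model, captured):
    placeholder = model.forward(hidden)
restored = model.forward(hidden)
```

The captured hidden states are what a value head would then consume; the placeholder only keeps the rest of the forward path shape-compatible.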

Initialization

forward(hidden_states, *args, **kwargs)#

nemo_rl.models.value.workers.megatron_value_worker.forward_step_value(
state,
global_valid_seqs,
global_valid_toks,
data_iterator,
model,
*,
value_head,
loss_fn,
pack_sequences=False,
defer_fp32_logits=None,
cp_normalize=True,
policy_cfg=None,
)#

Forward step for value model that captures hidden states and computes values.

This is similar to forward_step_arbitrary_loss but intercepts hidden states before the language model head and applies a value head instead.

class nemo_rl.models.value.workers.megatron_value_worker.MegatronValueWorkerImpl(
config: nemo_rl.models.value.config.ValueConfig,
tokenizer: nemo_rl.models.value.workers.megatron_value_worker.TokenizerType,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
*,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
**kwargs: Any,
)#

Bases: nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker

Megatron-Core based value function worker for PPO.

This worker wraps a Megatron-Core GPT model backbone with a value head (Linear: hidden_size -> 1) to predict per-token state values.

Supports:

  • Tensor parallelism (TP)

  • Pipeline parallelism (PP)

  • Context parallelism (CP)

  • Sequence packing

  • Activation checkpointing

  • FP8 quantization

  • MoE models

Initialization

Initialize the MegatronValueWorker.

Parameters:
  • config – Value model configuration.

  • tokenizer – HuggingFace tokenizer.

  • weights_path – Path to load finetuned weights from (optional).

  • optimizer_path – Path to load optimizer state from (optional).

  • init_optimizer – Whether to initialize the optimizer.

  • worker_sharding_annotations – Sharding topology for distributed training.

__repr__()#

_add_value_head_to_optimizer()#

Add value head parameters to the Megatron optimizer.

Since the Megatron optimizer manages parameters through DDP buffers, we create a separate PyTorch optimizer for the value head and step both during training.
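The idea of stepping two optimizers as one unit might look like the sketch below. `OptimizerPair` and `ToySGD` are hypothetical; the sketch assumes only that both optimizers expose `zero_grad()` and `step()`, and it does not reproduce the Megatron optimizer's actual interface or return values.

```python
class ToySGD:
    """Hypothetical minimal optimizer over parameters stored as dicts."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p["grad"] = 0.0

    def step(self):
        for p in self.params:
            p["value"] -= self.lr * p["grad"]


class OptimizerPair:
    """Steps a backbone optimizer and a value-head optimizer together."""

    def __init__(self, backbone_optimizer, head_optimizer):
        self.optimizers = (backbone_optimizer, head_optimizer)

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()


backbone_params = [{"value": 1.0, "grad": 0.5}]
head_params = [{"value": 2.0, "grad": 1.0}]
pair = OptimizerPair(ToySGD(backbone_params, lr=0.5), ToySGD(head_params, lr=0.5))

pair.step()       # both parameter groups are updated in the same call
pair.zero_grad()  # both gradient sets are cleared in the same call
```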

enable_forward_pre_hook()#

disable_forward_pre_hook(param_sync=True)#

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]#

Train the value function on a batch of data with a given loss function.

Parameters:
  • data – BatchedDataDict containing training data with the keys input_ids, input_lengths, token_mask, sample_mask, and returns.

  • loss_fn – Value loss function (e.g., MseValueLossFn).

  • eval_mode – If True, run forward only without parameter updates.

  • gbs – Global batch size override.

  • mbs – Micro batch size override.

Returns:

Dictionary with training metrics (global_loss, grad_norm, etc.)
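To show how the token_mask, sample_mask, and returns keys could enter a value loss, here is a generic masked mean-squared-error sketch. It is not the library's MseValueLossFn, whose normalization and any clipping are not described here; the function name and the toy numbers are assumptions.

```python
def masked_mse_value_loss(values, returns, token_mask, sample_mask):
    """Mean squared error over tokens where token_mask and sample_mask are both 1."""
    total = 0.0
    count = 0.0
    for v_seq, r_seq, m_seq, s in zip(values, returns, token_mask, sample_mask):
        for v, r, m in zip(v_seq, r_seq, m_seq):
            weight = m * s  # a token counts only if its sample is also kept
            total += weight * (v - r) ** 2
            count += weight
    return total / count if count > 0 else 0.0


# Hypothetical batch of 2 sequences with 3 tokens each.
values = [[0.5, 1.0, 0.0], [2.0, 2.0, 2.0]]
returns = [[1.0, 1.0, 5.0], [0.0, 0.0, 0.0]]
token_mask = [[1, 1, 0], [1, 1, 1]]  # last token of sample 0 is masked out
sample_mask = [1, 0]                 # sample 1 is dropped entirely

loss = masked_mse_value_loss(values, returns, token_mask, sample_mask)
```

Only the first two tokens of the first sample contribute, so the masked-out error of 5.0 and the dropped sample do not affect the result.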

get_values(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.value.interfaces.ValueOutputSpec]#

Get per-token value predictions for a batch of data.

Parameters:
  • data – BatchedDataDict containing input_ids and input_lengths.

  • micro_batch_size – Override for inference micro batch size.

Returns:

BatchedDataDict with “values” key of shape [batch_size, seq_length].

prepare_for_training() → None#

Move model and optimizer to CUDA for training.

prepare_for_inference()#

Prepare model for value inference.

move_model(
model: torch.nn.Module,
device: str,
move_params: bool = True,
move_grads: bool = True,
) → torch.nn.Module#

Move model parameters and gradient buffers to the specified device.

move_optimizer(device: str)#

Move optimizer state to the specified device.

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
**kwargs,
)#

Save a checkpoint of the value model.

Saves both the Megatron backbone checkpoint and the value head weights.

abstractmethod load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
)#

Load a checkpoint for the value model.

finish_inference() → None#

Offload model params to CPU after inference.

finish_training() → None#

Offload model, gradients, and optimizer to CPU after training.
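The prepare/finish methods suggest a bracketed call pattern around train and get_values. The helpers below sketch that pattern using only the method names documented on this page; the call order is inferred from the docstrings, and whether NeMo RL's own driver code calls them exactly this way is an assumption. `StubWorker` is a hypothetical stand-in that records calls, not the real worker.

```python
def run_value_update(worker, train_data, loss_fn):
    """Load onto the device, train on one batch, and always offload afterwards."""
    worker.prepare_for_training()
    try:
        return worker.train(train_data, loss_fn)
    finally:
        worker.finish_training()


def predict_values(worker, data, micro_batch_size=None):
    """Prepare for inference, fetch per-token values, and always offload afterwards."""
    worker.prepare_for_inference()
    try:
        return worker.get_values(data, micro_batch_size=micro_batch_size)
    finally:
        worker.finish_inference()


class StubWorker:
    """Hypothetical stand-in exposing the documented method names."""

    def __init__(self):
        self.calls = []

    def prepare_for_training(self):
        self.calls.append("prepare_for_training")

    def train(self, data, loss_fn, eval_mode=False, gbs=None, mbs=None):
        self.calls.append("train")
        return {"global_loss": 0.0}

    def finish_training(self):
        self.calls.append("finish_training")

    def prepare_for_inference(self):
        self.calls.append("prepare_for_inference")

    def get_values(self, data, micro_batch_size=None):
        self.calls.append("get_values")
        return {"values": []}

    def finish_inference(self):
        self.calls.append("finish_inference")


worker = StubWorker()
metrics = run_value_update(worker, train_data={}, loss_fn=None)
outputs = predict_values(worker, data={})
```

Wrapping the calls in try/finally means the offload step still runs if training or inference raises.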

class nemo_rl.models.value.workers.megatron_value_worker.MegatronValueWorker(
config: nemo_rl.models.value.config.ValueConfig,
tokenizer: nemo_rl.models.value.workers.megatron_value_worker.TokenizerType,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
*,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
**kwargs: Any,
)#

Bases: nemo_rl.models.value.workers.megatron_value_worker.MegatronValueWorkerImpl