nemo_rl.models.automodel.train#

Training utilities for automodel (DTensor-based) policy workers.

This module provides post-processor classes and forward/backward functions that follow the same pattern as nemo_rl/models/megatron/train.py.

Key differences from megatron approach:

  • Post-processors compute results directly (no callable return pattern)

  • forward_with_post_processing_fn calls post-processor directly

  • automodel_forward_backward uses PyTorch autograd instead of Megatron’s pipeline

Module Contents#

Classes#

LossPostProcessor

Post-processor for computing training loss from model outputs.

LogprobsPostProcessor

Post-processor for computing log probabilities from model outputs.

TopkLogitsPostProcessor

Post-processor for computing top-k logits from model outputs.

ScorePostProcessor

Post-processor for computing reward model scores from model outputs.

Functions#

model_forward

Perform a single forward pass through the model.

extract_logits

Extract logits from model outputs.

apply_temperature_scaling

Apply temperature scaling to logits.

apply_top_k_top_p_filtering_for_local_logits

Apply top-k and top-p filtering to the non-distributed logits.

redistribute_logits_for_cp

Redistribute logits for context parallel processing.

prepare_data_for_cp

Prepare data for context parallel processing.

forward_with_post_processing_fn

Perform a forward pass on a pre-processed microbatch and apply post-processing.

automodel_forward_backward

Execute forward and backward passes for automodel.

aggregate_training_statistics

Aggregate training statistics across microbatches and ranks.

Data#

API#

nemo_rl.models.automodel.train.PostProcessingFunction#

None

nemo_rl.models.automodel.train.model_forward(
model: torch.nn.Module,
processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
is_reward_model: bool = False,
allow_flash_attn_args: bool = True,
) → torch.Tensor#

Perform a single forward pass through the model.

Parameters:
  • model – The model to run the forward pass on

  • processed_inputs – ProcessedInputs containing all tensors for forward pass

  • is_reward_model – Whether this is a reward model

  • allow_flash_attn_args – Whether to pass flash_attn_kwargs to model

Returns:

Output tensor from the model (logits)

Return type:

torch.Tensor
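
A minimal usage sketch; policy_model and processed_inputs are assumed to come from the surrounding worker setup, and the no-grad context reflects an inference-style call:

    import torch
    from nemo_rl.models.automodel.train import model_forward

    # Single forward pass; gradients are not needed for pure inference.
    with torch.no_grad():
        logits = model_forward(
            policy_model,
            processed_inputs,
            is_reward_model=False,
            allow_flash_attn_args=True,
        )
    # For causal LMs the output is typically [batch, seq_len, vocab_size].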

nemo_rl.models.automodel.train.extract_logits(
model: torch.nn.Module,
outputs: Any,
) → torch.Tensor#

Extract logits from model outputs.

Parameters:
  • model – The model (used for lm_head if needed)

  • outputs – Model outputs (a tensor, a DTensor, or an object with a logits attribute)

Returns:

Logits tensor

Return type:

torch.Tensor
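
A simplified sketch of the dispatch this function performs; the handling of raw hidden states via model.lm_head is noted but omitted, since its exact trigger is an implementation detail:

    import torch

    def extract_logits_sketch(model: torch.nn.Module, outputs) -> torch.Tensor:
        # Hugging Face-style ModelOutput objects carry a .logits attribute.
        if hasattr(outputs, "logits"):
            return outputs.logits
        # Plain tensors (DTensor is a torch.Tensor subclass) are assumed to
        # already be logits; the real function can instead project hidden
        # states through model.lm_head when needed.
        if isinstance(outputs, torch.Tensor):
            return outputs
        raise TypeError(f"Unsupported model output type: {type(outputs)}")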

nemo_rl.models.automodel.train.apply_temperature_scaling(
logits: torch.Tensor,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
) → torch.Tensor#

Apply temperature scaling to logits.

Parameters:
  • logits – Logits tensor to scale

  • sampling_params – Sampling parameters

Returns:

Temperature-scaled logits

Return type:

torch.Tensor
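
The operation itself is a single division; a minimal sketch, assuming a missing or unit temperature is treated as a no-op:

    import torch
    from typing import Optional

    def temperature_scale_sketch(
        logits: torch.Tensor, temperature: Optional[float]
    ) -> torch.Tensor:
        # T > 1 flattens the distribution, T < 1 sharpens it, T == 1 is the identity.
        if temperature is None or temperature == 1.0:
            return logits
        return logits / temperature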

nemo_rl.models.automodel.train.apply_top_k_top_p_filtering_for_local_logits(
logits: torch.Tensor,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
) → torch.Tensor#

Apply top-k and top-p filtering to the non-distributed logits.

Parameters:
  • logits – Logits tensor to filter

  • sampling_params – Sampling parameters

Returns:

Filtered logits

Return type:

torch.Tensor
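
A minimal sketch of standard top-k/top-p filtering on unsharded logits; masked entries are set to -inf so that softmax assigns them zero probability (the defaults below are illustrative):

    import torch

    def top_k_top_p_sketch(
        logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0
    ) -> torch.Tensor:
        neg_inf = float("-inf")
        if top_k > 0:
            # Keep only the k largest logits at each position.
            kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_value, neg_inf)
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            probs = torch.softmax(sorted_logits, dim=-1)
            # Drop a token once the cumulative probability of the tokens ranked
            # above it already exceeds top_p; the top-1 token always survives.
            remove = probs.cumsum(dim=-1) - probs > top_p
            mask = torch.zeros_like(remove).scatter(-1, sorted_idx, remove)
            logits = logits.masked_fill(mask, neg_inf)
        return logits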

nemo_rl.models.automodel.train.redistribute_logits_for_cp(
logits: torch.Tensor,
device_mesh: Any,
cp_mesh: Any,
sequence_dim: int = 1,
) → torch.distributed.tensor.DTensor#

Redistribute logits for context parallel processing.

Handles the case where the logits may be a TP-sharded DTensor or a regular tensor, and converts them to a CP+TP-sharded DTensor.

Parameters:
  • logits – Logits tensor (may be DTensor or regular tensor)

  • device_mesh – Full device mesh

  • cp_mesh – Context parallel mesh (kept for signature compatibility)

  • sequence_dim – Dimension for sequence sharding

Returns:

DTensor sharded on both CP and TP dimensions
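
A rough sketch of the conversion, assuming a two-dimensional (cp, tp) sub-mesh and the usual [batch, seq, vocab] logits layout; the real helper's mesh handling may differ in detail:

    import torch
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor

    def redistribute_sketch(logits, mesh, sequence_dim: int = 1) -> DTensor:
        # Shard the sequence dim across CP ranks and the vocab dim across TP ranks.
        placements = [Shard(sequence_dim), Shard(2)]
        if isinstance(logits, DTensor):
            return logits.redistribute(mesh, placements)
        return distribute_tensor(logits, mesh, placements)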

nemo_rl.models.automodel.train.prepare_data_for_cp(
mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
cp_mesh: Any,
sequence_dim: int = 1,
) → tuple[torch.Tensor, nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]]#

Prepare data for context parallel processing.

Converts seq_index to a full tensor and wraps CP-sharded tensors in DTensors.

Parameters:
  • mb – Microbatch data dictionary

  • processed_inputs – Processed inputs containing CP buffers

  • cp_mesh – Context parallel mesh

  • sequence_dim – Dimension for sequence sharding

Returns:

Tuple of (seq_index_dtensor, updated_mb)

nemo_rl.models.automodel.train.forward_with_post_processing_fn(
model: torch.nn.Module,
post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction,
processed_mb: nemo_rl.models.automodel.data.ProcessedMicrobatch,
is_reward_model: bool = False,
allow_flash_attn_args: bool = True,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
sequence_dim: int = 1,
) → Tuple[Any, dict[str, Any], nemo_rl.models.automodel.data.ProcessedMicrobatch]#

Perform a forward pass on a pre-processed microbatch and apply post-processing.

This function takes a pre-processed microbatch (with sequence packing already handled), runs the forward step through the model, and applies the post-processing function to compute the result.

Unlike the Megatron approach, which returns a callable, this function computes and returns the result directly, since automodel uses PyTorch autograd.

Parameters:
  • model – The model to run the forward pass on

  • post_processing_fn – Post-processing function to apply to the logits

  • processed_mb – Pre-fetched ProcessedMicrobatch containing data and processed inputs

  • is_reward_model – Whether this is a reward model

  • allow_flash_attn_args – Whether to pass flash_attn_kwargs to model

  • global_valid_seqs – Global valid sequence count for loss normalization

  • global_valid_toks – Global valid token count for loss normalization

  • sampling_params – Sampling parameters (top-k, top-p, temperature)

  • sequence_dim – Sequence dimension

Returns:

Tuple of (result, metrics, processed_microbatch):

  • result – Output from post-processing (loss, logprobs, top-k logits, or scores)

  • metrics – Dictionary of metrics from post-processing

  • processed_microbatch – The ProcessedMicrobatch that was processed

Return type:

tuple
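
A minimal usage sketch for one training microbatch; model, loss_post_processor, processed_mb, and the global counts are assumed to come from the surrounding setup:

    from nemo_rl.models.automodel.train import forward_with_post_processing_fn

    # Forward pass plus post-processing in one call; with a LossPostProcessor,
    # the returned result is the differentiable loss tensor.
    result, metrics, processed_mb = forward_with_post_processing_fn(
        model=model,
        post_processing_fn=loss_post_processor,
        processed_mb=processed_mb,
        global_valid_seqs=global_valid_seqs,
        global_valid_toks=global_valid_toks,
    )
    result.backward()  # plain PyTorch autograd; no pipeline schedule needed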

nemo_rl.models.automodel.train.automodel_forward_backward(
model: torch.nn.Module,
data_iterator: Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch],
post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction,
forward_only: bool = False,
is_reward_model: bool = False,
allow_flash_attn_args: bool = True,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
sequence_dim: int = 1,
dp_size: int = 1,
cp_size: int = 1,
num_global_batches: int = 1,
train_context_fn: Optional[Callable[[nemo_rl.models.automodel.data.ProcessedInputs], Any]] = None,
num_valid_microbatches: Optional[int] = None,
on_microbatch_start: Optional[Callable[[int], None]] = None,
) → list[Tuple[Any, dict[str, Any]]]#

Execute forward and backward passes for automodel.

This is the main training loop function that coordinates forward and backward passes across multiple microbatches using PyTorch autograd.

Unlike megatron_forward_backward, which uses Megatron’s pipeline-parallel framework, this function uses standard PyTorch operations.

Parameters:
  • model – The model to train

  • data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)

  • post_processing_fn – Post-processing function to apply to the logits

  • forward_only – If True, skip backward pass

  • is_reward_model – Whether this is a reward model

  • allow_flash_attn_args – Whether to pass flash_attn_kwargs to model

  • global_valid_seqs – Global valid sequence count for loss normalization

  • global_valid_toks – Global valid token count for loss normalization

  • sampling_params – Sampling parameters (top-k, top-p, temperature)

  • sequence_dim – Sequence dimension

  • dp_size – Data parallel size

  • cp_size – Context parallel size

  • num_global_batches – Number of global batches (for metric scaling)

  • train_context_fn – Optional callable that takes ProcessedInputs and returns a context manager for the forward/backward pass. If None, no context is used.

  • num_valid_microbatches – Number of valid (non-dummy) microbatches. If provided, microbatches beyond this index are treated as dummy batches (loss *= 0). If None, all microbatches are considered valid.

  • on_microbatch_start – Optional callback called at the start of each microbatch with the microbatch index. Useful for cache clearing, etc.

Returns:

List of (result, metrics) tuples from each microbatch
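
A minimal training-step sketch; the model, post-processor, microbatch list, optimizer, and parallel sizes are assumed to be set up elsewhere, and the gradient clipping shown is illustrative:

    import torch
    from nemo_rl.models.automodel.train import automodel_forward_backward

    # Forward/backward over all microbatches, accumulating gradients.
    results = automodel_forward_backward(
        model=model,
        data_iterator=iter(processed_microbatches),
        post_processing_fn=loss_post_processor,
        forward_only=False,
        dp_size=dp_size,
        cp_size=cp_size,
    )
    mb_losses = [loss for loss, _metrics in results]
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()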

class nemo_rl.models.automodel.train.LossPostProcessor(
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
cfg: nemo_rl.models.policy.PolicyConfig,
device_mesh: Any,
cp_mesh: Any,
tp_mesh: Any,
cp_size: int,
dp_size: int,
enable_seq_packing: bool = False,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
)#

Post-processor for computing training loss from model outputs.

Initialization

Initialize LossPostProcessor.

Parameters:
  • loss_fn – Loss function to compute loss

  • cfg – Policy configuration dictionary

  • device_mesh – Full device mesh

  • cp_mesh – Context parallel mesh

  • tp_mesh – Tensor parallel mesh

  • cp_size – Context parallel size

  • dp_size – Data parallel size

  • enable_seq_packing – Whether sequence packing is enabled

  • sampling_params – Sampling parameters

__call__(
logits: torch.Tensor,
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
sequence_dim: int = 1,
) → tuple[torch.Tensor, dict[str, Any]]#

Compute loss from logits.

Parameters:
  • logits – Model output logits

  • data_dict – Microbatch data

  • processed_inputs – Processed inputs

  • global_valid_seqs – Global valid sequence count

  • global_valid_toks – Global valid token count

  • sequence_dim – Sequence dimension

Returns:

Tuple of (loss, metrics)

class nemo_rl.models.automodel.train.LogprobsPostProcessor(
cfg: nemo_rl.models.policy.PolicyConfig,
device_mesh: Any,
cp_mesh: Any,
tp_mesh: Any,
cp_size: int,
enable_seq_packing: bool = False,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
)#

Post-processor for computing log probabilities from model outputs.

Initialization

Initialize LogprobsPostProcessor.

Parameters:
  • cfg – Policy configuration dictionary

  • device_mesh – Full device mesh

  • cp_mesh – Context parallel mesh

  • tp_mesh – Tensor parallel mesh

  • cp_size – Context parallel size

  • enable_seq_packing – Whether sequence packing is enabled

  • sampling_params – Sampling parameters

__call__(
logits: torch.Tensor,
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
original_batch_size: int,
original_seq_len: int,
sequence_dim: int = 1,
) → torch.Tensor#

Compute token log probabilities from logits.

Parameters:
  • logits – Model output logits

  • data_dict – Microbatch data

  • processed_inputs – Processed inputs

  • original_batch_size – Original batch size before packing

  • original_seq_len – Original sequence length before packing

  • sequence_dim – Sequence dimension

Returns:

Token log probabilities tensor [batch_size, seq_length]

_compute_local_logprobs(
logits: torch.Tensor,
input_ids: torch.Tensor,
) → torch.Tensor#

Compute logprobs locally without distributed processing.

Parameters:
  • logits – Model output logits

  • input_ids – Input token IDs

Returns:

Token log probabilities
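
A minimal single-device sketch of this computation using the usual next-token shift; TP-sharded vocabularies, which the surrounding class is built to handle, are ignored here:

    import torch
    import torch.nn.functional as F

    def local_logprobs_sketch(
        logits: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        # logits: [batch, seq, vocab]; input_ids: [batch, seq].
        # Position t predicts token t+1, so drop the last logit and the first id.
        log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)
        targets = input_ids[:, 1:].unsqueeze(-1)
        token_logprobs = log_probs.gather(dim=-1, index=targets).squeeze(-1)
        # Pad position 0 with 0.0 so the output lines up with input_ids.
        return F.pad(token_logprobs, (1, 0), value=0.0)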

class nemo_rl.models.automodel.train.TopkLogitsPostProcessor(
cfg: nemo_rl.models.policy.PolicyConfig,
device_mesh: Any,
cp_mesh: Any,
tp_mesh: Any,
cp_size: int,
k: int,
enable_seq_packing: bool = False,
)#

Post-processor for computing top-k logits from model outputs.

Initialization

Initialize TopkLogitsPostProcessor.

Parameters:
  • cfg – Policy configuration dictionary

  • device_mesh – Full device mesh

  • cp_mesh – Context parallel mesh

  • tp_mesh – Tensor parallel mesh

  • cp_size – Context parallel size

  • k – Number of top logits to return

  • enable_seq_packing – Whether sequence packing is enabled

__call__(
logits: torch.Tensor,
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
original_batch_size: int,
original_seq_len: int,
sequence_dim: int = 1,
) → tuple[torch.Tensor, torch.Tensor]#

Compute top-k logits and indices from model outputs.

Parameters:
  • logits – Model output logits

  • data_dict – Microbatch data

  • processed_inputs – Processed inputs

  • original_batch_size – Original batch size before packing

  • original_seq_len – Original sequence length before packing

  • sequence_dim – Sequence dimension

Returns:

Tuple of (top-k values, top-k indices) tensors
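
At its core this is torch.topk over the vocabulary dimension; a minimal sketch that ignores the CP/TP resharding and sequence unpacking the class also performs:

    import torch

    def topk_logits_sketch(
        logits: torch.Tensor, k: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # logits: [batch, seq, vocab] -> values, indices: [batch, seq, k],
        # keeping only the k largest logits at each position.
        values, indices = torch.topk(logits, k, dim=-1)
        return values, indices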

class nemo_rl.models.automodel.train.ScorePostProcessor(cfg: nemo_rl.models.policy.PolicyConfig)#

Post-processor for computing reward model scores from model outputs.

Initialization

Initialize ScorePostProcessor.

Parameters:

cfg – Policy configuration dictionary

__call__(logits: torch.Tensor) → torch.Tensor#

Extract scores from reward model outputs.

Parameters:

logits – Model output logits

Returns:

Scores tensor

nemo_rl.models.automodel.train.aggregate_training_statistics(
losses: list[float],
all_mb_metrics: list[dict[str, Any]],
grad_norm: Optional[torch.Tensor],
dp_group: Any,
dtype: torch.dtype,
) → dict[str, Any]#

Aggregate training statistics across microbatches and ranks.

Parameters:
  • losses – List of loss values from each microbatch

  • all_mb_metrics – List of metrics dictionaries from each microbatch

  • grad_norm – Gradient norm tensor (or None in eval mode)

  • dp_group – Data parallel process group for all-reduce

  • dtype – Model dtype for metrics

Returns:

Dictionary containing aggregated metrics including global_loss, grad_norm, etc.
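
A rough sketch of the core aggregation pattern, assuming a CUDA backend; the real function also merges the per-microbatch metric dictionaries and reports grad_norm alongside the global loss:

    import torch
    import torch.distributed as dist

    def aggregate_sketch(losses: list[float], dp_group) -> dict:
        # Mean loss over local microbatches, then averaged across DP ranks.
        local_loss = torch.tensor(
            sum(losses) / max(len(losses), 1),
            device=torch.cuda.current_device(),
        )
        dist.all_reduce(local_loss, op=dist.ReduceOp.AVG, group=dp_group)
        return {"global_loss": local_loss.item()}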