nemo_rl.models.automodel.train#
Training utilities for automodel (DTensor-based) policy workers.
This module provides post-processor classes and forward/backward functions that follow the same pattern as nemo_rl/models/megatron/train.py.
Key differences from the megatron approach (a minimal sketch of the flow follows this list):
- Post-processors compute results directly (no callable return pattern).
- forward_with_post_processing_fn calls the post-processor directly.
- automodel_forward_backward uses PyTorch autograd instead of Megatron’s pipeline.
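A minimal sketch of this direct-compute flow, purely illustrative: `model`, `data_iterator`, `loss_post_processor`, and `forward_only` are placeholders assumed to be set up by the policy worker, not names defined in this module.

```python
from nemo_rl.models.automodel.train import forward_with_post_processing_fn

# Illustrative flow only -- not the actual worker implementation.
# The post-processor runs on the logits inside the forward call and
# immediately returns (result, metrics, processed_mb); a plain
# .backward() then drives PyTorch autograd, with no callable
# indirection as in the Megatron pipeline schedule.
result, metrics, processed_mb = forward_with_post_processing_fn(
    model=model,
    post_processing_fn=loss_post_processor,  # e.g. a LossPostProcessor instance
    processed_mb=next(data_iterator),
)
if not forward_only:
    result.backward()  # result is the scalar loss when a LossPostProcessor is used
```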
Module Contents#
Classes#
LossPostProcessor – Post-processor for computing training loss from model outputs.
LogprobsPostProcessor – Post-processor for computing log probabilities from model outputs.
TopkLogitsPostProcessor – Post-processor for computing top-k logits from model outputs.
ScorePostProcessor – Post-processor for computing reward model scores from model outputs.
Functions#
model_forward – Perform a single forward pass through the model.
extract_logits – Extract logits from model outputs.
apply_temperature_scaling – Apply temperature scaling to logits.
apply_top_k_top_p_filtering_for_local_logits – Apply top-k and top-p filtering to the non-distributed logits.
redistribute_logits_for_cp – Redistribute logits for context parallel processing.
prepare_data_for_cp – Prepare data for context parallel processing.
forward_with_post_processing_fn – Perform forward pass with pre-processed microbatch and apply post-processing.
automodel_forward_backward – Execute forward and backward passes for automodel.
aggregate_training_statistics – Aggregate training statistics across microbatches and ranks.
Data#
API#
- nemo_rl.models.automodel.train.PostProcessingFunction#
None
- nemo_rl.models.automodel.train.model_forward(
- model: torch.nn.Module,
- processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
- is_reward_model: bool = False,
- allow_flash_attn_args: bool = True,
Perform a single forward pass through the model.
- Parameters:
model – The model to run forward pass on
processed_inputs – ProcessedInputs containing all tensors for forward pass
is_reward_model – Whether this is a reward model
allow_flash_attn_args – Whether to pass flash_attn_kwargs to model
- Returns:
Output tensor from the model (logits)
- Return type:
torch.Tensor
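A minimal usage sketch; `model` and `processed_inputs` are assumed to be prepared by the worker's data pre-processing path and are not constructed here.

```python
from nemo_rl.models.automodel.train import model_forward

# One forward pass; set allow_flash_attn_args=False for models that
# do not accept flash_attn_kwargs in their forward signature.
logits = model_forward(
    model,
    processed_inputs,
    is_reward_model=False,
    allow_flash_attn_args=True,
)
# For causal LMs the result is typically shaped [batch_size, seq_len, vocab_size].
```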
- nemo_rl.models.automodel.train.extract_logits(
- model: torch.nn.Module,
- outputs: Any,
Extract logits from model outputs.
- Parameters:
model – The model (used for lm_head if needed)
outputs – Model outputs (can be tensor, DTensor, or object with logits attribute)
- Returns:
Logits tensor
- Return type:
torch.Tensor
- nemo_rl.models.automodel.train.apply_temperature_scaling(
- logits: torch.Tensor,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
Apply temperature scaling to logits.
- Parameters:
logits – Logits tensor to scale
sampling_params – Sampling parameters
- Returns:
Temperature-scaled logits
- Return type:
torch.Tensor
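Conceptually this divides the logits by the sampling temperature. A self-contained sketch of that operation, under the assumption that a missing or unit temperature leaves the logits unchanged (the in-repo helper reads the temperature from TrainingSamplingParams and may handle edge cases differently):

```python
from typing import Optional

import torch

def scale_by_temperature(logits: torch.Tensor, temperature: Optional[float]) -> torch.Tensor:
    """Divide logits by the temperature; no-op when no scaling is requested."""
    if temperature is None or temperature == 1.0:
        return logits
    return logits / temperature

logits = torch.randn(2, 4, 32000)           # [batch, seq, vocab]
scaled = scale_by_temperature(logits, 0.7)   # sharper distribution after softmax
```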
- nemo_rl.models.automodel.train.apply_top_k_top_p_filtering_for_local_logits(
- logits: torch.Tensor,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
Apply top-k and top-p filtering to the non-distributed logits.
- Parameters:
logits – Logits tensor to filter
sampling_params – Sampling parameters
- Returns:
Filtered logits
- Return type:
torch.Tensor
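The standard top-k/top-p (nucleus) filter masks out low-probability tokens before sampling. A self-contained sketch of that algorithm on unsharded logits (illustrative only; the in-repo helper takes its thresholds from TrainingSamplingParams and may differ in defaults and tie handling):

```python
import torch

def top_k_top_p_filter(
    logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0
) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p nucleus (vocab on the last dim)."""
    neg_inf = float("-inf")
    if top_k > 0:
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, neg_inf)
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token that crosses top_p
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), neg_inf)
    return logits

filtered = top_k_top_p_filter(torch.randn(2, 4, 32000), top_k=50, top_p=0.95)
```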
- nemo_rl.models.automodel.train.redistribute_logits_for_cp(
- logits: torch.Tensor,
- device_mesh: Any,
- cp_mesh: Any,
- sequence_dim: int = 1,
Redistribute logits for context parallel processing.
Handles the case where the logits may be a TP-sharded DTensor or a regular tensor, and converts them to a CP+TP-sharded DTensor.
- Parameters:
logits – Logits tensor (may be DTensor or regular tensor)
device_mesh – Full device mesh
cp_mesh – Context parallel mesh (kept for signature compatibility)
sequence_dim – Dimension for sequence sharding
- Returns:
DTensor sharded on both CP and TP dimensions
- nemo_rl.models.automodel.train.prepare_data_for_cp(
- mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
- cp_mesh: Any,
- sequence_dim: int = 1,
Prepare data for context parallel processing.
Converts seq_index to full tensor and wraps CP-sharded tensors in DTensor.
- Parameters:
mb – Microbatch data dictionary
processed_inputs – Processed inputs containing CP buffers
cp_mesh – Context parallel mesh
sequence_dim – Dimension for sequence sharding
- Returns:
Tuple of (seq_index_dtensor, updated_mb)
- nemo_rl.models.automodel.train.forward_with_post_processing_fn(
- model: torch.nn.Module,
- post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction,
- processed_mb: nemo_rl.models.automodel.data.ProcessedMicrobatch,
- is_reward_model: bool = False,
- allow_flash_attn_args: bool = True,
- global_valid_seqs: Optional[torch.Tensor] = None,
- global_valid_toks: Optional[torch.Tensor] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- sequence_dim: int = 1,
Perform forward pass with pre-processed microbatch and apply post-processing.
This function takes a pre-processed microbatch (with sequence packing already handled), runs the forward step through the model, and applies the post-processing function to compute the result.
Unlike the megatron approach, which returns a callable, this directly computes and returns the result since automodel uses PyTorch autograd.
- Parameters:
model – The model to run forward pass on
post_processing_fn – Post-processing function to apply to the logits
processed_mb – Pre-fetched ProcessedMicrobatch containing data and processed inputs
is_reward_model – Whether this is a reward model
allow_flash_attn_args – Whether to pass flash_attn_kwargs to model
global_valid_seqs – Global valid sequence count for loss normalization
global_valid_toks – Global valid token count for loss normalization
sampling_params – Sampling parameters (top-k, top-p, temperature)
sequence_dim – Sequence dimension
- Returns:
Tuple of (result, metrics, processed_microbatch):
result – Output from post-processing (loss, logprobs, topk, or scores)
metrics – Dictionary of metrics from post-processing
processed_microbatch – The ProcessedMicrobatch that was processed
- Return type:
tuple
- nemo_rl.models.automodel.train.automodel_forward_backward(
- model: torch.nn.Module,
- data_iterator: Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch],
- post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction,
- forward_only: bool = False,
- is_reward_model: bool = False,
- allow_flash_attn_args: bool = True,
- global_valid_seqs: Optional[torch.Tensor] = None,
- global_valid_toks: Optional[torch.Tensor] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- sequence_dim: int = 1,
- dp_size: int = 1,
- cp_size: int = 1,
- num_global_batches: int = 1,
- train_context_fn: Optional[Callable[[nemo_rl.models.automodel.data.ProcessedInputs], Any]] = None,
- num_valid_microbatches: Optional[int] = None,
- on_microbatch_start: Optional[Callable[[int], None]] = None,
Execute forward and backward passes for automodel.
This is the main training loop function that coordinates forward and backward passes across multiple microbatches using PyTorch autograd.
Unlike megatron_forward_backward, which uses Megatron’s pipeline-parallel framework, this uses standard PyTorch operations.
- Parameters:
model – The model to train
data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)
post_processing_fn – Post-processing function to apply to the logits
forward_only – If True, skip backward pass
is_reward_model – Whether this is a reward model
allow_flash_attn_args – Whether to pass flash_attn_kwargs to model
global_valid_seqs – Global valid sequence count for loss normalization
global_valid_toks – Global valid token count for loss normalization
sampling_params – Sampling parameters (top-k, top-p, temperature)
sequence_dim – Sequence dimension
dp_size – Data parallel size
cp_size – Context parallel size
num_global_batches – Number of global batches (for metric scaling)
train_context_fn – Optional callable that takes ProcessedInputs and returns a context manager for the forward/backward pass. If None, no context is used.
num_valid_microbatches – Number of valid (non-dummy) microbatches. If provided, microbatches beyond this index are treated as dummy batches (loss *= 0). If None, all microbatches are considered valid.
on_microbatch_start – Optional callback called at the start of each microbatch with the microbatch index. Useful for cache clearing, etc.
- Returns:
List of (result, metrics) tuples from each microbatch
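A sketch of one training step built on this function. All names other than the imported function are placeholders assumed to be created by the policy worker (the model, optimizer, pre-processed microbatches, post-processor, and parallel sizes), and the scalar-loss assumption holds only when a loss post-processor is used.

```python
from nemo_rl.models.automodel.train import automodel_forward_backward

results = automodel_forward_backward(
    model=model,
    data_iterator=iter(processed_microbatches),  # yields ProcessedMicrobatch objects
    post_processing_fn=loss_post_processor,
    forward_only=False,
    dp_size=dp_size,
    cp_size=cp_size,
)

# Each entry is a (result, metrics) tuple, one per microbatch.
losses = [float(result.detach()) for result, _ in results]
all_mb_metrics = [metrics for _, metrics in results]

optimizer.step()
optimizer.zero_grad()
```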
- class nemo_rl.models.automodel.train.LossPostProcessor(
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- cfg: nemo_rl.models.policy.PolicyConfig,
- device_mesh: Any,
- cp_mesh: Any,
- tp_mesh: Any,
- cp_size: int,
- dp_size: int,
- enable_seq_packing: bool = False,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Post-processor for computing training loss from model outputs.
Initialization
Initialize LossPostProcessor.
- Parameters:
loss_fn – Loss function to compute loss
cfg – Configuration dictionary
device_mesh – Full device mesh
cp_mesh – Context parallel mesh
tp_mesh – Tensor parallel mesh
cp_size – Context parallel size
dp_size – Data parallel size
enable_seq_packing – Whether sequence packing is enabled
sampling_params – Sampling parameters
- __call__(
- logits: torch.Tensor,
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
- global_valid_seqs: torch.Tensor,
- global_valid_toks: torch.Tensor,
- sequence_dim: int = 1,
Compute loss from logits.
- Parameters:
logits – Model output logits
data_dict – Microbatch data
processed_inputs – Processed inputs
global_valid_seqs – Global valid sequence count
global_valid_toks – Global valid token count
sequence_dim – Sequence dimension
- Returns:
Tuple of (loss, metrics)
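A construction sketch; every argument here is a placeholder supplied by the policy worker setup (the loss function, the PolicyConfig, the device meshes, and the parallel sizes). The resulting object is what gets passed as post_processing_fn to automodel_forward_backward.

```python
from nemo_rl.models.automodel.train import LossPostProcessor

# Placeholder arguments; in practice these come from the policy worker setup.
loss_post_processor = LossPostProcessor(
    loss_fn=loss_fn,
    cfg=policy_cfg,
    device_mesh=device_mesh,
    cp_mesh=cp_mesh,
    tp_mesh=tp_mesh,
    cp_size=cp_size,
    dp_size=dp_size,
    enable_seq_packing=False,
)
# Passed as post_processing_fn=loss_post_processor to automodel_forward_backward.
```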
- class nemo_rl.models.automodel.train.LogprobsPostProcessor(
- cfg: nemo_rl.models.policy.PolicyConfig,
- device_mesh: Any,
- cp_mesh: Any,
- tp_mesh: Any,
- cp_size: int,
- enable_seq_packing: bool = False,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Post-processor for computing log probabilities from model outputs.
Initialization
Initialize LogprobsPostProcessor.
- Parameters:
cfg – Configuration dictionary
device_mesh – Full device mesh
cp_mesh – Context parallel mesh
tp_mesh – Tensor parallel mesh
cp_size – Context parallel size
enable_seq_packing – Whether sequence packing is enabled
sampling_params – Sampling parameters
- __call__(
- logits: torch.Tensor,
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
- original_batch_size: int,
- original_seq_len: int,
- sequence_dim: int = 1,
Compute token log probabilities from logits.
- Parameters:
logits – Model output logits
data_dict – Microbatch data
processed_inputs – Processed inputs
original_batch_size – Original batch size before packing
original_seq_len – Original sequence length before packing
sequence_dim – Sequence dimension
- Returns:
Token log probabilities tensor [batch_size, seq_length]
- _compute_local_logprobs(
- logits: torch.Tensor,
- input_ids: torch.Tensor,
Compute logprobs locally without distributed processing.
- Parameters:
logits – Model output logits
input_ids – Input token IDs
- Returns:
Token log probabilities
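The non-distributed path is the standard next-token log-probability gather. A self-contained sketch of that computation (the in-repo method may differ in how it casts dtypes and aligns shifted positions):

```python
import torch
import torch.nn.functional as F

def compute_local_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability of each observed next token; position 0 gets 0 since it has no predecessor."""
    # logits: [batch, seq, vocab]; input_ids: [batch, seq]
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)     # predictions for tokens 1..seq-1
    next_tokens = input_ids[:, 1:].unsqueeze(-1)                  # tokens actually observed
    token_logprobs = log_probs.gather(dim=-1, index=next_tokens).squeeze(-1)
    return F.pad(token_logprobs, (1, 0), value=0.0)               # re-align to [batch, seq]
```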
- class nemo_rl.models.automodel.train.TopkLogitsPostProcessor(
- cfg: nemo_rl.models.policy.PolicyConfig,
- device_mesh: Any,
- cp_mesh: Any,
- tp_mesh: Any,
- cp_size: int,
- k: int,
- enable_seq_packing: bool = False,
Post-processor for computing top-k logits from model outputs.
Initialization
Initialize TopkLogitsPostProcessor.
- Parameters:
cfg – Configuration dictionary
device_mesh – Full device mesh
cp_mesh – Context parallel mesh
tp_mesh – Tensor parallel mesh
cp_size – Context parallel size
k – Number of top logits to return
enable_seq_packing – Whether sequence packing is enabled
- __call__(
- logits: torch.Tensor,
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs,
- original_batch_size: int,
- original_seq_len: int,
- sequence_dim: int = 1,
Compute top-k logits and indices from model outputs.
- Parameters:
logits – Model output logits
data_dict – Microbatch data
processed_inputs – Processed inputs
original_batch_size – Original batch size before packing
original_seq_len – Original sequence length before packing
sequence_dim – Sequence dimension
- Returns:
Tuple of (top-k values, top-k indices) tensors
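At its core this reduces to a top-k over the vocabulary dimension; a minimal sketch (the actual post-processor additionally handles CP/TP-sharded DTensor logits and unpacking back to the original batch and sequence shape):

```python
import torch

logits = torch.randn(2, 8, 32000)                        # [batch, seq, vocab]
topk_vals, topk_idx = torch.topk(logits, k=64, dim=-1)   # both shaped [batch, seq, 64]
```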
- class nemo_rl.models.automodel.train.ScorePostProcessor(cfg: nemo_rl.models.policy.PolicyConfig)#
Post-processor for computing reward model scores from model outputs.
Initialization
Initialize ScorePostProcessor.
- Parameters:
cfg – Configuration dictionary
- __call__(logits: torch.Tensor) → torch.Tensor#
Extract scores from reward model outputs.
- Parameters:
logits – Model output logits
- Returns:
Scores tensor
- nemo_rl.models.automodel.train.aggregate_training_statistics(
- losses: list[float],
- all_mb_metrics: list[dict[str, Any]],
- grad_norm: Optional[torch.Tensor],
- dp_group: Any,
- dtype: torch.dtype,
Aggregate training statistics across microbatches and ranks.
- Parameters:
losses – List of loss values from each microbatch
all_mb_metrics – List of metrics dictionaries from each microbatch
grad_norm – Gradient norm tensor (or None if eval mode)
dp_group – Data parallel process group for all-reduce
dtype – Model dtype for metrics
- Returns:
Dictionary containing aggregated metrics including global_loss, grad_norm, etc.
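The heart of the aggregation is averaging the per-rank loss over the data-parallel group. A sketch of that piece only (the function name and reduction op are illustrative; ReduceOp.AVG assumes the NCCL backend, and the real function also folds in the per-microbatch metrics and grad_norm):

```python
import torch
import torch.distributed as dist

def aggregate_global_loss(losses: list[float], dp_group) -> float:
    """Average microbatch losses on this rank, then mean-reduce across DP ranks."""
    local_mean = torch.tensor(sum(losses) / max(len(losses), 1), device="cuda")
    dist.all_reduce(local_mean, op=dist.ReduceOp.AVG, group=dp_group)  # NCCL-only op
    return local_mean.item()
```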