nemo_rl.models.megatron.train#

Module Contents#

Classes#

LossPostProcessor

Create a loss post-processing function for training.

LogprobsPostProcessor

Create a post-processing function that computes token log probabilities.

TopkLogitsPostProcessor

Create a post-processing function that computes top-k logits and indices.

Functions#

model_forward

Perform a single forward pass through the model.

apply_temperature_scaling

Apply temperature scaling to logits.

forward_with_post_processing_fn

Perform a forward pass on a pre-processed microbatch and return the output tensor and a post-processing function.

megatron_forward_backward

Execute forward and backward passes using Megatron’s utilities.

aggregate_training_statistics

Aggregate training statistics across microbatches and data-parallel ranks.

Data#

API#

nemo_rl.models.megatron.train.PostProcessingFunction#

None

nemo_rl.models.megatron.train.model_forward(
model: megatron.core.models.gpt.GPTModel,
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
input_ids_cp_sharded: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
defer_fp32_logits: Optional[bool] = False,
straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
use_linear_ce_fusion_loss: bool = False,
) → torch.Tensor#

Perform a single forward pass through the model.

Parameters:
  • model – The model to run forward pass on

  • data_dict – Dictionary containing batch data

  • input_ids_cp_sharded – Context-parallel sharded input token IDs

  • position_ids – Position IDs for tokens

  • attention_mask – Attention mask for the sequence

  • packed_seq_params – Parameters for packed sequences (optional)

  • defer_fp32_logits – Whether to skip the conversion of logits to fp32

  • straggler_timer – Straggler detector for profiling the forward pass

  • use_linear_ce_fusion_loss – Whether to use linear CE fusion loss

Returns:

Output tensor from the model (logits)

Return type:

torch.Tensor

nemo_rl.models.megatron.train.apply_temperature_scaling(
logits: torch.Tensor,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
) → torch.Tensor#

Apply temperature scaling to logits.

Parameters:
  • logits – Logits tensor to scale

  • sampling_params – Sampling parameters

Returns:

Temperature-scaled logits

Return type:

torch.Tensor
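
Temperature scaling is an elementwise division of the logits. A minimal pure-Python sketch (the function names and list-based "tensors" are illustrative, not the actual nemo_rl implementation):

```python
import math

def scale_by_temperature(logits, temperature):
    # Divide each logit by the temperature; T > 1 flattens the resulting
    # distribution, T < 1 sharpens it. A missing temperature is a no-op,
    # mirroring the Optional sampling_params in the real API.
    if temperature is None or temperature == 1.0:
        return logits
    return [x / temperature for x in logits]

def softmax(xs):
    # Numerically stable softmax, included only to show the effect of scaling.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `softmax(scale_by_temperature([2.0, 0.0], 0.5))` puts more mass on the top token than `softmax([2.0, 0.0])` does.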

nemo_rl.models.megatron.train.forward_with_post_processing_fn(
data_iterator: Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch],
model: megatron.core.models.gpt.GPTModel,
post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction,
defer_fp32_logits: Optional[bool] = False,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
enable_hidden_capture: Optional[bool] = False,
use_linear_ce_fusion_loss: bool = False,
) → Tuple[torch.Tensor, Callable]#

Perform a forward pass on a pre-processed microbatch and return the output tensor and a post-processing function.

This function takes a pre-processed microbatch (with sequence packing already handled), runs the forward step through the model, and prepares a wrapped post-processing function for the model outputs.

Parameters:
  • data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)

  • model – The model to run forward pass on

  • post_processing_fn – Post-processing function applied to the logits

  • defer_fp32_logits – Whether to defer FP32 conversion of logits

  • global_valid_seqs – Global valid sequence count for loss normalization

  • global_valid_toks – Global valid token count for loss normalization

  • sampling_params – Sampling parameters (top-k, top-p, temperature)

  • straggler_timer – Straggler detector for profiling the forward pass

  • draft_model – Draft model for online draft model training

  • enable_hidden_capture – Whether to enable hidden state capture for draft model training

Returns:

(output_tensor, post_processing_fn_wrapped)

  • output_tensor – Raw model outputs (logits)

  • post_processing_fn_wrapped – Function that creates the output post-processing function when called

Return type:

tuple
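
The (output tensor, deferred post-processor) contract can be illustrated with a torch-free sketch; `model_fn` and `post_processing_fn` here are hypothetical stand-ins for the real model forward and a `LossPostProcessor`-style callable:

```python
def run_forward_with_post_processing(microbatch, model_fn, post_processing_fn):
    # Run the forward pass now; defer post-processing until the caller
    # invokes the returned wrapper (e.g. from inside a pipeline schedule).
    output = model_fn(microbatch)

    def post_process(output_tensor):
        # The closure captures the microbatch, so the post-processor sees
        # both the raw outputs and the data that produced them.
        return post_processing_fn(microbatch, output_tensor)

    return output, post_process
```

Separating the forward pass from post-processing lets the caller decide when (and whether) to materialize losses or logprobs from the raw outputs.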

nemo_rl.models.megatron.train.megatron_forward_backward(
model: megatron.core.models.gpt.GPTModel,
data_iterator: Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch],
num_microbatches: int,
seq_length: int,
mbs: int,
post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction,
forward_only: bool = False,
defer_fp32_logits: Optional[bool] = False,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
enable_hidden_capture: Optional[bool] = False,
use_linear_ce_fusion_loss: bool = False,
) → Any#

Execute forward and backward passes using Megatron’s utilities.

This is the main training loop function that coordinates forward and backward passes across multiple microbatches using Megatron’s pipeline parallel execution framework.

Parameters:
  • model – The model to train

  • data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)

  • num_microbatches – Number of microbatches to process

  • seq_length – Sequence length

  • mbs – Micro batch size

  • post_processing_fn – Post-processing function applied to the logits

  • forward_only – If True, skip backward pass

  • defer_fp32_logits – Whether to skip the conversion of logits to fp32

  • global_valid_seqs – Global valid sequence count for loss normalization

  • global_valid_toks – Global valid token count for loss normalization

  • sampling_params – Sampling parameters (top-k, top-p, temperature)

  • straggler_timer – Straggler detector for profiling the forward pass

  • draft_model – Draft model for online draft model training

  • enable_hidden_capture – Whether to enable hidden state capture for draft model training

Returns:

Results from the forward/backward execution
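
The microbatch scheduling that this function delegates to Megatron's pipeline utilities reduces, conceptually, to a loop like the following hedged sketch (plain Python, no pipeline or data parallelism; names are illustrative):

```python
def forward_backward_loop(microbatches, forward_fn, backward_fn, forward_only=False):
    # Run a forward pass per microbatch, optionally followed by a backward
    # pass, and collect the per-microbatch losses -- the shape of what the
    # real schedule hands back for aggregation.
    losses = []
    for mb in microbatches:
        loss = forward_fn(mb)
        if not forward_only:
            backward_fn(loss)  # accumulate gradients for this microbatch
        losses.append(loss)
    return losses
```

`forward_only=True` corresponds to inference-style use (e.g. logprob computation), where gradients are never accumulated.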

class nemo_rl.models.megatron.train.LossPostProcessor(
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
cfg: nemo_rl.models.policy.PolicyConfig,
num_microbatches: int = 1,
cp_normalize: bool = True,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
)#

Initialization

__call__(
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
) → Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]]#

Create a loss post-processing function for training.

This function wraps a loss function with the necessary context and parameters to compute loss and metrics from model outputs. It handles sequence packing and context parallelism normalization.

Parameters:
  • data_dict – Batched data dictionary for the current microbatch

  • packed_seq_params – Parameters for packed sequences (optional)

  • global_valid_seqs – Global valid sequence count for loss normalization

  • global_valid_toks – Global valid token count for loss normalization

Returns:

Function that takes output tensor and returns (loss, metrics) tuple

Return type:

Callable
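
Conceptually, `__call__` builds a closure over the microbatch context so that the loss can later be computed from the output tensor alone. A simplified torch-free sketch (the names and the token-count normalization are illustrative; the real class also handles sequence packing and context parallelism):

```python
def make_loss_post_processor(target_mask, global_valid_toks):
    # Capture the microbatch context (here just a mask and a global token
    # count) and return a function of the model output alone.
    def post_process(per_token_losses):
        # Keep only losses at positions the mask marks as valid.
        masked = [l for l, m in zip(per_token_losses, target_mask) if m]
        # Normalize by the *global* valid-token count so microbatch losses
        # sum to a correctly scaled global loss.
        loss = sum(masked) / max(global_valid_toks, 1)
        metrics = {"num_valid_toks": len(masked), "loss": loss}
        return loss, metrics
    return post_process
```

In the real code, `global_valid_toks` has already been all-reduced across data-parallel ranks before the closure is created.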

class nemo_rl.models.megatron.train.LogprobsPostProcessor(
cfg: nemo_rl.models.policy.PolicyConfig,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
)#

Initialization

__call__(
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
input_ids: torch.Tensor,
cu_seqlens_padded: torch.Tensor,
) → Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]#

Create a post-processing function that computes token log probabilities.

This function returns a processor that takes model logits and converts them to token-level log probabilities, handling both packed and unpacked sequences.

Parameters:
  • data_dict – Batched data dictionary containing input sequences

  • input_ids – Processed input token IDs

  • cu_seqlens_padded – Cumulative sequence lengths for packed sequences

Returns:

Function that takes output tensor and returns (dummy_loss, {"logprobs": token_logprobs})

Return type:

Callable
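
The core computation — gathering `log_softmax(logits)` at each target token — can be sketched in plain Python (illustrative names; the real processor operates on packed torch tensors and handles parallelism):

```python
import math

def token_logprobs(logits_per_position, token_ids):
    # For each position t, return log_softmax(logits[t])[token_ids[t]]:
    # the log-probability the model assigns to the observed token.
    out = []
    for logits, tok in zip(logits_per_position, token_ids):
        # Numerically stable log-normalizer: log sum exp(x - max) + max.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        out.append(logits[tok] - log_z)
    return out
```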

class nemo_rl.models.megatron.train.TopkLogitsPostProcessor(
cfg: nemo_rl.models.policy.PolicyConfig,
k: int,
)#

Initialization

__call__(
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
cu_seqlens_padded: torch.Tensor,
) → Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]#

Create a post-processing function that computes top-k logits and indices.

This function returns a processor that extracts the top-k highest logits and their corresponding vocabulary indices from model outputs. It handles tensor parallelism, context parallelism, and sequence packing.

Parameters:
  • data_dict – Batched data dictionary

  • cu_seqlens_padded – Cumulative sequence lengths for packed sequences

Returns:

Function that takes output tensor and returns (dummy_loss, {"topk_logits": values, "topk_indices": indices})

Return type:

Callable
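
Top-k extraction itself is straightforward; a minimal sketch on a plain Python list (the real processor additionally gathers logits across tensor-parallel vocabulary shards and handles sequence packing):

```python
def topk_logits(logits, k):
    # Return the k largest logits with their vocabulary indices,
    # ordered from largest to smallest.
    indexed = sorted(enumerate(logits), key=lambda pair: pair[1], reverse=True)[:k]
    values = [v for _, v in indexed]
    indices = [i for i, _ in indexed]
    return values, indices
```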

nemo_rl.models.megatron.train.aggregate_training_statistics(
all_mb_metrics: List[Dict[str, Any]],
losses: List[float],
data_parallel_group: torch.distributed.ProcessGroup,
) → Tuple[Dict[str, List[Any]], torch.Tensor]#

Aggregate training statistics across microbatches and data-parallel ranks.

Computes a global loss by all-reducing per-gradient-buffer losses across the data-parallel group, then collects per-microbatch metrics into lists keyed by metric name.

Parameters:
  • all_mb_metrics – List of metric dicts from each microbatch.

  • losses – List of per-gradient-buffer scalar losses on this rank.

  • data_parallel_group – The data-parallel process group for all-reduce.

Returns:

  • mb_metrics: Dict mapping metric names to lists of values across microbatches.

  • global_loss: Tensor of losses summed across all data-parallel ranks.

Return type:

Tuple[Dict[str, List[Any]], torch.Tensor]
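
Ignoring the distributed all-reduce, the aggregation reduces to collecting per-microbatch dicts into lists keyed by metric name and summing this rank's losses. A hedged sketch with illustrative names:

```python
def aggregate_microbatch_metrics(all_mb_metrics, losses):
    # Collect per-microbatch metric dicts into lists keyed by metric name;
    # metrics absent from some microbatches simply yield shorter lists.
    mb_metrics = {}
    for metrics in all_mb_metrics:
        for name, value in metrics.items():
            mb_metrics.setdefault(name, []).append(value)
    # Sum this rank's losses; the real function then all-reduces this sum
    # across the data-parallel process group to get the global loss.
    local_loss = sum(losses)
    return mb_metrics, local_loss
```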