nemo_rl.models.megatron.train#
Module Contents#
Classes#
LossPostProcessor
LogprobsPostProcessor
TopkLogitsPostProcessor
Functions#
model_forward – Perform a single forward pass through the model.
apply_temperature_scaling – Apply temperature scaling to logits.
forward_with_post_processing_fn – Perform a forward pass on a pre-processed microbatch and return the output tensor and a post-processing function.
megatron_forward_backward – Execute forward and backward passes using Megatron’s utilities.
aggregate_training_statistics – Aggregate training statistics across microbatches and data-parallel ranks.
Data#
API#
- nemo_rl.models.megatron.train.PostProcessingFunction#
None
- nemo_rl.models.megatron.train.model_forward(
- model: megatron.core.models.gpt.GPTModel,
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- input_ids_cp_sharded: torch.Tensor,
- position_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- defer_fp32_logits: Optional[bool] = False,
- straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
- use_linear_ce_fusion_loss: bool = False,
- )
Perform a single forward pass through the model.
- Parameters:
model – The model to run forward pass on
data_dict – Dictionary containing batch data
input_ids_cp_sharded – Context-parallel sharded input token IDs
position_ids – Position IDs for tokens
attention_mask – Attention mask for the sequence
packed_seq_params – Parameters for packed sequences (optional)
defer_fp32_logits – Whether to skip the conversion of logits to fp32
straggler_timer – Straggler detector for profiling the forward pass
use_linear_ce_fusion_loss – Whether to use linear CE fusion loss
- Returns:
Output tensor from the model (logits)
- Return type:
torch.Tensor
- nemo_rl.models.megatron.train.apply_temperature_scaling(
- logits: torch.Tensor,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
- )
Apply temperature scaling to logits.
- Parameters:
logits – Logits tensor to scale
sampling_params – Sampling parameters
- Returns:
Temperature-scaled logits
- Return type:
torch.Tensor
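The scaling itself is an elementwise division of the logits by the sampling temperature: values below 1 sharpen the distribution, values above 1 flatten it, and a temperature of 1 (or absent sampling parameters) leaves the logits untouched. A minimal pure-Python sketch of that behavior (the real function operates on a torch tensor and a `TrainingSamplingParams` object; the plain-float signature here is a stand-in):

```python
from typing import List, Optional, Sequence

def apply_temperature_scaling_sketch(
    logits: Sequence[float],
    temperature: Optional[float],
) -> List[float]:
    """Divide logits by temperature; None or 1.0 leaves them unchanged."""
    if temperature is None or temperature == 1.0:
        return list(logits)
    return [x / temperature for x in logits]
```

For example, `apply_temperature_scaling_sketch([2.0, 4.0], 2.0)` yields `[1.0, 2.0]`.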
- nemo_rl.models.megatron.train.forward_with_post_processing_fn(
- data_iterator: Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch],
- model: megatron.core.models.gpt.GPTModel,
- post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction,
- defer_fp32_logits: Optional[bool] = False,
- global_valid_seqs: Optional[torch.Tensor] = None,
- global_valid_toks: Optional[torch.Tensor] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
- draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
- enable_hidden_capture: Optional[bool] = False,
- use_linear_ce_fusion_loss: bool = False,
- )
Perform a forward pass on a pre-processed microbatch and return the output tensor along with a post-processing function.
This function takes a pre-processed microbatch (with sequence packing already handled), runs the forward step through the model, and wraps the supplied post-processing function so it can be applied to the outputs later.
- Parameters:
data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)
model – The model to run forward pass on
post_processing_fn – Function used to post-process the logits
defer_fp32_logits – Whether to defer FP32 conversion of logits
global_valid_seqs – Global valid sequence count for loss normalization
global_valid_toks – Global valid token count for loss normalization
sampling_params – Sampling parameters (top-k, top-p, temperature)
straggler_timer – Straggler detector for profiling the forward pass
draft_model – Draft model for online draft model training
enable_hidden_capture – Whether to enable hidden state capture for draft model training
use_linear_ce_fusion_loss – Whether to use linear CE fusion loss
- Returns:
(output_tensor, post_processing_fn_wrapped)
output_tensor: Raw model outputs (logits)
post_processing_fn_wrapped: Function that creates the output post-processing function when called
- Return type:
tuple
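The returned pair follows a deferred-computation pattern: the raw logits come back immediately, while post-processing is packaged as a closure that captures the microbatch context, so the pipeline scheduler only needs to hand the output tensor back in later. A schematic sketch of that contract, using toy stand-ins (the names and callables below are hypothetical, not the real Megatron interfaces):

```python
def forward_with_post_processing_sketch(data, model, post_processing_fn):
    # Run the forward pass to get raw "logits".
    output = model(data)

    def wrapped(output_tensor):
        # Closure captures the microbatch `data`, so callers later
        # only supply the output tensor to get (loss, metrics).
        return post_processing_fn(data, output_tensor)

    return output, wrapped

# Toy stand-ins for a model and a post-processor:
toy_model = lambda d: [x * 2 for x in d]
toy_post = lambda data, out: (sum(out), {"n": len(data)})
```

Calling `forward_with_post_processing_sketch([1, 2], toy_model, toy_post)` returns the outputs `[2, 4]` plus a function that, applied to those outputs, yields `(6, {"n": 2})`.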
- nemo_rl.models.megatron.train.megatron_forward_backward(
- model: megatron.core.models.gpt.GPTModel,
- data_iterator: Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch],
- num_microbatches: int,
- seq_length: int,
- mbs: int,
- post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction,
- forward_only: bool = False,
- defer_fp32_logits: Optional[bool] = False,
- global_valid_seqs: Optional[torch.Tensor] = None,
- global_valid_toks: Optional[torch.Tensor] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
- draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
- enable_hidden_capture: Optional[bool] = False,
- use_linear_ce_fusion_loss: bool = False,
- )
Execute forward and backward passes using Megatron’s utilities.
This is the main training loop function that coordinates forward and backward passes across multiple microbatches using Megatron’s pipeline parallel execution framework.
- Parameters:
model – The model to train
data_iterator – Iterator yielding ProcessedMicrobatch objects (already processed)
num_microbatches – Number of microbatches to process
seq_length – Sequence length
mbs – Micro batch size
post_processing_fn – Function used to post-process the logits
forward_only – If True, skip backward pass
defer_fp32_logits – Whether to skip the conversion of logits to fp32
global_valid_seqs – Global valid sequence count for loss normalization
global_valid_toks – Global valid token count for loss normalization
sampling_params – Sampling parameters (top-k, top-p, temperature)
straggler_timer – Straggler detector for profiling the forward pass
draft_model – Draft model for online draft model training
enable_hidden_capture – Whether to enable hidden state capture for draft model training
use_linear_ce_fusion_loss – Whether to use linear CE fusion loss
- Returns:
Results from the forward/backward execution
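Conceptually, the scheduler pulls `num_microbatches` items from the data iterator, runs the forward step, applies the deferred post-processor to obtain a per-microbatch loss and metrics, and, unless `forward_only` is set, backpropagates. The plain-Python schematic below illustrates that loop with toy stand-ins; it deliberately omits pipeline parallelism and gradient accumulation, which Megatron's real scheduler handles:

```python
def forward_backward_sketch(model, data_iterator, num_microbatches,
                            loss_fn, forward_only=False):
    """Schematic microbatch loop: forward, post-process, (backward)."""
    losses, mb_metrics = [], []
    for _ in range(num_microbatches):
        batch = next(data_iterator)
        output = model(batch)               # forward pass
        loss, metrics = loss_fn(output)     # deferred post-processing
        losses.append(loss)
        mb_metrics.append(metrics)
        if not forward_only:
            pass  # backward pass would accumulate gradients here
    return losses, mb_metrics
```

For instance, with a model that adds 1 to each input and a loss that sums the outputs, two microbatches `[1, 2]` and `[3, 4]` produce losses `[5, 9]`.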
- class nemo_rl.models.megatron.train.LossPostProcessor(
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- cfg: nemo_rl.models.policy.PolicyConfig,
- num_microbatches: int = 1,
- cp_normalize: bool = True,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- draft_model: Optional[nemo_rl.models.megatron.config.MegatronModule] = None,
- )
Initialization
- __call__(
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- global_valid_seqs: Optional[torch.Tensor] = None,
- global_valid_toks: Optional[torch.Tensor] = None,
- )
Create a loss post-processing function for training.
This function wraps a loss function with the necessary context and parameters to compute loss and metrics from model outputs. It handles sequence packing and context parallelism normalization.
- Parameters:
data_dict – Batched data dictionary for the current microbatch
packed_seq_params – Parameters for packed sequences (optional)
global_valid_seqs – Global valid sequence count for loss normalization
global_valid_toks – Global valid token count for loss normalization
- Returns:
Function that takes output tensor and returns (loss, metrics) tuple
- Return type:
Callable
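`LossPostProcessor` is a callable factory: constructing it binds the loss function and configuration once, and each `__call__` with microbatch context returns the per-microbatch function the scheduler applies to the output tensor. A minimal sketch of that pattern (the division by `num_microbatches` is an assumed normalization so per-microbatch losses sum to a global-batch average; the real class also handles sequence packing and context-parallel normalization):

```python
class LossPostProcessorSketch:
    def __init__(self, loss_fn, num_microbatches=1):
        self.loss_fn = loss_fn
        self.num_microbatches = num_microbatches

    def __call__(self, data_dict):
        def post_process(output_tensor):
            loss, metrics = self.loss_fn(output_tensor, data_dict)
            # Assumed normalization: scale so summing per-microbatch
            # losses approximates the global-batch mean.
            return loss / self.num_microbatches, metrics
        return post_process
```

Usage: `LossPostProcessorSketch(toy_loss, num_microbatches=2)({"batch": 1})` yields a function ready to be applied to an output tensor.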
- class nemo_rl.models.megatron.train.LogprobsPostProcessor(
- cfg: nemo_rl.models.policy.PolicyConfig,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- )
Initialization
- __call__(
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- input_ids: torch.Tensor,
- cu_seqlens_padded: torch.Tensor,
- )
Create a post-processing function that computes token log probabilities.
This function returns a processor that takes model logits and converts them to token-level log probabilities, handling both packed and unpacked sequences.
- Parameters:
data_dict – Batched data dictionary containing input sequences
input_ids – Processed input token IDs
cu_seqlens_padded – Cumulative sequence lengths for packed sequences
- Returns:
Function that takes the output tensor and returns (dummy_loss, {"logprobs": token_logprobs})
- Return type:
Callable
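Converting logits to token log probabilities amounts to a log-softmax over the vocabulary followed by gathering the entry for each target token. A self-contained, numerically stable sketch for a single position (the real processor works on torch tensors and additionally handles packed sequences and parallelism):

```python
import math
from typing import Sequence

def token_logprob(logits: Sequence[float], token_id: int) -> float:
    """log softmax(logits)[token_id], via the max-shifted log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[token_id] - log_z
```

With uniform logits `[0.0, 0.0]`, each token's log probability is `log(0.5)`, as expected.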
- class nemo_rl.models.megatron.train.TopkLogitsPostProcessor(
- cfg: nemo_rl.models.policy.PolicyConfig,
- k: int,
- )
Initialization
- __call__(
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- cu_seqlens_padded: torch.Tensor,
- )
Create a post-processing function that computes top-k logits and indices.
This function returns a processor that extracts the top-k highest logits and their corresponding vocabulary indices from model outputs. It handles tensor parallelism, context parallelism, and sequence packing.
- Parameters:
data_dict – Batched data dictionary
cu_seqlens_padded – Cumulative sequence lengths for packed sequences
- Returns:
Function that takes the output tensor and returns (dummy_loss, {"topk_logits": values, "topk_indices": indices})
- Return type:
Callable
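For a single position, extracting the top-k reduces to a partial sort that keeps both logit values and their vocabulary indices. A minimal sketch (the real processor additionally gathers the vocabulary shards across tensor-parallel ranks and unpacks packed sequences):

```python
import heapq
from typing import List, Sequence, Tuple

def topk_logits(logits: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return (values, indices) of the k largest logits, descending."""
    pairs = heapq.nlargest(k, enumerate(logits), key=lambda p: p[1])
    indices = [i for i, _ in pairs]
    values = [v for _, v in pairs]
    return values, indices
```

For example, `topk_logits([0.1, 3.0, 2.0], 2)` returns `([3.0, 2.0], [1, 2])`.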
- nemo_rl.models.megatron.train.aggregate_training_statistics(
- all_mb_metrics: List[Dict[str, Any]],
- losses: List[float],
- data_parallel_group: torch.distributed.ProcessGroup,
- )
Aggregate training statistics across microbatches and data-parallel ranks.
Computes a global loss by all-reducing per-gradient-buffer losses across the data-parallel group, then collects per-microbatch metrics into lists keyed by metric name.
- Parameters:
all_mb_metrics – List of metric dicts from each microbatch.
losses – List of per-gradient-buffer scalar losses on this rank.
data_parallel_group – The data-parallel process group for all-reduce.
- Returns:
mb_metrics: Dict mapping metric names to lists of values across microbatches.
global_loss: Tensor of losses summed across all data-parallel ranks.
- Return type:
Tuple[Dict[str, List[Any]], torch.Tensor]
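The single-rank half of this aggregation, collecting per-microbatch metric dicts into one dict of lists and summing the per-buffer losses, can be sketched without `torch.distributed`; in the real function the summed loss is then all-reduced (SUM) over the data-parallel process group:

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

def aggregate_local_statistics(
    all_mb_metrics: List[Dict[str, Any]],
    losses: List[float],
) -> Tuple[Dict[str, List[Any]], float]:
    """Collect metrics into lists keyed by name; sum local losses."""
    mb_metrics: Dict[str, List[Any]] = defaultdict(list)
    for metrics in all_mb_metrics:
        for name, value in metrics.items():
            mb_metrics[name].append(value)
    # The real function all-reduces this local sum across the
    # data-parallel group to obtain the global loss.
    local_loss = sum(losses)
    return dict(mb_metrics), local_loss
```

For example, two microbatches reporting `{"lr": 0.1}` and `{"lr": 0.2}` with losses `[1.0, 2.0]` aggregate to `{"lr": [0.1, 0.2]}` and a local loss of `3.0`.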