bridge.training.gpt_step

Module Contents

Functions
- get_packed_seq_params – Extract packed sequence parameters from the batch.
- get_batch_from_iterator – Get a batch of data from the iterator.
- get_batch_on_this_tp_rank – Get a batch from the data iterator, handling TP broadcasting.
- get_batch – Generate a batch.
- forward_step – Forward training step.
- _create_loss_function – Create a partial loss function with the specified configuration.
Data

API
- bridge.training.gpt_step.logger
  'getLogger(...)'
- bridge.training.gpt_step.get_packed_seq_params(
- batch: dict[str, torch.Tensor],
- )
Extract packed sequence parameters from the batch.
Creates and returns a PackedSeqParams object with appropriate parameters for packed sequence processing.
- Parameters:
batch – Input batch containing packed sequence information
- Returns:
Parameters for packed sequence processing
- Return type:
PackedSeqParams
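The exact fields of PackedSeqParams are defined in Megatron Core, but the bookkeeping it carries can be illustrated without torch. A minimal sketch, assuming cumulative sequence boundaries (`cu_seqlens`) and maximum sequence length are the key fields; `packed_seq_params_sketch` and its dict return value are illustrative stand-ins, not the real API:

```python
from itertools import accumulate

def packed_seq_params_sketch(seq_lens: list[int]) -> dict:
    """Derive packed-sequence bookkeeping from per-sample lengths."""
    # cu_seqlens: cumulative offsets marking where each packed
    # sequence starts and ends inside the token buffer, beginning at 0.
    cu_seqlens = [0] + list(accumulate(seq_lens))
    return {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_q": max(seq_lens),
        "max_seqlen_kv": max(seq_lens),
        "qkv_format": "thd",  # packed (total-tokens, head, dim) layout
    }
```

For lengths `[3, 5, 2]` the packed buffer holds 10 tokens and the sequence boundaries land at offsets 0, 3, 8, and 10.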
- bridge.training.gpt_step.get_batch_from_iterator(
- data_iterator: Iterable,
- use_mtp: bool = False,
- skip_getting_attention_mask_from_dataset: bool = True,
- )
Get a batch of data from the iterator.
- Parameters:
data_iterator – The data iterator to get the batch from.
use_mtp – Whether Multi-Token Prediction layers are enabled.
skip_getting_attention_mask_from_dataset – If set, the dataset will pass a None attention mask.
- Returns:
A dictionary containing the batch data.
- Return type:
dict[str, torch.Tensor]
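The contract above can be sketched with plain dicts in place of tensors; `get_batch_from_iterator_sketch` is a hypothetical stand-in that only illustrates the `skip_getting_attention_mask_from_dataset` behavior:

```python
def get_batch_from_iterator_sketch(
    data_iterator,
    skip_getting_attention_mask_from_dataset: bool = True,
) -> dict:
    """Pull one batch dict from the iterator.

    When the flag is set, the attention mask is replaced with None so
    the attention kernel derives it (e.g. a causal mask) instead of
    materializing it in the dataset.
    """
    batch = dict(next(data_iterator))
    if skip_getting_attention_mask_from_dataset:
        batch["attention_mask"] = None
    return batch
```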
- bridge.training.gpt_step.get_batch_on_this_tp_rank(
- data_iterator: Iterable,
- cfg: megatron.bridge.training.config.ConfigContainer,
- use_mtp: bool = False,
- )
Get a batch from the data iterator, handling TP broadcasting.
On TP rank 0, it fetches the next batch from the iterator and broadcasts the necessary tensors to other TP ranks based on the pipeline stage. On other TP ranks, it allocates tensors and receives the broadcasted data.
- Parameters:
data_iterator – The data iterator.
cfg – The configuration container.
use_mtp – Whether Multi-Token Prediction layers are enabled.
- Returns:
A dictionary containing the batch data for the current rank.
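The fetch-on-rank-0-then-broadcast pattern can be sketched with a `broadcast_fn` callback standing in for `torch.distributed.broadcast`; the function name and callback are illustrative, not the real implementation:

```python
def get_batch_on_tp_rank_sketch(data_iterator, tp_rank: int, broadcast_fn):
    """TP rank 0 reads the iterator; every other rank receives.

    broadcast_fn stands in for torch.distributed.broadcast: given the
    batch on rank 0 (and None elsewhere), it must return the same
    batch on every rank in the tensor-parallel group.
    """
    if tp_rank == 0:
        batch = next(data_iterator)  # only rank 0 consumes the iterator
    else:
        batch = None                 # other ranks allocate and receive
    return broadcast_fn(batch)
```

Keeping the iterator on a single TP rank matters because all TP ranks must see identical data; letting each rank advance the iterator independently would desynchronize them.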
- bridge.training.gpt_step.get_batch(
- data_iterator: Iterable,
- cfg: megatron.bridge.training.config.ConfigContainer,
- use_mtp: bool = False,
- )
Generate a batch.
- Parameters:
data_iterator – Input data iterator
cfg – Configuration container
use_mtp – Whether Multi-Token Prediction layers are enabled
- Returns:
tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, cu_seqlens_argmin, and max_seqlen
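A hypothetical sketch of how a dict-shaped batch maps onto that positional tuple (key names taken from the return description above; the packed-sequence fields are absent when the data is not packed):

```python
def get_batch_tuple_sketch(batch: dict) -> tuple:
    """Flatten a batch dict into the positional tuple described above.

    Missing packed-sequence fields (cu_seqlens, cu_seqlens_argmin,
    max_seqlen) come back as None for unpacked data.
    """
    keys = ("tokens", "labels", "loss_mask", "attention_mask",
            "position_ids", "cu_seqlens", "cu_seqlens_argmin", "max_seqlen")
    return tuple(batch.get(k) for k in keys)
```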
- bridge.training.gpt_step.forward_step(
- state: megatron.bridge.training.state.GlobalState,
- data_iterator: Iterable,
- model: megatron.core.models.gpt.GPTModel,
- return_schedule_plan: bool = False,
- )
Forward training step.
- Parameters:
state – Global state for the run
data_iterator – Input data iterator
model – The GPT model
return_schedule_plan (bool) – Whether to return the schedule plan instead of the output tensor
- Returns:
tuple containing the output tensor and the loss function
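Megatron-style pipeline schedules run the forward step for its output and invoke the returned loss callable later. A dependency-free sketch of that (output, loss_fn) contract, with illustrative names (`masked_mean_loss`, `forward_step_sketch`) rather than the real implementation:

```python
from functools import partial

def masked_mean_loss(loss_mask, output):
    """Illustrative loss: mean of output values where the mask is 1."""
    masked = [o * m for o, m in zip(output, loss_mask)]
    return sum(masked) / max(sum(loss_mask), 1)

def forward_step_sketch(batch: dict, model_fn):
    """Run the model and return (output, loss_fn).

    The pipeline schedule calls loss_fn later with the output tensor,
    so per-batch state (the loss mask) is bound in via partial.
    """
    output = model_fn(batch["tokens"])
    return output, partial(masked_mean_loss, batch["loss_mask"])
```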
- bridge.training.gpt_step._create_loss_function(
- loss_mask: torch.Tensor,
- check_for_nan_in_loss: bool,
- check_for_spiky_loss: bool,
- )
Create a partial loss function with the specified configuration.
- Parameters:
loss_mask – Used to mask out some portions of the loss
check_for_nan_in_loss – Whether to check for NaN values in the loss
check_for_spiky_loss – Whether to check for spiky loss values
- Returns:
A partial function that can be called with output_tensor to compute the loss
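The partial-function pattern here can be sketched with `functools.partial`: configuration is bound at creation time and the output tensor arrives later. The masked-mean loss and NaN guard below are illustrative only; the real loss computation and the spiky-loss check live in Megatron Bridge:

```python
import math
from functools import partial

def _masked_loss(loss_mask, check_for_nan_in_loss, output):
    """Masked mean of per-token losses with an optional NaN guard."""
    total = sum(o * m for o, m in zip(output, loss_mask))
    loss = total / max(sum(loss_mask), 1)
    if check_for_nan_in_loss and math.isnan(loss):
        raise ValueError("NaN detected in training loss")
    return loss

def create_loss_function_sketch(loss_mask, check_for_nan_in_loss: bool = True):
    """Bind the mask and checks now; the caller supplies output later."""
    return partial(_masked_loss, loss_mask, check_for_nan_in_loss)
```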