bridge.training.gpt_step
Module Contents
Functions

- get_batch_from_iterator – Get a batch of data from the iterator.
- get_batch – Generate a batch.
- forward_step – Forward training step.
- _create_loss_function – Create a partial loss function with the specified configuration.
Data

- logger

API
- bridge.training.gpt_step.logger = getLogger(…)
- bridge.training.gpt_step.get_batch_from_iterator(
- data_iterator: Iterable,
- use_mtp: bool = False,
- skip_getting_attention_mask_from_dataset: bool = True,
- )
Get a batch of data from the iterator.
- Parameters:
data_iterator – The data iterator to get the batch from.
use_mtp – Whether Multi-Token Prediction layers are enabled.
skip_getting_attention_mask_from_dataset – If set, the batch is returned with a None attention mask rather than one read from the dataset.
- Returns:
A dictionary containing the batch data.
- Return type:
dict[str, torch.Tensor]
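A minimal usage sketch follows. The batch keys and shapes in the toy iterator are assumptions modeled on common Megatron-style GPT batches, and in practice this function is called inside an initialized training loop, which may be required for it to work.

```python
# Hedged sketch: pull one batch dict from a toy iterator. Key names and
# shapes are illustrative assumptions, not guaranteed by this API, and a
# real call typically happens inside an initialized Megatron training loop.
import torch

from megatron.bridge.training.gpt_step import get_batch_from_iterator

toy_iterator = iter([
    {
        "tokens": torch.randint(0, 32000, (2, 128)),
        "labels": torch.randint(0, 32000, (2, 128)),
        "loss_mask": torch.ones(2, 128),
        "position_ids": torch.arange(128).expand(2, -1),
    }
])

batch = get_batch_from_iterator(
    toy_iterator,
    use_mtp=False,
    # With the default True, the returned attention mask is None.
    skip_getting_attention_mask_from_dataset=True,
)
print({k: None if v is None else tuple(v.shape) for k, v in batch.items()})
```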
- bridge.training.gpt_step.get_batch(
- data_iterator: Iterable,
- cfg: megatron.bridge.training.config.ConfigContainer,
- use_mtp: bool = False,
- )
Generate a batch.
- Parameters:
data_iterator – Input data iterator.
cfg – Configuration container.
use_mtp – Whether Multi-Token Prediction layers are enabled.
- Returns:
tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, cu_seqlens_argmin, and max_seqlen
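A sketch of unpacking the eight-element return value, assuming an already-built ConfigContainer and data iterator; constructing those is setup-specific and elided here.

```python
# Hedged sketch: unpack the tuple documented above. `cfg` and
# `data_iterator` are assumed to come from an existing training setup.
from typing import Iterable

from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.gpt_step import get_batch


def fetch_one_batch(data_iterator: Iterable, cfg: ConfigContainer):
    (
        tokens,
        labels,
        loss_mask,
        attention_mask,
        position_ids,
        cu_seqlens,
        cu_seqlens_argmin,
        max_seqlen,
    ) = get_batch(data_iterator, cfg, use_mtp=False)
    # cu_seqlens / cu_seqlens_argmin / max_seqlen matter for packed-sequence
    # inputs; depending on the config they may be None (an assumption).
    return tokens, labels, loss_mask
```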
- bridge.training.gpt_step.forward_step(
- state: megatron.bridge.training.state.GlobalState,
- data_iterator: Iterable,
- model: megatron.core.models.gpt.GPTModel,
- return_schedule_plan: bool = False,
- )
Forward training step.
- Parameters:
state – Global state for the run.
data_iterator – Input data iterator.
model – The GPT model.
return_schedule_plan (bool) – Whether to return the schedule plan instead of the output tensor.
- Returns:
tuple containing the output tensor and the loss function
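A sketch of consuming the return value. Calling the returned loss function as `loss_fn(output_tensor)` mirrors the usual Megatron forward-step contract, but treat that call shape as an assumption; in real training the pipeline schedule performs this call.

```python
# Hedged sketch: run one forward step and apply the returned loss function.
# loss_fn(output_tensor) follows the usual Megatron schedule convention and
# is an assumption here, not a documented guarantee of this module.
from megatron.bridge.training.gpt_step import forward_step


def one_forward_pass(state, data_iterator, model):
    output_tensor, loss_fn = forward_step(
        state,
        data_iterator,
        model,
        return_schedule_plan=False,  # return the output tensor, not a plan
    )
    # In real training the pipeline schedule invokes loss_fn, not user code.
    return loss_fn(output_tensor)
```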
- bridge.training.gpt_step._create_loss_function(
- loss_mask: torch.Tensor,
- check_for_nan_in_loss: bool,
- check_for_spiky_loss: bool,
- )
Create a partial loss function with the specified configuration.
Kept here for backward compatibility with tests and callers that patch megatron.bridge.training.gpt_step.masked_next_token_loss.
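A sketch of the patching pattern the note above refers to: tests replace masked_next_token_loss at this module path so that loss functions created by _create_loss_function bind the stub. The stub's return value and the code under test are hypothetical.

```python
# Hedged sketch of the backward-compatible patch point described above.
# The stub's return value and the code under test are hypothetical.
from unittest import mock


def test_with_stubbed_loss():
    with mock.patch(
        "megatron.bridge.training.gpt_step.masked_next_token_loss"
    ) as fake_loss:
        fake_loss.return_value = None  # whatever the test asserts against
        # ... invoke the training-step code under test here; loss functions
        # built by _create_loss_function will now call fake_loss.
```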