bridge.training.gpt_step
#
Module Contents#
Functions#
| get_packed_seq_params | Extract packed sequence parameters from the batch. |
| get_batch_from_iterator | Get a batch of data from the iterator. |
| get_batch_on_this_tp_rank | Get a batch from the data iterator, handling TP broadcasting. |
| get_batch | Generate a batch. |
| forward_step | Forward training step. |
Data#
API#
- bridge.training.gpt_step.logger#
'getLogger(…)'
- bridge.training.gpt_step.get_packed_seq_params(
- batch: dict[str, torch.Tensor],
- )
Extract packed sequence parameters from the batch.
Creates and returns a PackedSeqParams object with appropriate parameters for packed sequence processing.
- Parameters:
batch – Input batch containing packed sequence information
- Returns:
Parameters for packed sequence processing
- Return type:
PackedSeqParams
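To illustrate what packed-sequence parameters look like, here is a pure-Python sketch of deriving cumulative sequence lengths (`cu_seqlens`) and `max_seqlen` from per-sequence lengths. The field names mirror the packed-sequence convention seen in this module's return values, but the function and dict below are illustrative stand-ins, not the real `PackedSeqParams` API.

```python
def packed_seq_params(seq_lens):
    """Build cumulative sequence offsets and the max length for a packed batch.

    Hypothetical sketch: real code would populate a PackedSeqParams object
    with torch tensors instead of Python lists.
    """
    cu_seqlens = [0]
    for n in seq_lens:
        # Running offsets marking where each sequence starts in the packed buffer.
        cu_seqlens.append(cu_seqlens[-1] + n)
    return {
        "cu_seqlens": cu_seqlens,     # boundaries of each sequence
        "max_seqlen": max(seq_lens),  # longest sequence in the pack
    }

params = packed_seq_params([3, 5, 2])
# cu_seqlens -> [0, 3, 8, 10], max_seqlen -> 5
```

Each attention kernel invocation can then use the offsets to treat the packed buffer as separate sequences without padding.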
- bridge.training.gpt_step.get_batch_from_iterator(
- data_iterator: Iterable,
- use_mtp: bool = False,
- )
Get a batch of data from the iterator.
- Parameters:
data_iterator – The data iterator to get the batch from.
use_mtp – Whether Multi-Token Prediction layers are enabled.
- Returns:
A dictionary containing the batch data.
- Return type:
dict[str, torch.Tensor]
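The fetch pattern can be sketched in pure Python, with plain dicts standing in for tensors. The key names below (including the MTP key) are hypothetical illustrations of the kind of fields a batch dict might carry, not the module's actual schema.

```python
def get_batch_from_iterator(data_iterator, use_mtp=False):
    """Sketch: pull the next batch dict and keep only the fields needed."""
    batch = next(data_iterator)  # advance the iterator by one batch
    required = ["tokens", "labels", "loss_mask", "position_ids"]
    if use_mtp:
        # Multi-Token Prediction layers would need extra targets
        # ("mtp_labels" is a hypothetical key name).
        required.append("mtp_labels")
    return {k: batch[k] for k in required if k in batch}
```

Filtering to a fixed key set keeps downstream code independent of whatever extra fields the dataset attaches to each sample.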
- bridge.training.gpt_step.get_batch_on_this_tp_rank(
- data_iterator: Iterable,
- cfg: megatron.bridge.training.config.ConfigContainer,
- use_mtp: bool = False,
- )
Get a batch from the data iterator, handling TP broadcasting.
On TP rank 0, it fetches the next batch from the iterator and broadcasts the necessary tensors to other TP ranks based on the pipeline stage. On other TP ranks, it allocates tensors and receives the broadcasted data.
- Parameters:
data_iterator – The data iterator.
cfg – The configuration container.
use_mtp – Whether Multi-Token Prediction layers are enabled.
- Returns:
A dictionary containing the batch data for the current rank.
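The control flow described above (rank 0 fetches, the others receive) can be sketched without `torch.distributed`. Here `broadcast` is a toy stand-in for a collective such as `torch.distributed.broadcast`, and the rank logic, key names, and channel mechanism are all illustrative assumptions.

```python
def make_broadcast(channel):
    """Toy broadcast: rank 0 publishes into a shared dict, others read from it."""
    def broadcast(key, tensor):
        if tensor is not None:
            channel[key] = tensor  # sender side (TP rank 0)
        return channel[key]        # receiver side (other TP ranks)
    return broadcast

def get_batch_on_this_tp_rank(rank, data_iterator, broadcast):
    """Sketch: only TP rank 0 touches the data iterator."""
    if rank == 0:
        batch = next(data_iterator)
        for key, tensor in batch.items():
            broadcast(key, tensor)  # ship each tensor to the other ranks
    else:
        # Other ranks allocate nothing from the iterator; they receive.
        batch = {key: broadcast(key, None) for key in ("tokens", "labels")}
    return batch
```

Reading the iterator on a single TP rank avoids duplicated I/O and keeps every rank in the tensor-parallel group working on identical data.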
- bridge.training.gpt_step.get_batch(
- data_iterator: Iterable,
- cfg: megatron.bridge.training.config.ConfigContainer,
- use_mtp: bool = False,
- )
Generate a batch.
- Parameters:
data_iterator – Input data iterator
cfg – Configuration container
use_mtp – Whether Multi-Token Prediction layers are enabled
- Returns:
tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, cu_seqlens_argmin, and max_seqlen
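A minimal sketch of the standard GPT batch construction behind several of these return values: labels are the input shifted left by one token, the loss mask zeroes out padding positions, and position IDs simply enumerate positions. Pure-Python lists and the `PAD = 0` convention stand in for torch tensors and the real tokenizer's pad ID.

```python
PAD = 0  # assumed pad token id for this sketch

def build_batch(token_ids):
    """Sketch: derive (tokens, labels, loss_mask, position_ids) from one sample."""
    tokens = token_ids[:-1]        # model input
    labels = token_ids[1:]         # next-token targets, shifted by one
    loss_mask = [0 if t == PAD else 1 for t in labels]  # no loss on padding
    position_ids = list(range(len(tokens)))
    return tokens, labels, loss_mask, position_ids
```

The packed-sequence fields (`cu_seqlens`, `cu_seqlens_argmin`, `max_seqlen`) would additionally be present when sequence packing is enabled.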
- bridge.training.gpt_step.forward_step(
- state: megatron.bridge.training.state.GlobalState,
- data_iterator: Iterable,
- model: megatron.core.models.gpt.GPTModel,
- )
Forward training step.
- Parameters:
state – Global state for the run
data_iterator – Input data iterator
model – The GPT Model
- Returns:
tuple containing the output tensor and the loss function
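The contract of returning the output tensor together with a loss function (rather than a computed loss) lets the training runner defer loss computation. Below is a toy sketch of that shape: the model and per-token "losses" are stand-ins, and `functools.partial` binds the loss mask the way a real forward step might bind batch state.

```python
from functools import partial

def loss_func(loss_mask, output):
    """Toy loss: mean of output values at unmasked positions."""
    masked = [o for o, m in zip(output, loss_mask) if m]
    return sum(masked) / len(masked)

def forward_step(model, batch):
    """Sketch: run the forward pass, return (output, loss callable)."""
    output = model(batch["tokens"])
    # Bind the mask now; the runner applies the loss function later.
    return output, partial(loss_func, batch["loss_mask"])
```

Deferring the loss this way is useful under pipeline parallelism, where only the last stage actually evaluates the loss on its outputs.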