bridge.training.gpt_step#

Module Contents#

Functions#

get_packed_seq_params – Extract packed sequence parameters from the batch.

get_batch_from_iterator – Get a batch of data from the iterator.

get_batch_on_this_tp_rank – Get a batch from the data iterator, handling TP broadcasting.

get_batch – Generate a batch.

forward_step – Forward training step.

Data#

API#

bridge.training.gpt_step.logger#

getLogger(…)

bridge.training.gpt_step.get_packed_seq_params(
batch: dict[str, torch.Tensor],
) → megatron.core.packed_seq_params.PackedSeqParams#

Extract packed sequence parameters from the batch.

Creates and returns a PackedSeqParams object with appropriate parameters for packed sequence processing.

Parameters:

batch – Input batch containing packed sequence information

Returns:

Parameters for packed sequence processing

Return type:

PackedSeqParams
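A minimal usage sketch, assuming megatron.bridge is installed. The batch keys shown (cu_seqlens, cu_seqlens_argmin, max_seqlen) are taken from the tensor names in get_batch's return description below; the values are purely illustrative:

```python
import torch

from megatron.bridge.training.gpt_step import get_packed_seq_params

# Illustrative packed batch: two sequences of lengths 3 and 5 packed
# along one dimension (cumulative boundaries 0, 3, 8). In a real run
# these tensors come from the data pipeline via get_batch.
batch = {
    "cu_seqlens": torch.tensor([0, 3, 8], dtype=torch.int32),
    "cu_seqlens_argmin": torch.tensor([2]),
    "max_seqlen": torch.tensor([5]),
}

packed_seq_params = get_packed_seq_params(batch)
# packed_seq_params can then be forwarded to the model so attention
# kernels respect the per-sequence boundaries encoded in cu_seqlens.
```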

bridge.training.gpt_step.get_batch_from_iterator(
data_iterator: Iterable,
use_mtp: bool = False,
) → dict[str, torch.Tensor]#

Get a batch of data from the iterator.

Parameters:
  • data_iterator – The data iterator to get the batch from.

  • use_mtp – Whether Multi-Token Prediction layers are enabled.

Returns:

A dictionary containing the batch data.

Return type:

dict[str, torch.Tensor]
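For illustration, any iterable of batch dictionaries works; the keys in this sketch are assumed stand-ins for what the GPT data pipeline actually yields:

```python
import torch

from megatron.bridge.training.gpt_step import get_batch_from_iterator

# Toy micro-batch with assumed keys; real batches come from the dataset.
dummy_batch = {
    "tokens": torch.randint(0, 1000, (2, 16)),
    "labels": torch.randint(0, 1000, (2, 16)),
    "loss_mask": torch.ones(2, 16),
    "position_ids": torch.arange(16).unsqueeze(0).repeat(2, 1),
}
data_iterator = iter([dummy_batch])

batch = get_batch_from_iterator(data_iterator, use_mtp=False)
```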

bridge.training.gpt_step.get_batch_on_this_tp_rank(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
) → dict[str, torch.Tensor]#

Get a batch from the data iterator, handling TP broadcasting.

On TP rank 0, it fetches the next batch from the iterator and broadcasts the tensors required by the current pipeline stage to the other TP ranks. On the other TP ranks, it allocates matching tensors and receives the broadcast data.

Parameters:
  • data_iterator – The data iterator.

  • cfg – The configuration container.

  • use_mtp – Whether Multi-Token Prediction layers are enabled.

Returns:

A dictionary containing the batch data for the current rank.
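A sketch of calling this inside a training loop; it assumes torch.distributed and Megatron's tensor-parallel state are already initialized (e.g. in a launched job) and that cfg is the run's ConfigContainer:

```python
from megatron.bridge.training.gpt_step import get_batch_on_this_tp_rank

def fetch_batch_for_rank(data_iterator, cfg):
    # TP rank 0 reads from data_iterator and broadcasts; the other TP
    # ranks allocate receive buffers and obtain the same tensors via
    # broadcast, so every rank returns an equivalent batch dictionary.
    return get_batch_on_this_tp_rank(data_iterator, cfg, use_mtp=False)
```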

bridge.training.gpt_step.get_batch(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]#

Generate a batch.

Parameters:
  • data_iterator – Input data iterator

  • cfg – Configuration container

  • use_mtp – Whether Multi-Token Prediction layers are enabled

Returns:

tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, cu_seqlens_argmin, and max_seqlen
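A sketch of unpacking the returned tuple, with names following the order above; data_iterator and cfg are assumed to come from a live training setup:

```python
from megatron.bridge.training.gpt_step import get_batch

def next_training_batch(data_iterator, cfg):
    (
        tokens,
        labels,
        loss_mask,
        attention_mask,
        position_ids,
        cu_seqlens,         # packed-sequence bookkeeping; see
        cu_seqlens_argmin,  # get_packed_seq_params above
        max_seqlen,
    ) = get_batch(data_iterator, cfg, use_mtp=False)
    return tokens, labels, loss_mask, attention_mask, position_ids
```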

bridge.training.gpt_step.forward_step(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
) → tuple[torch.Tensor, functools.partial]#

Forward training step.

Parameters:
  • state – Global state for the run

  • data_iterator – Input data iterator

  • model – The GPT Model

Returns:

tuple containing the output tensor and the loss function
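A sketch of the call pattern, assuming state, data_iterator, and model come from a live run; in typical training, forward_step is handed to Megatron Core's forward-backward scheduling rather than called directly:

```python
from megatron.bridge.training.gpt_step import forward_step

def run_forward(state, data_iterator, model):
    # forward_step returns the model output together with a
    # functools.partial loss closure; the training loop (or the
    # pipeline scheduler) then applies the closure to the output.
    output_tensor, loss_func = forward_step(state, data_iterator, model)
    return loss_func(output_tensor)
```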