bridge.training.utils.batch_utils

Module Contents

Functions

get_batch_on_this_tp_rank

Get a batch from the data iterator, handling TP broadcasting.

API

bridge.training.utils.batch_utils.get_batch_on_this_tp_rank(
    data_iterator: Iterable,
    cfg: megatron.bridge.training.config.ConfigContainer,
    use_mtp: bool = False,
) → dict[str, torch.Tensor]

Get a batch from the data iterator, handling TP broadcasting.

This is a generic helper shared by multiple recipes; the implementation is identical to the version previously defined in gpt_step.py.
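
A minimal usage sketch, assuming the surrounding recipe already provides a training data iterator and a populated ConfigContainer; the fetch_microbatch wrapper and its comments are illustrative, not part of this module.

```python
from typing import Iterable

import torch

from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.utils.batch_utils import get_batch_on_this_tp_rank


def fetch_microbatch(
    data_iterator: Iterable, cfg: ConfigContainer
) -> dict[str, torch.Tensor]:
    """Illustrative wrapper: pull one micro-batch for the current TP rank.

    `data_iterator` and `cfg` are assumed to come from the surrounding
    recipe (the training dataloader and its ConfigContainer, respectively).
    """
    # Tensor-parallel rank 0 reads from the iterator; the other TP ranks
    # receive the same tensors via broadcast inside the helper.
    return get_batch_on_this_tp_rank(
        data_iterator,
        cfg,
        use_mtp=False,  # set True when multi-token prediction is enabled
    )
```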