core.datasets.data_schedule#

Module Contents#

Classes#

HybridCPDataLoaderWrapper

A wrapper class around an existing data_iterator that reschedules packed sub-samples across DPxCP ranks on every next() call.

API#

class core.datasets.data_schedule.HybridCPDataLoaderWrapper(
data_iterator,
config,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

A wrapper class around an existing data_iterator. For every next() call, it:

  1. Each DP rank pulls a batch of packed samples.

  2. Extracts the sequence length of each sub-sample and all-gathers the lengths across the DP group.

  3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler (see the sketch after this list).

  4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all.

  5. Returns the assigned sub-samples to this rank.
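
The balancing in step 3 can be pictured as a longest-first greedy assignment: the longest remaining sub-sample always goes to the rank with the fewest tokens assigned so far. The sketch below only illustrates that idea; the actual BalancedCPScheduler policy and interface live in the library and may differ.

```python
# Illustration only: a generic longest-first greedy balancer, not the real
# BalancedCPScheduler.
import heapq
from typing import Dict, List, Tuple

def balance_subsamples(
    global_id_seqlens: List[Tuple[int, int]], num_ranks: int
) -> Dict[int, List[int]]:
    """Assign (global_id, seqlen) pairs to ranks so per-rank token totals stay even."""
    heap = [(0, rank) for rank in range(num_ranks)]  # (assigned tokens, rank)
    heapq.heapify(heap)
    assignment: Dict[int, List[int]] = {rank: [] for rank in range(num_ranks)}
    # Place the longest sub-samples first, always onto the least-loaded rank.
    for gid, seqlen in sorted(global_id_seqlens, key=lambda item: -item[1]):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(gid)
        heapq.heappush(heap, (load + seqlen, rank))
    return assignment
```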

Parameters:
  • data_iterator – The original data_iterator to wrap around.

  • config – The config object; must provide max_seqlen_per_dp_cp_rank.

  • pg_collection – Process group collection providing the data parallel and context parallel groups used for scheduling and rerouting.

Initialization
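
A minimal usage sketch, assuming torch.distributed is already initialized; train_dataloader, config, pg_collection, and forward_step are placeholders supplied by the surrounding training code, not part of this module.

```python
# Usage sketch; all lowercase names besides the wrapper itself are placeholders
# from the surrounding training setup.
from core.datasets.data_schedule import HybridCPDataLoaderWrapper

data_iterator = iter(train_dataloader)  # yields dicts of packed samples

wrapped_iterator = HybridCPDataLoaderWrapper(
    data_iterator=data_iterator,
    config=config,                # must expose config.max_seqlen_per_dp_cp_rank
    pg_collection=pg_collection,  # Optional[ProcessGroupCollection]
)

for batch in wrapped_iterator:
    # Each iteration yields only the sub-samples scheduled onto this DPxCP rank.
    loss = forward_step(batch)
```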

__iter__()#

Return self as an iterator.

get_global_seqlens(
subsample_seqlens: torch.Tensor,
) → List[int]#

Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches, but each microbatch may have a different number of subsamples.

We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks.
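
The two-phase gather could be sketched as a standalone function over torch.distributed as below; this is an illustration of the description above, not the method's actual implementation, and the returned offsets are the per-rank prefix sums referenced by the methods that follow.

```python
# Illustrative two-phase all-gather for ranks holding different subsample counts.
import torch
import torch.distributed as dist

def gather_variable_length_seqlens(subsample_seqlens: torch.Tensor, group=None):
    device = subsample_seqlens.device
    world_size = dist.get_world_size(group=group)

    # Phase 1: gather how many subsamples each rank holds.
    local_count = torch.tensor([subsample_seqlens.numel()], device=device)
    counts = [torch.zeros_like(local_count) for _ in range(world_size)]
    dist.all_gather(counts, local_count, group=group)
    counts = [int(c.item()) for c in counts]

    # Phase 2: pad to the maximum count so all ranks gather equal-sized tensors,
    # then trim each rank's slice back to its true count.
    max_count = max(counts)
    padded = torch.zeros(max_count, dtype=subsample_seqlens.dtype, device=device)
    padded[: subsample_seqlens.numel()] = subsample_seqlens
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)

    seqlens, offsets = [], [0]
    for rank, count in enumerate(counts):
        seqlens.extend(gathered[rank][:count].tolist())
        offsets.append(offsets[-1] + count)
    return seqlens, offsets
```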

get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered)#

Calculates the global ID for each subsample.

Each subsample is assigned a unique global ID so that it can be scheduled and rerouted independently of the DP rank that loaded it.

Returns:

  • global_id_seqlens – List of (global_id, seqlen) tuples for scheduling.

  • global_ids_this_rank – List of global IDs locally present on this rank.
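
One way the return values could be derived from the gathered lengths and per-rank offsets is sketched below; the body is an assumption based on the description above rather than the method's actual code, and dp_rank is passed explicitly only to keep the example self-contained.

```python
# Hypothetical global-ID assignment; IDs are positions in the rank-ordered,
# concatenated list of subsamples.
from typing import List, Tuple

def assign_global_ids(
    num_local_subsamples: int,
    offsets: List[int],           # prefix sums of per-rank subsample counts
    seqlens_gathered: List[int],  # seqlens concatenated in rank order
    dp_rank: int,
) -> Tuple[List[Tuple[int, int]], List[int]]:
    global_id_seqlens = list(enumerate(seqlens_gathered))
    # This rank owns the contiguous ID range starting at its offset.
    start = offsets[dp_rank]
    global_ids_this_rank = list(range(start, start + num_local_subsamples))
    return global_id_seqlens, global_ids_this_rank
```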

_gid_to_src_rank(gid: int, offsets: List[int]) → int#
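
The name and the offsets argument suggest a lookup of which rank's contiguous global-ID range contains a given ID; a plausible sketch, assuming offsets are the prefix sums built above:

```python
# Hypothetical reverse lookup: find r such that offsets[r] <= gid < offsets[r + 1].
import bisect
from typing import List

def gid_to_src_rank(gid: int, offsets: List[int]) -> int:
    return bisect.bisect_right(offsets, gid) - 1
```
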
reroute_samples_to_hdp_ranks(
batch,
global_ids_this_rank,
global_id_seqlens,
sample_id_groups,
offsets,
)#

Reroutes the sub-samples to the correct rank after scheduling.

For each key in the batch dict, we perform an all-to-all communication to transfer the data to the correct ranks. Since all CP ranks within a DP group have the same data, we only need to transfer data between matching CP ranks.
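
A simplified picture of the per-key exchange uses two all-to-all calls, one to exchange sizes and one to move the data. Everything below is illustrative: the send-plan construction, per-key iteration over the batch dict, and the matching of CP ranks are omitted.

```python
# Sketch of rerouting one batch key with torch.distributed all-to-all.
import torch
import torch.distributed as dist

def reroute_key(local_chunks, send_plan, group=None):
    """Exchange 1-D sub-sample tensors so each rank receives its scheduled data.

    local_chunks: non-empty list of tensors this rank currently holds.
    send_plan[r]: indices into local_chunks destined for rank r.
    """
    world_size = dist.get_world_size(group=group)
    device, dtype = local_chunks[0].device, local_chunks[0].dtype

    # Build one send buffer per destination rank and exchange sizes first, so
    # every rank knows how much data to expect from every peer.
    send_tensors = [
        torch.cat([local_chunks[i] for i in send_plan[r]])
        if send_plan[r] else torch.empty(0, dtype=dtype, device=device)
        for r in range(world_size)
    ]
    send_sizes = torch.tensor([t.numel() for t in send_tensors], device=device)
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes, group=group)

    # Second all-to-all moves the actual sub-sample data.
    recv_tensors = [
        torch.empty(int(n), dtype=dtype, device=device) for n in recv_sizes.tolist()
    ]
    dist.all_to_all(recv_tensors, send_tensors, group=group)
    return recv_tensors
```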

unpack_batch(batch)#

Unpacks the packed samples into a list of sub-samples. Since each sub-sample may be routed to different DPxCP ranks, we unpack the sample here to avoid unnecessarily transferring the entire packed sample.
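
For a packed batch keyed by cumulative sequence lengths, the split could look like the following; the cu_seqlens field name and last-dimension slicing are assumptions about the packed-batch layout, not guarantees from this module.

```python
# Hypothetical unpacking of a packed batch into per-subsample dicts.
import torch
from typing import Dict, List

def unpack_by_cu_seqlens(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    cu = batch["cu_seqlens"].tolist()  # e.g. [0, 512, 1536, 2048]
    subsamples = []
    for start, end in zip(cu[:-1], cu[1:]):
        subsamples.append(
            {k: v[..., start:end] for k, v in batch.items() if k != "cu_seqlens"}
        )
    return subsamples
```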

__next__() → Any#

Get the next item from the dataset, pull the scheduling metadata, and return the sub-samples assigned to this rank.