core.datasets.data_schedule#
Module Contents#
Classes#
`HybridCPDataLoaderWrapper` – A wrapper class that wraps around an existing data_iterator.
API#
- class core.datasets.data_schedule.HybridCPDataLoaderWrapper(data_iterator, config, pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None)#
A wrapper class that wraps around an existing data_iterator. For every next call, it:

1. Pulls a batch of packed samples on each DP rank.
2. Extracts the sequence lengths of each sub-sample and all-gathers them across the DP group.
3. Schedules the sub-samples onto the DPxCP ranks using the BalancedCPScheduler.
4. Reroutes the sub-samples to the correct rank with an all-to-all, based on the schedule.
5. Returns the sub-samples assigned to this rank.
- Parameters:
data_iterator – The original data_iterator to wrap around
config – The config object containing the max_seqlen_per_dp_cp_rank
dp_cp_group – Data parallel context parallel group.
Initialization
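A minimal usage sketch is shown below. The names `train_iterator`, `cfg`, `pg_collection`, and `train_step` are placeholders for objects your training script already provides; this is not the module's prescribed setup.

```python
# Hypothetical usage sketch: wrap an existing packed-sample iterator so that
# each next() call returns only the sub-samples scheduled for this DPxCP rank.
from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper

# Assumed to exist in the surrounding training script:
#   train_iterator - yields packed batches (dict of tensors) per DP rank
#   cfg            - config carrying max_seqlen_per_dp_cp_rank
#   pg_collection  - ProcessGroupCollection with the relevant DP/CP groups
hybrid_iterator = HybridCPDataLoaderWrapper(
    data_iterator=train_iterator,
    config=cfg,
    pg_collection=pg_collection,
)

for batch in hybrid_iterator:
    # `batch` holds the sub-samples assigned to this rank by the
    # BalancedCPScheduler, already rerouted via all-to-all.
    train_step(batch)  # placeholder training step
```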
- __iter__()#
Return self as an iterator.
- get_global_seqlens(subsample_seqlens: torch.Tensor)#
Gathers the sequence lengths of all subsamples from all DP ranks. Each DP rank loads the same number of microbatches but each microbatch may have a different number of subsamples.
We find the number of subsamples each rank holds and then gather the sequence lengths of all subsamples from all ranks.
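A rough sketch of the two-step gather this describes, using plain `torch.distributed` collectives. The helper name, return values, and padding strategy are illustrative assumptions, not the module's actual internals.

```python
import torch
import torch.distributed as dist


def gather_subsample_seqlens(subsample_seqlens: torch.Tensor, dp_group):
    """Hypothetical sketch: gather per-sub-sample sequence lengths from every
    DP rank when each rank may hold a different number of sub-samples."""
    world = dist.get_world_size(group=dp_group)
    device = subsample_seqlens.device

    # Step 1: all-gather how many sub-samples each rank holds.
    local_count = torch.tensor([subsample_seqlens.numel()], device=device)
    counts = [torch.zeros_like(local_count) for _ in range(world)]
    dist.all_gather(counts, local_count, group=dp_group)
    counts = [int(c.item()) for c in counts]

    # Step 2: pad the local lengths to the maximum count, all-gather, then trim.
    max_count = max(counts)
    padded = torch.zeros(max_count, dtype=subsample_seqlens.dtype, device=device)
    padded[: subsample_seqlens.numel()] = subsample_seqlens
    gathered = [torch.zeros_like(padded) for _ in range(world)]
    dist.all_gather(gathered, padded, group=dp_group)
    seqlens_per_rank = [g[:n] for g, n in zip(gathered, counts)]

    # Offsets give the starting global ID of each rank's sub-samples.
    offsets = [0]
    for n in counts:
        offsets.append(offsets[-1] + n)
    return seqlens_per_rank, offsets
```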
- get_global_id_seqlens(num_local_subsamples, offsets, seqlens_gathered)#
Calculates the global ID for each subsample.
We assign a unique global ID to each subsample.
- Returns:
global_id_seqlens – list of (global_id, seqlen) tuples for scheduling.
global_ids_this_rank – list of global IDs locally present on this rank.
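A self-contained illustration of one way this ID assignment can work, assuming rank r's k-th sub-sample receives global ID `offsets[r] + k`. The function and argument layout below are assumptions for the sketch, not the actual implementation.

```python
from typing import List, Tuple


def make_global_id_seqlens(
    num_local_subsamples: int,
    rank: int,
    offsets: List[int],
    seqlens_gathered: List[List[int]],
) -> Tuple[List[Tuple[int, int]], List[int]]:
    # (global_id, seqlen) pairs for every sub-sample across all ranks,
    # numbered contiguously in rank order.
    global_id_seqlens = []
    gid = 0
    for rank_seqlens in seqlens_gathered:
        for seqlen in rank_seqlens:
            global_id_seqlens.append((gid, seqlen))
            gid += 1
    # Global IDs of the sub-samples physically present on this rank.
    start = offsets[rank]
    global_ids_this_rank = list(range(start, start + num_local_subsamples))
    return global_id_seqlens, global_ids_this_rank


# Example: two ranks holding 3 and 2 sub-samples respectively.
pairs, local_ids = make_global_id_seqlens(
    num_local_subsamples=2, rank=1, offsets=[0, 3, 5],
    seqlens_gathered=[[1024, 512, 256], [2048, 768]],
)
print(local_ids)  # [3, 4]
```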
- _gid_to_src_rank(gid: int, offsets: List[int]) -> int#
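No docstring is given for this helper; a plausible reading is that, with `offsets` holding each rank's starting global ID followed by the total count, the owning rank is found by a binary search over the offsets. The snippet below is purely illustrative.

```python
import bisect
from typing import List


def gid_to_src_rank(gid: int, offsets: List[int]) -> int:
    # The source rank is the r whose interval [offsets[r], offsets[r+1]) contains gid.
    return bisect.bisect_right(offsets, gid) - 1


print(gid_to_src_rank(4, [0, 3, 5]))  # 1
```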
- reroute_samples_to_hdp_ranks(batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets)#
Reroutes the sub-samples to the correct rank after scheduling.
For each key in the batch dict, we perform an all-to-all communication to transfer the data to the correct ranks. Since all CP ranks within a DP group have the same data, we only need to transfer data between matching CP ranks.
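A hedged sketch of what a per-key all-to-all reroute can look like with plain `torch.distributed` collectives. It simplifies away the CP-rank matching described above and assumes at least one local sub-sample; the helper name and shapes are illustrative only.

```python
import torch
import torch.distributed as dist


def reroute_key(local_subsamples, dest_ranks, dp_group):
    """Illustrative sketch: `local_subsamples` is a list of 1-D tensors (one per
    local sub-sample); `dest_ranks` gives each one's destination rank."""
    world = dist.get_world_size(group=dp_group)
    dtype, device = local_subsamples[0].dtype, local_subsamples[0].device

    # Bucket local sub-samples by destination rank; one concatenated tensor per peer.
    buckets = [[] for _ in range(world)]
    for sample, dst in zip(local_subsamples, dest_ranks):
        buckets[dst].append(sample)
    send_tensors = [
        torch.cat(b) if b else torch.empty(0, dtype=dtype, device=device)
        for b in buckets
    ]

    # Exchange element counts first so receive buffers can be sized correctly.
    send_sizes = torch.tensor([t.numel() for t in send_tensors], device=device)
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes, group=dp_group)

    # Exchange the actual sub-sample data for this batch key.
    recv_tensors = [
        torch.empty(int(n), dtype=dtype, device=device) for n in recv_sizes
    ]
    dist.all_to_all(recv_tensors, send_tensors, group=dp_group)
    return recv_tensors
```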
- unpack_batch(batch)#
Unpacks the packed samples into a list of sub-samples. Since each sub-sample may be routed to different DPxCP ranks, we unpack the sample here to avoid unnecessarily transferring the entire packed sample.
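Conceptually, unpacking slices the packed sequence at the sub-sample boundaries. A self-contained illustration follows, assuming the boundaries are available as cumulative sequence lengths (`cu_seqlens` is an assumed name here, not necessarily the batch's actual key).

```python
import torch


def unpack_packed_sample(tokens: torch.Tensor, cu_seqlens: torch.Tensor):
    # Slice between consecutive cumulative boundaries to recover each sub-sample.
    return [
        tokens[start:end]
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
    ]


tokens = torch.arange(10)                 # packed sequence of length 10
cu_seqlens = torch.tensor([0, 4, 7, 10])  # three sub-samples: 4 + 3 + 3
print([t.tolist() for t in unpack_packed_sample(tokens, cu_seqlens)])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```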
- __next__() -> Any#
Get the next item from the dataset, pull scheduling metadata and return it.