bridge.data.samplers#

Dataloaders.

Module Contents#

Classes#

MegatronPretrainingSampler

Batch sampler for Megatron pretraining (sequential, non-random).

RandomSeedDataset

A dataset wrapper that sets the random seed based on epoch and index.

MegatronPretrainingRandomSampler

Batch sampler for Megatron pretraining (randomized).

Functions#

build_pretraining_data_loader

Build a dataloader for pretraining.

API#

bridge.data.samplers.build_pretraining_data_loader(
dataset: torch.utils.data.Dataset,
consumed_samples: int,
dataloader_type: str,
micro_batch_size: int,
num_workers: int,
data_sharding: bool,
worker_init_fn: Optional[Callable] = None,
collate_fn: Optional[Callable] = None,
pin_memory: bool = True,
persistent_workers: bool = False,
data_parallel_rank: int = 0,
data_parallel_size: int = 1,
drop_last: Optional[bool] = True,
) → Optional[torch.utils.data.DataLoader]#

Build a dataloader for pretraining.

Selects the appropriate sampler (MegatronPretrainingSampler or MegatronPretrainingRandomSampler) based on dataloader_type and constructs a PyTorch DataLoader.

Parameters:
  • dataset – The dataset to load data from.

  • consumed_samples – The number of samples already consumed (for resuming).

  • dataloader_type – Type of dataloader: ‘single’, ‘cyclic’, or ‘external’. With ‘external’, the dataset is passed through unchanged.

  • micro_batch_size – The batch size per GPU.

  • num_workers – Number of workers for the DataLoader.

  • data_sharding – Whether data sharding is enabled (used for random sampler).

  • worker_init_fn – Optional function to initialize workers.

  • collate_fn – Optional custom collate function.

  • pin_memory – Whether to pin memory for the DataLoader.

  • persistent_workers – Whether to use persistent workers.

  • data_parallel_rank – Rank of the current GPU in the data parallel group.

  • data_parallel_size – Total number of GPUs in the data parallel group.

  • drop_last – Whether to drop the last incomplete batch. Defaults to True.

Returns:

A PyTorch DataLoader instance, or the dataset itself if dataloader_type is ‘external’, or None if the input dataset is None.

Raises:

Exception – If an unsupported dataloader_type is provided.
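
The following is a minimal usage sketch (not from the source); the toy dataset and all argument values are illustrative, and only the signature documented above is assumed:

```python
import torch
from torch.utils.data import TensorDataset

from bridge.data.samplers import build_pretraining_data_loader

# Toy stand-in for a real pretraining corpus: 1024 samples of 512 token ids.
dataset = TensorDataset(torch.randint(0, 32000, (1024, 512)))

dataloader = build_pretraining_data_loader(
    dataset=dataset,
    consumed_samples=0,           # set to the saved count when resuming
    dataloader_type="single",     # or "cyclic" / "external" (see parameter docs)
    micro_batch_size=4,
    num_workers=2,
    data_sharding=True,
    data_parallel_rank=0,
    data_parallel_size=1,
)

for micro_batch in dataloader:
    ...  # feed each micro-batch to the training step
```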

class bridge.data.samplers.MegatronPretrainingSampler(
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
)#

Batch sampler for Megatron pretraining (sequential, non-random).

Provides indices for microbatches for a specific data parallel rank, ensuring that data is processed sequentially across ranks and iterations.

Parameters:
  • total_samples – Total number of samples in the dataset.

  • consumed_samples – Number of samples already consumed (for resuming).

  • micro_batch_size – Batch size per GPU.

  • data_parallel_rank – Rank of the current GPU in the data parallel group.

  • data_parallel_size – Total number of GPUs in the data parallel group.

  • drop_last (bool, optional) – If True, drops the last incomplete batch. Defaults to True.

Initialization

__len__() → int#

Return the total number of samples.

get_start_end_idx() → tuple[int, int]#

Calculate the start and end index for the current rank’s microbatch.

__iter__() → Iterator[list[int]]#

Yields lists of indices for each microbatch assigned to this rank.
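
A minimal sketch of wiring the sampler into a torch.utils.data.DataLoader as a batch_sampler; the toy dataset and all values below are illustrative assumptions, only the constructor arguments documented above are taken from the source:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from bridge.data.samplers import MegatronPretrainingSampler

dataset = TensorDataset(torch.randint(0, 32000, (1024, 512)))  # toy token ids

batch_sampler = MegatronPretrainingSampler(
    total_samples=len(dataset),
    consumed_samples=0,      # resume mid-run by passing the saved count
    micro_batch_size=4,
    data_parallel_rank=0,    # this rank's position in the data parallel group
    data_parallel_size=2,
    drop_last=True,
)

loader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,  # yields one list of indices per micro-batch
    num_workers=2,
    pin_memory=True,
)
```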

class bridge.data.samplers.RandomSeedDataset(dataset: torch.utils.data.Dataset, seed: int)#

Bases: torch.utils.data.Dataset

A dataset wrapper that sets the random seed based on epoch and index.

Ensures reproducibility for random operations within the dataset’s __getitem__ when using multiple workers.

Parameters:
  • dataset – The base dataset to wrap.

  • seed – The base random seed.

Initialization

Initialize RandomSeedDataset.

__len__() → int#

Return the length of the base dataset.

set_epoch(epoch: int) → None#

Set the current epoch number to adjust the random seed.

__getitem__(idx: int) → Any#

Get an item from the dataset, setting the random seed first.
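
A minimal sketch of wrapping a base dataset for reproducible per-item randomness (e.g. masking or augmentation inside __getitem__); the toy dataset and seed value are illustrative assumptions:

```python
import torch
from torch.utils.data import TensorDataset

from bridge.data.samplers import RandomSeedDataset

base_dataset = TensorDataset(torch.randint(0, 32000, (1024, 512)))  # toy data
seeded = RandomSeedDataset(dataset=base_dataset, seed=1234)

for epoch in range(3):
    seeded.set_epoch(epoch)   # fold the epoch into the per-item seed
    first_item = seeded[0]    # seed is set from epoch and index before the lookup
```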

class bridge.data.samplers.MegatronPretrainingRandomSampler(
dataset: torch.utils.data.Dataset,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
data_sharding: bool,
)#

Batch sampler for Megatron pretraining (randomized).

Provides indices for microbatches for a specific data parallel rank, randomizing the order of samples within each epoch while supporting resumption. Handles data sharding across ranks if enabled.

Parameters:
  • dataset – The dataset (potentially wrapped with RandomSeedDataset).

  • total_samples – Total number of samples in the dataset.

  • consumed_samples – Number of samples already consumed (for resuming).

  • micro_batch_size – Batch size per GPU.

  • data_parallel_rank – Rank of the current GPU in the data parallel group.

  • data_parallel_size – Total number of GPUs in the data parallel group.

  • data_sharding – Whether data sharding is enabled.

Initialization

__len__() → int#

Return the total number of samples.

__iter__() → Iterator[list[int]]#

Yields lists of indices for each microbatch assigned to this rank.

Handles randomization within an epoch and data sharding.
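
A minimal sketch of combining the random sampler with RandomSeedDataset in a DataLoader; the toy dataset and all values are illustrative assumptions, only the constructor arguments documented above are taken from the source:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from bridge.data.samplers import (
    MegatronPretrainingRandomSampler,
    RandomSeedDataset,
)

base_dataset = TensorDataset(torch.randint(0, 32000, (1024, 512)))  # toy data
wrapped = RandomSeedDataset(dataset=base_dataset, seed=1234)

random_batch_sampler = MegatronPretrainingRandomSampler(
    dataset=wrapped,
    total_samples=len(wrapped),
    consumed_samples=0,      # resume mid-epoch by passing the saved count
    micro_batch_size=4,
    data_parallel_rank=0,
    data_parallel_size=2,
    data_sharding=True,
)

loader = DataLoader(wrapped, batch_sampler=random_batch_sampler, num_workers=2)
```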