bridge.data.loaders

Module Contents

Functions

get_blend_and_blend_per_split

Determine dataset blends from command-line arguments or config files.

cyclic_iter

Create an infinite iterator from a finite iterable.

get_train_valid_test_num_samples

Calculate the number of samples for train, validation, and test sets.

build_train_valid_test_datasets

Build train, validation, and test datasets using a provider function.

build_train_valid_test_data_loaders

Build train, validation, and test data loaders.

build_train_valid_test_data_iterators

Build train, validation, and test data iterators.

setup_data_iterators

Set up data iterators, handling virtual pipeline parallelism if enabled.

API

bridge.data.loaders.get_blend_and_blend_per_split(
data_paths: Optional[list[str]] = None,
data_args_path: Optional[str] = None,
per_split_data_args_path: Optional[str] = None,
train_data_paths: Optional[list[str]] = None,
valid_data_paths: Optional[list[str]] = None,
test_data_paths: Optional[list[str]] = None,
) -> tuple[Optional[list[str]], Optional[list[list[str]]]]

Determine dataset blends from command-line arguments or config files.

Parses different ways dataset paths/weights can be specified (single list, per-split lists, config files) and returns the blend information.

Parameters:
  • data_paths – List of paths/weights for a single blended dataset.

  • data_args_path – Path to a file containing data paths/weights for a single blend.

  • per_split_data_args_path – Path to a JSON file containing train/valid/test splits, each with its own list of paths/weights.

  • train_data_paths – List of paths/weights specifically for the training split.

  • valid_data_paths – List of paths/weights specifically for the validation split.

  • test_data_paths – List of paths/weights specifically for the test split.

Returns:

A tuple (blend, blend_per_split):

  • blend: A list representing a single data blend, or None.

  • blend_per_split: A list containing the blends for the train, valid, and test splits, or None.

At most one of blend and blend_per_split is non-None.
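The choice between a single blend and per-split blends can be pictured with a small sketch. This is not the library's implementation: `resolve_blend_sketch` is a hypothetical helper, it ignores the two file-based arguments, it returns the path/weight lists unparsed, and the precedence it shows (a single blend wins over per-split lists) is an assumption. The paths and weights are made up.

```python
from typing import Optional


def resolve_blend_sketch(
    data_paths: Optional[list[str]] = None,
    train_data_paths: Optional[list[str]] = None,
    valid_data_paths: Optional[list[str]] = None,
    test_data_paths: Optional[list[str]] = None,
) -> tuple[Optional[list[str]], Optional[list[Optional[list[str]]]]]:
    """Hypothetical stand-in that returns (blend, blend_per_split)."""
    per_split = [train_data_paths, valid_data_paths, test_data_paths]
    if data_paths is not None:
        # A single blend shared by all splits (assumed to take precedence).
        return data_paths, None
    if any(paths is not None for paths in per_split):
        # One entry per split, in train/valid/test order.
        return None, per_split
    # Nothing was specified, so both results are None.
    return None, None


# Hypothetical weights and paths, for illustration only.
blend, blend_per_split = resolve_blend_sketch(
    data_paths=["0.7", "/data/a", "0.3", "/data/b"]
)
print(blend, blend_per_split)
```

Under these assumptions, the two results are never both populated, which keeps the caller's branching simple: check `blend` first, then `blend_per_split`.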

bridge.data.loaders.cyclic_iter(iter: Iterable) -> Iterator

Create an infinite iterator from a finite iterable.
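A generator that restarts iteration whenever the source is exhausted is one common way to get this behavior. The sketch below assumes that approach; `cyclic_iter_sketch` is a hypothetical name and may differ from how the library does it.

```python
from itertools import islice
from typing import Iterable, Iterator


def cyclic_iter_sketch(iterable: Iterable) -> Iterator:
    """Yield items forever by restarting iteration after each exhaustion."""
    while True:
        # Each pass calls iter() on the source again, so a re-iterable
        # source is re-entered rather than replayed from a cache.
        for item in iterable:
            yield item


# islice bounds the otherwise endless stream.
first_seven = list(islice(cyclic_iter_sketch([1, 2, 3]), 7))
print(first_seven)  # [1, 2, 3, 1, 2, 3, 1]
```

Unlike `itertools.cycle`, this form does not store the items it has seen. It therefore needs a source that can be iterated more than once; an empty iterable or an already-consumed generator would make it loop without ever yielding.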

bridge.data.loaders.get_train_valid_test_num_samples(
cfg: megatron.bridge.training.config.ConfigContainer,
) -> tuple[int, int, int]

Calculate the number of samples for train, validation, and test sets.

Determines sample counts based on training iterations, global batch size, and evaluation interval/iterations specified in the config.

Parameters:

cfg – The main configuration container.

Returns:

A tuple (train_samples, valid_samples, test_samples).
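As an illustration of the arithmetic involved, the sketch below computes the three counts from plain integers. The formula is an assumption (one validation run every `eval_interval` iterations plus a final one, and a single test run), and `num_samples_sketch` with its four arguments is hypothetical; the real function reads its inputs from `cfg` and may differ.

```python
def num_samples_sketch(
    train_iters: int,
    global_batch_size: int,
    eval_interval: int,
    eval_iters: int,
) -> tuple[int, int, int]:
    """Hypothetical sample-count arithmetic; the formula is assumed."""
    # Every training iteration consumes one global batch.
    train_samples = train_iters * global_batch_size
    # Assumed: validation runs every eval_interval iterations, plus once at the end.
    num_eval_runs = train_iters // eval_interval + 1
    valid_samples = num_eval_runs * eval_iters * global_batch_size
    # Assumed: the test set is evaluated a single time.
    test_samples = eval_iters * global_batch_size
    return train_samples, valid_samples, test_samples


counts = num_samples_sketch(
    train_iters=1000, global_batch_size=32, eval_interval=100, eval_iters=10
)
print(counts)  # (32000, 3520, 320)
```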

bridge.data.loaders.build_train_valid_test_datasets(
cfg: megatron.bridge.training.config.ConfigContainer,
build_train_valid_test_datasets_provider: Callable,
) -> tuple[Any, Any, Any]

Build train, validation, and test datasets using a provider function.

Parameters:
  • cfg – The main configuration container.

  • build_train_valid_test_datasets_provider – A function that takes train_val_test_num_samples and dataset_config and returns the datasets.

Returns:

A tuple (train_dataset, valid_dataset, test_dataset).
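A provider with the shape described above might look like the following sketch. Everything in it is hypothetical: `toy_datasets_provider`, the list-backed "datasets", and the `start` field of the config stand in for real dataset classes and a real dataset configuration object.

```python
from typing import Any, Sequence


def toy_datasets_provider(
    train_val_test_num_samples: Sequence[int], dataset_config: Any
) -> tuple[list[int], list[int], list[int]]:
    """Hypothetical provider that returns three list-backed datasets."""
    train_n, valid_n, test_n = train_val_test_num_samples
    start = dataset_config["start"]  # hypothetical config field
    # Carve consecutive, non-overlapping index ranges for the three splits.
    train = list(range(start, start + train_n))
    valid = list(range(start + train_n, start + train_n + valid_n))
    test_start = start + train_n + valid_n
    test = list(range(test_start, test_start + test_n))
    return train, valid, test


train_ds, valid_ds, test_ds = toy_datasets_provider((4, 2, 1), {"start": 0})
print(len(train_ds), len(valid_ds), len(test_ds))  # 4 2 1
```

Passing the provider as a callable keeps dataset construction out of the loader code, so the same build path can serve different dataset types.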

bridge.data.loaders.build_train_valid_test_data_loaders(
cfg: megatron.bridge.training.config.ConfigContainer,
train_state: megatron.bridge.training.state.TrainState,
build_train_valid_test_datasets_provider: Callable,
) -> tuple[Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader]]

Build train, validation, and test data loaders.

First builds the datasets using the provided provider function, then constructs PyTorch DataLoaders with appropriate sampling and configuration.

Parameters:
  • cfg – The main configuration container.

  • train_state – The current training state.

  • build_train_valid_test_datasets_provider – A function to build the datasets.

Returns:

A tuple (train_dataloader, valid_dataloader, test_dataloader).

bridge.data.loaders.build_train_valid_test_data_iterators(
cfg: megatron.bridge.training.config.ConfigContainer,
train_state: megatron.bridge.training.state.TrainState,
build_train_valid_test_datasets_provider: Callable,
) -> tuple[Optional[megatron.core.rerun_state_machine.RerunDataIterator], Optional[megatron.core.rerun_state_machine.RerunDataIterator], Optional[megatron.core.rerun_state_machine.RerunDataIterator]]

Build train, validation, and test data iterators.

Builds the data loaders first, then wraps them in appropriate iterators (e.g., RerunDataIterator, cyclic_iter) based on the configuration.

Parameters:
  • cfg – The main configuration container.

  • train_state – The current training state.

  • build_train_valid_test_datasets_provider – A function to build the datasets.

Returns:

A tuple (train_data_iterator, valid_data_iterator, test_data_iterator).

bridge.data.loaders.setup_data_iterators(
cfg: megatron.bridge.training.config.ConfigContainer,
train_state: megatron.bridge.training.state.TrainState,
model_length: int,
train_valid_test_datasets_provider: Callable,
) -> tuple[Union[Optional[megatron.core.rerun_state_machine.RerunDataIterator], list[Optional[megatron.core.rerun_state_machine.RerunDataIterator]]], Union[Optional[megatron.core.rerun_state_machine.RerunDataIterator], list[Optional[megatron.core.rerun_state_machine.RerunDataIterator]]], Union[Optional[megatron.core.rerun_state_machine.RerunDataIterator], list[Optional[megatron.core.rerun_state_machine.RerunDataIterator]]]]

Set up data iterators, handling virtual pipeline parallelism if enabled.

When virtual pipeline parallelism is enabled, calls build_train_valid_test_data_iterators once per virtual stage so that each stage gets its own iterators; otherwise calls it once.

Parameters:
  • cfg – The main configuration container.

  • train_state – The current training state.

  • model_length – The number of model chunks (used for virtual pipeline parallelism).

  • train_valid_test_datasets_provider – A function to build the datasets.

Returns:

A tuple (train_data_iterator, valid_data_iterator, test_data_iterator). Each element can be a single iterator or a list of iterators if virtual pipeline parallelism is enabled.
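The shape of the return value can be illustrated with a small sketch. It assumes one build call per model chunk when virtual pipeline parallelism is on; `setup_iterators_sketch`, `toy_build`, and the `virtual_pipeline_enabled` flag are hypothetical stand-ins, and plain Python iterators replace RerunDataIterator.

```python
from typing import Callable, Iterator, Optional, Union

IterTriple = tuple[Optional[Iterator], Optional[Iterator], Optional[Iterator]]
IterOrList = Union[Optional[Iterator], list[Optional[Iterator]]]


def setup_iterators_sketch(
    build_iterators: Callable[[], IterTriple],
    model_length: int,
    virtual_pipeline_enabled: bool,
) -> tuple[IterOrList, IterOrList, IterOrList]:
    """Hypothetical sketch: single iterators, or one list per split."""
    if not virtual_pipeline_enabled:
        # One build call; each element is a single iterator (or None).
        return build_iterators()
    train_iters, valid_iters, test_iters = [], [], []
    for _ in range(model_length):
        # One build call per model chunk; results are regrouped by split.
        train_it, valid_it, test_it = build_iterators()
        train_iters.append(train_it)
        valid_iters.append(valid_it)
        test_iters.append(test_it)
    return train_iters, valid_iters, test_iters


def toy_build() -> IterTriple:
    # Hypothetical builder; the test split is absent here.
    return iter([1, 2]), iter([3]), None


train, valid, test = setup_iterators_sketch(
    toy_build, model_length=2, virtual_pipeline_enabled=True
)
print(len(train), len(valid), test)  # 2 2 [None, None]
```

Callers therefore need to handle both shapes: an iterator (or None) per split in the plain case, and a list of `model_length` entries per split in the virtual pipeline case.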