bridge.data.megatron_mimo.base_provider

Base class for MegatronMIMO dataset providers.

Module Contents

Classes

MegatronMIMODatasetProvider

Abstract base class for MegatronMIMO dataset providers.

API

class bridge.data.megatron_mimo.base_provider.MegatronMIMODatasetProvider

Bases: megatron.bridge.training.config.DatasetProvider

Abstract base class for MegatronMIMO dataset providers.

All MegatronMIMO dataset providers must inherit from this class and implement the required methods. This ensures a consistent interface for MegatronMIMO data loading.

Required methods:

- build_datasets: Build train/valid/test datasets
- get_collate_fn: Return the collate function for batching
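The `abstractmethod` markers in this reference suggest the requirement is enforced by Python's `abc` machinery. Assuming that is the case, the sketch below shows the effect with a hypothetical stand-in base class, `ProviderBaseStandIn` (not part of Megatron Bridge): a subclass that omits either of the two required methods cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Callable


class ProviderBaseStandIn(ABC):
    """Hypothetical stand-in mirroring the two abstract methods documented here."""

    @abstractmethod
    def build_datasets(self, context):
        """Return (train_ds, valid_ds, test_ds); any element may be None."""

    @abstractmethod
    def get_collate_fn(self) -> Callable:
        """Return a callable that turns a list of samples into a batch dict."""


class IncompleteProvider(ProviderBaseStandIn):
    # Implements only one of the two required methods, so it stays abstract.
    def build_datasets(self, context):
        return None, None, None


class CompleteProvider(IncompleteProvider):
    # Adds the missing method, which makes the class concrete.
    def get_collate_fn(self) -> Callable:
        return list


def can_instantiate(cls) -> bool:
    """Return True if cls can be instantiated, False if abstract methods remain."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(IncompleteProvider))  # False
print(can_instantiate(CompleteProvider))    # True
```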

Example

```python
class MyMegatronMIMOProvider(MegatronMIMODatasetProvider):
    def build_datasets(self, context):
        # Build and return datasets
        return train_ds, valid_ds, test_ds

    def get_collate_fn(self):
        # Return collate function
        return my_collate_fn
```

dataloader_type: Optional[Literal['single', 'cyclic', 'external']] = 'single'

Dataloader type: 'single' (the default; sequential and resume-aware), 'cyclic' (shuffled across epochs; also resume-aware), or 'external' (the dataset is passed through as-is).
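To make "sequential", "shuffled across epochs", and "resume-aware" concrete, the toy function below computes which sample indices remain after a given number of samples has been consumed. It is a simplified illustration under stated assumptions, not Megatron Bridge's sampler: `sample_order`, its `consumed_samples` argument, and the per-epoch seeding scheme are all hypothetical.

```python
import random
from typing import List, Literal, Optional

DataloaderType = Literal["single", "cyclic", "external"]


def sample_order(
    dataloader_type: DataloaderType,
    num_samples: int,
    consumed_samples: int,
    seed: int = 1234,
) -> Optional[List[int]]:
    """Toy illustration (not Megatron Bridge's sampler) of indices still to be visited.

    'single':   one sequential pass; resuming skips already-consumed samples.
    'cyclic':   a new seeded shuffle for every epoch; resuming skips into the current epoch.
    'external': no order is imposed here (returns None); the caller iterates the dataset.
    """
    if dataloader_type == "single":
        return list(range(consumed_samples, num_samples))
    if dataloader_type == "cyclic":
        epoch, offset = divmod(consumed_samples, num_samples)
        order = list(range(num_samples))
        random.Random(seed + epoch).shuffle(order)  # same epoch -> same permutation
        return order[offset:]
    if dataloader_type == "external":
        return None
    raise ValueError(f"unknown dataloader_type: {dataloader_type!r}")


print(sample_order("single", 6, consumed_samples=4))  # [4, 5]
```

Seeding the shuffle by epoch is what makes the cyclic order reproducible after a restart: resuming at `consumed_samples=8` with 6 samples yields exactly the tail of epoch 1's permutation.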

abstractmethod build_datasets(context: megatron.bridge.training.config.DatasetBuildContext) -> Tuple[Optional[torch.utils.data.Dataset], Optional[torch.utils.data.Dataset], Optional[torch.utils.data.Dataset]]

Build train, validation, and test datasets.

Parameters:

context – Build context with sample counts.

Returns:

Tuple of (train_dataset, valid_dataset, test_dataset). Any element can be None if not needed.
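A minimal sketch of a `build_datasets` body follows. It assumes the build context exposes per-split sample counts; the stand-in `BuildContextStandIn` and its field names (`train_samples`, `valid_samples`, `test_samples`) are hypothetical, as are `ToyMultimodalDataset`, `ToyProvider`, and the `"images"` modality key. The dataset is map-style (`__len__` plus `__getitem__`), and a split with no requested samples is returned as `None`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class BuildContextStandIn:
    """Hypothetical stand-in for DatasetBuildContext; these field names are assumptions."""

    train_samples: int
    valid_samples: int
    test_samples: int


class ToyMultimodalDataset:
    """Map-style dataset (__len__/__getitem__) yielding fake multimodal samples."""

    def __init__(self, num_samples: int, seq_len: int = 4):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        return {
            "input_ids": [idx] * self.seq_len,
            "modality_inputs": {"images": [float(idx)] * 3},  # hypothetical modality key
        }


OptDataset = Optional[ToyMultimodalDataset]


class ToyProvider:
    # In real code this class would subclass MegatronMIMODatasetProvider.
    def build_datasets(self, context: BuildContextStandIn) -> Tuple[OptDataset, OptDataset, OptDataset]:
        # Size each split from the requested count; use None for a split with no samples.
        def make(n: int) -> OptDataset:
            return ToyMultimodalDataset(n) if n > 0 else None

        return make(context.train_samples), make(context.valid_samples), make(context.test_samples)


train_ds, valid_ds, test_ds = ToyProvider().build_datasets(BuildContextStandIn(8, 2, 0))
print(len(train_ds), len(valid_ds), test_ds)  # 8 2 None
```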

abstractmethod get_collate_fn() -> Callable

Return the collate function for batching.

The collate function should handle the modality_inputs dict in each sample and batch its contents appropriately for the model.

Returns:

Callable that takes a list of samples and returns a batch dict.
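One possible shape for such a collate function is sketched below, assuming each sample is a dict with `input_ids` and a `modality_inputs` dict keyed by modality name. The function name `collate_multimodal`, the sample layout, and the `attention_mask` output are hypothetical. It pads token ids to a common length and gathers each modality's inputs into a list; real code would usually build tensors at that point, while lists keep the sketch dependency-free.

```python
from typing import Any, Dict, List


def collate_multimodal(samples: List[Dict[str, Any]], pad_id: int = 0) -> Dict[str, Any]:
    """Hypothetical collate function: pad token ids and group modality inputs by name.

    Assumes every sample looks like {"input_ids": [...], "modality_inputs": {name: value}}
    and that all samples carry the same modality names.
    """
    max_len = max(len(s["input_ids"]) for s in samples)
    input_ids, attention_mask = [], []
    for s in samples:
        pad = max_len - len(s["input_ids"])
        input_ids.append(list(s["input_ids"]) + [pad_id] * pad)
        attention_mask.append([1] * len(s["input_ids"]) + [0] * pad)

    names = list(samples[0]["modality_inputs"])
    modality_inputs: Dict[str, List[Any]] = {name: [] for name in names}
    for s in samples:
        if set(s["modality_inputs"]) != set(names):
            raise ValueError("all samples must provide the same modality inputs")
        for name, value in s["modality_inputs"].items():
            modality_inputs[name].append(value)

    # Lists keep the sketch dependency-free; real code would usually build tensors here.
    return {"input_ids": input_ids, "attention_mask": attention_mask, "modality_inputs": modality_inputs}


batch = collate_multimodal([
    {"input_ids": [5, 6, 7], "modality_inputs": {"images": "img0"}},
    {"input_ids": [8], "modality_inputs": {"images": "img1"}},
])
print(batch["input_ids"])        # [[5, 6, 7], [8, 0, 0]]
print(batch["modality_inputs"])  # {'images': ['img0', 'img1']}
```

Under these assumptions, a provider's `get_collate_fn` would simply return `collate_multimodal`.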