bridge.data.mimo.base_provider

Base class for MIMO dataset providers.

Module Contents

Classes

MimoDatasetProvider

Abstract base class for MIMO dataset providers.

API

class bridge.data.mimo.base_provider.MimoDatasetProvider

Bases: megatron.bridge.training.config.DatasetProvider

Abstract base class for MIMO dataset providers.

All MIMO dataset providers must inherit from this class and implement the required methods. This ensures a consistent interface for MIMO data loading.

Required methods:

- build_datasets: Build the train, validation, and test datasets.
- get_collate_fn: Return the collate function used for batching.

Example:

```python
class MyMimoProvider(MimoDatasetProvider):
    def build_datasets(self, context):
        # Build and return datasets
        return train_ds, valid_ds, test_ds

    def get_collate_fn(self):
        # Return collate function
        return my_collate_fn
```
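To show the contract in a form that runs without `megatron.bridge` installed, the sketch below mirrors the documented interface with a hypothetical stand-in base class, `StandInMimoProvider`, assuming standard `abc` semantics: a subclass that omits either required method remains abstract and cannot be instantiated. `ToyProvider`, `IncompleteProvider`, and the list-backed datasets are illustrative only; a real provider would return `torch.utils.data.Dataset` objects.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple


class StandInMimoProvider(ABC):
    """Hypothetical stand-in mirroring MimoDatasetProvider's two required methods."""

    @abstractmethod
    def build_datasets(self, context: Any) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        """Return (train, valid, test); any element may be None."""

    @abstractmethod
    def get_collate_fn(self) -> Callable[[List[Dict[str, Any]]], Dict[str, Any]]:
        """Return a function that turns a list of samples into a batch dict."""


class ToyProvider(StandInMimoProvider):
    def build_datasets(self, context):
        # Plain lists stand in for torch.utils.data.Dataset objects here.
        train_ds = [{"modality_inputs": {"text": i}} for i in range(4)]
        return train_ds, None, None  # no validation or test split

    def get_collate_fn(self):
        def collate(samples):
            # Gather the "text" modality of every sample into one list.
            return {"text": [s["modality_inputs"]["text"] for s in samples]}
        return collate


class IncompleteProvider(StandInMimoProvider):
    def build_datasets(self, context):
        return None, None, None
    # get_collate_fn is missing, so this class stays abstract.


provider = ToyProvider()
train, valid, test = provider.build_datasets(context=None)
batch = provider.get_collate_fn()(train[:2])
print(batch)  # {'text': [0, 1]}

try:
    IncompleteProvider()
except TypeError as exc:
    # ABCMeta refuses to instantiate a class with unimplemented abstract methods.
    print("TypeError:", exc)
```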

abstractmethod build_datasets(
    context: megatron.bridge.training.config.DatasetBuildContext,
) -> Tuple[Optional[torch.utils.data.Dataset], Optional[torch.utils.data.Dataset], Optional[torch.utils.data.Dataset]]

Build train, validation, and test datasets.

Parameters:

context – Build context with sample counts.

Returns:

Tuple of (train_dataset, valid_dataset, test_dataset). Any element may be None if that split is not needed.
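A minimal sketch of a build_datasets implementation that honors the "any element may be None" rule. The actual fields of DatasetBuildContext are not documented in this section, so `ToyBuildContext` and its `train_samples` / `valid_samples` / `test_samples` fields are hypothetical, as is `SyntheticMimoDataset`. The dataset is map-style (`__len__` plus `__getitem__`), which is the same protocol a `torch.utils.data.Dataset` subclass implements.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyBuildContext:
    """Hypothetical stand-in for DatasetBuildContext; field names are assumptions."""
    train_samples: int
    valid_samples: int
    test_samples: int


class SyntheticMimoDataset:
    """Hypothetical map-style dataset yielding samples with a modality_inputs dict."""

    def __init__(self, num_samples: int, split: str):
        self.num_samples = num_samples
        self.split = split

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        # One entry per modality; only a toy "text" modality here.
        return {"modality_inputs": {"text": [idx, idx + 1]}, "split": self.split}


MaybeDataset = Optional[SyntheticMimoDataset]


def build_datasets(context: ToyBuildContext) -> Tuple[MaybeDataset, MaybeDataset, MaybeDataset]:
    def make(count: int, split: str) -> MaybeDataset:
        # A split with no requested samples is returned as None.
        return SyntheticMimoDataset(count, split) if count > 0 else None

    return (
        make(context.train_samples, "train"),
        make(context.valid_samples, "valid"),
        make(context.test_samples, "test"),
    )


train_ds, valid_ds, test_ds = build_datasets(ToyBuildContext(8, 2, 0))
print(len(train_ds), len(valid_ds), test_ds)  # 8 2 None
```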

abstractmethod get_collate_fn() -> Callable

Return the collate function for batching.

The collate function should handle each sample's modality_inputs dict and batch its entries appropriately for the model.

Returns:

Callable that takes a list of samples and returns a batch dict.
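One way such a collate function could look, as a sketch under stated assumptions: each sample is assumed to be a dict whose `modality_inputs` entry maps a modality name to that sample's input, and every sample is assumed to carry the same modality keys. The factory name `make_collate_fn` and the output keys `modality_inputs` / `batch_size` are hypothetical. Values are grouped into lists with the standard library only; with real tensors each list would typically be stacked (for example with `torch.stack`) or padded.

```python
from typing import Any, Callable, Dict, List


def make_collate_fn() -> Callable[[List[Dict[str, Any]]], Dict[str, Any]]:
    """Hypothetical factory returning the kind of callable get_collate_fn() returns."""

    def collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        batched: Dict[str, List[Any]] = {}
        for sample in samples:
            for modality, value in sample["modality_inputs"].items():
                # Group values by modality name, preserving sample order.
                # Assumes every sample has the same modality keys.
                batched.setdefault(modality, []).append(value)
        # With tensors, each list would usually be stacked or padded here.
        return {"modality_inputs": batched, "batch_size": len(samples)}

    return collate


samples = [
    {"modality_inputs": {"text": [1, 2], "image": "img0"}},
    {"modality_inputs": {"text": [3, 4], "image": "img1"}},
]
batch = make_collate_fn()(samples)
print(batch)
```

Grouping per modality keeps the batch dict keyed the same way as each sample's modality_inputs, so a model can look up each modality by name.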