bridge.data.mimo.base_provider
Base class for MIMO dataset providers.
Module Contents
Classes
MimoDatasetProvider | Abstract base class for MIMO dataset providers.
API
- class bridge.data.mimo.base_provider.MimoDatasetProvider
Bases: megatron.bridge.training.config.DatasetProvider
Abstract base class for MIMO dataset providers.
All MIMO dataset providers must inherit from this class and implement the required methods. This ensures a consistent interface for MIMO data loading.
Required methods:
- build_datasets: Build the train/valid/test datasets.
- get_collate_fn: Return the collate function for batching.
Example:

    class MyMimoProvider(MimoDatasetProvider):
        def build_datasets(self, context):
            # Build and return datasets
            return train_ds, valid_ds, test_ds

        def get_collate_fn(self):
            # Return collate function
            return my_collate_fn
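The example above can be expanded into something runnable without Megatron Bridge installed. The sketch below is a stand-in, not the library's code: `MimoDatasetProviderSketch` is a hypothetical base built on Python's `abc` module, on the assumption that the real class's `abstractmethod` markers are enforced the same way (this page does not show its metaclass). `ToyProvider`, `IncompleteProvider`, and the list-based datasets are likewise illustrative.

```python
import abc
from typing import Any, Callable, Optional, Tuple

Splits = Tuple[Optional[Any], Optional[Any], Optional[Any]]


class MimoDatasetProviderSketch(abc.ABC):
    """Hypothetical stand-in that mirrors MimoDatasetProvider's two abstract methods."""

    @abc.abstractmethod
    def build_datasets(self, context: Any) -> Splits:
        """Return (train, valid, test); any element may be None."""

    @abc.abstractmethod
    def get_collate_fn(self) -> Callable:
        """Return a callable that maps a list of samples to a batch dict."""


class IncompleteProvider(MimoDatasetProviderSketch):
    # Implements build_datasets but not get_collate_fn, so it stays abstract.
    def build_datasets(self, context: Any) -> Splits:
        return [], None, None


class ToyProvider(MimoDatasetProviderSketch):
    def build_datasets(self, context: Any) -> Splits:
        # Toy datasets are plain lists of sample dicts; no test split is built.
        train_ds = [{"id": i} for i in range(4)]
        valid_ds = [{"id": i} for i in range(2)]
        return train_ds, valid_ds, None

    def get_collate_fn(self) -> Callable:
        # Gather each sample's "id" into a single list under the same key.
        return lambda samples: {"id": [s["id"] for s in samples]}


# Instantiating a subclass that misses a required method raises TypeError.
try:
    IncompleteProvider()
    rejected = False
except TypeError:
    rejected = True

provider = ToyProvider()
train_ds, valid_ds, test_ds = provider.build_datasets(context=None)
batch = provider.get_collate_fn()(train_ds[:2])
print(rejected, batch)  # True {'id': [0, 1]}
```

Relying on `abc` means a provider that omits a required method is rejected when it is instantiated, rather than later when training first asks it for data.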
- abstractmethod build_datasets(context: megatron.bridge.training.config.DatasetBuildContext)
Build train, validation, and test datasets.
- Parameters:
context – Build context with sample counts.
- Returns:
Tuple of (train_dataset, valid_dataset, test_dataset). Any element may be None if that split is not needed.
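One way to honor the `build_datasets` contract is sketched below. `ToyBuildContext` and its `train_samples`/`valid_samples`/`test_samples` fields are hypothetical stand-ins for `DatasetBuildContext`, whose actual attributes are not listed on this page, and treating a zero count as "split not needed" is an assumed convention, not documented behavior.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Split = Optional[List[dict]]


@dataclass
class ToyBuildContext:
    # Hypothetical stand-in for DatasetBuildContext; real field names may differ.
    train_samples: int
    valid_samples: int
    test_samples: int


def make_split(name: str, num_samples: int) -> Split:
    # Assumed convention: a split with no requested samples is returned as None.
    if num_samples <= 0:
        return None
    return [{"split": name, "index": i} for i in range(num_samples)]


def build_datasets(context: ToyBuildContext) -> Tuple[Split, Split, Split]:
    # Size each split from the sample counts carried by the context.
    return (
        make_split("train", context.train_samples),
        make_split("valid", context.valid_samples),
        make_split("test", context.test_samples),
    )


ctx = ToyBuildContext(train_samples=8, valid_samples=2, test_samples=0)
train_ds, valid_ds, test_ds = build_datasets(ctx)
print(len(train_ds), len(valid_ds), test_ds)  # 8 2 None
```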
- abstractmethod get_collate_fn() → Callable
Return the collate function for batching.
The collate function should handle each sample's modality_inputs dict and batch its contents appropriately for the model.
- Returns:
Callable that takes a list of samples and returns a batch dict.
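A collate function for MIMO-style samples might group values per modality and per field, as in the minimal sketch below. The sample layout (an `input_ids` list plus a `modality_inputs` dict mapping a modality name such as `"images"` to a dict of fields) and the name `mimo_collate` are assumptions for illustration; this page specifies neither. Values are gathered into plain Python lists to keep the example dependency-free, whereas a real collate function would more likely stack tensors (for example with `torch.stack`) and pad variable-length fields.

```python
from collections import defaultdict
from typing import Any, Dict, List


def mimo_collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate for samples shaped like
    {"input_ids": [...], "modality_inputs": {modality: {field: value}}}."""
    batch: Dict[str, Any] = {"input_ids": [s["input_ids"] for s in samples]}
    # modality -> field -> list of per-sample values, in sample order.
    grouped: Dict[str, Dict[str, List[Any]]] = defaultdict(lambda: defaultdict(list))
    for sample in samples:
        for modality, fields in sample.get("modality_inputs", {}).items():
            for name, value in fields.items():
                grouped[modality][name].append(value)
    # Convert the nested defaultdicts to plain dicts for a clean batch dict.
    batch["modality_inputs"] = {m: dict(f) for m, f in grouped.items()}
    return batch


samples = [
    {"input_ids": [1, 2, 3], "modality_inputs": {"images": {"pixel_values": [0.1, 0.2]}}},
    {"input_ids": [4, 5, 6], "modality_inputs": {"images": {"pixel_values": [0.3, 0.4]}}},
]
batch = mimo_collate(samples)
print(batch["modality_inputs"]["images"]["pixel_values"])  # [[0.1, 0.2], [0.3, 0.4]]
```

Because values are appended sample by sample, this sketch assumes every sample in a batch carries the same modalities and fields; mixed batches would need explicit masking or placeholder entries to keep positions aligned.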