bridge.data.megatron_mimo.base_provider
Base class for MegatronMIMO dataset providers.
Module Contents
Classes
MegatronMIMODatasetProvider | Abstract base class for MegatronMIMO dataset providers.
API
- class bridge.data.megatron_mimo.base_provider.MegatronMIMODatasetProvider
Bases: megatron.bridge.training.config.DatasetProvider
Abstract base class for MegatronMIMO dataset providers.
All MegatronMIMO dataset providers must inherit from this class and implement the required methods. This ensures a consistent interface for MegatronMIMO data loading.
Required methods:
- build_datasets: build the train/valid/test datasets.
- get_collate_fn: return the collate function used for batching.
Example:

    class MyMegatronMIMOProvider(MegatronMIMODatasetProvider):
        def build_datasets(self, context):
            # Build and return datasets
            return train_ds, valid_ds, test_ds

        def get_collate_fn(self):
            # Return collate function
            return my_collate_fn
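The contract above relies on Python's abstract-base-class machinery: a subclass that leaves either build_datasets or get_collate_fn unimplemented cannot be instantiated. The sketch below shows that behaviour with a hypothetical stand-in base class (ToyMIMOProviderBase) built directly on abc.ABC, because the real MegatronMIMODatasetProvider and its DatasetProvider parent are not imported here. It assumes, as the abstractmethod markers on this page suggest, that the real class enforces the same rule.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple


class ToyMIMOProviderBase(ABC):
    """Hypothetical stand-in for MegatronMIMODatasetProvider (not the real class)."""

    @abstractmethod
    def build_datasets(self, context: Any) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        """Return (train, valid, test); any element may be None."""

    @abstractmethod
    def get_collate_fn(self) -> Callable[[list], dict]:
        """Return a function that turns a list of samples into a batch dict."""


class IncompleteProvider(ToyMIMOProviderBase):
    # Implements only one of the two abstract methods.
    def build_datasets(self, context):
        return [], None, None


class CompleteProvider(ToyMIMOProviderBase):
    # Implements both abstract methods, so it can be instantiated.
    def build_datasets(self, context):
        return [{"id": 0}], None, None

    def get_collate_fn(self):
        return lambda samples: {"id": [s["id"] for s in samples]}


def can_instantiate(cls) -> bool:
    """True if Python allows creating an instance of cls."""
    try:
        cls()
    except TypeError:
        # ABCMeta raises TypeError when abstract methods remain unimplemented.
        return False
    return True


if __name__ == "__main__":
    print(can_instantiate(IncompleteProvider))  # False
    print(can_instantiate(CompleteProvider))    # True
```

Because the check happens at instantiation time, a provider that forgets get_collate_fn fails as soon as it is constructed rather than later, when the first batch is collated.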
- dataloader_type: Optional[Literal['single', 'cyclic', 'external']] = 'single'
Dataloader type. One of 'single' (the default; sequential and resume-aware), 'cyclic' (shuffled across epochs; also resume-aware), or 'external' (pass-through).
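To make the 'single' and 'cyclic' descriptions concrete, here is a toy model of the two index orders. It assumes that "resume-aware" means a restarted run can skip exactly the samples it has already consumed, and that 'cyclic' reshuffles each epoch with a reproducible per-epoch seed. The functions sequential_indices and cyclic_indices are hypothetical and are not Megatron Bridge's sampler code; 'external' is not modelled, since in that mode the dataset is passed through.

```python
import random
from itertools import islice
from typing import Iterator


def sequential_indices(num_samples: int, consumed: int) -> Iterator[int]:
    """'single'-style order (toy): walk the dataset in order, resuming after `consumed` samples."""
    for i in range(consumed, num_samples):
        yield i


def cyclic_indices(num_samples: int, consumed: int, seed: int = 1234) -> Iterator[int]:
    """'cyclic'-style order (toy): reshuffle every epoch, loop forever, resume after `consumed`."""
    epoch, offset = divmod(consumed, num_samples)
    while True:
        # Seeding with (seed + epoch) makes each epoch's permutation reproducible,
        # which lets a restarted run skip exactly the samples already seen.
        perm = list(range(num_samples))
        random.Random(seed + epoch).shuffle(perm)
        yield from perm[offset:]
        epoch, offset = epoch + 1, 0


if __name__ == "__main__":
    print(list(sequential_indices(5, consumed=2)))  # [2, 3, 4]
    first_run = list(islice(cyclic_indices(5, consumed=0), 12))
    resumed = list(islice(cyclic_indices(5, consumed=7), 5))
    print(resumed == first_run[7:12])  # True
```

The resumed cyclic stream reproduces the tail of the uninterrupted one because the permutation for a given epoch depends only on the seed and the epoch number, not on how many samples were drawn before the restart.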
- abstractmethod build_datasets(context: megatron.bridge.training.config.DatasetBuildContext)
Build train, validation, and test datasets.
- Parameters:
context – Build context with sample counts.
- Returns:
Tuple of (train_dataset, valid_dataset, test_dataset). Any element may be None if that split is not needed.
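A minimal sketch of the build_datasets contract, assuming a context that exposes per-split sample counts. ToyBuildContext and its train_samples/valid_samples/test_samples fields are hypothetical stand-ins for DatasetBuildContext, and the samples are plain dicts rather than a real multimodal dataset. A split with zero requested samples is returned as None, matching the documented return value.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyBuildContext:
    """Hypothetical stand-in for DatasetBuildContext: only the sample counts."""
    train_samples: int
    valid_samples: int
    test_samples: int


def make_split(num_samples: int, prefix: str) -> Optional[List[Dict[str, Any]]]:
    """Build a list of toy samples, or None when the split is not requested."""
    if num_samples <= 0:
        return None
    return [
        {
            "input_ids": [i, i + 1, i + 2],
            # Per-modality payloads; real samples would hold tensors or file paths.
            "modality_inputs": {"images": f"{prefix}_img_{i}"},
        }
        for i in range(num_samples)
    ]


def build_datasets(context: ToyBuildContext):
    """Follow the documented contract: return (train, valid, test), any of them None."""
    return (
        make_split(context.train_samples, "train"),
        make_split(context.valid_samples, "valid"),
        make_split(context.test_samples, "test"),
    )


if __name__ == "__main__":
    train, valid, test = build_datasets(ToyBuildContext(4, 2, 0))
    print(len(train), len(valid), test)  # 4 2 None
```

In a real provider this logic would sit in the subclass's build_datasets(self, context) method.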
- abstractmethod get_collate_fn() -> Callable
Return the collate function for batching.
The collate function should handle each sample's modality_inputs dict and batch its entries appropriately for the model.
- Returns:
Callable that takes a list of samples and returns a batch dict.
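One possible shape for the function returned by get_collate_fn, assuming each sample is a dict with a token list under "input_ids" and a modality_inputs dict mapping modality names to payloads. toy_collate, those keys, and the right-padding scheme are illustrative assumptions; a real collate function would typically stack tensors (for example with torch.stack) instead of building Python lists.

```python
from typing import Any, Dict, List


def toy_collate(samples: List[Dict[str, Any]], pad_id: int = 0) -> Dict[str, Any]:
    """Batch samples of the (assumed) form
    {"input_ids": [int, ...], "modality_inputs": {name: value, ...}}.
    """
    max_len = max(len(s["input_ids"]) for s in samples)
    # Right-pad token sequences so every row has the same length.
    input_ids = [s["input_ids"] + [pad_id] * (max_len - len(s["input_ids"])) for s in samples]

    # Regroup per-sample modality dicts into one list per modality.
    # Samples lacking a modality contribute None, keeping batch positions aligned.
    names = sorted({name for s in samples for name in s.get("modality_inputs", {})})
    modality_inputs = {
        name: [s.get("modality_inputs", {}).get(name) for s in samples] for name in names
    }
    return {"input_ids": input_ids, "modality_inputs": modality_inputs}


if __name__ == "__main__":
    batch = toy_collate([
        {"input_ids": [1, 2, 3], "modality_inputs": {"images": "a.png"}},
        {"input_ids": [4], "modality_inputs": {"images": "b.png", "audio": "b.wav"}},
    ])
    print(batch["input_ids"])  # [[1, 2, 3], [4, 0, 0]]
```

Filling a missing modality with None keeps every per-modality list aligned with the batch index; dropping it instead would make it impossible to tell which sample a payload belongs to.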