bridge.data.megatron_mimo.loaders
Data loader utilities for MegatronMIMO training.
Module Contents
Functions
build_megatron_mimo_data_loaders | Build MegatronMIMO data loaders with globally consistent sampling.
API
bridge.data.megatron_mimo.loaders.build_megatron_mimo_data_loaders(
    cfg: megatron.bridge.training.config.ConfigContainer,
    train_state: megatron.bridge.training.state.TrainState,
    megatron_mimo_provider: megatron.bridge.training.config.DatasetProvider,
    train_samples: int,
    valid_samples: int,
    test_samples: int,
)
Build MegatronMIMO data loaders with globally consistent sampling.
All data-loading ranks receive identical global micro-batches (the sampler uses dp_size=1). Per-module DP sub-sharding is deferred to slice_batch_for_megatron_mimo in the forward step, which keeps the batches consistent with the BridgeCommunicator’s fan-in/fan-out routing for asymmetric DP configs. Only ranks that need data (first/last PP stage) get non-None loaders.
- Parameters:
cfg – Configuration container with MegatronMIMOProvider as cfg.model.
train_state – Current training state.
megatron_mimo_provider – MegatronMIMO dataset provider (e.g., MockMegatronMIMOProvider) that implements a get_collate_fn() method.
train_samples – Number of training samples.
valid_samples – Number of validation samples.
test_samples – Number of test samples.
- Returns:
Tuple of (train_loader, valid_loader, test_loader). Returns (None, None, None) if this rank doesn’t need data.
- Raises:
ValueError – If cfg.model is not a MegatronMIMOProvider, or if megatron_mimo_parallelism_config is None.
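Because a rank that does not need data receives None in place of each loader, calling code has to tolerate a missing iterator. The following is a minimal sketch of one way to do that, not the library's own handling: `next_batch_or_none` is a hypothetical helper, and plain Python lists stand in for the loaders that build_megatron_mimo_data_loaders would return.

```python
from typing import Any, Iterator, Optional


def next_batch_or_none(data_iterator: Optional[Iterator[Any]]) -> Optional[Any]:
    """Hypothetical helper: fetch the next batch, or None if this rank has no loader."""
    if data_iterator is None:
        return None
    return next(data_iterator)


# Stand-ins (assumptions) for the tuple returned on a rank that needs data
# and on a rank that does not.
data_rank_loaders = (
    [{"tokens": [1, 2]}, {"tokens": [3, 4]}],  # train
    [{"tokens": [5, 6]}],                      # valid
    [{"tokens": [7, 8]}],                      # test
)
no_data_rank_loaders = (None, None, None)

train_loader, valid_loader, test_loader = data_rank_loaders
train_iter = iter(train_loader) if train_loader is not None else None
first_batch = next_batch_or_none(train_iter)

idle_loader, _, _ = no_data_rank_loaders
idle_iter = iter(idle_loader) if idle_loader is not None else None
idle_batch = next_batch_or_none(idle_iter)
```

In this sketch, a rank without a loader still runs the same code path and simply sees None as its batch.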
- Example:
>>> from megatron.bridge.data.megatron_mimo import MockMegatronMIMOProvider, build_megatron_mimo_data_loaders
>>> provider = MockMegatronMIMOProvider(
...     seq_length=2048,
...     processor_paths={"vision": "openai/clip-vit-large-patch14"},
...     tokenizer_path="meta-llama/Llama-2-7b-hf",
...     special_token_ids={"vision": 32000},
...     modality_configs={"vision": {"type": "image", "width": 224, "height": 224}},
... )
>>> train_loader, valid_loader, test_loader = build_megatron_mimo_data_loaders(
...     cfg, train_state, provider,
...     train_samples=10000, valid_samples=1000, test_samples=1000,
... )
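The sampling scheme described above (identical global micro-batches on every data-loading rank, with per-module DP slicing deferred to the forward step) can be illustrated with a small sketch. It is not the actual slice_batch_for_megatron_mimo: `slice_batch_for_dp_rank` is a hypothetical helper, contiguous equal-sized shards are an assumption, and the DP sizes of 4 and 2 for the two modules are made up for illustration.

```python
from typing import Dict, List

Batch = Dict[str, List[int]]


def slice_batch_for_dp_rank(batch: Batch, dp_rank: int, dp_size: int) -> Batch:
    """Hypothetical helper: return the contiguous shard of a global
    micro-batch that belongs to dp_rank within a DP group of dp_size."""
    sliced: Batch = {}
    for key, samples in batch.items():
        if len(samples) % dp_size != 0:
            raise ValueError(
                f"batch size {len(samples)} is not divisible by dp_size {dp_size}"
            )
        per_rank = len(samples) // dp_size
        start = dp_rank * per_rank
        sliced[key] = samples[start:start + per_rank]
    return sliced


# Every data-loading rank sees the same global micro-batch (sampler dp_size=1).
global_batch: Batch = {"sample_id": [0, 1, 2, 3]}

# Assumed asymmetric DP config: one module with DP=4, another with DP=2.
# Each module slices the same global batch by its own DP rank.
vision_shards = [slice_batch_for_dp_rank(global_batch, r, 4) for r in range(4)]
language_shards = [slice_batch_for_dp_rank(global_batch, r, 2) for r in range(2)]
```

Since both modules slice the same global batch, the samples held by DP=4 ranks 0 and 1 are exactly those held by DP=2 rank 0, which is the kind of alignment a fan-in/fan-out routing step relies on.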