bridge.data.megatron_mimo.loaders

Data loader utilities for MegatronMIMO training.

Module Contents

Functions

build_megatron_mimo_data_loaders

Build MegatronMIMO data loaders with globally consistent sampling.

API

bridge.data.megatron_mimo.loaders.build_megatron_mimo_data_loaders(
cfg: megatron.bridge.training.config.ConfigContainer,
train_state: megatron.bridge.training.state.TrainState,
megatron_mimo_provider: megatron.bridge.training.config.DatasetProvider,
train_samples: int,
valid_samples: int,
test_samples: int,
) → Tuple[Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader], Optional[torch.utils.data.DataLoader]]

Build MegatronMIMO data loaders with globally consistent sampling.

All data-loading ranks receive identical global micro-batches because the sampler is built with dp_size=1. Per-module data-parallel (DP) sub-sharding is deferred to slice_batch_for_megatron_mimo in the forward step, which keeps the sharding consistent with the BridgeCommunicator’s fan-in/fan-out routing for asymmetric DP configs. Only ranks that need data, namely the first and last pipeline-parallel (PP) stages, get non-None loaders.
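The deferred sharding described above can be illustrated with a minimal sketch in plain Python. This is not the library's implementation: `slice_for_dp_rank` is a hypothetical stand-in for slice_batch_for_megatron_mimo, and it assumes a contiguous, even split along the batch dimension. Every rank starts from the same global micro-batch, and each module keeps only the slice owned by its own DP rank.

```python
from typing import Dict, List, Sequence


def slice_for_dp_rank(
    batch: Dict[str, Sequence], dp_rank: int, dp_size: int
) -> Dict[str, List]:
    """Hypothetical helper: keep the contiguous shard of each field owned by dp_rank.

    Assumes every field shares the same leading (batch) dimension and that
    the batch size is divisible by dp_size.
    """
    if not 0 <= dp_rank < dp_size:
        raise ValueError(f"dp_rank {dp_rank} out of range for dp_size {dp_size}")
    sliced = {}
    for name, values in batch.items():
        if len(values) % dp_size != 0:
            raise ValueError(
                f"field {name!r} has {len(values)} samples, not divisible by {dp_size}"
            )
        shard = len(values) // dp_size
        sliced[name] = list(values[dp_rank * shard : (dp_rank + 1) * shard])
    return sliced


# Every data-loading rank sees the same global micro-batch (sampler dp_size=1).
global_micro_batch = {"sample_id": [0, 1, 2, 3], "label": ["a", "b", "c", "d"]}

# Asymmetric DP (illustrative sizes): an encoder with DP=4, a language model with DP=2.
encoder_shards = [slice_for_dp_rank(global_micro_batch, r, 4) for r in range(4)]
llm_shards = [slice_for_dp_rank(global_micro_batch, r, 2) for r in range(2)]
```

In this sketch, encoder shards 0 and 1 together cover exactly the samples in language-model shard 0, because both modules slice the same global batch rather than drawing from independently sharded samplers.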

Parameters:
  • cfg – Configuration container whose cfg.model is a MegatronMIMOProvider.

  • train_state – Current training state.

  • megatron_mimo_provider – MegatronMIMO dataset provider (e.g., MockMegatronMIMOProvider) that exposes a get_collate_fn() method.

  • train_samples – Number of training samples.

  • valid_samples – Number of validation samples.

  • test_samples – Number of test samples.

Returns:

Tuple of (train_loader, valid_loader, test_loader). Returns (None, None, None) if this rank doesn’t need data.
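Because the returned loaders are None on ranks that do not need data, calling code has to check before iterating. A minimal sketch, using plain lists as stand-ins for DataLoader objects; `iter_or_empty` is a hypothetical helper and not part of the library:

```python
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def iter_or_empty(loader: Optional[Iterable[T]]) -> Iterator[T]:
    """Hypothetical helper: iterate a loader, or yield nothing when it is None."""
    if loader is None:
        return iter(())
    return iter(loader)


# Stand-ins for the returned tuple on two different ranks.
loaders_on_data_rank = ([{"step": 0}, {"step": 1}], [{"step": 0}], [{"step": 0}])
loaders_on_other_rank = (None, None, None)

train_loader, valid_loader, test_loader = loaders_on_data_rank
seen_on_data_rank = [batch["step"] for batch in iter_or_empty(train_loader)]

train_loader, valid_loader, test_loader = loaders_on_other_rank
seen_on_other_rank = [batch["step"] for batch in iter_or_empty(train_loader)]
```

Whether a rank without a loader should skip the step or still take part in the forward pass depends on the training loop; the sketch only shows the None check.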

Raises:

ValueError – If cfg.model is not a MegatronMIMOProvider, or if megatron_mimo_parallelism_config is None.

Example:

>>> from megatron.bridge.data.megatron_mimo import MockMegatronMIMOProvider, build_megatron_mimo_data_loaders
>>> provider = MockMegatronMIMOProvider(
...     seq_length=2048,
...     processor_paths={"vision": "openai/clip-vit-large-patch14"},
...     tokenizer_path="meta-llama/Llama-2-7b-hf",
...     special_token_ids={"vision": 32000},
...     modality_configs={"vision": {"type": "image", "width": 224, "height": 224}},
... )
>>> train_loader, valid_loader, test_loader = build_megatron_mimo_data_loaders(
...     cfg, train_state, provider,
...     train_samples=10000, valid_samples=1000, test_samples=1000,
... )