bridge.diffusion.data.common.diffusion_energon_datamodule#
Module Contents#
Classes#
DiffusionDataModule: A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
API#
- class bridge.diffusion.data.common.diffusion_energon_datamodule.DiffusionDataModuleConfig#
Bases:
megatron.bridge.data.utils.DatasetProvider

- path: str#
None
- seq_length: int#
None
- micro_batch_size: int#
None
- packing_buffer_size: int#
None
- global_batch_size: int#
None
- num_workers: int#
None
- task_encoder_seq_length: int#
None
- dataloader_type: str#
‘external’
- use_train_split_for_val: bool#
False
- build_datasets(
- context: megatron.bridge.data.utils.DatasetBuildContext,
)#
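The configuration fields above can be illustrated with a plain dataclass stand-in. This is a sketch only: the field names and the two defaults (`'external'`, `False`) come from the listing above, while the defaults supplied at construction time and the class name `DiffusionDataModuleConfigSketch` are assumptions; the real class derives from `megatron.bridge.data.utils.DatasetProvider` and implements `build_datasets`.

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in for DiffusionDataModuleConfig; field names and the
# dataloader_type / use_train_split_for_val defaults mirror the listing
# above, everything else here is an assumption.
@dataclass
class DiffusionDataModuleConfigSketch:
    path: str
    seq_length: int
    micro_batch_size: int
    packing_buffer_size: Optional[int]
    global_batch_size: int
    num_workers: int
    task_encoder_seq_length: int
    dataloader_type: str = "external"
    use_train_split_for_val: bool = False

cfg = DiffusionDataModuleConfigSketch(
    path="/data/energon_dataset",   # hypothetical dataset path
    seq_length=2048,
    micro_batch_size=1,
    packing_buffer_size=None,
    global_batch_size=8,
    num_workers=1,
    task_encoder_seq_length=2048,
)
```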
- class bridge.diffusion.data.common.diffusion_energon_datamodule.DiffusionDataModule(
- path: str,
- seq_length: int = 2048,
- micro_batch_size: int = 1,
- global_batch_size: int = 8,
- num_workers: int = 1,
- pin_memory: bool = True,
- packing_buffer_size: int = None,
- task_encoder: megatron.energon.DefaultTaskEncoder = None,
- use_train_split_for_val: bool = False,
)#

Bases:
megatron.bridge.data.energon.base_energon_datamodule.EnergonMultiModalDataModule

A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
This data module is designed to work with multimodal datasets that involve both images and text. It provides a seamless interface to load training and validation data, manage batching, and handle the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon framework for efficient data handling in large-scale distributed training.
Attributes:
- path (str): Path to the energon dataset.
- tokenizer (Tokenizer): The tokenizer used for processing text.
- image_processor (ImageProcessor): The image processor used for preprocessing images.
- seq_length (int): The maximum sequence length for tokenized text.
- micro_batch_size (int): The batch size for training and validation.
- num_workers (int): Number of workers for data loading.
- pin_memory (bool): Whether to pin memory in the DataLoader.
- multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples.
- task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples.
- init_global_step (int): The initial global step for the trainer, used for resuming training.
- data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples.
- train_dataloader_object (Optional): The DataLoader object for training data.
- val_dataloader_object (Optional): The DataLoader object for validation data.
Initialization
Initialize the DiffusionDataModule.
Parameters:
- path (str): Path to the dataset.
- tokenizer (Tokenizer): The tokenizer used for processing text.
- image_processor (ImageProcessor): The image processor used for preprocessing images.
- seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
- micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
- num_workers (int, optional): Number of workers for data loading. Defaults to 1.
- pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
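As background on how `micro_batch_size` and `global_batch_size` relate: in Megatron-style training the global batch is accumulated from micro-batches across data-parallel ranks. The helper below is a sketch of that standard relation, not something this page states; the function name and the divisibility check are assumptions.

```python
def grad_accum_steps(global_batch_size: int, micro_batch_size: int,
                     data_parallel_size: int) -> int:
    """Number of micro-batches accumulated per optimizer step.

    Sketch of the standard Megatron-style relation (an assumption, not
    stated on this page): global_batch_size must be divisible by
    micro_batch_size * data_parallel_size.
    """
    denom = micro_batch_size * data_parallel_size
    if global_batch_size % denom != 0:
        raise ValueError(
            "global_batch_size must be a multiple of "
            "micro_batch_size * data_parallel_size"
        )
    return global_batch_size // denom
```

With the defaults above (`micro_batch_size=1`, `global_batch_size=8`) and four data-parallel ranks, each rank would accumulate two micro-batches per step.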
- datasets_provider(
- worker_config,
- split: Literal['train', 'val'] = 'val',
)#
Provide the dataset for training or validation.
This method retrieves the dataset for the specified split (either ‘train’ or ‘val’) and configures it according to the worker configuration.
Parameters:
- worker_config: Configuration for the data loader workers.
- split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'.
Returns: Dataset: The dataset configured for the specified split.
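The config's `use_train_split_for_val` flag suggests that the validation split can be redirected to the training split. The function below is a hypothetical illustration of that dispatch (the name `resolve_split` and the exact redirection logic are assumptions; the actual behavior lives inside `datasets_provider`):

```python
from typing import Literal

def resolve_split(split: Literal["train", "val"],
                  use_train_split_for_val: bool) -> str:
    # Hypothetical sketch: when validation is requested but the config
    # asks to reuse the training split, fall back to 'train'.
    if split == "val" and use_train_split_for_val:
        return "train"
    return split
```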
- val_dataloader()#
Configure the validation DataLoader.
This method configures the DataLoader for validation data. It takes no arguments.
Returns: DataLoader: The DataLoader for validation data.
- load_state_dict(state_dict: Dict[str, Any]) → None#
Load the state of the data module from a checkpoint.
This method is called when loading a checkpoint. It restores the state of the data module, including the state of the dataloader and the number of consumed samples.
Parameters: state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module.
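The consumed-samples bookkeeping described above can be sketched as a minimal `state_dict`/`load_state_dict` round trip. The class name and the `'consumed_samples'` key below are illustrative assumptions, not the module's actual checkpoint schema:

```python
from typing import Any, Dict

class ConsumedSamplesTracker:
    """Minimal sketch of checkpointable dataloader state; the field and
    key names here are illustrative, not the real checkpoint schema."""

    def __init__(self) -> None:
        self.consumed_samples = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"consumed_samples": self.consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore the position in the data stream so resumed training
        # does not revisit samples already seen before the checkpoint.
        self.consumed_samples = state_dict["consumed_samples"]
```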