bridge.diffusion.data.common.diffusion_energon_datamodule#

Module Contents#

Classes#

DiffusionDataModuleConfig

DiffusionDataModule

A PyTorch Lightning DataModule for handling multimodal datasets with images and text.

API#

class bridge.diffusion.data.common.diffusion_energon_datamodule.DiffusionDataModuleConfig#

Bases: megatron.bridge.data.utils.DatasetProvider

path: str#

None

seq_length: int#

None

micro_batch_size: int#

None

packing_buffer_size: int#

None

global_batch_size: int#

None

num_workers: int#

None

task_encoder_seq_length: int#

None

dataloader_type: str#

'external'

use_train_split_for_val: bool#

False

build_datasets(
context: megatron.bridge.data.utils.DatasetBuildContext,
)#
class bridge.diffusion.data.common.diffusion_energon_datamodule.DiffusionDataModule(
path: str,
seq_length: int = 2048,
micro_batch_size: int = 1,
global_batch_size: int = 8,
num_workers: int = 1,
pin_memory: bool = True,
packing_buffer_size: int = None,
task_encoder: megatron.energon.DefaultTaskEncoder = None,
use_train_split_for_val: bool = False,
)#

Bases: megatron.bridge.data.energon.base_energon_datamodule.EnergonMultiModalDataModule

A PyTorch Lightning DataModule for handling multimodal datasets with images and text.

This data module is designed to work with multimodal datasets that involve both images and text. It provides a seamless interface to load training and validation data, manage batching, and handle the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon framework for efficient data handling in large-scale distributed training.

Attributes:

path (str): Path to the energon dataset.
tokenizer (Tokenizer): The tokenizer used for processing text.
image_processor (ImageProcessor): The image processor used for preprocessing images.
seq_length (int): The maximum sequence length for tokenized text.
micro_batch_size (int): The batch size for training and validation.
num_workers (int): Number of workers for data loading.
pin_memory (bool): Whether to pin memory in the DataLoader.
multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples.
task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples.
init_global_step (int): The initial global step for the trainer, used for resuming training.
data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples.
train_dataloader_object (Optional): The DataLoader object for training data.
val_dataloader_object (Optional): The DataLoader object for validation data.

Initialization

Initialize the DiffusionDataModule.

Parameters:

path (str): Path to the dataset.
tokenizer (Tokenizer): The tokenizer used for processing text.
image_processor (ImageProcessor): The image processor used for preprocessing images.
seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
global_batch_size (int, optional): The global batch size across all data-parallel ranks. Defaults to 8.
num_workers (int, optional): Number of workers for data loading. Defaults to 1.
pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
packing_buffer_size (int, optional): Buffer size for sequence packing. Defaults to None.
task_encoder (DefaultTaskEncoder, optional): Encoder responsible for encoding and batching samples. Defaults to None.
use_train_split_for_val (bool, optional): Whether to reuse the training split for validation. Defaults to False.
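The micro and global batch sizes relate through the data-parallel world size and gradient accumulation. A small illustration of the standard Megatron-style arithmetic (this module's exact sampler behavior is an assumption here):

```python
# In Megatron-style training, each optimizer step consumes
# global_batch_size samples, assembled from micro batches run on each
# data-parallel rank and accumulated over several forward/backward passes.
def grad_accum_steps(global_batch_size: int,
                     micro_batch_size: int,
                     data_parallel_size: int) -> int:
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError("global batch must be divisible by micro * DP size")
    return global_batch_size // samples_per_pass


# With the documented defaults (global=8, micro=1) on 4 data-parallel ranks,
# each step accumulates gradients over 2 passes.
steps = grad_accum_steps(global_batch_size=8, micro_batch_size=1,
                         data_parallel_size=4)
```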

datasets_provider(
worker_config,
split: Literal['train', 'val'] = 'val',
)#

Provide the dataset for training or validation.

This method retrieves the dataset for the specified split (either ‘train’ or ‘val’) and configures it according to the worker configuration.

Parameters:

worker_config: Configuration for the data loader workers.
split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'.

Returns: Dataset: The dataset configured for the specified split.
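The use_train_split_for_val flag documented on the config suggests the requested split is remapped before the dataset is built. A hypothetical helper showing that likely remapping (the real datasets_provider may differ in detail):

```python
# Hypothetical split resolution: when use_train_split_for_val is set,
# validation requests are served from the training split instead.
def resolve_split(split: str, use_train_split_for_val: bool) -> str:
    if use_train_split_for_val and split == "val":
        return "train"
    return split
```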

val_dataloader()#

Configure the validation DataLoader.

This method configures the DataLoader for validation data. It takes no arguments; the worker configuration is taken from the module's stored settings.

Returns: DataLoader: The DataLoader for validation data.

load_state_dict(state_dict: Dict[str, Any]) -> None#

Load the state of the data module from a checkpoint.

This method is called when loading a checkpoint. It restores the state of the data module, including the state of the dataloader and the number of consumed samples.

Parameters: state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module.
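A minimal sketch of the restore pattern described above; the checkpoint key name and surrounding class are assumptions, not the actual module internals:

```python
from typing import Any, Dict


# Illustrative stand-in for the data module's checkpoint-restore hook.
class DataModuleStateSketch:
    def __init__(self) -> None:
        self.consumed_samples = 0

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore the dataloader position so resumed training does not
        # revisit samples already consumed before the checkpoint.
        self.consumed_samples = state_dict.get("consumed_samples", 0)


dm = DataModuleStateSketch()
dm.load_state_dict({"consumed_samples": 1024})
```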