bridge.diffusion.data.flux.flux_mock_datamodule#

Mock data module for FLUX model training.

Module Contents#

Classes#

_MockT2IDataset

A mock dataset class for text-to-image tasks, simulating data samples for training and testing.

FluxMockDataModuleConfig

Configuration for FLUX mock data module.

Functions#

_collate_fn

Collate function to batch samples from _MockT2IDataset.

API#

class bridge.diffusion.data.flux.flux_mock_datamodule._MockT2IDataset(
image_H: int = 1024,
image_W: int = 1024,
length: int = 100000,
image_precached: bool = True,
text_precached: bool = True,
prompt_seq_len: int = 512,
pooled_prompt_dim: int = 768,
context_dim: int = 4096,
vae_scale_factor: int = 8,
vae_channels: int = 16,
)#

Bases: torch.utils.data.Dataset

A mock dataset class for text-to-image tasks, simulating data samples for training and testing.

This dataset generates synthetic data for both image and text inputs, with options to use pre-cached latent representations or raw data. The class is designed for use in testing and prototyping machine learning models.

.. attribute:: image_H

Height of the generated images.

Type:

int

.. attribute:: image_W

Width of the generated images.

Type:

int

.. attribute:: length

Total number of samples in the dataset.

Type:

int

.. attribute:: image_precached

Whether to use pre-cached latent representations for images.

Type:

bool

.. attribute:: text_precached

Whether to use pre-cached embeddings for text.

Type:

bool

.. attribute:: prompt_seq_len

Sequence length for text prompts.

Type:

int

.. attribute:: pooled_prompt_dim

Dimensionality of pooled text embeddings.

Type:

int

.. attribute:: context_dim

Dimensionality of the text embedding context.

Type:

int

.. attribute:: vae_scale_factor

Scaling factor for the VAE latent representation.

Type:

int

.. attribute:: vae_channels

Number of channels in the VAE latent representation.

Type:

int

Initialization

__getitem__(index)#

Retrieves a single sample from the dataset.

When pre-caching is enabled, the sample contains pre-cached image latents and text embeddings; otherwise it contains raw image and text data.

Parameters:

index (int) – Index of the sample to retrieve.

Returns:

A dictionary containing the generated data sample with keys:

- 'latents': Pre-cached latent representation of the image [C, H, W].
- 'prompt_embeds': Pre-cached text prompt embeddings [seq_len, context_dim].
- 'pooled_prompt_embeds': Pooled text prompt embeddings [pooled_dim].
- 'text_ids': Text position IDs [seq_len, 3].

Return type:

dict
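
For illustration, a minimal sketch of what a pre-cached sample with these keys and shapes might look like. The helper name `make_mock_sample` is hypothetical, not part of this module; the spatial downsampling by `vae_scale_factor` follows the documented attributes.

```python
import torch

def make_mock_sample(image_H=1024, image_W=1024, vae_scale_factor=8,
                     vae_channels=16, prompt_seq_len=512,
                     context_dim=4096, pooled_prompt_dim=768):
    # Latents are spatially downsampled by the VAE scale factor.
    latent_H = image_H // vae_scale_factor
    latent_W = image_W // vae_scale_factor
    return {
        "latents": torch.randn(vae_channels, latent_H, latent_W),
        "prompt_embeds": torch.randn(prompt_seq_len, context_dim),
        "pooled_prompt_embeds": torch.randn(pooled_prompt_dim),
        "text_ids": torch.zeros(prompt_seq_len, 3),
    }

sample = make_mock_sample()
print(sample["latents"].shape)  # torch.Size([16, 128, 128])
```

With the default FLUX-like settings, a 1024x1024 image maps to a 128x128 latent grid with 16 channels.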

__len__()#

Returns the total number of samples in the dataset.

bridge.diffusion.data.flux.flux_mock_datamodule._collate_fn(samples)#

Collate function to batch samples from _MockT2IDataset.

Parameters:

samples – List of sample dictionaries from the dataset.

Returns:

Batched dictionary with stacked tensors.

Return type:

dict
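
A collate function of this kind can be sketched by stacking each tensor field across samples along a new batch dimension. This is a sketch consistent with the documented behavior, not the module's actual implementation.

```python
import torch

def collate_fn(samples):
    # Stack each tensor field along a new leading batch dimension;
    # assumes all samples share the same keys and tensor shapes.
    return {key: torch.stack([s[key] for s in samples])
            for key in samples[0]}

batch = collate_fn([
    {"latents": torch.randn(16, 128, 128)},
    {"latents": torch.randn(16, 128, 128)},
])
print(batch["latents"].shape)  # torch.Size([2, 16, 128, 128])
```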

class bridge.diffusion.data.flux.flux_mock_datamodule.FluxMockDataModuleConfig#

Bases: megatron.bridge.data.utils.DatasetProvider

Configuration for FLUX mock data module.

This data module generates synthetic data for FLUX model training, matching the expected input format of FluxForwardStep.

.. attribute:: path

Unused, kept for interface compatibility.

.. attribute:: seq_length

Sequence length (unused for FLUX, kept for interface compatibility).

.. attribute:: packing_buffer_size

Packing buffer size (unused for FLUX).

.. attribute:: micro_batch_size

Micro batch size for training.

.. attribute:: global_batch_size

Global batch size for training.

.. attribute:: num_workers

Number of data loading workers.

.. attribute:: dataloader_type

Type of dataloader ("external" for mock data).

.. attribute:: image_H

Height of input images.

.. attribute:: image_W

Width of input images.

.. attribute:: vae_channels

Number of VAE latent channels.

.. attribute:: vae_scale_factor

VAE spatial downsampling factor.

.. attribute:: prompt_seq_len

Sequence length for T5 text embeddings.

.. attribute:: context_dim

Dimensionality of T5 text embeddings.

.. attribute:: pooled_prompt_dim

Dimensionality of CLIP pooled embeddings.

.. attribute:: image_precached

Whether images are pre-encoded as VAE latents.

.. attribute:: text_precached

Whether text is pre-encoded as embeddings.

.. attribute:: num_train_samples

Number of training samples.

path: str = <Multiline-String>#
seq_length: int#

1024

packing_buffer_size: int#

None

micro_batch_size: int#

1

global_batch_size: int#

4

num_workers: int#

8

dataloader_type: str#

'external'

image_H: int#

1024

image_W: int#

1024

vae_channels: int#

16

vae_scale_factor: int#

8

prompt_seq_len: int#

512

context_dim: int#

4096

pooled_prompt_dim: int#

768

image_precached: bool#

True

text_precached: bool#

True

num_train_samples: int#

10000

__post_init__()#

Initialize the mock dataset and dataloader.

build_datasets(
_context: megatron.bridge.data.utils.DatasetBuildContext,
)#

Build and return train/val/test dataloaders.
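
To illustrate the pattern, a minimal sketch of wiring a mock dataset and collate function into a dataloader. `TinyMockDataset` and `collate` are hypothetical stand-ins; the real module wires `_MockT2IDataset` and `_collate_fn` through `FluxMockDataModuleConfig.build_datasets`.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TinyMockDataset(Dataset):
    """Hypothetical stand-in for _MockT2IDataset."""

    def __init__(self, length=16):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Pre-cached image latents only, for brevity.
        return {"latents": torch.randn(16, 128, 128)}

def collate(samples):
    # Stand-in for _collate_fn: stack each key across samples.
    return {k: torch.stack([s[k] for s in samples]) for k in samples[0]}

loader = DataLoader(TinyMockDataset(), batch_size=4,
                    num_workers=0, collate_fn=collate)
batch = next(iter(loader))
print(batch["latents"].shape)  # torch.Size([4, 16, 128, 128])
```

In the actual config, `micro_batch_size` and `num_workers` would play the roles of `batch_size` and `num_workers` here.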