bridge.diffusion.data.flux.flux_mock_datamodule#
Mock data module for FLUX model training.
Module Contents#
Classes#
_MockT2IDataset — A mock dataset class for text-to-image tasks, simulating data samples for training and testing.
FluxMockDataModuleConfig — Configuration for FLUX mock data module.
Functions#
_collate_fn — Collate function to batch samples from _MockT2IDataset.
API#
- class bridge.diffusion.data.flux.flux_mock_datamodule._MockT2IDataset(
- image_H: int = 1024,
- image_W: int = 1024,
- length: int = 100000,
- image_precached: bool = True,
- text_precached: bool = True,
- prompt_seq_len: int = 512,
- pooled_prompt_dim: int = 768,
- context_dim: int = 4096,
- vae_scale_factor: int = 8,
- vae_channels: int = 16,
- )
Bases:
torch.utils.data.Dataset

A mock dataset class for text-to-image tasks, simulating data samples for training and testing.
This dataset generates synthetic data for both image and text inputs, with options to use pre-cached latent representations or raw data. The class is designed for use in testing and prototyping machine learning models.
.. attribute:: image_H
Height of the generated images.
- Type:
int
.. attribute:: image_W
Width of the generated images.
- Type:
int
.. attribute:: length
Total number of samples in the dataset.
- Type:
int
.. attribute:: image_precached
Whether to use pre-cached latent representations for images.
- Type:
bool
.. attribute:: text_precached
Whether to use pre-cached embeddings for text.
- Type:
bool
.. attribute:: prompt_seq_len
Sequence length for text prompts.
- Type:
int
.. attribute:: pooled_prompt_dim
Dimensionality of pooled text embeddings.
- Type:
int
.. attribute:: context_dim
Dimensionality of the text embedding context.
- Type:
int
.. attribute:: vae_scale_factor
Scaling factor for the VAE latent representation.
- Type:
int
.. attribute:: vae_channels
Number of channels in the VAE latent representation.
- Type:
int
Initialization
- __getitem__(index)#
Retrieves a single sample from the dataset.
The sample includes pre-cached latent representations for images and text.
- Parameters:
index (int) – Index of the sample to retrieve.
- Returns:
A dictionary containing the generated data sample with keys:
- "latents": Pre-cached latent representation of the image [C, H, W].
- "prompt_embeds": Pre-cached text prompt embeddings [seq_len, context_dim].
- "pooled_prompt_embeds": Pooled text prompt embeddings [pooled_dim].
- "text_ids": Text position IDs [seq_len, 3].
- Return type:
dict
- __len__()#
Returns the total number of samples in the dataset.
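The documented sample layout can be sketched with a minimal stand-in dataset. This is an illustrative mock (the class name `MockT2ISketch` is hypothetical, not the actual `_MockT2IDataset` implementation); the tensor shapes follow the attribute and `__getitem__` descriptions above, assuming the pre-cached latent is downsampled spatially by `vae_scale_factor`:

```python
import torch
from torch.utils.data import Dataset


class MockT2ISketch(Dataset):
    """Illustrative stand-in mirroring the documented sample layout."""

    def __init__(self, image_H=1024, image_W=1024, length=100,
                 prompt_seq_len=512, pooled_prompt_dim=768,
                 context_dim=4096, vae_scale_factor=8, vae_channels=16):
        self.length = length
        # Pre-cached image latents are spatially downsampled by the VAE.
        self.latent_H = image_H // vae_scale_factor
        self.latent_W = image_W // vae_scale_factor
        self.vae_channels = vae_channels
        self.prompt_seq_len = prompt_seq_len
        self.pooled_prompt_dim = pooled_prompt_dim
        self.context_dim = context_dim

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Shapes follow the __getitem__ return description above.
        return {
            "latents": torch.randn(self.vae_channels, self.latent_H, self.latent_W),
            "prompt_embeds": torch.randn(self.prompt_seq_len, self.context_dim),
            "pooled_prompt_embeds": torch.randn(self.pooled_prompt_dim),
            "text_ids": torch.zeros(self.prompt_seq_len, 3),
        }


sample = MockT2ISketch()[0]
print(sample["latents"].shape)  # torch.Size([16, 128, 128])
```

With the defaults above, a 1024x1024 image and a scale factor of 8 yield a 16-channel 128x128 latent, matching the `[C, H, W]` convention in the return description.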
- bridge.diffusion.data.flux.flux_mock_datamodule._collate_fn(samples)#
Collate function to batch samples from _MockT2IDataset.
- Parameters:
samples – List of sample dictionaries from the dataset.
- Returns:
Batched dictionary with stacked tensors.
- Return type:
dict
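A collate function with this contract can be sketched by stacking each key across the sample dictionaries. This is an illustrative version under the assumption that every sample shares the same keys and per-key tensor shapes, not the actual `_collate_fn` source:

```python
import torch


def collate_sketch(samples):
    """Stack per-sample tensors into batched tensors, key by key.

    Assumes all samples contain the same keys with identically
    shaped tensors, as produced by the mock dataset.
    """
    return {key: torch.stack([s[key] for s in samples]) for key in samples[0]}


batch = collate_sketch([{"latents": torch.randn(16, 128, 128)} for _ in range(4)])
print(batch["latents"].shape)  # torch.Size([4, 16, 128, 128])
```

Each stacked tensor gains a leading batch dimension equal to the number of samples, which is the "batched dictionary with stacked tensors" described above.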
- class bridge.diffusion.data.flux.flux_mock_datamodule.FluxMockDataModuleConfig#
Bases:
megatron.bridge.data.utils.DatasetProvider

Configuration for FLUX mock data module.
This data module generates synthetic data for FLUX model training, matching the expected input format of FluxForwardStep.
.. attribute:: path
Unused, kept for interface compatibility.
.. attribute:: seq_length
Sequence length (unused for FLUX, kept for interface compatibility).
.. attribute:: packing_buffer_size
Packing buffer size (unused for FLUX).
.. attribute:: micro_batch_size
Micro batch size for training.
.. attribute:: global_batch_size
Global batch size for training.
.. attribute:: num_workers
Number of data loading workers.
.. attribute:: dataloader_type
Type of dataloader ("external" for mock data).
.. attribute:: image_H
Height of input images.
.. attribute:: image_W
Width of input images.
.. attribute:: vae_channels
Number of VAE latent channels.
.. attribute:: vae_scale_factor
VAE spatial downsampling factor.
.. attribute:: prompt_seq_len
Sequence length for T5 text embeddings.
.. attribute:: context_dim
Dimensionality of T5 text embeddings.
.. attribute:: pooled_prompt_dim
Dimensionality of CLIP pooled embeddings.
.. attribute:: image_precached
Whether images are pre-encoded as VAE latents.
.. attribute:: text_precached
Whether text is pre-encoded as embeddings.
.. attribute:: num_train_samples
Number of training samples.
- path: str = <Multiline-String>#
- seq_length: int#
1024
- packing_buffer_size: int#
None
- micro_batch_size: int#
1
- global_batch_size: int#
4
- num_workers: int#
8
- dataloader_type: str#
'external'
- image_H: int#
1024
- image_W: int#
1024
- vae_channels: int#
16
- vae_scale_factor: int#
8
- prompt_seq_len: int#
512
- context_dim: int#
4096
- pooled_prompt_dim: int#
768
- image_precached: bool#
True
- text_precached: bool#
True
- num_train_samples: int#
10000
- __post_init__()#
Initialize the mock dataset and dataloader.
- build_datasets(
- _context: megatron.bridge.data.utils.DatasetBuildContext,
- )
Build and return train/val/test dataloaders.
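The field defaults above can be summarized as a plain-dataclass mirror. This is a reference sketch only (`FluxMockConfigSketch` is a hypothetical name; the real `FluxMockDataModuleConfig` subclasses `megatron.bridge.data.utils.DatasetProvider` and also builds the dataloaders in `__post_init__`):

```python
from dataclasses import dataclass


@dataclass
class FluxMockConfigSketch:
    # Defaults mirror the documented FluxMockDataModuleConfig values.
    path: str = ""                   # unused; documented default elided
    seq_length: int = 1024           # unused for FLUX, interface compatibility
    micro_batch_size: int = 1
    global_batch_size: int = 4
    num_workers: int = 8
    dataloader_type: str = "external"
    image_H: int = 1024
    image_W: int = 1024
    vae_channels: int = 16
    vae_scale_factor: int = 8
    prompt_seq_len: int = 512        # T5 text embedding sequence length
    context_dim: int = 4096          # T5 embedding dimensionality
    pooled_prompt_dim: int = 768     # CLIP pooled embedding dimensionality
    image_precached: bool = True
    text_precached: bool = True
    num_train_samples: int = 10000


cfg = FluxMockConfigSketch(global_batch_size=8)
# Latent spatial size implied by the image size and VAE scale factor:
print(cfg.image_H // cfg.vae_scale_factor)  # 128
```

Overriding a field at construction time, as with `global_batch_size` here, leaves all other documented defaults intact.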