bridge.diffusion.data.flux.flux_taskencoder

Module Contents

Classes

FluxTaskEncoder

Task encoder for Flux dataset.

Functions

cook

Processes a raw sample dictionary from an Energon dataset and returns a new dictionary containing a specific set of keys.

API

bridge.diffusion.data.flux.flux_taskencoder.cook(sample: dict) → dict

Processes a raw sample dictionary from an Energon dataset and returns a new dictionary containing a specific set of keys.

Parameters:

sample (dict) – The input dictionary containing the raw sample data.

Returns:

A new dictionary containing the processed sample data with the following keys:

- All keys from the result of basic_sample_keys(sample)
- ‘json’: metadata such as the image resolution
- ‘pth’: the image latent tensor
- ‘pickle’: the text embeddings (T5 embeddings and pooled CLIP embeddings)

Return type:

dict
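The contract described above (bookkeeping keys plus the ‘json’, ‘pth’ and ‘pickle’ payloads) can be illustrated with a minimal sketch. It is not the actual cook implementation: cook_sketch and basic_sample_keys_stub are hypothetical names, and the stub assumes that the bookkeeping entries are the ones whose names start with "__", which may differ from what energon's basic_sample_keys really returns. The metadata layout in the example sample is also hypothetical.

```python
from typing import Any, Dict


def basic_sample_keys_stub(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for energon's basic_sample_keys(): here it simply
    # keeps the bookkeeping entries whose names start with "__".
    return {k: v for k, v in sample.items() if k.startswith("__")}


def cook_sketch(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Build a new dict from the bookkeeping keys plus the three payloads the
    # docstring lists; any other entries of the raw sample are dropped.
    return {
        **basic_sample_keys_stub(sample),
        "json": sample["json"],      # metadata, e.g. resolution
        "pth": sample["pth"],        # image latent tensor
        "pickle": sample["pickle"],  # T5 / pooled CLIP text embeddings
    }


raw = {
    "__key__": "sample_000001",
    "json": {"resolution": [512, 512]},  # hypothetical metadata layout
    "pth": "<latent tensor>",
    "pickle": {"t5": "<t5 embeddings>", "clip_pooled": "<clip pooled>"},
    "txt": "a photo of a cat",           # extra entry, not carried over
}
cooked = cook_sketch(raw)
print(sorted(cooked))  # ['__key__', 'json', 'pickle', 'pth']
```

Because the function returns a fresh dictionary, the raw sample is left untouched and only the listed keys flow on to the task encoder.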

class bridge.diffusion.data.flux.flux_taskencoder.FluxTaskEncoder(
*args,
vae_scale_factor: int = 8,
seq_length: int = 1024,
latent_channels: int = 16,
**kwargs,
)

Bases: megatron.bridge.diffusion.data.common.diffusion_task_encoder_with_sp.DiffusionTaskEncoderWithSequencePacking

Task encoder for Flux dataset.

Attributes:

cookers (list) – A list of Cooker objects used for processing.

vae_scale_factor (int) – The VAE downsampling factor. Defaults to 8.

seq_length (int) – The sequence length. Defaults to 1024.

latent_channels (int) – The number of latent channels produced by the VAE. Defaults to 16.
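The relationship between vae_scale_factor, latent_channels and seq_length can be made concrete with some shape arithmetic. This is a hedged sketch: it assumes the latent is later split into 2×2 patches, one token per patch, as in the Flux transformer. That patchification step and the names latent_shape, packed_seq_length and patch_size are assumptions for illustration, not taken from this class.

```python
from typing import Tuple


def latent_shape(height: int, width: int, vae_scale_factor: int = 8,
                 latent_channels: int = 16) -> Tuple[int, int, int]:
    # The VAE downsamples each spatial dimension by vae_scale_factor.
    return (latent_channels, height // vae_scale_factor, width // vae_scale_factor)


def packed_seq_length(height: int, width: int, vae_scale_factor: int = 8,
                      patch_size: int = 2) -> int:
    # Assumption: the latent is split into patch_size x patch_size patches
    # and each patch becomes one token.
    _, h, w = latent_shape(height, width, vae_scale_factor)
    return (h // patch_size) * (w // patch_size)


def token_dim(latent_channels: int = 16, patch_size: int = 2) -> int:
    # Each token holds all channels of one patch.
    return latent_channels * patch_size * patch_size


print(latent_shape(512, 512))         # (16, 64, 64)
print(packed_seq_length(512, 512))    # 1024
print(packed_seq_length(1024, 1024))  # 4096
print(token_dim())                    # 64
```

Under this assumption, the default seq_length of 1024 corresponds to a 512×512 image, and doubling the resolution quadruples the token count.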

Initialization

cookers

None

encode_sample(sample: dict) → dict
batch(
samples: List[megatron.bridge.diffusion.data.common.diffusion_sample.DiffusionSample],
) → dict

Returns a dictionary containing the data for the batch.
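The base class name, DiffusionTaskEncoderWithSequencePacking, suggests that samples are packed along the sequence dimension rather than padded. The sketch below shows one common packed layout: token sequences are concatenated and their cumulative lengths (cu_seqlens) are recorded. Whether batch() produces exactly this dictionary is not stated here. ToySample, pack_batch and the returned keys are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToySample:
    # Hypothetical stand-in for DiffusionSample: a sample key plus a
    # variable-length sequence of tokens (each token a list of floats).
    key: str
    tokens: List[List[float]]


def pack_batch(samples: List[ToySample]) -> Dict[str, object]:
    # Concatenate all token sequences along the sequence axis and record
    # cumulative sequence lengths so each sample's span can be recovered.
    tokens: List[List[float]] = []
    cu_seqlens = [0]
    for sample in samples:
        tokens.extend(sample.tokens)
        cu_seqlens.append(cu_seqlens[-1] + len(sample.tokens))
    return {
        "keys": [s.key for s in samples],
        "tokens": tokens,
        "cu_seqlens": cu_seqlens,
    }


batch = pack_batch([
    ToySample("a", [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
    ToySample("b", [[0.7, 0.8], [0.9, 1.0]]),
])
print(batch["cu_seqlens"])   # [0, 3, 5]
print(len(batch["tokens"]))  # 5
```

Sample i occupies tokens[cu_seqlens[i]:cu_seqlens[i + 1]]. This lets attention kernels that accept cumulative lengths keep samples separate without padding.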