bridge.diffusion.data.flux.flux_taskencoder
Module Contents
Classes
- FluxTaskEncoder: Task encoder for the Flux dataset.
Functions
- cook: Processes a raw sample dictionary from an Energon dataset and returns a new dictionary with specific keys.
API
- bridge.diffusion.data.flux.flux_taskencoder.cook(sample: dict) -> dict
Processes a raw sample dictionary from an Energon dataset and returns a new dictionary that contains only a specific set of keys.
- Parameters:
sample (dict) – The input dictionary containing the raw sample data.
- Returns:
A new dictionary containing the processed sample data, with the following keys:
  - All keys from the result of basic_sample_keys(sample).
  - ‘json’: metadata such as the resolution.
  - ‘pth’: the image latent tensor.
  - ‘pickle’: the text embeddings (T5 and CLIP pooled).
- Return type:
dict
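To make the returned structure concrete, here is a minimal sketch of the key selection that cook performs. It is not the library implementation: basic_sample_keys_stub is a hypothetical stand-in for Energon's basic_sample_keys, assumed here to keep only the double-underscore bookkeeping fields (such as '__key__'); the real helper may select a different set. cook_sketch is likewise a hypothetical name.

```python
from typing import Any, Dict


def basic_sample_keys_stub(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for Energon's basic_sample_keys.

    Assumption: the bookkeeping fields are the dunder-prefixed entries
    (for example '__key__'); the real helper may differ.
    """
    return {k: v for k, v in sample.items() if k.startswith("__")}


def cook_sketch(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative cook(): keep bookkeeping keys plus the three payloads."""
    return dict(
        **basic_sample_keys_stub(sample),
        json=sample["json"],      # metadata such as the resolution
        pth=sample["pth"],        # image latent tensor
        pickle=sample["pickle"],  # T5 and CLIP pooled text embeddings
    )


raw = {
    "__key__": "shard0/000001",
    "json": {"resolution": [512, 512]},
    "pth": "latent-tensor-placeholder",
    "pickle": {"t5": "t5-placeholder", "clip_pooled": "clip-placeholder"},
    "txt": "fields not listed above are dropped",
}
cooked = cook_sketch(raw)
print(sorted(cooked))
```

Any field that is neither a bookkeeping key nor one of ‘json’, ‘pth’, ‘pickle’ (here, ‘txt’) is left out of the cooked sample.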
- class bridge.diffusion.data.flux.flux_taskencoder.FluxTaskEncoder(*args, vae_scale_factor: int = 8, seq_length: int = 1024, latent_channels: int = 16, **kwargs)
Bases: megatron.bridge.diffusion.data.common.diffusion_task_encoder_with_sp.DiffusionTaskEncoderWithSequencePacking
Task encoder for the Flux dataset.
- Attributes:
cookers (list) – A list of Cooker objects used for processing.
vae_scale_factor (int) – The VAE downsampling factor. Defaults to 8.
seq_length (int) – The sequence length. Defaults to 1024.
latent_channels (int) – The number of latent channels produced by the VAE. Defaults to 16.
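The vae_scale_factor and latent_channels attributes determine the latent shape for a given image size. The sketch below assumes that an H×W image is encoded to a latent of shape (latent_channels, H / vae_scale_factor, W / vae_scale_factor), and that the latent is then grouped into 2×2 patches as in common Flux implementations. Under those assumptions a 512×512 image yields 1024 tokens, which happens to equal the seq_length default; whether FluxTaskEncoder uses seq_length this way is not stated here. The helper names and the patch_size parameter are hypothetical.

```python
from typing import Tuple


def latent_shape(height: int, width: int, vae_scale_factor: int = 8,
                 latent_channels: int = 16) -> Tuple[int, int, int]:
    """Return (C, H, W) of the VAE latent for an image of the given pixel size."""
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("image size must be divisible by vae_scale_factor")
    return latent_channels, height // vae_scale_factor, width // vae_scale_factor


def packed_token_count(height: int, width: int, vae_scale_factor: int = 8,
                       patch_size: int = 2) -> int:
    """Number of image tokens after patch_size x patch_size patchification (assumed)."""
    _, lat_h, lat_w = latent_shape(height, width, vae_scale_factor)
    if lat_h % patch_size or lat_w % patch_size:
        raise ValueError("latent size must be divisible by patch_size")
    return (lat_h // patch_size) * (lat_w // patch_size)


print(latent_shape(512, 512))        # latent shape for a 512x512 image
print(packed_token_count(512, 512))  # token count under the 2x2 patch assumption
```

Doubling the resolution to 1024×1024 quadruples the token count (4096 under the same assumptions), which is why the sequence length matters when choosing training resolutions.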
Initialization
- cookers
None
- encode_sample(sample: dict) -> dict
- batch(samples: List[megatron.bridge.diffusion.data.common.diffusion_sample.DiffusionSample])
Returns a dictionary containing the data for the batch.
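batch takes a list of samples and returns a single dictionary. As a hedged illustration of that shape only, the sketch below collates plain per-sample dicts (standing in for DiffusionSample, whose fields are not documented on this page) into one dict that maps each field to the list of per-sample values. The real method may stack tensors or pack sequences instead, since its base class is DiffusionTaskEncoderWithSequencePacking; collate_sketch is a hypothetical name.

```python
from typing import Any, Dict, List


def collate_sketch(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Group per-sample fields into one dict of lists (hypothetical batching)."""
    if not samples:
        raise ValueError("cannot batch an empty list of samples")
    keys = samples[0].keys()
    # Every sample must expose the same fields so each list lines up by index.
    for s in samples[1:]:
        if s.keys() != keys:
            raise ValueError("all samples must share the same fields")
    return {k: [s[k] for s in samples] for k in keys}


batch = collate_sketch([
    {"latents": "latent-0", "seq_len": 1024},
    {"latents": "latent-1", "seq_len": 4096},
])
print(batch)
```

Keeping values as lists, rather than stacking them, leaves room for samples of different sequence lengths, which is the situation sequence packing is designed to handle.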