bridge.diffusion.recipes.flux.flux
Module Contents
Functions
| model_config | Configure the FLUX model. |
| pretrain_config | Create a pre-training configuration for the FLUX model. |
API
- bridge.diffusion.recipes.flux.flux.model_config(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    seq_length: int = 1024,
    num_joint_layers: int = 19,
    num_single_layers: int = 38,
    hidden_size: int = 3072,
    num_attention_heads: int = 24,
    in_channels: int = 64,
    context_dim: int = 4096,
    guidance_embed: bool = False,
    guidance_scale: float = 3.5,
  )
Configure the FLUX model.
- Parameters:
tensor_parallelism (int) – Degree of tensor model parallelism.
pipeline_parallelism (int) – Degree of pipeline model parallelism.
pipeline_parallelism_dtype (Optional[torch.dtype]) – Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]) – Size of virtual pipeline parallelism.
context_parallelism (int) – Degree of context parallelism.
sequence_parallelism (bool) – Whether to use sequence parallelism.
seq_length (int) – Sequence length for the model.
num_joint_layers (int) – Number of double (joint) transformer blocks.
num_single_layers (int) – Number of single transformer blocks.
hidden_size (int) – Hidden dimension size.
num_attention_heads (int) – Number of attention heads.
in_channels (int) – Number of input channels (latent channels).
context_dim (int) – Text encoder context dimension.
guidance_embed (bool) – Whether to use guidance embedding (for FLUX-dev).
guidance_scale (float) – Classifier-free guidance scale.
- Returns:
Configuration for the FLUX model.
- Return type:
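The shape defaults above imply a couple of divisibility constraints that are easy to check before launching a job. The following is a minimal sketch, not part of the recipe module: `check_flux_shape` is a hypothetical helper, and it assumes the usual Megatron-style rule that attention heads are split evenly across tensor-parallel ranks.

```python
def check_flux_shape(hidden_size=3072, num_attention_heads=24, tensor_parallelism=1):
    """Return (head_dim, heads_per_tp_rank) for the given shape.

    Hypothetical helper for illustration only. Raises ValueError when the
    hidden size is not divisible by the head count, or when the heads cannot
    be split evenly across tensor-parallel ranks (an assumed constraint).
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size={hidden_size} is not divisible by "
            f"num_attention_heads={num_attention_heads}"
        )
    if num_attention_heads % tensor_parallelism != 0:
        raise ValueError(
            f"num_attention_heads={num_attention_heads} is not divisible by "
            f"tensor_parallelism={tensor_parallelism}"
        )
    head_dim = hidden_size // num_attention_heads
    heads_per_rank = num_attention_heads // tensor_parallelism
    return head_dim, heads_per_rank


# Documented defaults: 3072 / 24 gives 128-dimensional heads.
print(check_flux_shape())                      # (128, 24)
print(check_flux_shape(tensor_parallelism=8))  # (128, 3)
```

With the documented defaults, tensor-parallel degrees of 1, 2, 3, 4, 6, 8, 12, and 24 pass this check; a value such as 5 raises `ValueError`.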
- bridge.diffusion.recipes.flux.flux.pretrain_config(
    dir: Optional[str] = None,
    name: str = 'default',
    data_paths: Optional[List[str]] = None,
    mock: bool = False,
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    use_megatron_fsdp: bool = False,
    num_joint_layers: int = 19,
    num_single_layers: int = 38,
    hidden_size: int = 3072,
    num_attention_heads: int = 24,
    in_channels: int = 64,
    context_dim: int = 4096,
    guidance_embed: bool = False,
    guidance_scale: float = 3.5,
    image_H: int = 1024,
    image_W: int = 1024,
    vae_channels: int = 16,
    vae_scale_factor: int = 8,
    prompt_seq_len: int = 512,
    pooled_prompt_dim: int = 768,
    train_iters: int = 10000,
    global_batch_size: int = 4,
    micro_batch_size: int = 1,
    lr: float = 0.0001,
    lr_warmup_iters: int = 1000,
    precision_config: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]] = 'bf16_mixed',
    comm_overlap_config: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig] = None,
  )
Create a pre-training configuration for the FLUX model.
- Parameters:
dir (Optional[str]) – Base directory for saving logs and checkpoints.
name (str) – Name of the pre-training run.
data_paths (Optional[List[str]]) – List of paths to dataset files. If None, mock data will be used.
mock (bool) – Whether to use mock data. If True, ignores data_paths.
tensor_parallelism (int) – Degree of tensor model parallelism.
pipeline_parallelism (int) – Degree of pipeline model parallelism.
pipeline_parallelism_dtype (Optional[torch.dtype]) – Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]) – Size of virtual pipeline parallelism.
context_parallelism (int) – Degree of context parallelism.
sequence_parallelism (bool) – Whether to use sequence parallelism.
use_megatron_fsdp (bool) – Whether to use Megatron FSDP.
num_joint_layers (int) – Number of double (joint) transformer blocks.
num_single_layers (int) – Number of single transformer blocks.
hidden_size (int) – Hidden dimension size.
num_attention_heads (int) – Number of attention heads.
in_channels (int) – Number of input channels (latent channels).
context_dim (int) – Text encoder context dimension.
guidance_embed (bool) – Whether to use guidance embedding (for FLUX-dev).
guidance_scale (float) – Classifier-free guidance scale.
image_H (int) – Image height.
image_W (int) – Image width.
vae_channels (int) – Number of VAE latent channels.
vae_scale_factor (int) – VAE downsampling factor.
prompt_seq_len (int) – Sequence length for text prompts (T5).
pooled_prompt_dim (int) – Dimensionality of pooled text embeddings (CLIP).
train_iters (int) – Total number of training iterations.
global_batch_size (int) – Global batch size for training.
micro_batch_size (int) – Micro batch size for training.
lr (float) – Learning rate.
lr_warmup_iters (int) – Number of warmup iterations for the learning rate.
precision_config (Optional[Union[MixedPrecisionConfig, str]]) – Precision configuration.
comm_overlap_config (Optional[CommOverlapConfig]) – Communication overlap configuration.
- Returns:
Configuration for pre-training.
- Return type:
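Several of the defaults above are related by simple arithmetic, which the sketch below makes explicit. It is illustrative only and does not call the recipe: `latent_tokens` and `grad_accum_steps` are hypothetical helpers. The first assumes FLUX-style 2x2 patch packing of the VAE latents (the `patch` argument is not a documented parameter). The second assumes the common Megatron convention that the data-parallel size is the world size divided by the product of the tensor, pipeline, and context parallel degrees.

```python
def latent_tokens(image_H=1024, image_W=1024, vae_scale_factor=8,
                  vae_channels=16, patch=2):
    """Return (num_image_tokens, channels_per_token).

    Assumes the VAE downsamples each side by vae_scale_factor and that the
    latents are then packed into patch x patch tokens (an assumption).
    """
    if image_H % vae_scale_factor or image_W % vae_scale_factor:
        raise ValueError("image size must be divisible by vae_scale_factor")
    lat_h = image_H // vae_scale_factor
    lat_w = image_W // vae_scale_factor
    if lat_h % patch or lat_w % patch:
        raise ValueError("latent size must be divisible by the patch size")
    num_tokens = (lat_h // patch) * (lat_w // patch)
    channels_per_token = vae_channels * patch * patch
    return num_tokens, channels_per_token


def grad_accum_steps(world_size, tensor_parallelism=1, pipeline_parallelism=1,
                     context_parallelism=1, global_batch_size=4,
                     micro_batch_size=1):
    """Return the number of micro-batches each data-parallel rank runs per step."""
    model_parallel = tensor_parallelism * pipeline_parallelism * context_parallelism
    if world_size % model_parallel:
        raise ValueError("world_size must be divisible by TP * PP * CP")
    data_parallel = world_size // model_parallel
    samples_per_pass = micro_batch_size * data_parallel
    if global_batch_size % samples_per_pass:
        raise ValueError(
            "global_batch_size must be divisible by "
            "micro_batch_size * data_parallel_size"
        )
    return global_batch_size // samples_per_pass


# Documented defaults: 1024x1024 image, scale factor 8, 16 latent channels.
print(latent_tokens())                # (4096, 64)
# Two GPUs, no model parallelism, global batch 4, micro batch 1.
print(grad_accum_steps(world_size=2))  # 2
```

Under the packing assumption, the 64 channels per token agree with the `in_channels` default of 64. A 512x512 image would give 1024 tokens, the same value as the `seq_length` default of `model_config`; whether the recipe derives one from the other is not stated in this reference.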