bridge.diffusion.recipes.flux.flux#

Module Contents#

Functions#

model_config

Configure the FLUX model.

pretrain_config

Create a pre-training configuration for the FLUX model.

API#

bridge.diffusion.recipes.flux.flux.model_config(
tensor_parallelism: int = 1,
pipeline_parallelism: int = 1,
pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = False,
seq_length: int = 1024,
num_joint_layers: int = 19,
num_single_layers: int = 38,
hidden_size: int = 3072,
num_attention_heads: int = 24,
in_channels: int = 64,
context_dim: int = 4096,
guidance_embed: bool = False,
guidance_scale: float = 3.5,
) → megatron.bridge.diffusion.models.flux.flux_provider.FluxProvider#

Configure the FLUX model.

Parameters:
  • tensor_parallelism (int) – Degree of tensor model parallelism.

  • pipeline_parallelism (int) – Degree of pipeline model parallelism.

  • pipeline_parallelism_dtype (Optional[torch.dtype]) – Data type for pipeline parallelism.

  • virtual_pipeline_parallelism (Optional[int]) – Number of virtual pipeline stages per pipeline rank (interleaved pipeline schedule). None disables virtual pipelining.

  • context_parallelism (int) – Degree of context parallelism.

  • sequence_parallelism (bool) – Whether to use sequence parallelism.

  • seq_length (int) – Sequence length for the model.

  • num_joint_layers (int) – Number of double (joint) transformer blocks.

  • num_single_layers (int) – Number of single transformer blocks.

  • hidden_size (int) – Hidden dimension size.

  • num_attention_heads (int) – Number of attention heads.

  • in_channels (int) – Number of input channels (latent channels).

  • context_dim (int) – Text encoder context dimension.

  • guidance_embed (bool) – Whether to use guidance embedding (for FLUX-dev).

  • guidance_scale (float) – Classifier-free guidance scale.

Returns:

Configuration for the FLUX model.

Return type:

FluxProvider
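As a sanity check on the defaults above, two derived quantities follow directly from the parameters: the per-head attention dimension, and the patchified input channel count. The sketch below is illustrative arithmetic only, not a call into the recipe; the 2×2 latent patchification and the 16 VAE latent channels are assumptions about the standard FLUX architecture (compare the `vae_channels=16` default of `pretrain_config`).

```python
# Illustrative arithmetic over model_config's defaults (no megatron import needed).
hidden_size = 3072          # model_config default
num_attention_heads = 24    # model_config default
vae_latent_channels = 16    # assumed: FLUX VAE latent channels (cf. pretrain_config's vae_channels)
patch = 2                   # assumed: FLUX packs latents into 2x2 spatial patches

# Per-head dimension used by the attention layers.
head_dim = hidden_size // num_attention_heads

# Patchifying C latent channels into 2x2 blocks yields C * 4 input channels,
# which matches the in_channels=64 default above.
in_channels = vae_latent_channels * patch * patch

print(head_dim, in_channels)  # 128 64
```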

bridge.diffusion.recipes.flux.flux.pretrain_config(
dir: Optional[str] = None,
name: str = 'default',
data_paths: Optional[List[str]] = None,
mock: bool = False,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 1,
pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = False,
use_megatron_fsdp: bool = False,
num_joint_layers: int = 19,
num_single_layers: int = 38,
hidden_size: int = 3072,
num_attention_heads: int = 24,
in_channels: int = 64,
context_dim: int = 4096,
guidance_embed: bool = False,
guidance_scale: float = 3.5,
image_H: int = 1024,
image_W: int = 1024,
vae_channels: int = 16,
vae_scale_factor: int = 8,
prompt_seq_len: int = 512,
pooled_prompt_dim: int = 768,
train_iters: int = 10000,
global_batch_size: int = 4,
micro_batch_size: int = 1,
lr: float = 0.0001,
lr_warmup_iters: int = 1000,
precision_config: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]] = 'bf16_mixed',
comm_overlap_config: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig] = None,
) → megatron.bridge.training.config.ConfigContainer#

Create a pre-training configuration for the FLUX model.

Parameters:
  • dir (Optional[str]) – Base directory for saving logs and checkpoints.

  • name (str) – Name of the pre-training run.

  • data_paths (Optional[List[str]]) – List of paths to dataset files. If None, mock data will be used.

  • mock (bool) – Whether to use mock data. If True, ignores data_paths.

  • tensor_parallelism (int) – Degree of tensor model parallelism.

  • pipeline_parallelism (int) – Degree of pipeline model parallelism.

  • pipeline_parallelism_dtype (Optional[torch.dtype]) – Data type for pipeline parallelism.

  • virtual_pipeline_parallelism (Optional[int]) – Number of virtual pipeline stages per pipeline rank (interleaved pipeline schedule). None disables virtual pipelining.

  • context_parallelism (int) – Degree of context parallelism.

  • sequence_parallelism (bool) – Whether to use sequence parallelism.

  • use_megatron_fsdp (bool) – Whether to use Megatron FSDP.

  • num_joint_layers (int) – Number of double (joint) transformer blocks.

  • num_single_layers (int) – Number of single transformer blocks.

  • hidden_size (int) – Hidden dimension size.

  • num_attention_heads (int) – Number of attention heads.

  • in_channels (int) – Number of input channels (latent channels).

  • context_dim (int) – Text encoder context dimension.

  • guidance_embed (bool) – Whether to use guidance embedding (for FLUX-dev).

  • guidance_scale (float) – Classifier-free guidance scale.

  • image_H (int) – Image height in pixels.

  • image_W (int) – Image width in pixels.

  • vae_channels (int) – Number of VAE latent channels.

  • vae_scale_factor (int) – VAE downsampling factor.

  • prompt_seq_len (int) – Sequence length for text prompts (T5).

  • pooled_prompt_dim (int) – Dimensionality of pooled text embeddings (CLIP).

  • train_iters (int) – Total number of training iterations.

  • global_batch_size (int) – Global batch size for training.

  • micro_batch_size (int) – Micro batch size for training.

  • lr (float) – Learning rate.

  • lr_warmup_iters (int) – Number of warmup iterations for the learning rate.

  • precision_config (Optional[Union[MixedPrecisionConfig, str]]) – Precision configuration.

  • comm_overlap_config (Optional[CommOverlapConfig]) – Communication overlap configuration.

Returns:

Configuration for pre-training.

Return type:

ConfigContainer
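A minimal invocation might look like the following. The import path and parameter names are taken from the signature above; the directory, run name, and parallelism degrees are placeholder choices for illustration, and the snippet assumes megatron-bridge and torch are installed.

```python
import torch

# Import path follows the fully qualified name in the API reference above.
from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config

cfg = pretrain_config(
    dir="/tmp/flux_runs",             # placeholder output directory for logs/checkpoints
    name="flux_tp2_mock",             # placeholder run name
    mock=True,                        # use mock data, so data_paths can stay None
    tensor_parallelism=2,
    pipeline_parallelism=1,
    pipeline_parallelism_dtype=torch.bfloat16,
    guidance_embed=True,              # enable the FLUX-dev guidance embedding
    train_iters=10_000,
    global_batch_size=4,
    micro_batch_size=1,
    lr=1e-4,
    lr_warmup_iters=1_000,
    precision_config="bf16_mixed",
)
# cfg is a megatron.bridge.training.config.ConfigContainer that can be
# passed on to the training entry point.
```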