bridge.diffusion.models.flux.flux_step#

FLUX Forward Step.

This is a prototype showing how to integrate the FlowMatchingPipeline into Megatron’s training flow, reusing the well-tested flow matching logic.

Module Contents#

Classes#

FluxForwardStep

Forward step for FLUX using FlowMatchingPipeline.

Functions#

flux_data_step

Process batch data for FLUX model.

Data#

API#

bridge.diffusion.models.flux.flux_step.logger#

‘getLogger(…)’

bridge.diffusion.models.flux.flux_step.flux_data_step(dataloader_iter, store_in_state=False)#

Process batch data for FLUX model.

Parameters:
  • dataloader_iter – Iterator over the dataloader.

  • store_in_state – If True, store the batch in GlobalState for callbacks.

Returns:

Processed batch dictionary with tensors moved to CUDA.
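The data-step pattern can be sketched as follows. This is a minimal, self-contained illustration, not the actual implementation: `FakeTensor` is a stand-in for `torch.Tensor` so the sketch runs without PyTorch, and `data_step` is a hypothetical name.

```python
# Sketch of the data-step pattern: pull the next batch from the iterator and
# move anything tensor-like to the target device. FakeTensor stands in for
# torch.Tensor so the example is self-contained.
class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data, self.device = data, device

    def to(self, device):
        return FakeTensor(self.data, device)


def data_step(dataloader_iter, device="cuda"):
    batch = next(dataloader_iter)
    # Move tensor-like values to the device; leave non-tensor metadata as-is.
    return {k: v.to(device) if hasattr(v, "to") else v for k, v in batch.items()}


batches = iter([{"latents": FakeTensor([1.0, 2.0]), "caption": "a cat"}])
out = data_step(batches)
```

In the real function the batch values are `torch.Tensor`s and the move is `tensor.cuda()`/`tensor.to("cuda")`; the dictionary-comprehension shape is the same.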

class bridge.diffusion.models.flux.flux_step.FluxForwardStep(
timestep_sampling: str = 'logit_normal',
logit_mean: float = 0.0,
logit_std: float = 1.0,
flow_shift: float = 1.0,
scheduler_steps: int = 1000,
guidance_scale: float = 3.5,
use_loss_weighting: bool = False,
)#

Forward step for FLUX using FlowMatchingPipeline.

This class demonstrates how to integrate the FlowMatchingPipeline into Megatron's training flow.

Parameters:
  • timestep_sampling – Method for sampling timesteps (“logit_normal”, “uniform”, “mode”).

  • logit_mean – Mean for logit-normal sampling.

  • logit_std – Standard deviation for logit-normal sampling.

  • flow_shift – Shift parameter for timestep transformation (default: 1.0 for FLUX).

  • scheduler_steps – Number of scheduler training steps.

  • guidance_scale – Guidance scale for FLUX-dev models.

  • use_loss_weighting – Whether to apply flow-based loss weighting.

Initialization
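The interaction between `logit_mean`, `logit_std`, and `flow_shift` can be sketched as below. This is a hedged illustration of the common logit-normal sampling plus timestep-shift recipe used by flow-matching models, not the pipeline's exact code; `sample_timestep` is a hypothetical helper.

```python
import math
import random

def sample_timestep(logit_mean=0.0, logit_std=1.0, flow_shift=1.0):
    # Logit-normal sampling: draw a normal variate in logit space, then squash
    # it through a sigmoid so the timestep lands in (0, 1).
    x = random.gauss(logit_mean, logit_std)
    t = 1.0 / (1.0 + math.exp(-x))
    # Timestep shift (identity when flow_shift == 1.0): larger shifts bias
    # sampled timesteps toward higher noise levels.
    return flow_shift * t / (1.0 + (flow_shift - 1.0) * t)

ts = [sample_timestep(flow_shift=3.0) for _ in range(1000)]
```

With `flow_shift=1.0` (the FLUX default listed above) the shift is a no-op and the sampled timestep is just the sigmoid of the normal draw.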

__call__(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.common.vision_module.vision_module.VisionModule,
) → tuple[torch.Tensor, functools.partial]#

Forward training step using FlowMatchingPipeline.

Parameters:
  • state – Global state for the run.

  • data_iterator – Input data iterator.

  • model – The FLUX model.

Returns:

Tuple containing the output tensor and the loss function.

_prepare_batch_for_pipeline(batch: dict) → dict#

Prepare Megatron batch for FlowMatchingPipeline.

Maps Megatron batch keys to FlowMatchingPipeline expected format:

  • latents -> image_latents (for consistency)

  • Keeps prompt_embeds, pooled_prompt_embeds, text_ids as-is
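The key mapping above can be sketched as a small dictionary transform. This is an illustrative stand-in for the method, using plain lists in place of tensors:

```python
# Sketch of the key remapping: rename "latents" to "image_latents" and pass
# the text-conditioning keys through unchanged.
def prepare_batch_for_pipeline(batch: dict) -> dict:
    prepared = dict(batch)  # shallow copy so the caller's batch is not mutated
    if "latents" in prepared:
        prepared["image_latents"] = prepared.pop("latents")
    return prepared

megatron_batch = {
    "latents": [0.1, 0.2],
    "prompt_embeds": [1.0],
    "pooled_prompt_embeds": [2.0],
    "text_ids": [3],
}
pipeline_batch = prepare_batch_for_pipeline(megatron_batch)
```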

_training_step_with_pipeline(
model: megatron.core.models.common.vision_module.vision_module.VisionModule,
batch: dict,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor#

Perform single training step using FlowMatchingPipeline.

Parameters:
  • model – The FLUX model.

  • batch – Data batch prepared for pipeline.

Returns:

On the last pipeline stage, a tuple of (output_tensor, loss, loss_mask); on other pipeline stages, the output tensor alone.
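The stage-dependent return contract can be sketched as follows. This is a simplified illustration, with an arbitrary masked loss standing in for the real flow-matching loss: only the last pipeline stage computes a loss, while earlier stages forward their activations unchanged.

```python
# Sketch of the stage-dependent return: the last pipeline stage returns
# (output, loss, loss_mask); other stages return only the output tensor.
def training_step(output, loss_mask, is_last_stage):
    if is_last_stage:
        # Illustrative masked mean-squared loss against a zero target.
        vals = [o * o for o, m in zip(output, loss_mask) if m]
        loss = sum(vals) / max(len(vals), 1)
        return output, loss, loss_mask
    return output

last = training_step([1.0, 2.0], [1, 1], is_last_stage=True)
mid = training_step([1.0, 2.0], [1, 1], is_last_stage=False)
```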

_create_loss_function(
loss_mask: torch.Tensor,
check_for_nan_in_loss: bool,
check_for_spiky_loss: bool,
) → functools.partial#

Create a partial loss function with the specified configuration.

Parameters:
  • loss_mask – Used to mask out some portions of the loss.

  • check_for_nan_in_loss – Whether to check for NaN values in the loss.

  • check_for_spiky_loss – Whether to check for spiky loss values.

Returns:

A partial function that can be called with output_tensor to compute the loss.
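The partial-binding pattern this method relies on can be sketched with plain Python. The loss body here is a hypothetical masked mean, not the pipeline's actual loss; what matters is that the mask and check flags are bound up front so the training loop can later call the result with only the output tensor.

```python
import functools
import math

def masked_mean_loss(loss_mask, check_for_nan_in_loss, output_tensor):
    # Average only the unmasked elements.
    vals = [v for v, m in zip(output_tensor, loss_mask) if m]
    loss = sum(vals) / max(len(vals), 1)
    if check_for_nan_in_loss and math.isnan(loss):
        raise ValueError("NaN detected in loss")
    return loss

# Bind the configuration now; the caller later supplies only output_tensor.
loss_fn = functools.partial(masked_mean_loss, [1, 0, 1], True)
result = loss_fn([2.0, 100.0, 4.0])  # the masked-out 100.0 is ignored
```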