bridge.diffusion.models.flux.flux_step#
FLUX Forward Step.
This is a prototype showing how to integrate the FlowMatchingPipeline into Megatron’s training flow, reusing the well-tested flow matching logic.
Module Contents#
Classes#
FluxForwardStep | Forward step for FLUX using FlowMatchingPipeline.
Functions#
flux_data_step | Process batch data for FLUX model.
Data#
API#
- bridge.diffusion.models.flux.flux_step.logger#
‘getLogger(…)’
- bridge.diffusion.models.flux.flux_step.flux_data_step(dataloader_iter, store_in_state=False)#
Process batch data for FLUX model.
- Parameters:
dataloader_iter – Iterator over the dataloader.
store_in_state – If True, store the batch in GlobalState for callbacks.
- Returns:
Processed batch dictionary with tensors moved to CUDA.
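The contract above can be sketched as follows. This is a hypothetical, self-contained illustration (the names `flux_data_step_sketch` and the dict-based `state` are stand-ins; the real implementation moves each tensor to CUDA and stores the batch on `GlobalState`):

```python
# Sketch of the flux_data_step contract: pull the next batch from the
# iterator, return it as a dict, and optionally stash it for callbacks.
def flux_data_step_sketch(dataloader_iter, store_in_state=False, state=None):
    batch = next(dataloader_iter)
    # The real code moves each torch.Tensor value to CUDA here;
    # this sketch passes values through unchanged.
    processed = {k: v for k, v in batch.items()}
    if store_in_state and state is not None:
        state["batch"] = processed  # real code stores on GlobalState
    return processed

batches = iter([{"latents": [0.1, 0.2], "prompt_embeds": [1.0]}])
state = {}
out = flux_data_step_sketch(batches, store_in_state=True, state=state)
print(sorted(out))  # ['latents', 'prompt_embeds']
```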
- class bridge.diffusion.models.flux.flux_step.FluxForwardStep(
- timestep_sampling: str = 'logit_normal',
- logit_mean: float = 0.0,
- logit_std: float = 1.0,
- flow_shift: float = 1.0,
- scheduler_steps: int = 1000,
- guidance_scale: float = 3.5,
- use_loss_weighting: bool = False,
- )#
Forward step for FLUX using FlowMatchingPipeline.
This class demonstrates how to integrate the FlowMatchingPipeline into Megatron’s training flow, reusing its well-tested flow matching logic.
- Parameters:
timestep_sampling – Method for sampling timesteps (“logit_normal”, “uniform”, “mode”).
logit_mean – Mean for logit-normal sampling.
logit_std – Standard deviation for logit-normal sampling.
flow_shift – Shift parameter for timestep transformation (default: 1.0 for FLUX).
scheduler_steps – Number of scheduler training steps.
guidance_scale – Guidance scale for FLUX-dev models.
use_loss_weighting – Whether to apply flow-based loss weighting.
Initialization
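How the sampling-related constructor arguments might interact can be sketched in isolation. This is an assumption-laden illustration, not the class’s actual implementation: `sample_timestep` is a hypothetical helper, and the shift formula shown is the common flow-matching schedule warp, assumed here to match what `flow_shift` controls:

```python
import math
import random

def sample_timestep(timestep_sampling="logit_normal",
                    logit_mean=0.0, logit_std=1.0, flow_shift=1.0):
    """Sketch: draw a training timestep in (0, 1) per the chosen strategy."""
    if timestep_sampling == "logit_normal":
        # Sample a normal variate, squash through a sigmoid into (0, 1).
        u = random.gauss(logit_mean, logit_std)
        t = 1.0 / (1.0 + math.exp(-u))
    elif timestep_sampling == "uniform":
        t = random.random()
    else:
        raise ValueError(f"unsupported sampling mode: {timestep_sampling}")
    # Assumed shift warp: flow_shift * t / (1 + (flow_shift - 1) * t);
    # with flow_shift=1.0 (the FLUX default above) t is unchanged.
    return flow_shift * t / (1.0 + (flow_shift - 1.0) * t)

random.seed(0)
ts = [sample_timestep() for _ in range(1000)]
print(all(0.0 < t < 1.0 for t in ts))  # True
```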
- __call__(
- state: megatron.bridge.training.state.GlobalState,
- data_iterator: Iterable,
- model: megatron.core.models.common.vision_module.vision_module.VisionModule,
- )#
Forward training step using FlowMatchingPipeline.
- Parameters:
state – Global state for the run.
data_iterator – Input data iterator.
model – The FLUX model.
- Returns:
Tuple containing the output tensor and the loss function.
- _prepare_batch_for_pipeline(batch: dict) → dict#
Prepare Megatron batch for FlowMatchingPipeline.
Maps Megatron batch keys to FlowMatchingPipeline expected format:
latents -> image_latents (for consistency)
Keeps prompt_embeds, pooled_prompt_embeds, text_ids as-is
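The key mapping described above amounts to a simple rename. A minimal sketch (hypothetical helper name; the real method may do additional validation):

```python
# Sketch of the Megatron -> FlowMatchingPipeline key mapping:
# rename 'latents' to 'image_latents', pass conditioning keys through.
def prepare_batch_for_pipeline(batch: dict) -> dict:
    prepared = dict(batch)  # shallow copy; do not mutate the Megatron batch
    if "latents" in prepared:
        prepared["image_latents"] = prepared.pop("latents")
    # prompt_embeds, pooled_prompt_embeds, text_ids are kept as-is
    return prepared

out = prepare_batch_for_pipeline({"latents": "x", "prompt_embeds": "e"})
print(sorted(out))  # ['image_latents', 'prompt_embeds']
```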
- _training_step_with_pipeline(
- model: megatron.core.models.common.vision_module.vision_module.VisionModule,
- batch: dict,
- )#
Perform single training step using FlowMatchingPipeline.
- Parameters:
model – The FLUX model.
batch – Data batch prepared for pipeline.
- Returns:
On the last pipeline stage, a tuple of (output_tensor, loss, loss_mask); on other stages, the output tensor.
- _create_loss_function(
- loss_mask: torch.Tensor,
- check_for_nan_in_loss: bool,
- check_for_spiky_loss: bool,
- )#
Create a partial loss function with the specified configuration.
- Parameters:
loss_mask – Used to mask out some portions of the loss.
check_for_nan_in_loss – Whether to check for NaN values in the loss.
check_for_spiky_loss – Whether to check for spiky loss values.
- Returns:
A partial function that can be called with output_tensor to compute the loss.
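The “partial loss function” pattern can be sketched with `functools.partial`. This is a simplified stand-in (plain lists instead of `torch.Tensor`, and only the NaN check is shown; `loss_func` is a hypothetical name):

```python
from functools import partial
import math

def loss_func(loss_mask, check_for_nan_in_loss, check_for_spiky_loss, output):
    """Sketch: mask the per-element loss, average, optionally sanity-check."""
    masked = [o * m for o, m in zip(output, loss_mask)]
    denom = sum(loss_mask) or 1.0
    loss = sum(masked) / denom
    if check_for_nan_in_loss and math.isnan(loss):
        raise RuntimeError("NaN detected in loss")
    return loss

# The method returns something like this partial, which the training loop
# later calls with just the model's output tensor:
fn = partial(loss_func, [1.0, 0.0, 1.0], True, False)
print(fn([2.0, 99.0, 4.0]))  # 3.0
```

Binding the mask and check flags up front keeps the call site in the training loop uniform: it only ever supplies `output_tensor`.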