nemo_automodel.components.flow_matching.pipeline
FlowMatching Pipeline - Model-agnostic implementation with adapter pattern.
This module provides a unified FlowMatchingPipeline class that is completely independent of specific model implementations through the ModelAdapter abstraction.
Features:
- Model-agnostic design via ModelAdapter protocol
- Various timestep sampling strategies (uniform, logit_normal, mode, lognorm)
- Flow shift transformation
- Sigma clamping for finetuning
- Loss weighting
- Detailed training logging
API
Flow Matching Pipeline - Model-agnostic implementation.
This pipeline handles all flow matching training logic while delegating model-specific operations to a ModelAdapter. This allows adding support for new model architectures without modifying the pipeline code.
Features:
- Noise scheduling with linear interpolation
- Timestep sampling with various strategies
- Optional sigma-noise flow shift toggle
- Flow shift transformation
- Sigma clamping for finetuning
- Loss weighting
- Detailed training logging
Build a bell-shaped midpoint noise timestep weighting table.
Returns a 1D tensor of length num_train_timesteps with weights following a Gaussian bell curve centered at the midpoint (t=steps/2).
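A minimal sketch of such a table in plain Python, with a list standing in for the 1D tensor. The function name `bell_midpoint_weights`, the `std_fraction` width parameter, and the mean-1 normalization are assumptions for illustration; the description above fixes only the Gaussian shape, the length, and the midpoint center.

```python
import math
from typing import List


def bell_midpoint_weights(num_train_timesteps: int, std_fraction: float = 0.25) -> List[float]:
    """Gaussian bell weights over timestep indices, peaked at the midpoint.

    std_fraction (hypothetical) sets the bell's standard deviation as a
    fraction of num_train_timesteps.
    """
    mid = num_train_timesteps / 2                # bell center: t = steps / 2
    std = std_fraction * num_train_timesteps     # bell width
    weights = [math.exp(-0.5 * ((t - mid) / std) ** 2) for t in range(num_train_timesteps)]
    # Rescale so the weights average to 1 (an assumed normalization).
    mean_w = sum(weights) / len(weights)
    return [w / mean_w for w in weights]


weights = bell_midpoint_weights(1000)
print(len(weights), weights.index(max(weights)))  # prints "1000 500"
```

Timesteps near the middle of the schedule get weights above 1 and timesteps near either end get weights below 1, so a per-timestep loss multiplied by this table emphasizes mid-noise levels.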
Log detailed training information.
Log detailed loss information.
Sample u values from the configured distribution.
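A minimal sketch of three of the listed strategies, assuming the common definitions: `uniform` draws u ~ U(0, 1), `logit_normal` applies a sigmoid to a normal sample, and `mode` uses the SD3-style mode-sampling formula (also used by Hugging Face Diffusers training scripts). The helper name `sample_u` and the default `logit_mean`, `logit_std`, and `mode_scale` values are illustrative assumptions; `lognorm` is omitted because its exact definition is not given here.

```python
import math
import random
from typing import List, Optional


def sample_u(n: int, strategy: str = "uniform", logit_mean: float = 0.0,
             logit_std: float = 1.0, mode_scale: float = 1.29,
             rng: Optional[random.Random] = None) -> List[float]:
    """Draw n values of u in [0, 1] (hypothetical helper, not the module's API)."""
    rng = rng or random.Random()
    if strategy == "uniform":
        # u ~ U(0, 1)
        return [rng.random() for _ in range(n)]
    if strategy == "logit_normal":
        # u = sigmoid(z) with z ~ N(logit_mean, logit_std^2)
        return [1.0 / (1.0 + math.exp(-rng.gauss(logit_mean, logit_std))) for _ in range(n)]
    if strategy == "mode":
        # SD3-style mode sampling: u = 1 - v - s * (cos^2(pi * v / 2) - 1 + v), v ~ U(0, 1)
        return [
            1.0 - v - mode_scale * (math.cos(math.pi * v / 2) ** 2 - 1.0 + v)
            for v in (rng.random() for _ in range(n))
        ]
    raise ValueError(f"unsupported strategy: {strategy!r}")
```

Passing a seeded `random.Random` makes the draws reproducible, which is convenient when comparing strategies.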
Compute flow matching loss with optional weighting.
Loss weight: w = 1 + flow_shift * σ
Parameters:
Model prediction
Target (velocity = noise - clean)
Sigma values for each sample
Optional batch dictionary containing loss_mask
Returns: torch.Tensor
Per-element weighted loss
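A minimal sketch of the weighted loss under stated assumptions: the base loss is elementwise squared error between prediction and target, each sample's elements are scaled by w = 1 + flow_shift * σ, and an optional 0/1 mask is multiplied in. Nested lists stand in for tensors, and the name `weighted_flow_matching_loss` is hypothetical.

```python
from typing import List, Optional, Sequence


def weighted_flow_matching_loss(
    pred: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    sigma: Sequence[float],
    flow_shift: float,
    loss_mask: Optional[Sequence[Sequence[float]]] = None,
) -> List[List[float]]:
    """Per-element weighted squared error; one row per sample (hypothetical helper)."""
    losses = []
    for i, (p_row, t_row) in enumerate(zip(pred, target)):
        w = 1.0 + flow_shift * sigma[i]  # w = 1 + flow_shift * sigma, one weight per sample
        row = [w * (p - t) ** 2 for p, t in zip(p_row, t_row)]
        if loss_mask is not None:
            # Zero out masked-off elements (assumed 0/1 mask aligned with the latents).
            row = [value * m for value, m in zip(row, loss_mask[i])]
        losses.append(row)
    return losses


print(weighted_flow_matching_loss([[1.0, 2.0]], [[0.0, 0.0]], sigma=[0.5], flow_shift=2.0))
# [[2.0, 8.0]]
```

Here the target would be the velocity noise - clean; the result stays per-element, matching the return description, and any reduction to a scalar is left to the caller.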
Determine task type based on data type and randomization.
Sample timesteps and compute sigma values with flow shift.
Implements the flow shift transformation: σ = shift / (shift + (1/u - 1))
Parameters:
Number of timesteps to sample
Device for tensor operations
Returns: torch.Tensor
Sigma values in [sigma_min, sigma_max]
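A minimal sketch of the shift-then-clamp step, taking already-sampled u values rather than a sample count. It uses the algebraically equivalent form shift * u / (1 + (shift - 1) * u), which matches shift / (shift + (1/u - 1)) for u in (0, 1] and stays finite at u = 0. The names `shift_sigma` and `sigmas_from_u` are hypothetical.

```python
from typing import List, Sequence


def shift_sigma(u: float, shift: float) -> float:
    # shift * u / (1 + (shift - 1) * u) equals shift / (shift + (1/u - 1)) for u in (0, 1],
    # but stays finite at u = 0.
    return shift * u / (1.0 + (shift - 1.0) * u)


def sigmas_from_u(u_values: Sequence[float], shift: float,
                  sigma_min: float = 0.0, sigma_max: float = 1.0) -> List[float]:
    # Apply the flow shift, then clamp into [sigma_min, sigma_max] (sigma clamping for finetuning).
    return [min(max(shift_sigma(u, shift), sigma_min), sigma_max) for u in u_values]


print(sigmas_from_u([0.0, 0.5, 1.0], shift=3.0, sigma_min=0.1, sigma_max=0.9))  # [0.1, 0.75, 0.9]
```

With shift > 1 the transform pushes σ toward 1 (noisier inputs); shift = 1 leaves u unchanged.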
Execute a single training step with flow matching.
Expected batch format:
    {
        'video_latents': torch.Tensor,    # [B, C, F, H, W] for video
        OR
        'image_latents': torch.Tensor,    # [B, C, H, W] for image
        'text_embeddings': torch.Tensor,  # [B, seq_len, dim]
        'data_type': str,                 # 'video' or 'image' (optional)
        … additional model-specific keys handled by adapter
    }
Parameters:
The model to train
Batch of training data
Device to use
Data type for operations
Current training step (for logging)
Whether to collect scalar diagnostics. Disable in hot training paths to avoid host/device synchronizations.
Whether to run scalar loss explosion checks. Disable in hot training paths to avoid host/device synchronizations.
Returns: torch.Tensor
Per-element weighted loss
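A minimal sketch of reading the latents out of a batch in the documented format. The helper `resolve_latents` is hypothetical, and two behaviors are assumptions: `video_latents` is preferred when both latent keys are present, and a missing `data_type` is inferred from whichever latent key was found.

```python
from typing import Any, Dict, Tuple


def resolve_latents(batch: Dict[str, Any]) -> Tuple[Any, str]:
    """Pick the latent tensor from a batch in the documented format (hypothetical helper)."""
    if "video_latents" in batch:
        latents, inferred = batch["video_latents"], "video"   # [B, C, F, H, W]
    elif "image_latents" in batch:
        latents, inferred = batch["image_latents"], "image"   # [B, C, H, W]
    else:
        raise KeyError("batch must contain 'video_latents' or 'image_latents'")
    # 'data_type' is optional in the documented format; fall back to the inferred type.
    data_type = batch.get("data_type", inferred)
    if data_type not in ("video", "image"):
        raise ValueError(f"unexpected data_type: {data_type!r}")
    return latents, data_type
```

Any extra model-specific keys are left in the batch untouched, in line with the adapter handling them.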
Simple linear interpolation schedule for flow matching.
Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1
Parameters:
Starting point (clean latents)
Ending point (noise)
Sigma values in [0, 1]
Returns: torch.Tensor
Interpolated tensor at sigma
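A minimal sketch of the interpolation on plain lists (the name `interpolate` is hypothetical). It also checks numerically that the derivative of x_t with respect to σ is x_1 - x_0, that is, noise - clean, which is why that velocity is the training target.

```python
from typing import List, Sequence


def interpolate(x0: Sequence[float], x1: Sequence[float], sigma: float) -> List[float]:
    # x_t = (1 - sigma) * x_0 + sigma * x_1, applied elementwise
    return [(1.0 - sigma) * a + sigma * b for a, b in zip(x0, x1)]


clean, noise = [0.0, 2.0], [1.0, 0.0]
x_a = interpolate(clean, noise, 0.25)   # [0.25, 1.5]
x_b = interpolate(clean, noise, 0.75)   # [0.75, 0.5]
# Finite-difference velocity along sigma equals noise - clean for a linear path.
velocity = [(b - a) / 0.5 for a, b in zip(x_a, x_b)]
print(velocity)  # [1.0, -2.0]
```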
Factory function to create a model adapter by name.
Parameters:
Type of adapter ('hunyuan', 'simple', 'flux', 'flux2', 'qwen_image')
Additional arguments passed to the adapter constructor
Returns: ModelAdapter
ModelAdapter instance
Factory function to create a pipeline with a specific adapter.
Parameters:
Type of adapter ('hunyuan', 'simple')
Arguments for the adapter constructor
Arguments for the pipeline constructor
Returns: FlowMatchingPipeline
FlowMatchingPipeline instance
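One common way to implement factories like these is a name-to-class registry, sketched below under that assumption. Every class and function name here (`ModelAdapterStub`, `PipelineStub`, `make_adapter`, `make_pipeline`, `ADAPTER_REGISTRY`) is a hypothetical stand-in, not this module's real API, and the registry lists only two adapter names for brevity. It mirrors the split described above: one set of arguments goes to the adapter constructor, another to the pipeline constructor.

```python
from typing import Any, Callable, Dict, Optional


class ModelAdapterStub:
    """Hypothetical stand-in for a ModelAdapter; it just records its constructor kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class HunyuanAdapterStub(ModelAdapterStub):
    pass


class SimpleAdapterStub(ModelAdapterStub):
    pass


# Map adapter names to constructors; a real registry would list every supported name.
ADAPTER_REGISTRY: Dict[str, Callable[..., ModelAdapterStub]] = {
    "hunyuan": HunyuanAdapterStub,
    "simple": SimpleAdapterStub,
}


def make_adapter(adapter_type: str, **adapter_kwargs: Any) -> ModelAdapterStub:
    try:
        cls = ADAPTER_REGISTRY[adapter_type]
    except KeyError:
        raise ValueError(
            f"unknown adapter type {adapter_type!r}; expected one of {sorted(ADAPTER_REGISTRY)}"
        ) from None
    return cls(**adapter_kwargs)


class PipelineStub:
    """Hypothetical stand-in for the pipeline: holds the adapter plus its own config."""

    def __init__(self, adapter: ModelAdapterStub, **pipeline_kwargs: Any) -> None:
        self.adapter = adapter
        self.config = pipeline_kwargs


def make_pipeline(adapter_type: str,
                  adapter_kwargs: Optional[Dict[str, Any]] = None,
                  pipeline_kwargs: Optional[Dict[str, Any]] = None) -> PipelineStub:
    # Build the adapter first, then hand it to the pipeline with the pipeline-only kwargs.
    adapter = make_adapter(adapter_type, **(adapter_kwargs or {}))
    return PipelineStub(adapter, **(pipeline_kwargs or {}))
```

Adding a new architecture then means registering one more adapter class; the pipeline side stays unchanged, which is the point of the adapter pattern described at the top of this page.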