nemo_automodel.components.flow_matching.pipeline#
Flow Matching Pipeline - Model-agnostic implementation with adapter pattern.
This module provides a unified FlowMatchingPipeline class that remains independent of specific model implementations via the ModelAdapter abstraction.
Features:
- Model-agnostic design via ModelAdapter protocol
- Various timestep sampling strategies (uniform, logit_normal, mode, lognorm)
- Flow shift transformation
- Sigma clamping for finetuning
- Loss weighting
- Detailed training logging
Module Contents#
Classes#
| `LinearInterpolationSchedule` | Simple linear interpolation schedule for flow matching. |
| `FlowMatchingPipeline` | Flow Matching Pipeline - Model-agnostic implementation. |
Functions#
| `create_adapter` | Factory function to create a model adapter by name. |
| `create_pipeline` | Factory function to create a pipeline with a specific adapter. |
Data#
API#
- nemo_automodel.components.flow_matching.pipeline.logger#
`getLogger(…)`
- class nemo_automodel.components.flow_matching.pipeline.LinearInterpolationSchedule#
Simple linear interpolation schedule for flow matching.
- forward(
- x0: torch.Tensor,
- x1: torch.Tensor,
- sigma: torch.Tensor,
- )
Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1
- Parameters:
x0 – Starting point (clean latents)
x1 – Ending point (noise)
sigma – Sigma values in [0, 1]
- Returns:
Interpolated tensor at sigma
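The interpolation above can be sketched in a few lines of PyTorch. This is an illustrative stand-in for the schedule's `forward`, not the module's actual source; the helper name is hypothetical:

```python
import torch


def linear_interpolation(x0: torch.Tensor, x1: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Compute x_t = (1 - sigma) * x0 + sigma * x1, broadcasting sigma over non-batch dims."""
    # Reshape sigma from [B] to [B, 1, 1, ...] so it broadcasts against x0 / x1.
    sigma = sigma.view(-1, *([1] * (x0.dim() - 1)))
    return (1.0 - sigma) * x0 + sigma * x1


# At sigma = 0 we recover the clean latents; at sigma = 1 we recover pure noise.
x0 = torch.zeros(2, 4, 8, 8)  # clean latents
x1 = torch.ones(2, 4, 8, 8)   # noise
xt = linear_interpolation(x0, x1, torch.tensor([0.0, 0.5]))
```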
- class nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline(
- model_adapter: nemo_automodel.components.flow_matching.adapters.ModelAdapter,
- num_train_timesteps: int = 1000,
- timestep_sampling: str = 'logit_normal',
- flow_shift: float = 3.0,
- i2v_prob: float = 0.3,
- cfg_dropout_prob: float = 0.1,
- logit_mean: float = 0.0,
- logit_std: float = 1.0,
- mix_uniform_ratio: float = 0.1,
- use_sigma_noise: bool = True,
- sigma_min: float = 0.0,
- sigma_max: float = 1.0,
- use_loss_weighting: bool = True,
- log_interval: int = 100,
- summary_log_interval: int = 10,
- device: Optional[torch.device] = None,
- )
Flow Matching Pipeline - Model-agnostic implementation.
This pipeline handles all flow matching training logic while delegating model-specific operations to a ModelAdapter. This allows adding support for new model architectures without modifying the pipeline code.
Features:
- Noise scheduling with linear interpolation
- Timestep sampling with various strategies
- Optional sigma-noise flow shift toggle
- Flow shift transformation
- Sigma clamping for finetuning
- Loss weighting
- Detailed training logging
.. rubric:: Example

```python
from automodel.flow_matching.adapters import HunyuanAdapter

# Create pipeline with HunyuanVideo adapter
pipeline = FlowMatchingPipeline(
    model_adapter=HunyuanAdapter(),
    flow_shift=3.0,
    timestep_sampling="logit_normal",
)

# Training step
weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
    model, batch, device, dtype, global_step
)
```
Initialization
Initialize the FlowMatching pipeline.
- Parameters:
model_adapter – ModelAdapter instance for model-specific operations
num_train_timesteps – Total number of timesteps for the flow
timestep_sampling –
Sampling strategy:
"uniform": Pure uniform sampling
"logit_normal": SD3-style logit-normal (recommended)
"mode": Mode-based sampling
"lognorm": Log-normal based sampling
"mix": Mix of lognorm and uniform
flow_shift – Shift parameter for timestep transformation
i2v_prob – Probability of using image-to-video conditioning
cfg_dropout_prob – Probability of dropping text embeddings for CFG training
logit_mean – Mean for logit-normal distribution
logit_std – Std for logit-normal distribution
mix_uniform_ratio – Ratio of uniform samples when using mix
use_sigma_noise – Whether to use shifted sigma-noise sampling. If False, sample sigma uniformly without flow shift ("uniform_no_shift")
sigma_min – Minimum sigma (0.0 for pretrain)
sigma_max – Maximum sigma (1.0 for pretrain)
use_loss_weighting – Whether to apply flow-based loss weighting
log_interval – Steps between detailed logs
summary_log_interval – Steps between summary logs
device – Device to use for computations
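The recommended "logit_normal" strategy draws a normal sample and squashes it through a sigmoid, concentrating sigmas at mid-range noise levels. A minimal sketch assuming the defaults above (`logit_mean=0.0`, `logit_std=1.0`); the function name is illustrative:

```python
import torch


def sample_logit_normal(batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0) -> torch.Tensor:
    """SD3-style logit-normal sampling: sigmoid of a normal draw always lands in (0, 1)."""
    n = torch.randn(batch_size) * logit_std + logit_mean
    return torch.sigmoid(n)


torch.manual_seed(0)
u = sample_logit_normal(1024)
# Samples cluster around 0.5 rather than spreading uniformly over (0, 1).
```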
- sample_timesteps(
- batch_size: int,
- device: Optional[torch.device] = None,
- )
Sample timesteps and compute sigma values with flow shift.
Implements the flow shift transformation: σ = shift / (shift + (1/u - 1))
- Parameters:
batch_size – Number of timesteps to sample
device – Device for tensor operations
- Returns:
sigma – Sigma values in [sigma_min, sigma_max]
timesteps – Timesteps in [0, num_train_timesteps]
sampling_method – Name of the sampling method used
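The shift transform can be verified numerically. A minimal sketch with the default `flow_shift=3.0`; the helper name is hypothetical:

```python
import torch


def apply_flow_shift(u: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """Map raw samples u in (0, 1) to shifted sigmas: sigma = shift / (shift + (1/u - 1))."""
    return shift / (shift + (1.0 / u - 1.0))


u = torch.tensor([0.25, 0.5, 0.75])
sigma = apply_flow_shift(u)
# With shift > 1 every sigma is pushed toward 1 (noisier timesteps),
# e.g. u = 0.5 maps to sigma = 3 / (3 + 1) = 0.75.
```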
- _sample_from_distribution(
- batch_size: int,
- device: torch.device,
- )
Sample u values from the configured distribution.
- determine_task_type(data_type: str) -> str#
Determine task type based on data type and randomization.
- compute_loss(
- model_pred: torch.Tensor,
- target: torch.Tensor,
- sigma: torch.Tensor,
- batch: Optional[Dict[str, Any]] = None,
- )
Compute flow matching loss with optional weighting.
Loss weight: w = 1 + flow_shift * σ
- Parameters:
model_pred – Model prediction
target – Target (velocity = noise - clean)
sigma – Sigma values for each sample
batch – Optional batch dictionary containing loss_mask
- Returns:
weighted_loss – Per-element weighted loss
average_weighted_loss – Scalar average weighted loss
unweighted_loss – Per-element raw MSE loss
average_unweighted_loss – Scalar average unweighted loss
loss_weight – Applied weights
loss_mask – Loss mask from batch (or None if not present)
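The weighting scheme can be sketched directly: compute per-element MSE against the velocity target (noise - clean), then scale by w = 1 + flow_shift * sigma. This is an illustrative reimplementation under those assumptions, not the pipeline's actual code:

```python
import torch


def weighted_flow_loss(model_pred, latents, noise, sigma, flow_shift=3.0):
    """Per-element weighted flow matching loss; sigma is [B], other tensors are [B, ...]."""
    target = noise - latents                 # velocity target
    unweighted = (model_pred - target) ** 2  # per-element MSE
    w = 1.0 + flow_shift * sigma             # loss weight, shape [B]
    w = w.view(-1, *([1] * (model_pred.dim() - 1)))
    weighted = w * unweighted
    return weighted, weighted.mean()


pred = torch.zeros(2, 4)
latents = torch.zeros(2, 4)
noise = torch.ones(2, 4)
sigma = torch.tensor([0.0, 1.0])  # sample 0 unweighted (w=1), sample 1 weighted by w=4
weighted, avg = weighted_flow_loss(pred, latents, noise, sigma)
```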
- step(
- model: torch.nn.Module,
- batch: Dict[str, Any],
- device: torch.device = torch.device('cuda'),
- dtype: torch.dtype = torch.bfloat16,
- global_step: int = 0,
- )
Execute a single training step with flow matching.
Expected batch format:

```python
{
    "video_latents": torch.Tensor,    # [B, C, F, H, W] for video, OR
    "image_latents": torch.Tensor,    # [B, C, H, W] for image
    "text_embeddings": torch.Tensor,  # [B, seq_len, dim]
    "data_type": str,                 # "video" or "image" (optional)
    # ... additional model-specific keys handled by adapter
}
```
- Parameters:
model – The model to train
batch – Batch of training data
device – Device to use
dtype – Data type for operations
global_step – Current training step (for logging)
- Returns:
weighted_loss – Per-element weighted loss
average_weighted_loss – Scalar average weighted loss
loss_mask – Mask indicating valid loss elements (or None)
metrics – Dictionary of training metrics
- _log_detailed(
- global_step: int,
- sampling_method: str,
- batch_size: int,
- sigma: torch.Tensor,
- timesteps: torch.Tensor,
- latents: torch.Tensor,
- noise: torch.Tensor,
- noisy_latents: torch.Tensor,
- )
Log detailed training information.
- _log_loss_detailed(
- global_step: int,
- model_pred: torch.Tensor,
- target: torch.Tensor,
- loss_weight: torch.Tensor,
- unweighted_loss: torch.Tensor,
- weighted_loss: torch.Tensor,
- )
Log detailed loss information.
- nemo_automodel.components.flow_matching.pipeline.create_adapter(
- adapter_type: str,
- **kwargs,
- )
Factory function to create a model adapter by name.
- Parameters:
adapter_type – Type of adapter ("hunyuan", "simple", "flux")
**kwargs – Additional arguments passed to the adapter constructor
- Returns:
ModelAdapter instance
- nemo_automodel.components.flow_matching.pipeline.create_pipeline(
- adapter_type: str,
- adapter_kwargs: Optional[Dict[str, Any]] = None,
- **pipeline_kwargs,
- )
Factory function to create a pipeline with a specific adapter.
- Parameters:
adapter_type – Type of adapter ("hunyuan", "simple")
adapter_kwargs – Arguments for the adapter constructor
**pipeline_kwargs – Arguments for the pipeline constructor
- Returns:
FlowMatchingPipeline instance
.. rubric:: Example

```python
pipeline = create_pipeline(
    adapter_type="hunyuan",
    adapter_kwargs={"use_condition_latents": True},
    flow_shift=3.0,
    timestep_sampling="logit_normal",
)
```