nemo_automodel.components.flow_matching.pipeline#

FlowMatching Pipeline - Model-agnostic implementation with adapter pattern.

This module provides a unified FlowMatchingPipeline class that is completely independent of specific model implementations through the ModelAdapter abstraction.

Features:

  • Model-agnostic design via ModelAdapter protocol

  • Various timestep sampling strategies (uniform, logit_normal, mode, lognorm)

  • Flow shift transformation

  • Sigma clamping for finetuning

  • Loss weighting

  • Detailed training logging

Module Contents#

Classes#

LinearInterpolationSchedule

Simple linear interpolation schedule for flow matching.

FlowMatchingPipeline

Flow Matching Pipeline - Model-agnostic implementation.

Functions#

create_adapter

Factory function to create a model adapter by name.

create_pipeline

Factory function to create a pipeline with a specific adapter.

Data#

API#

nemo_automodel.components.flow_matching.pipeline.logger#

'getLogger(…)'

class nemo_automodel.components.flow_matching.pipeline.LinearInterpolationSchedule#

Simple linear interpolation schedule for flow matching.

forward(
x0: torch.Tensor,
x1: torch.Tensor,
sigma: torch.Tensor,
) torch.Tensor#

Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1

Parameters:
  • x0 – Starting point (clean latents)

  • x1 – Ending point (noise)

  • sigma – Sigma values in [0, 1]

Returns:

Interpolated tensor at sigma
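The schedule is pure elementwise arithmetic, so it can be illustrated with scalars. A minimal sketch of the formula above in plain Python (the helper name `linear_interpolate` is hypothetical, not the class's API, which operates on `torch.Tensor`s):

```python
def linear_interpolate(x0, x1, sigma):
    """x_t = (1 - sigma) * x0 + sigma * x1, the linear interpolation schedule."""
    return (1.0 - sigma) * x0 + sigma * x1

# sigma = 0 recovers the clean latent, sigma = 1 recovers pure noise
print(linear_interpolate(2.0, 10.0, 0.0))   # 2.0
print(linear_interpolate(2.0, 10.0, 1.0))   # 10.0
print(linear_interpolate(2.0, 10.0, 0.25))  # 4.0
```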

class nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline(
model_adapter: nemo_automodel.components.flow_matching.adapters.ModelAdapter,
num_train_timesteps: int = 1000,
timestep_sampling: str = 'logit_normal',
flow_shift: float = 3.0,
i2v_prob: float = 0.3,
cfg_dropout_prob: float = 0.1,
logit_mean: float = 0.0,
logit_std: float = 1.0,
mix_uniform_ratio: float = 0.1,
use_sigma_noise: bool = True,
sigma_min: float = 0.0,
sigma_max: float = 1.0,
use_loss_weighting: bool = True,
log_interval: int = 100,
summary_log_interval: int = 10,
device: Optional[torch.device] = None,
)#

Flow Matching Pipeline - Model-agnostic implementation.

This pipeline handles all flow matching training logic while delegating model-specific operations to a ModelAdapter. This allows adding support for new model architectures without modifying the pipeline code.

Features:

  • Noise scheduling with linear interpolation

  • Timestep sampling with various strategies

  • Optional sigma-noise flow shift toggle

  • Flow shift transformation

  • Sigma clamping for finetuning

  • Loss weighting

  • Detailed training logging

.. rubric:: Example

Create pipeline with HunyuanVideo adapter#

    from nemo_automodel.components.flow_matching.adapters import HunyuanAdapter

    pipeline = FlowMatchingPipeline(
        model_adapter=HunyuanAdapter(),
        flow_shift=3.0,
        timestep_sampling="logit_normal",
    )

Training step#

    weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(
        model, batch, device, dtype, global_step
    )

Initialization

Initialize the FlowMatching pipeline.

Parameters:
  • model_adapter – ModelAdapter instance for model-specific operations

  • num_train_timesteps – Total number of timesteps for the flow

  • timestep_sampling

    Sampling strategy:

    • "uniform": Pure uniform sampling

    • "logit_normal": SD3-style logit-normal (recommended)

    • "mode": Mode-based sampling

    • "lognorm": Log-normal based sampling

    • "mix": Mix of lognorm and uniform

  • flow_shift – Shift parameter for timestep transformation

  • i2v_prob – Probability of using image-to-video conditioning

  • cfg_dropout_prob – Probability of dropping text embeddings for CFG training

  • logit_mean – Mean for logit-normal distribution

  • logit_std – Std for logit-normal distribution

  • mix_uniform_ratio – Ratio of uniform samples when using mix

  • use_sigma_noise – Whether to use shifted sigma-noise sampling. If False, sample sigma uniformly without flow shift ("uniform_no_shift")

  • sigma_min – Minimum sigma (0.0 for pretrain)

  • sigma_max – Maximum sigma (1.0 for pretrain)

  • use_loss_weighting – Whether to apply flow-based loss weighting

  • log_interval – Steps between detailed logs

  • summary_log_interval – Steps between summary logs

  • device – Device to use for computations
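The "logit_normal" strategy draws a normal sample and squashes it through a sigmoid, biasing sigma toward mid-range noise levels. A scalar sketch of that idea in plain Python (the helper `sample_sigma_logit_normal` is hypothetical, not the pipeline's internal code, which produces batched tensors):

```python
import math
import random

def sample_sigma_logit_normal(logit_mean=0.0, logit_std=1.0):
    """SD3-style logit-normal sampling sketch: draw u ~ Normal(mean, std),
    then map it through a sigmoid so sigma always lands in (0, 1)."""
    u = random.gauss(logit_mean, logit_std)
    return 1.0 / (1.0 + math.exp(-u))  # sigmoid

random.seed(0)
sigmas = [sample_sigma_logit_normal() for _ in range(1000)]
# every draw is a valid sigma in the open interval (0, 1)
assert all(0.0 < s < 1.0 for s in sigmas)
```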

sample_timesteps(
batch_size: int,
device: Optional[torch.device] = None,
) Tuple[torch.Tensor, torch.Tensor, str]#

Sample timesteps and compute sigma values with flow shift.

Implements the flow shift transformation: σ = shift / (shift + (1/u - 1))

Parameters:
  • batch_size – Number of timesteps to sample

  • device – Device for tensor operations

Returns:

  • sigma – Sigma values in [sigma_min, sigma_max]

  • timesteps – Timesteps in [0, num_train_timesteps]

  • sampling_method – Name of the sampling method used

Return type:

Tuple[torch.Tensor, torch.Tensor, str]
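The flow shift transformation above can be checked with a few scalar values. A minimal sketch in plain Python (the helper `flow_shift_sigma` is hypothetical, shown only to illustrate the documented formula):

```python
def flow_shift_sigma(u, shift=3.0):
    """Flow shift transform: sigma = shift / (shift + (1/u - 1)).
    u is a draw in (0, 1]; a larger shift pushes sigma toward 1,
    spending more training steps at high-noise levels."""
    return shift / (shift + (1.0 / u - 1.0))

# shift = 1 is the identity; shift = 3 inflates mid-range sigmas
print(flow_shift_sigma(0.5, shift=1.0))  # 0.5
print(flow_shift_sigma(0.5, shift=3.0))  # 0.75
```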

_sample_from_distribution(
batch_size: int,
device: torch.device,
) torch.Tensor#

Sample u values from the configured distribution.

determine_task_type(data_type: str) str#

Determine task type based on data type and randomization.

compute_loss(
model_pred: torch.Tensor,
target: torch.Tensor,
sigma: torch.Tensor,
batch: Optional[Dict[str, Any]] = None,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]#

Compute flow matching loss with optional weighting.

Loss weight: w = 1 + flow_shift * σ

Parameters:
  • model_pred – Model prediction

  • target – Target (velocity = noise - clean)

  • sigma – Sigma values for each sample

  • batch – Optional batch dictionary containing loss_mask

Returns:

  • weighted_loss – Per-element weighted loss

  • average_weighted_loss – Scalar average weighted loss

  • unweighted_loss – Per-element raw MSE loss

  • average_unweighted_loss – Scalar average unweighted loss

  • loss_weight – Applied weights

  • loss_mask – Loss mask from batch (or None if not present)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
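The documented weight w = 1 + flow_shift * σ applied to a squared error can be sketched with scalars. A minimal plain-Python illustration (the helper `weighted_mse` is hypothetical, not the pipeline's internal tensor implementation):

```python
def weighted_mse(pred, target, sigma, flow_shift=3.0, use_loss_weighting=True):
    """Scalar sketch of flow-based loss weighting: the squared error is
    scaled by w = 1 + flow_shift * sigma, so high-noise samples (large
    sigma) contribute more to the loss."""
    unweighted = (pred - target) ** 2
    weight = 1.0 + flow_shift * sigma if use_loss_weighting else 1.0
    return weight * unweighted

# at sigma = 0.5 with flow_shift = 3.0, the weight is 1 + 1.5 = 2.5
print(weighted_mse(1.0, 0.0, 0.5))  # 2.5
```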

step(
model: torch.nn.Module,
batch: Dict[str, Any],
device: torch.device = torch.device('cuda'),
dtype: torch.dtype = torch.bfloat16,
global_step: int = 0,
) Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]#

Execute a single training step with flow matching.

Expected batch format:

    {
        "video_latents": torch.Tensor,    # [B, C, F, H, W] for video, OR
        "image_latents": torch.Tensor,    # [B, C, H, W] for image
        "text_embeddings": torch.Tensor,  # [B, seq_len, dim]
        "data_type": str,                 # "video" or "image" (optional)
        # ... additional model-specific keys handled by the adapter
    }

Parameters:
  • model – The model to train

  • batch – Batch of training data

  • device – Device to use

  • dtype – Data type for operations

  • global_step – Current training step (for logging)

Returns:

  • weighted_loss – Per-element weighted loss

  • average_weighted_loss – Scalar average weighted loss

  • loss_mask – Mask indicating valid loss elements (or None)

  • metrics – Dictionary of training metrics

Return type:

Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]

_log_detailed(
global_step: int,
sampling_method: str,
batch_size: int,
sigma: torch.Tensor,
timesteps: torch.Tensor,
latents: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
)#

Log detailed training information.

_log_loss_detailed(
global_step: int,
model_pred: torch.Tensor,
target: torch.Tensor,
loss_weight: torch.Tensor,
unweighted_loss: torch.Tensor,
weighted_loss: torch.Tensor,
)#

Log detailed loss information.

nemo_automodel.components.flow_matching.pipeline.create_adapter(
adapter_type: str,
**kwargs,
) nemo_automodel.components.flow_matching.adapters.ModelAdapter#

Factory function to create a model adapter by name.

Parameters:
  • adapter_type – Type of adapter ("hunyuan", "simple", "flux")

  • **kwargs – Additional arguments passed to the adapter constructor

Returns:

ModelAdapter instance

nemo_automodel.components.flow_matching.pipeline.create_pipeline(
adapter_type: str,
adapter_kwargs: Optional[Dict[str, Any]] = None,
**pipeline_kwargs,
) nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline#

Factory function to create a pipeline with a specific adapter.

Parameters:
  • adapter_type – Type of adapter ("hunyuan", "simple")

  • adapter_kwargs – Arguments for the adapter constructor

  • **pipeline_kwargs – Arguments for the pipeline constructor

Returns:

FlowMatchingPipeline instance

.. rubric:: Example

    pipeline = create_pipeline(
        adapter_type="hunyuan",
        adapter_kwargs={"use_condition_latents": True},
        flow_shift=3.0,
        timestep_sampling="logit_normal",
    )