nemo_automodel.components.flow_matching.pipeline

FlowMatching Pipeline - Model-agnostic implementation with adapter pattern.

This module provides a unified FlowMatchingPipeline class that is completely independent of specific model implementations through the ModelAdapter abstraction.

Features:

  • Model-agnostic design via ModelAdapter protocol
  • Various timestep sampling strategies (uniform, logit_normal, mode, lognorm)
  • Flow shift transformation
  • Sigma clamping for finetuning
  • Loss weighting
  • Detailed training logging
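
The adapter split listed above can be pictured with a minimal sketch. The `ToyAdapter` protocol, its `predict` method, and the two toy models are hypothetical names invented for illustration; they are not the library's `ModelAdapter` interface. Plain Python lists stand in for tensors.

```python
from typing import Callable, List, Protocol


class ToyAdapter(Protocol):
    """Hypothetical adapter interface; not the real ModelAdapter API."""

    def predict(self, model: Callable, noisy: List[float], sigma: float) -> List[float]:
        ...


class PositionalAdapter:
    """Adapts a model that takes (values, sigma) positionally."""

    def predict(self, model, noisy, sigma):
        return model(noisy, sigma)


class KeywordAdapter:
    """Adapts a model that takes keyword inputs and a timestep in [0, 1000]."""

    def predict(self, model, noisy, sigma):
        return model(hidden_states=noisy, timestep=sigma * 1000)


def run_model(adapter: ToyAdapter, model: Callable, noisy: List[float], sigma: float) -> List[float]:
    """Pipeline-side code: identical for every model, only the adapter differs."""
    return adapter.predict(model, noisy, sigma)


def model_a(values, sigma):
    return [v * sigma for v in values]


def model_b(hidden_states, timestep):
    return [v * timestep / 1000 for v in hidden_states]


out_a = run_model(PositionalAdapter(), model_a, [2.0, 4.0], 0.5)
out_b = run_model(KeywordAdapter(), model_b, [2.0, 4.0], 0.5)
```

Both calls go through the same `run_model` function, so supporting a new model shape in this sketch only requires a new adapter class.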

Module Contents

Classes

Name | Description
FlowMatchingPipeline | Flow Matching Pipeline - Model-agnostic implementation.
LinearInterpolationSchedule | Simple linear interpolation schedule for flow matching.

Functions

Name | Description
create_adapter | Factory function to create a model adapter by name.
create_pipeline | Factory function to create a pipeline with a specific adapter.

Data

logger

API

class nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline(
model_adapter: nemo_automodel.components.flow_matching.adapters.ModelAdapter,
num_train_timesteps: int = 1000,
timestep_sampling: str = 'logit_normal',
flow_shift: float = 3.0,
i2v_prob: float = 0.3,
cfg_dropout_prob: float = 0.1,
logit_mean: float = 0.0,
logit_std: float = 1.0,
mix_uniform_ratio: float = 0.1,
use_sigma_noise: bool = True,
sigma_min: float = 0.0,
sigma_max: float = 1.0,
use_loss_weighting: bool = True,
loss_weighting_scheme: str = 'linear',
log_interval: int = 100,
summary_log_interval: int = 10,
device: typing.Optional[torch.device] = None
)

Flow Matching Pipeline - Model-agnostic implementation.

This pipeline handles all flow matching training logic while delegating model-specific operations to a ModelAdapter. This allows adding support for new model architectures without modifying the pipeline code.

Features:

  • Noise scheduling with linear interpolation
  • Timestep sampling with various strategies
  • Optional sigma-noise flow shift toggle
  • Flow shift transformation
  • Sigma clamping for finetuning
  • Loss weighting
  • Detailed training logging

_bsmntw_weights = self._build_bsmntw_weights()
device
noise_schedule = LinearInterpolationSchedule()

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline._build_bsmntw_weights() -> torch.Tensor

Build Bell-Shaped Midpoint Noise Timestep Weighting table.

Returns a 1D tensor of length num_train_timesteps with weights following a Gaussian bell curve centered at the midpoint (t=steps/2).
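
A bell-shaped weighting table of this kind could be built roughly as follows. The width parameter `std_fraction`, the function name, and the absence of normalization are assumptions; the source states only that the curve is Gaussian and centered at the midpoint. A plain Python list is used in place of a 1D tensor.

```python
import math
from typing import List


def build_bell_weights(num_train_timesteps: int = 1000, std_fraction: float = 0.25) -> List[float]:
    """Gaussian bell over timestep indices, centered at num_train_timesteps / 2.

    std_fraction is an assumed parameter: the standard deviation expressed as a
    fraction of the number of timesteps. No normalization is applied, so the
    value at the exact midpoint is 1.0.
    """
    midpoint = num_train_timesteps / 2
    std = std_fraction * num_train_timesteps
    return [
        math.exp(-0.5 * ((t - midpoint) / std) ** 2)
        for t in range(num_train_timesteps)
    ]


weights = build_bell_weights(1000)
```

Timesteps near the midpoint receive the largest weight, and the weight falls off symmetrically toward both ends of the schedule.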

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline._log_detailed(
global_step: int,
sampling_method: str,
batch_size: int,
sigma: torch.Tensor,
timesteps: torch.Tensor,
latents: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor
)

Log detailed training information.

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline._log_loss_detailed(
global_step: int,
model_pred: torch.Tensor,
target: torch.Tensor,
loss_weight: torch.Tensor,
unweighted_loss: torch.Tensor,
weighted_loss: torch.Tensor
)

Log detailed loss information.

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline._sample_from_distribution(
batch_size: int,
device: torch.device
) -> torch.Tensor

Sample u values from the configured distribution.
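
One plausible reading of the `logit_normal` strategy combined with `mix_uniform_ratio` is sketched below: draw a normal sample, squash it through a sigmoid, and occasionally substitute a uniform draw. The mixing rule and the function name `sample_u` are assumptions, not the library's implementation.

```python
import math
import random
from typing import List, Optional


def sample_u(
    batch_size: int,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mix_uniform_ratio: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw batch_size values of u in [0, 1]."""
    if rng is None:
        rng = random.Random()
    samples = []
    for _ in range(batch_size):
        if rng.random() < mix_uniform_ratio:
            # Assumed mixing rule: a fraction of samples is drawn uniformly.
            u = rng.random()
        else:
            # Logit-normal: sigmoid of a normal sample.
            z = rng.gauss(logit_mean, logit_std)
            u = 1.0 / (1.0 + math.exp(-z))
        samples.append(u)
    return samples


u_values = sample_u(8, rng=random.Random(0))
```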

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.compute_loss(
model_pred: torch.Tensor,
target: torch.Tensor,
sigma: torch.Tensor,
batch: typing.Optional[typing.Dict[str, typing.Any]] = None
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]]

Compute flow matching loss with optional weighting.

Loss weight: w = 1 + flow_shift * σ

Parameters:

model_pred
torch.Tensor

Model prediction

target
torch.Tensor

Target (velocity = noise - clean)

sigma
torch.Tensor

Sigma values for each sample

batch
Optional[Dict[str, Any]], defaults to None

Optional batch dictionary containing loss_mask

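Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

A tuple whose elements include the per-element weighted loss.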
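
As a sketch of the weighting formula w = 1 + flow_shift * σ, the following computes a per-element squared error and scales it by the weight. The helper name `weighted_flow_loss`, the use of squared error, and the shape of the return value are assumptions; `loss_mask` handling is omitted. Plain Python lists stand in for tensors.

```python
from typing import List, Tuple


def weighted_flow_loss(
    model_pred: List[float],
    target: List[float],
    sigma: float,
    flow_shift: float = 3.0,
) -> Tuple[List[float], List[float], float]:
    """Return (weighted per-element loss, unweighted per-element loss, weight)."""
    weight = 1.0 + flow_shift * sigma  # w = 1 + flow_shift * sigma
    unweighted = [(p - t) ** 2 for p, t in zip(model_pred, target)]
    weighted = [weight * e for e in unweighted]
    return weighted, unweighted, weight


weighted, unweighted, weight = weighted_flow_loss([0.5, 1.0], [1.0, 1.0], sigma=0.5)
```

With `flow_shift = 3.0`, the weight ranges from 1 at σ = 0 to 4 at σ = 1, so samples closer to pure noise contribute more to the loss.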

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.determine_task_type(
data_type: str
) -> str

Determine task type based on data type and randomization.

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.sample_timesteps(
batch_size: int,
device: typing.Optional[torch.device] = None
) -> typing.Tuple[torch.Tensor, torch.Tensor, str]

Sample timesteps and compute sigma values with flow shift.

Implements the flow shift transformation: σ = shift / (shift + (1/u - 1))

Parameters:

batch_size
int

Number of timesteps to sample

device
Optional[torch.device], defaults to None

Device for tensor operations

Returns: Tuple[torch.Tensor, torch.Tensor, str]

A tuple that includes the sigma values in [sigma_min, sigma_max].
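
The flow shift formula can be checked numerically with a small sketch. The function name `flow_shift_sigma` is hypothetical, and mapping σ to a timestep by multiplying with `num_train_timesteps` is an assumption, not something the source states.

```python
def flow_shift_sigma(
    u: float,
    shift: float = 3.0,
    sigma_min: float = 0.0,
    sigma_max: float = 1.0,
) -> float:
    """sigma = shift / (shift + (1/u - 1)), clamped to [sigma_min, sigma_max].

    u must lie in (0, 1]; u = 0 would divide by zero.
    """
    sigma = shift / (shift + (1.0 / u - 1.0))
    return min(max(sigma, sigma_min), sigma_max)


sigma = flow_shift_sigma(0.5)            # shift = 3 pushes u = 0.5 toward noise
identity = flow_shift_sigma(0.25, 1.0)   # shift = 1 leaves u unchanged
clamped = flow_shift_sigma(0.5, sigma_max=0.6)

# Assumed mapping from sigma to a training timestep.
num_train_timesteps = 1000
timestep = sigma * num_train_timesteps
```

A shift greater than 1 moves every sample toward higher σ, while `sigma_min` and `sigma_max` restrict the range afterwards, which is how the clamping for finetuning fits into this sketch.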

nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline.step(
model: torch.nn.Module,
batch: typing.Dict[str, typing.Any],
device: torch.device = torch.device('cuda'),
dtype: torch.dtype = torch.bfloat16,
global_step: int = 0,
collect_metrics: bool = True,
check_loss: bool = True
) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor], typing.Dict[str, typing.Any]]

Execute a single training step with flow matching.

Expected batch format:

    {
        "video_latents": torch.Tensor,    # [B, C, F, H, W] for video, OR
        "image_latents": torch.Tensor,    # [B, C, H, W] for image
        "text_embeddings": torch.Tensor,  # [B, seq_len, dim]
        "data_type": str,                 # "video" or "image" (optional)
        # … additional model-specific keys handled by adapter
    }

Parameters:

model
nn.Module

The model to train

batch
Dict[str, Any]

Batch of training data

device
torch.device, defaults to torch.device('cuda')

Device to use

dtype
torch.dtype, defaults to torch.bfloat16

Data type for operations

global_step
int, defaults to 0

Current training step (for logging)

collect_metrics
bool, defaults to True

Whether to collect scalar diagnostics. Disable in hot training paths to avoid host/device synchronizations.

check_loss
bool, defaults to True

Whether to run scalar loss explosion checks. Disable in hot training paths to avoid host/device synchronizations.

Returns: Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]

A tuple whose elements include the per-element weighted loss.
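
The overall flow of a training step can be sketched in plain Python. Everything here is a simplified assumption: `toy_step` is a hypothetical name, σ is passed in rather than sampled, the batch key lookup is a guess based on the batch format above, and lists of floats stand in for latent tensors.

```python
import random
from typing import Any, Callable, Dict, List, Tuple


def toy_step(
    model: Callable[[List[float], float], List[float]],
    batch: Dict[str, Any],
    sigma: float,
    flow_shift: float = 3.0,
    seed: int = 0,
) -> Tuple[float, Dict[str, float]]:
    """One simplified flow matching step: noise, interpolate, predict, weigh."""
    # Assumed key lookup: prefer video latents, fall back to image latents.
    key = "video_latents" if "video_latents" in batch else "image_latents"
    latents = batch[key]

    rng = random.Random(seed)
    noise = [rng.gauss(0.0, 1.0) for _ in latents]

    # x_t = (1 - sigma) * x_0 + sigma * x_1
    noisy = [(1.0 - sigma) * x0 + sigma * x1 for x0, x1 in zip(latents, noise)]
    # Target velocity = noise - clean
    target = [x1 - x0 for x0, x1 in zip(latents, noise)]

    pred = model(noisy, sigma)
    weight = 1.0 + flow_shift * sigma
    per_element = [weight * (p - t) ** 2 for p, t in zip(pred, target)]
    loss = sum(per_element) / len(per_element)
    return loss, {"sigma": sigma, "loss": loss}


def zero_model(noisy: List[float], sigma: float) -> List[float]:
    """Toy model that always predicts zero velocity."""
    return [0.0 for _ in noisy]


loss, metrics = toy_step(zero_model, {"image_latents": [0.1, 0.2, 0.3]}, sigma=0.5)
```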

class nemo_automodel.components.flow_matching.pipeline.LinearInterpolationSchedule()

Simple linear interpolation schedule for flow matching.

nemo_automodel.components.flow_matching.pipeline.LinearInterpolationSchedule.forward(
x0: torch.Tensor,
x1: torch.Tensor,
sigma: torch.Tensor
) -> torch.Tensor

Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1

Parameters:

x0
torch.Tensor

Starting point (clean latents)

x1
torch.Tensor

Ending point (noise)

sigma
torch.Tensor

Sigma values in [0, 1]

Returns: torch.Tensor

Interpolated tensor at sigma
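
The interpolation formula is simple enough to verify by hand. The sketch below uses a hypothetical `linear_interpolate` function over Python lists rather than the class's `forward` method on tensors.

```python
from typing import List


def linear_interpolate(x0: List[float], x1: List[float], sigma: float) -> List[float]:
    """x_t = (1 - sigma) * x_0 + sigma * x_1, applied element-wise."""
    return [(1.0 - sigma) * a + sigma * b for a, b in zip(x0, x1)]


clean = [0.0, 4.0]
noise = [4.0, 0.0]

at_start = linear_interpolate(clean, noise, 0.0)    # equals the clean latents
at_end = linear_interpolate(clean, noise, 1.0)      # equals the noise
quarter = linear_interpolate(clean, noise, 0.25)    # a quarter of the way to noise
```

Differentiating x_t with respect to σ gives x_1 - x_0, which matches the velocity target (noise - clean) used by `compute_loss`.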

nemo_automodel.components.flow_matching.pipeline.create_adapter(
adapter_type: str,
**kwargs
) -> nemo_automodel.components.flow_matching.adapters.ModelAdapter

Factory function to create a model adapter by name.

Parameters:

adapter_type
str

Type of adapter (“hunyuan”, “simple”, “flux”, “flux2”, “qwen_image”)

**kwargs
Defaults to {}

Additional arguments passed to the adapter constructor

Returns: ModelAdapter

ModelAdapter instance

nemo_automodel.components.flow_matching.pipeline.create_pipeline(
adapter_type: str,
adapter_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
**pipeline_kwargs
) -> nemo_automodel.components.flow_matching.pipeline.FlowMatchingPipeline

Factory function to create a pipeline with a specific adapter.

Parameters:

adapter_type
str

Type of adapter (“hunyuan”, “simple”)

adapter_kwargs
Optional[Dict[str, Any]], defaults to None

Arguments for the adapter constructor

**pipeline_kwargs
Defaults to {}

Arguments for the pipeline constructor

Returns: FlowMatchingPipeline

FlowMatchingPipeline instance
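
The two factory functions follow a common registry pattern, which can be sketched with stand-in classes. `ToySimpleAdapter`, `ToyPipeline`, the registry dictionary, and the `toy_` functions are all hypothetical; the sketch mirrors only the documented signatures, not the library's internals or its error handling.

```python
from typing import Any, Dict, Optional


class ToySimpleAdapter:
    """Stand-in adapter; the real adapter constructors are not shown in the source."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale


class ToyPipeline:
    """Stand-in pipeline holding an adapter and one constructor option."""

    def __init__(self, model_adapter: Any, flow_shift: float = 3.0):
        self.model_adapter = model_adapter
        self.flow_shift = flow_shift


# Assumed registry mapping adapter names to classes.
_TOY_ADAPTERS = {"simple": ToySimpleAdapter}


def toy_create_adapter(adapter_type: str, **kwargs: Any) -> Any:
    """Look up the adapter class by name and forward kwargs to its constructor."""
    try:
        adapter_cls = _TOY_ADAPTERS[adapter_type]
    except KeyError:
        raise ValueError(f"Unknown adapter type: {adapter_type!r}") from None
    return adapter_cls(**kwargs)


def toy_create_pipeline(
    adapter_type: str,
    adapter_kwargs: Optional[Dict[str, Any]] = None,
    **pipeline_kwargs: Any,
) -> ToyPipeline:
    """Build the adapter first, then pass it to the pipeline constructor."""
    adapter = toy_create_adapter(adapter_type, **(adapter_kwargs or {}))
    return ToyPipeline(model_adapter=adapter, **pipeline_kwargs)


pipeline = toy_create_pipeline("simple", adapter_kwargs={"scale": 2.0}, flow_shift=1.0)
```

Keeping adapter arguments in a separate dictionary avoids name clashes between adapter options and pipeline options.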

nemo_automodel.components.flow_matching.pipeline.logger = logging.getLogger(__name__)