nemo_automodel.components.flow_matching.time_shift_utils

Module Contents

Functions

| Name | Description |
| --- | --- |
| compute_density_for_timestep_sampling | Sample timesteps from different distributions for better training coverage. |
| get_flow_match_loss_weight | Compute loss weights for flow matching based on sigma values. |
| time_shift | Convert timesteps to sigmas with sequence-length-aware shifting. |

API

nemo_automodel.components.flow_matching.time_shift_utils.compute_density_for_timestep_sampling(
weighting_scheme: str,
batch_size: int,
logit_mean: float = 0.0,
logit_std: float = 1.0,
mode_scale: float = 1.29
)

Sample timesteps from different distributions for better training coverage.

Parameters:

weighting_scheme
str

sampling distribution to use: "uniform", "logit_normal", or "mode"

batch_size
int

number of samples to generate

logit_mean
float. Defaults to 0.0

mean for logit-normal distribution

logit_std
float. Defaults to 1.0

standard deviation for logit-normal distribution

mode_scale
float. Defaults to 1.29

scale for mode-based sampling

Returns:

Tensor of shape (batch_size,) with values in [0, 1]
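
The three schemes can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: `sample_timestep_density` is a hypothetical stand-in that returns a Python list rather than a tensor, and the formula used for "mode" is the mode-sampling form commonly used for rectified-flow training, which is an assumption about what this function computes.

```python
import math
import random


def sample_timestep_density(weighting_scheme, batch_size, logit_mean=0.0,
                            logit_std=1.0, mode_scale=1.29, rng=None):
    """Hypothetical stand-in: draw `batch_size` values in [0, 1]."""
    rng = rng or random.Random()
    samples = []
    for _ in range(batch_size):
        if weighting_scheme == "logit_normal":
            # Draw from a normal distribution, then squash with a sigmoid.
            z = rng.gauss(logit_mean, logit_std)
            u = 1.0 / (1.0 + math.exp(-z))
        elif weighting_scheme == "mode":
            # Assumed form: warp a uniform draw so that density is
            # redistributed according to mode_scale.
            u = rng.random()
            u = 1.0 - u - mode_scale * (math.cos(math.pi * u / 2.0) ** 2 - 1.0 + u)
        elif weighting_scheme == "uniform":
            u = rng.random()
        else:
            # Sketch-only choice: reject unknown schemes explicitly.
            raise ValueError(f"unknown weighting_scheme: {weighting_scheme!r}")
        samples.append(u)
    return samples


if __name__ == "__main__":
    for scheme in ("uniform", "logit_normal", "mode"):
        print(scheme, sample_timestep_density(scheme, 4, rng=random.Random(0)))
```

With the default `logit_mean` of 0.0, the "logit_normal" draws concentrate around 0.5, so mid-range timesteps are sampled more often than with "uniform".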

nemo_automodel.components.flow_matching.time_shift_utils.get_flow_match_loss_weight(
sigma: torch.Tensor,
shift: float = 3.0
)

Compute loss weights for flow matching based on sigma values.

Higher sigma (more noise) typically gets higher weight.

Parameters:

sigma
torch.Tensor

sigma values in range [0, 1]

shift
float. Defaults to 3.0

weight scaling factor

Returns:

Loss weights with same shape as sigma
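
This page does not state the weighting formula. One form consistent with the description (weight grows with sigma, scaled by `shift`) is the linear weight `1 + shift * sigma`. The sketch below assumes that form, so treat it as illustrative rather than as the library's definition. `linear_sigma_weight` and `weighted_mean_loss` are hypothetical helpers that operate on plain Python lists.

```python
def linear_sigma_weight(sigmas, shift=3.0):
    """Assumed weighting: 1 + shift * sigma (not confirmed by this page)."""
    return [1.0 + shift * s for s in sigmas]


def weighted_mean_loss(per_sample_losses, sigmas, shift=3.0):
    """Scale each sample's loss by its weight, then average over the batch."""
    weights = linear_sigma_weight(sigmas, shift)
    total = sum(w * loss for w, loss in zip(weights, per_sample_losses))
    return total / len(per_sample_losses)


if __name__ == "__main__":
    sigmas = [0.0, 0.5, 1.0]
    print(linear_sigma_weight(sigmas))                   # [1.0, 2.5, 4.0]
    print(weighted_mean_loss([1.0, 1.0, 1.0], sigmas))   # 2.5
```

Under this assumed form, a sample at sigma = 1.0 contributes four times as much to the loss as a sample at sigma = 0.0 when `shift` is 3.0.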

nemo_automodel.components.flow_matching.time_shift_utils.time_shift(
t: torch.Tensor,
image_seq_len: int,
shift_type: str = 'constant',
base_shift: float = 0.5,
max_shift: float = 1.15,
constant: float = 3.0
)

Convert timesteps to sigmas with sequence-length-aware shifting.

Parameters:

t
torch.Tensor

timesteps in range [0, 1]

image_seq_len
int

number of tokens (frames * height * width / patch_size^2)

shift_type
str. Defaults to 'constant'

shifting mode: "linear", "sqrt", or "constant"

base_shift
float. Defaults to 0.5

base shift for linear mode

max_shift
float. Defaults to 1.15

max shift for linear mode

constant
float. Defaults to 3.0

shift value for constant mode (default 3.0 matches Pika)

Returns:

sigma values for noise scheduling
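
As a rough illustration of the "constant" mode, the sketch below applies the widely used constant timestep shift `sigma = c * t / (1 + (c - 1) * t)`. That this is exactly what `time_shift` computes is an assumption, and `constant_time_shift` is a hypothetical helper working on plain Python lists. The "linear" and "sqrt" modes derive the shift from `image_seq_len` and are not sketched here, because this page does not give their formulas.

```python
def constant_time_shift(timesteps, constant=3.0):
    """Assumed constant-mode shift: sigma = c * t / (1 + (c - 1) * t)."""
    return [constant * t / (1.0 + (constant - 1.0) * t) for t in timesteps]


if __name__ == "__main__":
    # Endpoints are preserved; interior timesteps move toward higher noise.
    print(constant_time_shift([0.0, 0.5, 1.0]))  # [0.0, 0.75, 1.0]
```

With `constant` set to 1.0 the mapping is the identity. Values above 1.0 map interior timesteps to larger sigmas, so more of the schedule is spent at high noise levels.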