nemo_automodel.components.flow_matching.time_shift_utils

Module Contents

Functions

  • time_shift – Convert timesteps to sigmas with sequence-length-aware shifting.

  • compute_density_for_timestep_sampling – Sample timesteps from different distributions for better training coverage.

  • get_flow_match_loss_weight – Compute loss weights for flow matching based on sigma values.

API

nemo_automodel.components.flow_matching.time_shift_utils.time_shift(
t: torch.Tensor,
image_seq_len: int,
shift_type: str = 'constant',
base_shift: float = 0.5,
max_shift: float = 1.15,
constant: float = 3.0,
)

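Convert timesteps to sigmas with a shift that can depend on the sequence length. With the default shift_type="constant", the fixed shift given by constant is applied.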

Parameters:
  • t – timesteps in the range [0, 1]

  • image_seq_len – number of tokens (frames * height * width / patch_size^2)

  • shift_type – "linear", "sqrt", or "constant"

  • base_shift – base shift for linear mode

  • max_shift – max shift for linear mode

  • constant – shift value for constant mode (default 3.0 matches Pika)

Returns:

Sigma values for noise scheduling.
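The shift can be sketched in a few lines. This is a minimal, hypothetical re-implementation (time_shift_sketch is not part of NeMo), assuming the module follows the SD3/Flux conventions suggested by its defaults: "constant" mode maps t to s*t / (1 + (s - 1)*t) with s = constant, and "linear" mode interpolates an exponent mu from base_shift to max_shift as the sequence length grows from 256 to 4096 tokens (Flux's reference lengths, assumed here) and then uses s = exp(mu). The "sqrt" mode is left out because its formula is not documented on this page.

```python
import math

import torch


def time_shift_sketch(
    t: torch.Tensor,
    image_seq_len: int,
    shift_type: str = "constant",
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    constant: float = 3.0,
    base_seq_len: int = 256,   # assumption: Flux's reference sequence lengths
    max_seq_len: int = 4096,
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not the NeMo source."""
    if shift_type == "constant":
        shift = constant
    elif shift_type == "linear":
        # Interpolate mu linearly in sequence length, then use shift = exp(mu).
        slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        mu = base_shift + slope * (image_seq_len - base_seq_len)
        shift = math.exp(mu)
    else:
        # "sqrt" is not sketched: its formula is not documented here.
        raise NotImplementedError(f"shift_type={shift_type!r} is not covered by this sketch")
    # Maps 0 -> 0 and 1 -> 1; for shift > 1, intermediate t move toward higher noise.
    return shift * t / (1 + (shift - 1) * t)


# Illustrative token count: 1 frame of a 128x128 latent with 2x2 patches.
frames, height, width, patch_size = 1, 128, 128, 2
image_seq_len = frames * height * width // patch_size**2  # 4096

t = torch.linspace(0, 1, 5)
sigma = time_shift_sketch(t, image_seq_len)  # constant mode, shift 3.0
# sigma ≈ [0.0, 0.5, 0.75, 0.9, 1.0]
```

In the sketched "linear" mode, longer sequences receive a larger shift, so the same t maps to a noisier sigma.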

nemo_automodel.components.flow_matching.time_shift_utils.compute_density_for_timestep_sampling(
weighting_scheme: str,
batch_size: int,
logit_mean: float = 0.0,
logit_std: float = 1.0,
mode_scale: float = 1.29,
)

Sample timesteps from a chosen distribution, which controls how training is spread across noise levels.

Parameters:
  • weighting_scheme – "uniform", "logit_normal", or "mode"

  • batch_size – number of samples to generate

  • logit_mean – mean of the normal distribution in logit space, used by "logit_normal"

  • logit_std – standard deviation of the normal distribution in logit space, used by "logit_normal"

  • mode_scale – scale parameter for "mode" sampling

Returns:

Tensor of shape (batch_size,) with values in [0, 1]
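The three schemes can be sketched as follows, assuming they match the formulation popularized by Stable Diffusion 3: logit-normal sampling, and a "mode" warping of a uniform draw whose usual scale 1.29 coincides with this function's default. sample_timesteps_sketch is a hypothetical name, and its generator argument is an addition for reproducibility, not a documented parameter.

```python
import math
from typing import Optional

import torch


def sample_timesteps_sketch(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mode_scale: float = 1.29,
    generator: Optional[torch.Generator] = None,  # hypothetical extra argument
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not the NeMo source."""
    if weighting_scheme == "uniform":
        # Every timestep in [0, 1) is equally likely.
        return torch.rand(batch_size, generator=generator)
    if weighting_scheme == "logit_normal":
        # logit(u) ~ N(logit_mean, logit_std); the sigmoid maps back to (0, 1)
        # and concentrates samples around sigmoid(logit_mean).
        z = torch.normal(logit_mean, logit_std, size=(batch_size,), generator=generator)
        return torch.sigmoid(z)
    if weighting_scheme == "mode":
        # Warp a uniform draw; 0, 0.5 and 1 are fixed points of the mapping, and a
        # positive mode_scale pulls probability mass toward the middle of [0, 1].
        u = torch.rand(batch_size, generator=generator)
        return 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    raise ValueError(f"unknown weighting_scheme: {weighting_scheme!r}")


g = torch.Generator().manual_seed(0)
u = sample_timesteps_sketch("logit_normal", batch_size=8, generator=g)
```

The sampled values presumably serve as t for time_shift; that wiring is an assumption and is not stated on this page.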

nemo_automodel.components.flow_matching.time_shift_utils.get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0)

Compute loss weights for flow matching based on sigma values.

Higher sigma (more noise) typically gets higher weight.

Parameters:
  • sigma – sigma values in the range [0, 1]

  • shift – weight scaling factor

Returns:

Loss weights with same shape as sigma
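The sketch below shows how per-sample weights with the same shape as sigma can scale a flow-matching loss. It assumes the rectified-flow convention x_sigma = (1 - sigma) * x0 + sigma * noise with velocity target noise - x0, which this page does not state. Both helper functions are hypothetical, and torch.ones stands in for get_flow_match_loss_weight(sigma) because the weighting formula is not documented here.

```python
import torch


def add_noise_sketch(x0: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x_sigma = (1 - sigma) * x0 + sigma * noise (assumed convention)."""
    # Reshape per-sample sigma (B,) so it broadcasts over the remaining dims.
    s = sigma.view(-1, *([1] * (x0.dim() - 1)))
    return (1 - s) * x0 + s * noise


def weighted_flow_matching_loss_sketch(
    model_pred: torch.Tensor,  # (B, ...) predicted velocity
    x0: torch.Tensor,          # (B, ...) clean latents
    noise: torch.Tensor,       # (B, ...) Gaussian noise
    weights: torch.Tensor,     # (B,) per-sample loss weights, same shape as sigma
) -> torch.Tensor:
    """Hypothetical helper: per-sample MSE against the velocity target, scaled by weights."""
    target = noise - x0  # d(x_sigma)/d(sigma) under the interpolation above
    per_sample = ((model_pred.float() - target.float()) ** 2).flatten(1).mean(dim=1)  # (B,)
    return (weights.float() * per_sample).mean()


B = 4
x0 = torch.randn(B, 16, 8, 8)
noise = torch.randn_like(x0)
sigma = torch.rand(B)
x_noisy = add_noise_sketch(x0, noise, sigma)  # model input at noise level sigma
weights = torch.ones(B)  # placeholder for get_flow_match_loss_weight(sigma)
loss = weighted_flow_matching_loss_sketch(noise - x0, x0, noise, weights)  # perfect prediction
```

Reducing to a per-sample loss first keeps the weighting a simple (B,) multiply, which is why weights sharing the shape of sigma is convenient.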