nemo_automodel.components.flow_matching.time_shift_utils
Module Contents
Functions
| Function | Description |
|---|---|
| time_shift | Convert timesteps to sigmas with sequence-length-aware shifting. |
| compute_density_for_timestep_sampling | Sample timesteps from different distributions for better training coverage. |
| get_flow_match_loss_weight | Compute loss weights for flow matching based on sigma values. |
API
- nemo_automodel.components.flow_matching.time_shift_utils.time_shift(t: torch.Tensor, image_seq_len: int, shift_type: str = 'constant', base_shift: float = 0.5, max_shift: float = 1.15, constant: float = 3.0)
Convert timesteps to sigmas with sequence-length-aware shifting.
- Parameters:
t – timesteps in range [0, 1]
image_seq_len – number of tokens (frames * height * width / patch_size^2)
shift_type – shifting mode: “linear”, “sqrt”, or “constant”
base_shift – base shift used in “linear” mode
max_shift – maximum shift used in “linear” mode
constant – shift value used in “constant” mode (the default, 3.0, matches Pika)
- Returns:
sigma values for noise scheduling
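The docstring does not state the shift formula itself. Below is a minimal sketch, assuming two things. First, "constant" mode uses the common SD3-style shift sigma = s*t / (1 + (s - 1)*t). Second, "linear" mode is Flux-style: log(s) is interpolated linearly from base_shift at 256 tokens to max_shift at 4096 tokens. The 256/4096 anchors are an assumption, borrowed from Flux, whose defaults share the 0.5/1.15 values above. The name `time_shift_sketch` is hypothetical. The "sqrt" mode is left out because its formula is not documented here.

```python
import math

import torch


def time_shift_sketch(
    t: torch.Tensor,
    image_seq_len: int,
    shift_type: str = "constant",
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    constant: float = 3.0,
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not the library's code."""
    if shift_type == "constant":
        # Fixed shift factor, independent of sequence length.
        s = constant
    elif shift_type == "linear":
        # Assumed Flux-style schedule: mu grows linearly with the token count,
        # from base_shift at 256 tokens to max_shift at 4096 tokens.
        slope = (max_shift - base_shift) / (4096 - 256)
        mu = base_shift + slope * (image_seq_len - 256)
        s = math.exp(mu)
    else:
        raise NotImplementedError(f"shift_type={shift_type!r} is not covered by this sketch")
    # sigma = s*t / (1 + (s - 1)*t) maps 0 -> 0 and 1 -> 1 and, for s > 1,
    # pushes intermediate timesteps toward higher noise levels.
    return s * t / (1.0 + (s - 1.0) * t)


t = torch.tensor([0.0, 0.5, 1.0])
sigmas = time_shift_sketch(t, image_seq_len=4096)  # constant mode, s = 3: values 0.0, 0.75, 1.0
```

Writing the linear mode as exp(mu)*t / (1 + (exp(mu) - 1)*t) is algebraically the same as the Flux form exp(mu) / (exp(mu) + 1/t - 1). It avoids dividing by zero at t = 0.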
- nemo_automodel.components.flow_matching.time_shift_utils.compute_density_for_timestep_sampling(weighting_scheme: str, batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0, mode_scale: float = 1.29)
Sample timesteps from different distributions for better training coverage.
- Parameters:
weighting_scheme – “uniform”, “logit_normal”, or “mode”
batch_size – number of samples to generate
logit_mean – mean of the logit-normal distribution
logit_std – standard deviation of the logit-normal distribution
mode_scale – scale parameter for “mode” sampling
- Returns:
Tensor of shape (batch_size,) with values in [0, 1]
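The three schemes can be sketched as follows. The sketch is modeled on the function of the same name in Hugging Face diffusers' training utilities. Whether NeMo Automodel uses exactly these formulas is an assumption. The name `compute_density_sketch` and the optional `generator` argument (added for reproducibility) are hypothetical.

```python
import math
from typing import Optional

import torch


def compute_density_sketch(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    mode_scale: float = 1.29,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Hypothetical re-implementation; returns u in [0, 1] with shape (batch_size,)."""
    if weighting_scheme == "logit_normal":
        # Draw from N(logit_mean, logit_std^2) and squash through a sigmoid;
        # with the defaults, samples cluster around the middle of [0, 1].
        u = torch.normal(logit_mean, logit_std, (batch_size,), generator=generator)
        u = torch.sigmoid(u)
    elif weighting_scheme == "mode":
        # SD3-style mode sampling: warps uniform draws so that, for
        # 0 < mode_scale < ~1.75, more samples land near the middle of [0, 1].
        u = torch.rand(batch_size, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # "uniform" (and, in this sketch, any unrecognized scheme) is U(0, 1).
        u = torch.rand(batch_size, generator=generator)
    return u


g = torch.Generator().manual_seed(0)
timesteps = compute_density_sketch("logit_normal", batch_size=4, generator=g)
```

In a training step, timesteps sampled this way would typically be passed through `time_shift` to obtain the sigmas used for noising.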
- nemo_automodel.components.flow_matching.time_shift_utils.get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0)
Compute loss weights for flow matching based on sigma values.
Samples with higher sigma (more noise) typically receive higher weights.
- Parameters:
sigma – sigma values in range [0, 1]
shift – weight scaling factor
- Returns:
Loss weights with the same shape as sigma
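The docstring says only that weights grow with sigma. The sketch below shows where such per-sample weights would enter a flow-matching loss. The weighting 1 + shift * sigma is a hypothetical placeholder that fits that description; it is not the library's formula. The interpolation x_t = (1 - sigma)*x0 + sigma*noise and the velocity target noise - x0 follow the common rectified-flow convention and are also assumptions. All function names here are hypothetical.

```python
import torch


def loss_weight_sketch(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # Hypothetical placeholder: increases with sigma and keeps sigma's shape
    # (weight 1 at sigma = 0, 1 + shift at sigma = 1).
    return 1.0 + shift * sigma


def weighted_flow_matching_loss(
    model_pred: torch.Tensor, x0: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor
) -> torch.Tensor:
    # Velocity target under the assumed convention x_t = (1 - sigma) * x0 + sigma * noise.
    target = noise - x0
    # Per-sample MSE, averaged over all non-batch dimensions.
    per_sample = ((model_pred - target) ** 2).flatten(1).mean(dim=1)
    # Scale each sample's loss by its sigma-dependent weight, then average over the batch.
    return (loss_weight_sketch(sigma) * per_sample).mean()


x0 = torch.randn(4, 3, 8, 8)
noise = torch.randn_like(x0)
sigma = torch.rand(4)  # one sigma per sample, e.g. from a shifted timestep schedule
s = sigma.view(-1, 1, 1, 1)  # broadcast sigma over the C, H, W dimensions
x_t = (1 - s) * x0 + s * noise  # noisy input a model would receive
model_pred = torch.zeros_like(x0)  # stand-in for model(x_t, sigma)
loss = weighted_flow_matching_loss(model_pred, x0, noise, sigma)
```

Weighting per sample rather than per element keeps each sample's contribution independent of its spatial size, so only sigma decides its relative importance.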