bridge.diffusion.models.wan.wan_step#

Module Contents#

Classes#

Functions#

Data#

API#

bridge.diffusion.models.wan.wan_step.logger#

`getLogger(…)`

bridge.diffusion.models.wan.wan_step.wan_data_step(qkv_format, dataloader_iter)#
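
A data step conventionally pulls one batch from the iterator and arranges tensors for the requested attention layout. The sketch below illustrates that pattern only; the batch keys, tensor shapes, and layout handling are assumptions, not the actual Wan data schema:

```python
from typing import Iterator

import torch


def data_step_sketch(qkv_format: str, dataloader_iter: Iterator) -> dict:
    """Hypothetical data step: fetch one batch, reshape for qkv_format."""
    batch = next(dataloader_iter)  # assumed to yield a dict of tensors
    video = batch["video"]  # assumed key; shape [batch, seq, hidden] here
    if qkv_format == "sbhd":
        # The "sbhd" layout puts the sequence dimension first.
        video = video.transpose(0, 1).contiguous()
    return {"video": video}


# Usage with a toy single-batch iterator:
toy_iter = iter([{"video": torch.zeros(2, 8, 4)}])
out = data_step_sketch("sbhd", toy_iter)
```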
class bridge.diffusion.models.wan.wan_step.WanForwardStep(
use_sigma_noise: bool = True,
timestep_sampling: str = 'uniform',
logit_mean: float = 0.0,
logit_std: float = 1.0,
flow_shift: float = 3.0,
mix_uniform_ratio: float = 0.1,
sigma_min: float = 0.0,
sigma_max: float = 1.0,
)#

Initialization
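
The constructor parameters (`timestep_sampling`, `logit_mean`, `logit_std`, `flow_shift`, `sigma_min`, `sigma_max`) match the common logit-normal noise-level sampling scheme with a flow shift used in rectified-flow training. A standalone sketch of that scheme, using the same parameter names; the exact formula Wan uses is an assumption:

```python
import torch


def sample_sigmas(
    batch_size: int,
    timestep_sampling: str = "uniform",
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    flow_shift: float = 3.0,
    sigma_min: float = 0.0,
    sigma_max: float = 1.0,
) -> torch.Tensor:
    """Sample per-example noise levels (sigmas) in [sigma_min, sigma_max]."""
    if timestep_sampling == "logit_normal":
        # Logit-normal: sigmoid of a Gaussian draw concentrates samples
        # around sigmoid(logit_mean).
        u = torch.sigmoid(torch.randn(batch_size) * logit_std + logit_mean)
    else:
        u = torch.rand(batch_size)  # uniform fallback
    # Flow shift (>1) biases samples toward higher noise levels while
    # keeping the endpoints 0 and 1 fixed.
    u = flow_shift * u / (1.0 + (flow_shift - 1.0) * u)
    return sigma_min + (sigma_max - sigma_min) * u


sigmas = sample_sigmas(4, timestep_sampling="logit_normal")
```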

__call__(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.common.vision_module.vision_module.VisionModule,
) -> tuple[torch.Tensor, functools.partial]#

Forward training step.
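
Per the signature, the step returns the model output together with a `functools.partial` loss function, following the Megatron forward-step calling convention. A minimal illustration of that convention; the stand-in model, batch keys, and loss are toys, not the real `GlobalState`/`VisionModule` machinery:

```python
import functools

import torch


def loss_fn(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Masked mean loss; returns (loss, metrics) as Megatron expects."""
    losses = output_tensor.float()
    loss = (losses * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    return loss, {"loss": loss.detach()}


def forward_step_sketch(data_iterator, model):
    """Hypothetical forward step: run the model, bind the loss mask."""
    batch = next(data_iterator)
    output_tensor = model(batch["video"])  # assumed batch key
    loss_mask = torch.ones_like(output_tensor)
    # Return the output plus a loss callable that only needs output_tensor.
    return output_tensor, functools.partial(loss_fn, loss_mask)


model = torch.nn.Identity()
it = iter([{"video": torch.ones(2, 3)}])
out, loss_partial = forward_step_sketch(it, model)
loss, metrics = loss_partial(out)
```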

_create_loss_function(
loss_mask: torch.Tensor,
check_for_nan_in_loss: bool,
check_for_spiky_loss: bool,
) -> functools.partial#

Create a partial loss function with the specified configuration.

Parameters:
  • loss_mask – Mask selecting which elements contribute to the loss

  • check_for_nan_in_loss – Whether to check for NaN values in the loss

  • check_for_spiky_loss – Whether to check for spiky loss values

Returns:

A partial function that can be called with output_tensor to compute the loss
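
The described pattern binds the mask and the check flags up front so the caller only supplies `output_tensor` later. A self-contained sketch of that binding; the loss formula and check behavior here are illustrative assumptions:

```python
import functools

import torch


def masked_loss(
    loss_mask: torch.Tensor,
    check_for_nan_in_loss: bool,
    check_for_spiky_loss: bool,
    output_tensor: torch.Tensor,
) -> torch.Tensor:
    """Masked mean loss with optional sanity checks (toy versions)."""
    loss = (output_tensor.float() * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    if check_for_nan_in_loss and torch.isnan(loss):
        raise ValueError("NaN detected in loss")
    if check_for_spiky_loss and loss.item() > 1e4:  # assumed threshold
        raise ValueError("Spiky loss detected")
    return loss


def create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss):
    # Bind everything except output_tensor; the returned partial is
    # called later as fn(output_tensor).
    return functools.partial(
        masked_loss, loss_mask, check_for_nan_in_loss, check_for_spiky_loss
    )


fn = create_loss_function(torch.ones(2, 2), True, False)
loss = fn(torch.full((2, 2), 2.0))
```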