bridge.diffusion.models.wan.wan_step
Module Contents
Classes
Functions
Data
API
- bridge.diffusion.models.wan.wan_step.logger
‘getLogger(…)’
- bridge.diffusion.models.wan.wan_step.wan_data_step(qkv_format, dataloader_iter)
- class bridge.diffusion.models.wan.wan_step.WanForwardStep(
  use_sigma_noise: bool = True,
  timestep_sampling: str = 'uniform',
  logit_mean: float = 0.0,
  logit_std: float = 1.0,
  flow_shift: float = 3.0,
  mix_uniform_ratio: float = 0.1,
  sigma_min: float = 0.0,
  sigma_max: float = 1.0,
  )
Initialization
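The constructor arguments above are documented only by name and default value. As a rough illustration of how parameters with these names are commonly used in flow-matching training, the following is a minimal, dependency-free sketch. It is not the implementation of WanForwardStep: the function `sample_sigma`, the `'logit_normal'` mode string, the shift formula, and the clamping step are all assumptions made for this example.

```python
import math
import random
from typing import Optional


def sample_sigma(
    timestep_sampling: str = "uniform",
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    flow_shift: float = 3.0,
    mix_uniform_ratio: float = 0.1,
    sigma_min: float = 0.0,
    sigma_max: float = 1.0,
    rng: Optional[random.Random] = None,
) -> float:
    """Hypothetical helper: draw one noise level in [sigma_min, sigma_max]."""
    rng = rng or random.Random()

    # Assumption: in "logit_normal" mode, a fraction mix_uniform_ratio of
    # the draws still falls back to plain uniform sampling.
    if timestep_sampling == "logit_normal" and rng.random() >= mix_uniform_ratio:
        # Logit-normal: squash a Gaussian sample through a sigmoid.
        z = rng.gauss(logit_mean, logit_std)
        u = 1.0 / (1.0 + math.exp(-z))
    else:
        u = rng.random()

    # Assumption: a time-shift of this common form biases samples toward
    # higher noise levels when flow_shift > 1 (identity when it equals 1).
    sigma = flow_shift * u / (1.0 + (flow_shift - 1.0) * u)

    # Assumption: the result is clamped to the configured range.
    return min(max(sigma, sigma_min), sigma_max)


if __name__ == "__main__":
    rng = random.Random(0)
    print([round(sample_sigma("logit_normal", rng=rng), 3) for _ in range(5)])
```

With `flow_shift=1.0` and uniform sampling, the sketch returns the raw uniform draw unchanged, which makes the effect of the shift easy to isolate when experimenting.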
- __call__(
  state: megatron.bridge.training.state.GlobalState,
  data_iterator: Iterable,
  model: megatron.core.models.common.vision_module.vision_module.VisionModule,
  )
  Forward training step.
- _create_loss_function(
  loss_mask: torch.Tensor,
  check_for_nan_in_loss: bool,
  check_for_spiky_loss: bool,
  )
  Create a partial loss function with the specified configuration.
  - Parameters:
    loss_mask – Used to mask out some portions of the loss.
    check_for_nan_in_loss – Whether to check for NaN values in the loss.
    check_for_spiky_loss – Whether to check for spiky loss values.
  - Returns:
    A partial function that can be called with output_tensor to compute the loss.
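The pattern described for _create_loss_function, binding the mask and the check flags ahead of time so that only output_tensor remains to be supplied, can be sketched with `functools.partial`. This is an illustrative stand-in, not the library's code: plain Python lists replace `torch.Tensor` so the example runs without dependencies, and `masked_mean_loss`, `create_loss_function`, and `spiky_threshold` are hypothetical names introduced here.

```python
import math
from functools import partial
from typing import Callable, Sequence


def masked_mean_loss(
    loss_mask: Sequence[float],
    output_tensor: Sequence[float],
    check_for_nan_in_loss: bool = False,
    check_for_spiky_loss: bool = False,
    spiky_threshold: float = 1e3,  # hypothetical cutoff for a "spiky" loss
) -> float:
    """Average the per-element losses where the mask is non-zero."""
    total = sum(m * x for m, x in zip(loss_mask, output_tensor))
    count = sum(loss_mask)
    loss = total / count if count else 0.0

    if check_for_nan_in_loss and math.isnan(loss):
        raise ValueError("NaN detected in loss")
    if check_for_spiky_loss and loss > spiky_threshold:
        raise ValueError("loss exceeds the spiky-loss threshold")
    return loss


def create_loss_function(
    loss_mask: Sequence[float],
    check_for_nan_in_loss: bool,
    check_for_spiky_loss: bool,
) -> Callable[[Sequence[float]], float]:
    """Bind everything except output_tensor, as the docstring describes."""
    return partial(
        masked_mean_loss,
        loss_mask,
        check_for_nan_in_loss=check_for_nan_in_loss,
        check_for_spiky_loss=check_for_spiky_loss,
    )


if __name__ == "__main__":
    loss_fn = create_loss_function([1, 1, 0], True, False)
    # The third element is masked out and does not contribute.
    print(loss_fn([2.0, 4.0, 100.0]))
```

Returning a partial lets the caller hand back a single callable alongside the model output, so the training loop does not need to know about the mask or the check flags when it finally computes the loss.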