bridge.diffusion.models.wan.wan_model#

Module Contents#

Classes#

Head

WanModel

WanModel is a VisionModule that implements a Wan model.

Functions#

sinusoidal_embedding_1d

API#

bridge.diffusion.models.wan.wan_model.sinusoidal_embedding_1d(dim, position)#
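The exact implementation is not reproduced in this reference. A minimal pure-Python sketch of a standard 1D sinusoidal embedding, assuming the common sin/cos channel split and a 10000 frequency base (the actual code may differ, e.g. in frequency ordering or dtype handling):

```python
import math

def sinusoidal_embedding_1d(dim, positions):
    # Hypothetical sketch: the first dim//2 channels are sines, the rest
    # cosines, with frequencies spaced geometrically from 1 down to 1/10000.
    half = dim // 2
    out = []
    for p in positions:
        freqs = [p / (10000 ** (i / half)) for i in range(half)]
        out.append([math.sin(f) for f in freqs] + [math.cos(f) for f in freqs])
    return out

# Embed three positions into an 8-dim vector each.
emb = sinusoidal_embedding_1d(8, [0, 1, 2])
```

At position 0 the sine half is all zeros and the cosine half all ones, which is a quick sanity check for any variant of this scheme.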
class bridge.diffusion.models.wan.wan_model.Head(dim, out_dim, patch_size, eps=1e-06)#

Bases: torch.nn.Module

Initialization

forward(x, e)#
Parameters:
  • x (Tensor) – Shape [B, L1, C]

  • e (Tensor) – Shape [B, C]
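A hedged illustration of the shapes involved: a head of this kind typically projects each of the L1 tokens of x from the channel dim C to out_dim * prod(patch_size) channels, modulated by the conditioning vector e. The concrete numbers below (hidden size 1536, out_dim 16, patch size (1, 2, 2)) are illustrative assumptions, not values taken from this module:

```python
from math import prod

def head_output_shape(x_shape, out_dim, patch_size):
    # Hypothetical shape helper: the head maps x [B, L1, C] to
    # [B, L1, out_dim * prod(patch_size)]; e [B, C] only modulates,
    # so it does not change the output shape.
    B, L1, _C = x_shape
    return (B, L1, out_dim * prod(patch_size))

shape = head_output_shape((2, 32760, 1536), out_dim=16, patch_size=(1, 2, 2))
```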

class bridge.diffusion.models.wan.wan_model.WanModel(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
transformer_decoder_layer_spec=WanLayerWithAdaLNspec,
**kwargs,
)#

Bases: megatron.core.models.common.vision_module.vision_module.VisionModule

WanModel is a VisionModule that implements a Wan model.

Attributes:
  • config (TransformerConfig) – Configuration for the transformer.

  • pre_process (bool) – Whether to apply pre-processing steps.

  • post_process (bool) – Whether to apply post-processing steps.

  • fp16_lm_cross_entropy (bool) – Whether to use fp16 for the cross-entropy loss.

  • parallel_output (bool) – Whether to use parallel output.

  • transformer_decoder_layer_spec (WanLayerWithAdaLNspec) – Specification for the transformer decoder layer.

  • model_type (ModelType) – Type of the model.

Initialization

forward(
x: torch.Tensor,
grid_sizes: list[Tuple[int, int, int]],
t: torch.Tensor,
context: torch.Tensor,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
**kwargs,
) torch.Tensor#

Forward pass.

Parameters:
  • x (List[Tensor]) – list of VAE-encoded data (in_channel, f, h, w)

  • grid_sizes (List[Tuple[int, int, int]]) – list of grid sizes (f, h, w)

  • t (Tensor) – timesteps

  • context (List[Tensor]) – list of context embeddings (text_len, hidden_size)

  • packed_seq_params (PackedSeqParams) – packed sequence parameters

Returns:

output tensor (still patchified) of shape [seq_len, batch_size, hidden_size]

Return type:

Tensor
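As a rough illustration of how grid_sizes relates to the packed sequence length: each (f, h, w) grid contributes f*h*w patch tokens, and the packed sequence concatenates the samples. The cumulative-offset bookkeeping shown here is an assumption about what PackedSeqParams carries; the concrete packing logic in the model may differ:

```python
from math import prod

# Hypothetical latent grids for two samples, already divided by the patch size.
grid_sizes = [(21, 30, 52), (21, 30, 52)]

# Tokens contributed by each sample, and cumulative offsets for packing.
seq_lens = [prod(g) for g in grid_sizes]
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)
```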

set_input_tensor(input_tensor: torch.Tensor) None#

Sets the input tensor for the model.

See megatron.model.transformer.set_input_tensor()

Parameters:

input_tensor (Tensor) – Sets the input tensor for the model.

sharded_state_dict(
prefix: str = 'module.',
sharded_offsets: tuple = (),
metadata: Optional[Dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Sharded state dict implementation for GPTModel backward-compatibility (removing extra state).

Parameters:
  • prefix (str) – Module name prefix.

  • sharded_offsets (tuple) – PP related offsets, expected to be empty at this module level.

  • metadata (Optional[Dict]) – metadata controlling sharded state dict creation.

Returns:

sharded state dict for the GPTModel

Return type:

ShardedStateDict
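A minimal sketch of the "removing extra state" cleanup mentioned above, using a mock state dict. The real method operates on ShardedStateDict entries, and the `_extra_state` key pattern below (TransformerEngine-style) is an assumption:

```python
# Mock state dict with one weight entry and one extra-state entry.
state = {
    "module.decoder.layers.0.self_attention.weight": "w0",
    "module.decoder.layers.0.self_attention._extra_state": "fp8-meta",
}

# Drop the (assumed) "_extra_state" entries for backward compatibility.
cleaned = {k: v for k, v in state.items() if not k.endswith("._extra_state")}
```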

_mark_trainable_params_for_tp_grad_avg(
modules: Optional[list] = None,
) None#

Mark selected modules’ trainable parameters so their gradients are averaged across the TP domain.
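A hypothetical sketch of what this marking might look like: tagging trainable parameters with a flag that a later gradient hook can check before averaging over the tensor-parallel group. Both the flag name and the mock parameter class are assumptions, not the actual mechanism:

```python
class MockParam:
    """Stand-in for torch.nn.Parameter (hypothetical, for illustration)."""
    def __init__(self, requires_grad=True):
        self.requires_grad = requires_grad

def mark_for_tp_grad_avg(params):
    # Tag each trainable parameter; a gradient hook elsewhere would check
    # this (assumed) flag and average the gradient across the TP group.
    for p in params:
        if p.requires_grad:
            p.average_gradients_across_tp_domain = True

trainable, frozen = MockParam(True), MockParam(False)
mark_for_tp_grad_avg([trainable, frozen])
```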

_set_embedder_weights_replica_id(
tensor: torch.Tensor,
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
embedder_weight_key: str,
) None#

Set replica IDs of the t_embedder weights for the sharded state dict.

Parameters:
  • tensor (Tensor) – the embedder weight tensor

  • sharded_state_dict (ShardedStateDict) – state dict with the weight to tie

  • embedder_weight_key (str) – key of the weight in the state dict. This entry will be replaced with a tied version.

Returns: None, acts in-place
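A hedged mock of the replica-ID adjustment: for a weight that is replicated across tensor-parallel ranks, folding the TP rank into replica_id lets the checkpointing layer treat only one rank's copy as primary when saving. The exact replica_id tuple layout used by megatron.core.dist_checkpointing is an assumption here:

```python
class MockShardedTensor:
    # Minimal stand-in for a ShardedTensor entry in a sharded state dict.
    def __init__(self, key, replica_id=(0, 0, 0)):
        self.key = key
        self.replica_id = replica_id

def set_embedder_replica_id(sharded_state_dict, embedder_weight_key, tp_rank):
    # Hypothetical: record the TP rank in the first replica_id slot so only
    # the rank with replica_id (0, ...) saves the replicated embedder weight.
    entry = sharded_state_dict[embedder_weight_key]
    entry.replica_id = (tp_rank,) + tuple(entry.replica_id[1:])

sd = {"t_embedder.weight": MockShardedTensor("t_embedder.weight")}
set_embedder_replica_id(sd, "t_embedder.weight", tp_rank=1)
```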