bridge.diffusion.models.wan.wan_model#
Module Contents#
Classes#
Functions#
API#
- bridge.diffusion.models.wan.wan_model.sinusoidal_embedding_1d(dim, position)#
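The intent can be illustrated with the standard 1D sinusoidal timestep embedding used by diffusion transformers. The sketch below is a framework-free approximation assuming the common formulation (`dim // 2` log-spaced frequencies with concatenated sin and cos halves); the actual `wan_model` implementation may order or scale terms differently.

```python
import math

def sinusoidal_embedding_1d(dim, position):
    """Sketch of a standard 1D sinusoidal embedding (not the real code).

    Assumes dim is even: dim // 2 frequencies 10000^(-i / (dim // 2)),
    with the sine half followed by the cosine half.
    """
    half = dim // 2
    freqs = [10000.0 ** (-i / half) for i in range(half)]
    # Concatenate sin terms, then cos terms, for a single scalar position.
    return ([math.sin(position * f) for f in freqs]
            + [math.cos(position * f) for f in freqs])

emb = sinusoidal_embedding_1d(8, position=3.0)  # 4 sines followed by 4 cosines
```

In the real model, `position` would typically be a tensor of timesteps and the result a `[len(position), dim]` tensor rather than a Python list.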
- class bridge.diffusion.models.wan.wan_model.Head(dim, out_dim, patch_size, eps=1e-06)#
Bases:
torch.nn.Module

Initialization
- forward(x, e)#
- Parameters:
x (Tensor) – Shape [B, L1, C]
e (Tensor) – Shape [B, C]
- class bridge.diffusion.models.wan.wan_model.WanModel(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- transformer_decoder_layer_spec=WanLayerWithAdaLNspec,
- **kwargs,
- )#
Bases:
megatron.core.models.common.vision_module.vision_module.VisionModule

WanModel is a VisionModule that implements a Wan model.

.. attribute:: config
Configuration for the transformer.
- Type:
TransformerConfig
.. attribute:: pre_process
Whether to apply pre-processing steps.
- Type:
bool
.. attribute:: post_process
Whether to apply post-processing steps.
- Type:
bool
.. attribute:: fp16_lm_cross_entropy
Whether to use fp16 for cross-entropy loss.
- Type:
bool
.. attribute:: parallel_output
Whether to use parallel output.
- Type:
bool
.. attribute:: transformer_decoder_layer_spec
Specification for the transformer decoder layer.
- Type:
WanLayerWithAdaLNspec
.. attribute:: model_type
Type of the model.
- Type:
Initialization
- forward(
- x: torch.Tensor,
- grid_sizes: list[Tuple[int, int, int]],
- t: torch.Tensor,
- context: torch.Tensor,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
- **kwargs,
- )#
Forward pass.
- Parameters:
x (List[Tensor]) – list of VAE-encoded data (in_channel, f, h, w)
grid_sizes (List[Tuple[int, int, int]]) – list of grid sizes (f, h, w)
t (Tensor) – timesteps
context (List[Tensor]) – list of context (text_len, hidden_size)
packed_seq_params (PackedSeqParams) – packed sequence parameters
- Returns:
output tensor (still patchified) of shape [seq_len, batch_size, hidden_size]
- Return type:
Tensor
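Since the output is still patchified with shape [seq_len, batch_size, hidden_size], the sequence length is tied to the (f, h, w) grid sizes. The hypothetical helper below, not part of the API, sketches that bookkeeping under the assumption that each patch-grid cell becomes one token.

```python
# Hypothetical helper (assumption): one token per patch-grid cell, so a
# sample with grid (f, h, w) contributes f * h * w tokens to seq_len.
def tokens_per_sample(grid_sizes):
    return [f * h * w for (f, h, w) in grid_sizes]

grid_sizes = [(4, 8, 8), (2, 16, 16)]
tokens_per_sample(grid_sizes)  # [256, 512]
```

With packed sequences (see `packed_seq_params`), the total packed length would be the sum of these per-sample counts.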
- set_input_tensor(input_tensor: torch.Tensor) -> None#
Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
- Parameters:
input_tensor (Tensor) – The input tensor for the model.
- sharded_state_dict(
- prefix: str = 'module.',
- sharded_offsets: tuple = (),
- metadata: Optional[Dict] = None,
- )#
Sharded state dict implementation for GPTModel backward-compatibility (removing extra state).
- Parameters:
prefix (str) – Module name prefix.
sharded_offsets (tuple) – PP-related offsets, expected to be empty at this module level.
metadata (Optional[Dict]) – Metadata controlling sharded state dict creation.
- Returns:
sharded state dict for the GPTModel
- Return type:
ShardedStateDict
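The "removing extra state" step mentioned above can be illustrated on a plain dict. This is only a sketch of the idea: the real method operates on Megatron's `ShardedStateDict` objects, not plain dicts, and the `_extra_state` key suffix is an assumption based on Megatron's checkpoint naming convention.

```python
# Illustrative only: drop entries whose leaf name is "_extra_state",
# mimicking the backward-compatibility cleanup on a plain dict.
def strip_extra_state(state_dict):
    return {k: v for k, v in state_dict.items()
            if not k.endswith('._extra_state')}

sd = {
    'module.decoder.layers.0.self_attention.linear_qkv.weight': 1,
    'module.decoder.layers.0.self_attention.linear_qkv._extra_state': 2,
}
clean = strip_extra_state(sd)  # keeps only the weight entry
```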
- _mark_trainable_params_for_tp_grad_avg(
- modules: Optional[list] = None,
Mark selected modules' trainable parameters to average gradients across the TP domain.
- _set_embedder_weights_replica_id(
- tensor: torch.Tensor,
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- embedder_weight_key: str,
- )#
Set replica IDs of the weights in t_embedder for the sharded state dict.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict with the weight to tie
embedder_weight_key (str) – key of the weight in the state dict. This entry will be replaced with a tied version.
Returns: None; acts in-place.