bridge.diffusion.models.wan.wan_layer_spec#

Module Contents#

Classes#

WanWithAdaLNSubmodules

Submodule specifications for the Wan transformer layer with adaptive layer normalization.

WanAdaLN

Adaptive Layer Normalization Module for DiT.

WanLayerWithAdaLN

A single transformer layer.

Functions#

get_wan_block_with_transformer_engine_spec

Build the ModuleSpec for a Wan transformer block backed by Transformer Engine.

API#

class bridge.diffusion.models.wan.wan_layer_spec.WanWithAdaLNSubmodules#

Bases: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules

temporal_self_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

full_self_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

norm1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

norm3: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

norm2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
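
For illustration, a hedged sketch of filling this container and wrapping it in a ModuleSpec. The IdentityOp placeholders are assumptions for the sketch only; the real wiring comes from get_wan_block_with_transformer_engine_spec() (documented below).

```python
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

from bridge.diffusion.models.wan.wan_layer_spec import (
    WanLayerWithAdaLN,
    WanWithAdaLNSubmodules,
)

# Placeholder wiring for illustration only: each field accepts a ModuleSpec
# or a class; the production spec substitutes real attention/norm modules.
submodules = WanWithAdaLNSubmodules(
    temporal_self_attention=IdentityOp,
    full_self_attention=IdentityOp,
    norm1=IdentityOp,
    norm2=IdentityOp,
    norm3=IdentityOp,
)
layer_spec = ModuleSpec(module=WanLayerWithAdaLN, submodules=submodules)
```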

class bridge.diffusion.models.wan.wan_layer_spec.WanAdaLN(
config: megatron.core.transformer.transformer_config.TransformerConfig,
)#

Bases: megatron.core.transformer.module.MegatronModule

Adaptive Layer Normalization Module for DiT.

Initialization

forward(timestep_emb)#
normalize_modulate(norm, hidden_states, shift, scale)#
modulate(x, shift, scale)#
scale_add(residual, x, gate)#
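
Judging by the method names and the DiT reference, these appear to follow the standard AdaLN modulation pattern. A minimal standalone sketch of that conventional pattern (an assumption, not this class's extracted code):

```python
import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # DiT-style modulation: shift and scale are predicted from the timestep
    # embedding and applied to the normalized hidden states.
    return x * (1 + scale) + shift

def scale_add(residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # Gated residual connection: the gate is likewise conditioned on the
    # timestep (zero-initialized in AdaLN-Zero so the block starts as identity).
    return residual + gate * x
```
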
class bridge.diffusion.models.wan.wan_layer_spec.WanLayerWithAdaLN(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: float = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
)#

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

A single transformer layer.

Transformer layer takes input with size [s, b, h] and returns an output of the same size.

DiT with Adaptive Layer Normalization.

Initialization

_mark_trainable_params_for_tp_grad_avg(
modules: Optional[list] = None,
) → None#

Mark selected modules’ trainable parameters so their gradients are averaged across the tensor-parallel (TP) domain.

add_residual(x: torch.Tensor, residual: torch.Tensor) → torch.Tensor#
forward(
hidden_states,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None,
sequence_len_offset=None,
inference_context=None,
rotary_pos_cos_sin=None,
**kwargs,
)#
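
A hedged call sketch, assuming `layer` is an already-constructed WanLayerWithAdaLN (see the builder below); only the [s, b, h] input convention is taken from the docstring above, and most optional arguments can stay at their None defaults:

```python
import torch

s, b, h = 16, 2, 1024  # hypothetical sequence length, micro-batch, hidden size
hidden_states = torch.randn(s, b, h)

# `layer` is an assumed, already-built WanLayerWithAdaLN instance.
# Megatron TransformerLayer subclasses conventionally return transformed
# hidden states of the same [s, b, h] shape.
output = layer(hidden_states, attention_mask=None)
```
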
bridge.diffusion.models.wan.wan_layer_spec.get_wan_block_with_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#
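
A minimal usage sketch of the builder, assuming `config` is an already-populated TransformerConfig; instantiating a Transformer-Engine-backed layer typically also requires initialized Megatron parallel state and a GPU:

```python
from megatron.core.transformer.spec_utils import build_module

from bridge.diffusion.models.wan.wan_layer_spec import (
    get_wan_block_with_transformer_engine_spec,
)

spec = get_wan_block_with_transformer_engine_spec()
print(spec.module)      # the layer class the spec instantiates
print(spec.submodules)  # the WanWithAdaLNSubmodules wiring

# `config` is an assumed, ready TransformerConfig for the target model size;
# build_module forwards spec.submodules to the layer constructor.
layer = build_module(spec, config=config, layer_number=1)
```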