bridge.diffusion.models.wan.wan_layer_spec#
Module Contents#
Classes#
- WanWithAdaLNSubmodules
- WanAdaLN — Adaptive Layer Normalization Module for DiT.
- WanLayerWithAdaLN — A single transformer layer.
Functions#
API#
- class bridge.diffusion.models.wan.wan_layer_spec.WanWithAdaLNSubmodules#
Bases: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules

- temporal_self_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- full_self_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- norm1: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- norm3: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- norm2: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- class bridge.diffusion.models.wan.wan_layer_spec.WanAdaLN(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- )#
Bases: megatron.core.transformer.module.MegatronModule

Adaptive Layer Normalization Module for DiT.
Initialization
- forward(timestep_emb)#
- normalize_modulate(norm, hidden_states, shift, scale)#
- modulate(x, shift, scale)#
- scale_add(residual, x, gate)#
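The modulation helpers follow the common DiT AdaLN pattern: `modulate` applies a shift and scale after normalization, and `scale_add` gates a branch output before adding it back to the residual stream. Below is a minimal pure-Python sketch of the elementwise math only; the real `WanAdaLN` operates on torch tensors with shift/scale/gate derived from the timestep embedding, and the exact `x * (1 + scale) + shift` form is an assumption based on standard DiT practice.

```python
# Sketch of AdaLN elementwise operations on plain Python lists.
# Names mirror the WanAdaLN methods above; shapes/broadcasting are simplified.

def modulate(x, shift, scale):
    # Assumed standard DiT form: x * (1 + scale) + shift.
    return [xi * (1.0 + scale) + shift for xi in x]

def scale_add(residual, x, gate):
    # Gated residual connection: residual + gate * x.
    return [r + gate * xi for r, xi in zip(residual, x)]

def normalize_modulate(norm, hidden_states, shift, scale):
    # Apply the norm callable, then modulate its output.
    return modulate(norm(hidden_states), shift, scale)

# Toy usage: identity "norm", scalar shift/scale/gate.
h = [1.0, 2.0]
out = normalize_modulate(lambda v: v, h, shift=0.5, scale=1.0)
print(out)  # [2.5, 4.5]
gated = scale_add([0.0, 0.0], out, gate=0.1)
```

The gate lets the model start each branch near zero contribution and learn how much of the modulated output to mix into the residual stream.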
- class bridge.diffusion.models.wan.wan_layer_spec.WanLayerWithAdaLN(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
- layer_number: int = 1,
- hidden_dropout: float = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- vp_stage: Optional[int] = None,
- )#
Bases: megatron.core.transformer.transformer_layer.TransformerLayer

A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an output of the same size.
DiT with Adaptive Layer Normalization.
Initialization
- _mark_trainable_params_for_tp_grad_avg(
- modules: Optional[list] = None,
- )#
Mark selected modules' trainable parameters to average gradients across the TP domain.
- add_residual(x: torch.Tensor, residual: torch.Tensor) → torch.Tensor#
- forward(
- hidden_states,
- attention_mask=None,
- context=None,
- context_mask=None,
- rotary_pos_emb=None,
- rotary_pos_cos=None,
- rotary_pos_sin=None,
- attention_bias=None,
- inference_params=None,
- packed_seq_params=None,
- sequence_len_offset=None,
- inference_context=None,
- rotary_pos_cos_sin=None,
- **kwargs,
- )#
- bridge.diffusion.models.wan.wan_layer_spec.get_wan_block_with_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#
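`get_wan_block_with_transformer_engine_spec` returns a `ModuleSpec` that declares the layer topology as data: the layer class (`WanLayerWithAdaLN`) paired with a submodules container (`WanWithAdaLNSubmodules`) that names an implementation for each slot. The sketch below illustrates this spec pattern with stand-in classes rather than the real `megatron.core.transformer.spec_utils.ModuleSpec` and Transformer Engine modules; all class and field names other than the pattern itself are illustrative assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type

# Stand-in for megatron.core.transformer.spec_utils.ModuleSpec:
# it records which class to build and what to hand it, so the layer
# topology is declared as data rather than hard-coded in the model.
@dataclass
class ModuleSpec:
    module: Type
    submodules: Any = None
    params: Dict[str, Any] = field(default_factory=dict)

@dataclass
class LayerSubmodules:
    # Mirrors two of the slots on WanWithAdaLNSubmodules (norms omitted).
    temporal_self_attention: Optional[Type] = None
    full_self_attention: Optional[Type] = None

class DummyAttention:
    """Placeholder for a Transformer Engine attention implementation."""

class DummyLayer:
    """Placeholder for WanLayerWithAdaLN; consumes a submodules container."""
    def __init__(self, submodules):
        self.submodules = submodules

def get_block_spec() -> ModuleSpec:
    # Analogous role to get_wan_block_with_transformer_engine_spec():
    # return a spec naming the layer class plus its submodule classes.
    return ModuleSpec(
        module=DummyLayer,
        submodules=LayerSubmodules(
            temporal_self_attention=DummyAttention,
            full_self_attention=DummyAttention,
        ),
    )

spec = get_block_spec()
layer = spec.module(spec.submodules)  # the "build" step a spec consumer performs
```

Keeping the wiring in a spec function lets callers swap attention or norm implementations (e.g. Transformer Engine vs. local) without touching the layer class itself.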