bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block#
Copied from https://github.com/Thaurun/mbridge/blob/4462d1e284626d2ed9d3e3e3e5a40f2ee42a2c74/mbridge/models/qwen3_vl/transformer_block.py
Module Contents#
Classes#
- Qwen3VLVisionTransformerBlock – Vision Transformer Block for Qwen3VL vision model.
- Qwen3VLTransformerBlock – Transformer Block for Qwen3VL model.
Data#
API#
- bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.te_checkpoint#
None
- class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLVisionTransformerBlock(
- config: megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config.Qwen3VLTransformerConfig,
- spec: Union[megatron.core.transformer.transformer_block.TransformerBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
- post_layer_norm: bool = True,
- pre_process: bool = True,
- post_process: bool = True,
- vp_stage: Optional[int] = None,
- patch_merger_spec: megatron.core.transformer.spec_utils.ModuleSpec = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Bases:
megatron.core.transformer.transformer_block.TransformerBlock

Vision Transformer Block for Qwen3VL vision model.
Initialization
- _checkpointed_forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- context: torch.Tensor,
- context_mask: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
- attention_bias: torch.Tensor,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
- use_inner_fp8_context: bool,
Forward method with activation checkpointing.
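The grouping strategy behind activation checkpointing can be illustrated without any framework code. The sketch below is not the Megatron-Core implementation; it only shows the "uniform" recompute pattern, where consecutive layers are partitioned into fixed-size groups and each group is re-run during the backward pass instead of storing its activations (`uniform_checkpoint_groups` and `recompute_num_layers` are illustrative names):

```python
# Illustrative sketch only: partition layer indices into groups of
# `recompute_num_layers` consecutive layers. Under uniform activation
# recompute, each group's forward is replayed in the backward pass
# rather than keeping its intermediate activations in memory.
def uniform_checkpoint_groups(num_layers: int, recompute_num_layers: int) -> list[list[int]]:
    groups = []
    for start in range(0, num_layers, recompute_num_layers):
        end = min(start + recompute_num_layers, num_layers)
        groups.append(list(range(start, end)))
    return groups
```

For example, `uniform_checkpoint_groups(6, 2)` yields `[[0, 1], [2, 3], [4, 5]]`; a trailing partial group is kept, so 5 layers with group size 2 produce `[[0, 1], [2, 3], [4]]`.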
- forward(
- hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
- attention_mask: Optional[torch.Tensor],
- context: Optional[torch.Tensor] = None,
- context_mask: Optional[torch.Tensor] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[torch.Tensor] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.
- Parameters:
hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.
attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.
context (Tensor, optional) – Context tensor for cross-attention.
context_mask (Tensor, optional) – Mask for the cross-attention context.
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.
inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.
packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.
- Returns:
The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
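The documented shape contract is easy to get wrong when calling the block directly, since Megatron uses a sequence-first layout. This hypothetical helper (not part of the API) just encodes the shapes stated above — hidden states as [s, b, h] and a self-attention mask as [1, 1, s, s]:

```python
# Hypothetical validation helper, not part of Megatron Bridge: checks
# that inputs match the shapes documented for forward().
def check_forward_shapes(hidden_shape: tuple, mask_shape: tuple) -> tuple:
    s, b, h = hidden_shape            # sequence-first: [s, b, h]
    if mask_shape != (1, 1, s, s):    # boolean self-attention mask
        raise ValueError(f"attention_mask must be [1, 1, {s}, {s}], got {mask_shape}")
    return s, b, h
```

For instance, a batch of 4 sequences of length 128 with hidden size 1024 is passed as a [128, 4, 1024] tensor, with a [1, 1, 128, 128] mask.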
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: dict = None,
Generate a sharded state dictionary for the transformer block.
- Parameters:
prefix (str, optional) – Prefix to be added to all keys in the state dict. Defaults to an empty string.
sharded_offsets (tuple, optional) – Tuple of sharding offsets.
metadata (dict, optional) – Additional metadata for sharding. Can specify if layers are non-homogeneous. Defaults to None.
- Returns:
A dictionary containing the sharded state of the model.
- Return type:
ShardedStateDict
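The `prefix` parameter follows the usual nested-module convention: every key the block emits is namespaced under the caller-supplied prefix, so parent modules compose prefixes like `"decoder.layers.0."`. A minimal sketch of that convention, using a plain dict in place of a real ShardedStateDict:

```python
# Illustrative sketch of the key-prefixing convention only; a real
# sharded_state_dict also carries sharding offsets and metadata.
def prefix_state_dict(state_dict: dict, prefix: str = "") -> dict:
    return {f"{prefix}{key}": value for key, value in state_dict.items()}
```

So `prefix_state_dict({"weight": w}, "layers.0.")` produces the key `"layers.0.weight"`, and nesting another level simply prepends again.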
- class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLTransformerBlock(
- config: megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config.Qwen3VLTransformerConfig,
- spec: Union[megatron.core.transformer.transformer_block.TransformerBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
- post_layer_norm: bool = True,
- pre_process: bool = True,
- post_process: bool = True,
- vp_stage: Optional[int] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Bases:
megatron.core.transformer.transformer_block.TransformerBlock

Transformer Block for Qwen3VL model.
Initialization
- _checkpointed_forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- context: torch.Tensor,
- context_mask: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
- attention_bias: torch.Tensor,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
- use_inner_fp8_context: bool,
- visual_pos_masks: Optional[torch.Tensor] = None,
- deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
Forward method with activation checkpointing.
- forward(
- hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
- attention_mask: Optional[torch.Tensor],
- context: Optional[torch.Tensor] = None,
- context_mask: Optional[torch.Tensor] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[torch.Tensor] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- visual_pos_masks: Optional[torch.Tensor] = None,
- deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.
- Parameters:
hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.
attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.
context (Tensor, optional) – Context tensor for cross-attention.
context_mask (Tensor, optional) – Mask for the cross-attention context.
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.
inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.
packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.
- Returns:
The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
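The `visual_pos_masks` and `deepstack_visual_embeds` arguments suggest a masked-injection pattern: visual embeddings are added onto the hidden states at the token positions the mask marks as visual. The following is a hypothetical pure-Python sketch of that pattern (real code would use tensor indexing, e.g. `hidden[mask] += embeds`), with lists standing in for tensors:

```python
# Hypothetical sketch of deepstack-style injection: add the i-th visual
# embedding onto the i-th masked position of the hidden states. Not the
# actual _deepstack_process implementation.
def add_visual_embeds(hidden: list, visual_mask: list, visual_embeds: list) -> list:
    out = [list(vec) for vec in hidden]   # copy so the input is untouched
    next_embed = 0                        # index of the next unused visual embedding
    for pos, is_visual in enumerate(visual_mask):
        if is_visual:
            out[pos] = [h + v for h, v in zip(out[pos], visual_embeds[next_embed])]
            next_embed += 1
    return out
```

Positions where the mask is False pass through unchanged; the visual embeddings are consumed in order, one per True position.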
- _deepstack_process(
- hidden_states: torch.Tensor,
- visual_pos_masks: torch.Tensor,
- visual_embeds: torch.Tensor,