bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block#

Copied from https://github.com/Thaurun/mbridge/blob/4462d1e284626d2ed9d3e3e3e5a40f2ee42a2c74/mbridge/models/qwen3_vl/transformer_block.py

Module Contents#

Classes#

Qwen3VLVisionTransformerBlock

Vision Transformer Block for Qwen3VL vision model.

Qwen3VLTransformerBlock

Transformer Block for Qwen3VL model.

Data#

API#

bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.te_checkpoint#

None

class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLVisionTransformerBlock(
config: megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config.Qwen3VLTransformerConfig,
spec: Union[megatron.core.transformer.transformer_block.TransformerBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
vp_stage: Optional[int] = None,
patch_merger_spec: Optional[megatron.core.transformer.spec_utils.ModuleSpec] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.transformer_block.TransformerBlock

Vision Transformer Block for Qwen3VL vision model.

Initialization

_checkpointed_forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor,
rotary_pos_emb: torch.Tensor,
attention_bias: torch.Tensor,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
use_inner_fp8_context: bool,
)#

Forward method with activation checkpointing.
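The method above delegates to Megatron's activation-recompute machinery. As a framework-free illustration of the underlying pattern (not the actual implementation, and all names below are invented for the sketch), activation checkpointing caches only the inputs at chunk boundaries and re-runs the layers inside a chunk when their activations are needed again during the backward pass:

```python
# Toy sketch of uniform activation checkpointing: instead of caching the
# output of every layer, cache only chunk-boundary inputs and recompute
# the layers inside a chunk on demand.

def run_layers(layers, x):
    for layer in layers:
        x = layer(x)
    return x

def checkpointed_forward(layers, x, chunk_size=2):
    """Return the final output plus the cached chunk inputs (the only
    activations kept alive; everything else is recomputed later)."""
    boundaries = []
    for start in range(0, len(layers), chunk_size):
        boundaries.append(x)                    # cache chunk input only
        x = run_layers(layers[start:start + chunk_size], x)
    return x, boundaries

def recompute_chunk(layers, boundaries, chunk_idx, chunk_size=2):
    """Re-run one chunk from its cached input, as a backward pass would."""
    start = chunk_idx * chunk_size
    return run_layers(layers[start:start + chunk_size], boundaries[chunk_idx])

layers = [lambda v, k=k: v + k for k in range(4)]  # 4 stand-in "layers"
out, cached = checkpointed_forward(layers, 0)
assert out == 6                                    # 0 + 0 + 1 + 2 + 3
assert recompute_chunk(layers, cached, 1) == out   # chunk 1 recomputes the tail
```

The trade-off is the usual one: memory drops from one activation per layer to one per chunk, at the cost of re-running each chunk once in the backward pass.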

forward(
hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
attention_mask: Optional[torch.Tensor],
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#

Perform the forward pass through the transformer block.

This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.

Parameters:
  • hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention.

  • context_mask (Tensor, optional) – Mask for the cross-attention context.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.

  • inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.

  • packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.

Returns:

The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
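The attention_bias shape requirement above can be checked mechanically. Below is a small, framework-free helper (illustrative only, not part of this API) that tests whether a proposed bias shape broadcasts to [b, num_head, sq, skv] under standard NumPy/PyTorch broadcasting rules:

```python
def broadcastable(shape, target):
    """True if `shape` broadcasts to `target` under NumPy/PyTorch rules:
    align trailing dims; each dim must equal the target dim or be 1."""
    if len(shape) > len(target):
        return False
    padded = (1,) * (len(target) - len(shape)) + tuple(shape)
    return all(s == t or s == 1 for s, t in zip(padded, target))

b, num_head, sq, skv = 4, 16, 128, 128
target = (b, num_head, sq, skv)

assert broadcastable((1, 1, sq, skv), target)      # example from the docstring
assert broadcastable((b, 1, sq, skv), target)      # per-sample bias also works
assert not broadcastable((2, 1, sq, skv), target)  # mismatched batch dim
```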

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Generate a sharded state dictionary for the transformer block.

Parameters:
  • prefix (str, optional) – Prefix to be added to all keys in the state dict. Defaults to an empty string.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (dict, optional) – Additional metadata for sharding. Can specify if layers are non-homogeneous. Defaults to None.

Returns:

A dictionary containing the sharded state of the model.

Return type:

ShardedStateDict
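How `prefix` and `sharded_offsets` qualify checkpoint entries can be sketched without Megatron. The following is an illustrative stand-in (not the real ShardedStateDict type or its construction logic): the prefix namespaces every key, and the offsets travel alongside each entry so each rank writes a disjoint slice of the global checkpoint:

```python
# Illustrative sketch (not Megatron's actual ShardedStateDict) of how a
# key prefix and per-rank sharding offsets qualify checkpoint entries.

def sharded_state_dict(params, prefix="", sharded_offsets=()):
    """Map each param name to a (key, offsets) entry, mirroring how the
    real method prepends `prefix` and threads sharding metadata through."""
    return {
        name: {"key": f"{prefix}{name}", "offsets": sharded_offsets}
        for name in params
    }

entries = sharded_state_dict(
    ["layers.0.weight", "layers.0.bias"],
    prefix="decoder.",
    sharded_offsets=((0, 2, 8),),  # (dim, this_rank_offset, global_size)
)
assert entries["layers.0.weight"]["key"] == "decoder.layers.0.weight"
assert entries["layers.0.bias"]["offsets"] == ((0, 2, 8),)
```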

class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLTransformerBlock(
config: megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config.Qwen3VLTransformerConfig,
spec: Union[megatron.core.transformer.transformer_block.TransformerBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
vp_stage: Optional[int] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.transformer_block.TransformerBlock

Transformer Block for Qwen3VL model.

Initialization

_checkpointed_forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor,
rotary_pos_emb: torch.Tensor,
attention_bias: torch.Tensor,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
use_inner_fp8_context: bool,
visual_pos_masks: Optional[torch.Tensor] = None,
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
)#

Forward method with activation checkpointing.

forward(
hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
attention_mask: Optional[torch.Tensor],
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
visual_pos_masks: Optional[torch.Tensor] = None,
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
)#

Perform the forward pass through the transformer block.

This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.

Parameters:
  • hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention.

  • context_mask (Tensor, optional) – Mask for the cross-attention context.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.

  • inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.

  • packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.

  • visual_pos_masks (Tensor, optional) – Boolean mask marking the positions of visual tokens in hidden_states.

  • deepstack_visual_embeds (list[Tensor], optional) – Visual embeddings merged into hidden_states at the masked positions in early decoder layers.

Returns:

The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

_deepstack_process(
hidden_states: torch.Tensor,
visual_pos_masks: torch.Tensor,
visual_embeds: torch.Tensor,
)#