bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block

Copied from https://github.com/Thaurun/mbridge/blob/4462d1e284626d2ed9d3e3e3e5a40f2ee42a2c74/mbridge/models/qwen3_vl/transformer_block.py

Module Contents

Classes

Qwen3VLTransformerBlock

Transformer block for Qwen3-VL with DeepStack visual-feature injection.

Data

API

bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.te_checkpoint

None

class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLTransformerBlock

Bases: megatron.core.transformer.transformer_block.TransformerBlock

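Megatron-Core transformer block extended for Qwen3-VL. In addition to the standard inputs, it accepts visual_pos_masks and deepstack_visual_embeds and merges the DeepStack visual features into the hidden states at the visual token positions (see _deepstack_process).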

_checkpointed_forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor,
rotary_pos_emb: torch.Tensor,
attention_bias: torch.Tensor,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
use_inner_fp8_context: bool,
visual_pos_masks: Optional[torch.Tensor] = None,
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
)

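Forward pass with activation checkpointing: intermediate layer activations are recomputed during the backward pass instead of being stored. Takes the same DeepStack inputs (visual_pos_masks, deepstack_visual_embeds) as forward.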
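Activation checkpointing trades compute for memory. The sketch below shows the general pattern with PyTorch's `torch.utils.checkpoint.checkpoint`, running a stack of layers in chunks so that only each chunk's input is kept for backward. It is an illustration of the technique, not this method's implementation: the real method may route through Transformer Engine (`te_checkpoint`) and Megatron's recompute settings, and `ToyLayer`, `checkpointed_forward`, and `chunk_size` are hypothetical names.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a transformer layer operating on [s, b, h] tensors."""

    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + torch.tanh(self.proj(hidden_states))


def checkpointed_forward(layers: nn.ModuleList, hidden_states: torch.Tensor, chunk_size: int = 2) -> torch.Tensor:
    """Run `layers` in chunks of `chunk_size`; only chunk inputs are kept for backward."""

    def make_chunk(start: int, end: int):
        def run_chunk(x: torch.Tensor) -> torch.Tensor:
            for layer in layers[start:end]:
                x = layer(x)
            return x

        return run_chunk

    for start in range(0, len(layers), chunk_size):
        end = min(start + chunk_size, len(layers))
        # Non-reentrant checkpointing: activations inside the chunk are recomputed in backward.
        hidden_states = checkpoint(make_chunk(start, end), hidden_states, use_reentrant=False)
    return hidden_states


# Usage: 4 toy layers, checkpointed two at a time.
torch.manual_seed(0)
layers = nn.ModuleList(ToyLayer(8) for _ in range(4))
x = torch.randn(5, 2, 8, requires_grad=True)  # [s, b, h]
out = checkpointed_forward(layers, x)
out.sum().backward()
```

In this sketch the chunk size trades the number of stored chunk-boundary activations against the peak memory needed to recompute one chunk during backward.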

forward(
hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
attention_mask: Optional[torch.Tensor],
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
visual_pos_masks: Optional[torch.Tensor] = None,
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
)

Perform the forward pass through the transformer block.

This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.

Parameters:
  • hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention.

  • context_mask (Tensor, optional) – Mask for the cross-attention context.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • attention_bias (Tensor, optional) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.

  • inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.

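  • packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.

  • visual_pos_masks (Tensor, optional) – Boolean mask marking the positions of visual tokens in hidden_states.

  • deepstack_visual_embeds (list[Tensor], optional) – DeepStack visual features merged into hidden_states at the positions selected by visual_pos_masks.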
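To make the attention_bias shapes concrete, the plain-PyTorch sketch below turns a [1, 1, s, s] boolean causal mask into an additive bias and lets broadcasting expand it over the batch and head dimensions. It assumes the Megatron-style convention that True marks a position to be masked out; `mask_to_bias` is a hypothetical helper, and this does not describe how Transformer Engine's cuDNN backend consumes the bias internally.

```python
import torch


def mask_to_bias(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a boolean mask (True = masked out, an assumed convention) to an additive bias."""
    bias = torch.zeros(attention_mask.shape, dtype=dtype)
    # Masked positions get the most negative finite value, so softmax sends them to 0.
    return bias.masked_fill(attention_mask, torch.finfo(dtype).min)


s, b, heads = 4, 2, 3
# Causal mask of shape [1, 1, s, s]: True above the diagonal (future positions).
causal = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)[None, None]
bias = mask_to_bias(causal)  # [1, 1, s, s], broadcastable to [b, heads, s, s]

scores = torch.randn(b, heads, s, s)  # stand-in for Q * K.T
probs = torch.softmax(scores + bias, dim=-1)
```

Keeping the bias at [1, 1, s, s] avoids materializing a per-batch, per-head copy; broadcasting expands it only inside the addition.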

Returns:

The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]
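To illustrate how per-layer DeepStack features could be threaded through the layer loop, the sketch below assumes the schedule used by the Hugging Face Transformers Qwen3-VL text model: entry i of deepstack_visual_embeds is added to the output of layer i at the positions where visual_pos_masks is True, so only the first len(deepstack_visual_embeds) layers receive visual features. The function name `deepstack_forward`, the [s, b] mask shape, and the plain callable layers are assumptions, not this class's API.

```python
from typing import Callable, Optional, Sequence

import torch
from torch import nn


def deepstack_forward(
    layers: Sequence[Callable[[torch.Tensor], torch.Tensor]],
    hidden_states: torch.Tensor,  # [s, b, h]
    visual_pos_masks: Optional[torch.Tensor] = None,  # [s, b], bool (assumed shape)
    deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,  # each [n_visual, h]
) -> torch.Tensor:
    """Apply `layers` in order; after layer i, add deepstack_visual_embeds[i] at visual positions."""
    for layer_idx, layer in enumerate(layers):
        hidden_states = layer(hidden_states)
        if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
            # Out-of-place update so the caller's tensors are never modified.
            updated = hidden_states.clone()
            updated[visual_pos_masks] = updated[visual_pos_masks] + deepstack_visual_embeds[layer_idx]
            hidden_states = updated
    return hidden_states


# Usage: 3 identity layers, features injected after layers 0 and 1.
s, b, h = 6, 2, 4
mask = torch.zeros(s, b, dtype=torch.bool)
mask[1:3, 0] = True  # two visual tokens in batch sample 0
n_visual = int(mask.sum())
x = torch.zeros(s, b, h)
out = deepstack_forward(
    [nn.Identity() for _ in range(3)],
    x,
    mask,
    [torch.ones(n_visual, h), 2 * torch.ones(n_visual, h)],
)
```

With identity layers, the visual positions end up holding 1 + 2 = 3 and all other positions stay 0, which makes the injection schedule easy to check.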

_deepstack_process(
hidden_states: torch.Tensor,
visual_pos_masks: torch.Tensor,
visual_embeds: torch.Tensor,
)
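No docstring is rendered for _deepstack_process. Judging only from its signature, it merges visual_embeds into hidden_states at the positions selected by visual_pos_masks. The stand-alone sketch below shows such a masked add, modeled on the Hugging Face Transformers Qwen3-VL helper of the same name; the [s, b, h] layout, the [s, b] mask shape, and the function name `deepstack_process` are assumptions, and the real method may differ (for example, in how it orders visual tokens across batch samples or handles sequence parallelism).

```python
import torch


def deepstack_process(
    hidden_states: torch.Tensor,
    visual_pos_masks: torch.Tensor,
    visual_embeds: torch.Tensor,
) -> torch.Tensor:
    """Add `visual_embeds` to the rows of `hidden_states` selected by `visual_pos_masks`.

    Assumed shapes: hidden_states [s, b, h], visual_pos_masks [s, b] (bool),
    visual_embeds [n, h] with n == visual_pos_masks.sum().
    """
    visual_pos_masks = visual_pos_masks.to(hidden_states.device)
    visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
    out = hidden_states.clone()
    # Boolean indexing visits selected positions in row-major order over [s, b],
    # so visual_embeds must be laid out in that same order.
    out[visual_pos_masks] = out[visual_pos_masks] + visual_embeds
    return out
```

The ordering comment matters for the [s, b, h] layout: with b > 1, row-major order over [s, b] interleaves batch samples, whereas a [b, s, h] layout would list all visual tokens of sample 0 first. Returning a new tensor instead of writing in place is a choice made in this sketch for safety; an in-place update would save memory.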