bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block#
Copied from https://github.com/Thaurun/mbridge/blob/4462d1e284626d2ed9d3e3e3e5a40f2ee42a2c74/mbridge/models/qwen3_vl/transformer_block.py
Module Contents#
Classes#
Qwen3VLTransformerBlock: Transformer class.
Data#
API#
- bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.te_checkpoint#
None
- class bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block.Qwen3VLTransformerBlock#
Bases:
megatron.core.transformer.transformer_block.TransformerBlock
Transformer class.
- _checkpointed_forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- context: torch.Tensor,
- context_mask: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
- attention_bias: torch.Tensor,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams,
- use_inner_fp8_context: bool,
- visual_pos_masks: Optional[torch.Tensor] = None,
- deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
Forward method with activation checkpointing.
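The recompute behavior itself is controlled by the parent TransformerBlock's configuration (recompute granularity, method, and number of layers). The sketch below is only an illustration of the general activation-checkpointing pattern using torch.utils.checkpoint, with hypothetical layers and run_layer names; it is not the actual Megatron recompute path.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Illustrative only: a hypothetical stack of layers showing the general idea of
# activation checkpointing -- activations inside each checkpointed segment are
# discarded after the forward pass and recomputed during backward.
layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(2)])

def run_layer(layer, hidden_states):
    # One checkpointed segment; in the real block a segment is one or more
    # transformer layers (self-attention + MLP), not a single Linear.
    return layer(hidden_states)

hidden_states = torch.randn(8, 2, 16, requires_grad=True)  # [s, b, h]
for layer in layers:
    hidden_states = checkpoint(run_layer, layer, hidden_states, use_reentrant=False)
hidden_states.sum().backward()
```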
- forward(
- hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
- attention_mask: Optional[torch.Tensor],
- context: Optional[torch.Tensor] = None,
- context_mask: Optional[torch.Tensor] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[torch.Tensor] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- visual_pos_masks: Optional[torch.Tensor] = None,
- deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including self-attention, optional cross-attention, and feed-forward operations.
- Parameters:
hidden_states (Union[Tensor, WrappedTensor]) – Input tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.
attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.
context (Tensor, optional) – Context tensor for cross-attention.
context_mask (Tensor, optional) – Mask for the cross-attention context.
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
attention_bias (Tensor) – Bias tensor for Q * K.T, in a shape broadcastable to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. Used as an alternative to the attention mask for TE cuDNN attention.
inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.
packed_seq_params (PackedSeqParams, optional) – Parameters for packed sequence processing.
- Returns:
The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.
- Return type:
Union[Tensor, Tuple[Tensor, Tensor]]
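For orientation, a minimal call sketch follows. It assumes an already constructed and initialized Qwen3VLTransformerBlock instance named block and uses illustrative shapes only; constructing the block itself requires a Megatron TransformerConfig and layer spec, which are out of scope here.

```python
import torch

# Hypothetical shapes: s = sequence length, b = batch size, h = hidden size.
s, b, h = 128, 2, 4096

hidden_states = torch.randn(s, b, h, dtype=torch.bfloat16, device="cuda")  # [s, b, h]
# Boolean mask of shape [1, 1, s, s]; True marks positions to be masked out
# (here a causal mask built from the strict upper triangle).
attention_mask = torch.triu(
    torch.ones(1, 1, s, s, dtype=torch.bool, device="cuda"), diagonal=1
)

# `block` is assumed to be an initialized Qwen3VLTransformerBlock.
# visual_pos_masks / deepstack_visual_embeds are only needed when visual
# tokens are present; for a text-only step they can be left as None.
output = block(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    rotary_pos_emb=None,
    visual_pos_masks=None,
    deepstack_visual_embeds=None,
)
# `output` has shape [s, b, h] (or a (hidden, context) tuple when cross-attention is used).
```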
- _deepstack_process(
- hidden_states: torch.Tensor,
- visual_pos_masks: torch.Tensor,
- visual_embeds: torch.Tensor,