bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model
Module Contents
Classes
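Qwen3OmniThinkerModel | Qwen3-Omni thinker model.
Functions
_deep_getattr |
_patch_get_input_embeddings | Match ms-swift’s tower patching for gradient-checkpoint input hooks.
_build_text_only_mrope_position_ids | Create text-only multimodal rope ids shaped [3, batch, seq].
_configure_multimodal_attn_impl | Apply a requested attention implementation to HF multimodal configs.
_enable_multimodal_gradient_checkpointing | Best-effort enable gradient checkpointing for HF multimodal towers.
_trim_feature_sequence |
_normalize_visual_outputs |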
API
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._deep_getattr(
- module: torch.nn.Module,
- attr_path: str,
- )
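The signature suggests a dotted-path attribute lookup on a module. The following is a minimal sketch of that idea, not the helper itself: it assumes `attr_path` is dot-separated and that a missing attribute yields `None`, neither of which is stated on this page. The name `deep_getattr_sketch` and the toy object are hypothetical.

```python
from types import SimpleNamespace


def deep_getattr_sketch(obj, attr_path):
    """Follow a dot-separated attribute path; return None if any hop is missing.

    Hypothetical stand-in, not the library helper.
    """
    missing = object()  # sentinel so that a stored None is not mistaken for "absent"
    current = obj
    for name in attr_path.split("."):
        current = getattr(current, name, missing)
        if current is missing:
            return None
    return current


# Toy object standing in for a tower with nested submodules.
tower = SimpleNamespace(visual=SimpleNamespace(patch_embed="patch-embed-module"))
found = deep_getattr_sketch(tower, "visual.patch_embed")
absent = deep_getattr_sketch(tower, "visual.merger")
```

Returning `None` rather than raising keeps callers simple when a tower lacks an optional submodule; the real helper may choose differently.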
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._patch_get_input_embeddings(
- module: torch.nn.Module,
- attr_path: str,
- )
Match ms-swift’s tower patching for gradient-checkpoint input hooks.
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._build_text_only_mrope_position_ids(
- input_ids: torch.Tensor,
- )
Create text-only multimodal rope ids shaped [3, batch, seq].
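To make the `[3, batch, seq]` shape concrete, here is a hedged sketch of one way such ids could be built. It assumes that, for text-only input, all three rope axes share the same `0..seq-1` ramp and that padding is ignored; the actual function may handle masks or offsets differently. `text_only_mrope_ids_sketch` is a hypothetical name.

```python
import torch


def text_only_mrope_ids_sketch(input_ids: torch.Tensor) -> torch.Tensor:
    """Return position ids of shape [3, batch, seq] for text-only input.

    Hypothetical stand-in: every rope axis gets the same 0..seq-1 ramp.
    """
    batch, seq = input_ids.shape
    ramp = torch.arange(seq, dtype=torch.long, device=input_ids.device)
    # [seq] -> [1, 1, seq] -> [3, batch, seq]; expand returns a view, not a copy.
    return ramp.view(1, 1, seq).expand(3, batch, seq)


ids = torch.zeros(2, 5, dtype=torch.long)
pos = text_only_mrope_ids_sketch(ids)
```

With no image or video grid, the three axes carry no distinct spatial information, which is why a single shared ramp is a natural choice in this sketch.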
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._configure_multimodal_attn_impl(
- config: object,
- attn_impl: str | None,
- )
Apply a requested attention implementation to HF multimodal configs.
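As an illustration of what applying an attention implementation to nested configs can look like, the sketch below copies the requested value onto a config and its sub-configs. It assumes the value lives in the `_attn_implementation` attribute used by Hugging Face configs and that sub-configs are named `vision_config`, `audio_config`, and `text_config`; the function name and the toy config objects are hypothetical.

```python
from types import SimpleNamespace


def configure_attn_impl_sketch(config, attn_impl):
    """Copy attn_impl onto a config and any nested sub-configs that exist.

    Hypothetical stand-in; the sub-config names below are assumptions.
    """
    if attn_impl is None:
        return  # nothing requested, leave the configs untouched
    targets = [config]
    for name in ("vision_config", "audio_config", "text_config"):
        sub = getattr(config, name, None)
        if sub is not None:
            targets.append(sub)
    for target in targets:
        target._attn_implementation = attn_impl


cfg = SimpleNamespace(
    _attn_implementation="eager",
    vision_config=SimpleNamespace(_attn_implementation="eager"),
    audio_config=SimpleNamespace(_attn_implementation="eager"),
)
configure_attn_impl_sketch(cfg, "sdpa")

untouched = SimpleNamespace(_attn_implementation="eager")
configure_attn_impl_sketch(untouched, None)
```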
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._enable_multimodal_gradient_checkpointing(
- module: torch.nn.Module,
- )
Best-effort enable gradient checkpointing for HF multimodal towers.
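"Best-effort" here presumably means that a tower without checkpointing support is skipped rather than treated as an error. A minimal sketch of that pattern follows, assuming the tower may expose a Hugging Face-style `gradient_checkpointing_enable()` method; the function name, the boolean return value, and the toy classes are hypothetical.

```python
class TowerWithSupport:
    """Toy tower exposing a Hugging Face-style enable method."""

    def __init__(self):
        self.checkpointing = False

    def gradient_checkpointing_enable(self):
        self.checkpointing = True


class TowerWithoutSupport:
    """Toy tower with no checkpointing hook at all."""


def enable_checkpointing_sketch(module) -> bool:
    """Try to enable gradient checkpointing; report success instead of raising."""
    enable = getattr(module, "gradient_checkpointing_enable", None)
    if not callable(enable):
        return False
    try:
        enable()
    except Exception:
        # Best-effort: a tower that rejects checkpointing is skipped, not fatal.
        return False
    return True


supported = TowerWithSupport()
ok = enable_checkpointing_sketch(supported)
skipped = enable_checkpointing_sketch(TowerWithoutSupport())
```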
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._trim_feature_sequence(
- features: torch.Tensor | None,
- multiscale_features: list[torch.Tensor] | None,
- expected_tokens: int,
- feature_name: str,
- )
- bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model._normalize_visual_outputs(
- outputs: object,
- )
- class bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model.Qwen3OmniThinkerModel(
- language_transformer_config: megatron.bridge.models.qwen_omni.modeling_qwen3_omni.transformer_config.Qwen3OmniTransformerConfig,
- language_transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
- thinker_transformer_config: transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe.Qwen3OmniMoeThinkerConfig,
- parallel_output: bool = True,
- pre_process: bool = True,
- post_process: bool = True,
- add_encoder: bool = True,
- add_decoder: bool = True,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
- )
Bases: megatron.core.transformer.MegatronModule
Qwen3-Omni thinker model.
The current implementation supports multimodal thinker-side forward paths for text, vision, and audio inputs.
Initialization
- set_input_tensor(input_tensor) -> None
- freeze(
- freeze_language_model: bool = False,
- freeze_vision_model: bool = False,
- freeze_audio_model: bool = False,
- )
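The three flags above suggest per-tower freezing. The sketch below shows the general pattern on a toy module: it assumes freezing means setting `requires_grad` to `False` on the selected submodule's parameters, and the submodule names `language_model`, `vision_model`, and `audio_model` are hypothetical, not taken from this page.

```python
import torch


class ToyThinker(torch.nn.Module):
    """Toy stand-in with hypothetical submodule names."""

    def __init__(self):
        super().__init__()
        self.language_model = torch.nn.Linear(4, 4)
        self.vision_model = torch.nn.Linear(4, 4)
        self.audio_model = torch.nn.Linear(4, 4)

    def freeze(self, freeze_language_model=False, freeze_vision_model=False,
               freeze_audio_model=False):
        flags = {
            "language_model": freeze_language_model,
            "vision_model": freeze_vision_model,
            "audio_model": freeze_audio_model,
        }
        for name, should_freeze in flags.items():
            if should_freeze:
                # requires_grad_(False) flips the flag on every parameter in place.
                getattr(self, name).requires_grad_(False)


model = ToyThinker()
model.freeze(freeze_vision_model=True, freeze_audio_model=True)
trainable = sorted(
    {name.split(".")[0] for name, p in model.named_parameters() if p.requires_grad}
)
```

A typical use of this pattern is training only the language side while keeping pretrained encoders fixed.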
- static _get_placeholder_mask(
- input_ids: torch.LongTensor,
- inputs_embeds: torch.FloatTensor,
- image_features: torch.FloatTensor | None = None,
- video_features: torch.FloatTensor | None = None,
- image_token_id: int = 151655,
- video_token_id: int = 151656,
- )
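A placeholder mask of this kind is typically used to scatter encoder features into the embedding positions held by special tokens. The sketch below illustrates that idea for images only. It assumes a `[batch, seq, hidden]` embedding layout, which may differ from the layout used inside this model, and `image_placeholder_mask_sketch` is a hypothetical name; only the default token id comes from the signature above.

```python
import torch

IMAGE_TOKEN_ID = 151655  # default image_token_id from the signature above


def image_placeholder_mask_sketch(input_ids, inputs_embeds,
                                  image_token_id=IMAGE_TOKEN_ID):
    """Boolean mask shaped like inputs_embeds, True at image placeholder tokens."""
    token_mask = input_ids == image_token_id               # [batch, seq]
    return token_mask.unsqueeze(-1).expand_as(inputs_embeds)  # [batch, seq, hidden]


input_ids = torch.tensor([[1, 151655, 151655, 2]])
inputs_embeds = torch.zeros(1, 4, 8)
image_features = torch.ones(2, 8)  # one feature row per placeholder token

mask = image_placeholder_mask_sketch(input_ids, inputs_embeds)
# masked_scatter copies source values, in order, into the True positions.
merged = inputs_embeds.masked_scatter(mask, image_features)
```

The number of `True` entries in the mask must equal the number of elements in the feature tensor, which is why a real implementation would usually validate token counts before scattering.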
- get_image_features(
- pixel_values: torch.FloatTensor,
- image_grid_thw: torch.LongTensor,
- )
- get_video_features(
- pixel_values_videos: torch.FloatTensor,
- video_grid_thw: torch.LongTensor,
- )
- get_audio_features(
- input_features: torch.FloatTensor,
- feature_attention_mask: torch.LongTensor | None = None,
- audio_feature_lengths: torch.LongTensor | None = None,
- expected_audio_token_counts: torch.LongTensor | None = None,
- )
- forward(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- labels: torch.Tensor | None = None,
- loss_mask: torch.Tensor | None = None,
- inference_params: megatron.core.InferenceParams | None = None,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams | None = None,
- extra_block_kwargs: dict | None = None,
- runtime_gather_output: bool | None = None,
- pixel_values: torch.Tensor | None = None,
- pixel_values_videos: torch.Tensor | None = None,
- image_grid_thw: torch.Tensor | None = None,
- video_grid_thw: torch.Tensor | None = None,
- video_second_per_grid: torch.Tensor | None = None,
- input_features: torch.Tensor | None = None,
- feature_attention_mask: torch.Tensor | None = None,
- audio_feature_lengths: torch.Tensor | None = None,
- **kwargs,
- )