bridge.models.qwen_omni.modeling_qwen3_omni.model

Module Contents

Classes

Qwen3OmniModel: Qwen3-Omni model wrapper.

API

class bridge.models.qwen_omni.modeling_qwen3_omni.model.Qwen3OmniModel(
language_transformer_config: megatron.bridge.models.qwen_omni.modeling_qwen3_omni.transformer_config.Qwen3OmniTransformerConfig,
language_transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
thinker_transformer_config: transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe.Qwen3OmniMoeThinkerConfig,
talker_transformer_config: transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe.Qwen3OmniMoeTalkerConfig | None = None,
code2wav_transformer_config: transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe.Qwen3OmniMoeCode2WavConfig | None = None,
parallel_output: bool = True,
pre_process: bool = True,
post_process: bool = True,
add_encoder: bool = True,
add_decoder: bool = True,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
)

Bases: megatron.core.transformer.MegatronModule

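Qwen3-Omni model wrapper. It holds the language model together with the vision and audio models, each of which can be frozen independently with freeze(). The talker and code2wav configurations are optional and default to None.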
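The pre_process and post_process flags follow the usual Megatron Core pipeline-parallel convention: only the first pipeline stage builds the input embedding and only the last stage builds the output layer. The sketch below shows how a caller might derive the two flags from a pipeline rank, assuming this class uses the same convention. The helper stage_flags and the StageFlags tuple are hypothetical and not part of this module; add_encoder and add_decoder are deliberately left out because their stage assignment is not documented here.

```python
from typing import NamedTuple


class StageFlags(NamedTuple):
    """Hypothetical container for the per-stage constructor flags."""

    pre_process: bool   # this stage builds the input embedding
    post_process: bool  # this stage builds the output layer (and computes the loss)


def stage_flags(pp_rank: int, pp_size: int) -> StageFlags:
    """Hypothetical helper: derive pre/post-process flags from a pipeline rank."""
    if not 0 <= pp_rank < pp_size:
        raise ValueError(f"pp_rank {pp_rank} out of range for pp_size {pp_size}")
    return StageFlags(
        pre_process=(pp_rank == 0),
        post_process=(pp_rank == pp_size - 1),
    )


if __name__ == "__main__":
    # Four pipeline stages: only rank 0 and rank 3 get a flag set.
    for rank in range(4):
        print(rank, stage_flags(rank, pp_size=4))
```

With a single pipeline stage both flags are True, so one rank holds the whole model, which matches the constructor defaults.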


shared_embedding_or_output_weight()
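This method has no docstring. In Megatron Core language models, shared_embedding_or_output_weight() returns the word-embedding weight on the first pipeline stage, the output-layer weight on the last stage, and None on middle stages, so that the two copies can be kept in sync when embeddings and output weights are tied. Whether Qwen3OmniModel simply delegates to its language model for this is an assumption. The FakeStage class below is hypothetical and only mirrors that selection logic, with strings in place of weight tensors.

```python
from typing import Optional


class FakeStage:
    """Hypothetical stand-in for one pipeline stage; strings replace weight tensors."""

    def __init__(self, pre_process: bool, post_process: bool) -> None:
        self.pre_process = pre_process
        self.post_process = post_process
        self.embedding_weight = "word_embeddings.weight" if pre_process else None
        self.output_weight = "output_layer.weight" if post_process else None

    def shared_embedding_or_output_weight(self) -> Optional[str]:
        if self.pre_process:
            # First stage: the input embedding (also chosen when a stage is both first and last).
            return self.embedding_weight
        if self.post_process:
            # Last stage: the output projection that may be tied to the embedding.
            return self.output_weight
        # Middle stages own neither weight.
        return None
```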
set_input_tensor(input_tensor) -> None
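set_input_tensor is also undocumented. In Megatron Core, pipeline schedules call it on every stage after the first to hand over the activation received from the previous stage, because the forward call itself only receives the batch; GPTModel accepts either a single tensor or a one-element list and passes the tensor on to its decoder. The sketch assumes this class honors the same contract. FakeDecoder and FakeModel are hypothetical stand-ins, and a plain string plays the role of the activation tensor.

```python
from typing import Any, List, Optional, Union


class FakeDecoder:
    """Hypothetical stand-in for a transformer block."""

    def __init__(self) -> None:
        self.input_tensor: Optional[Any] = None

    def set_input_tensor(self, tensor: Any) -> None:
        # On non-first stages this replaces the embedding output as the block input.
        self.input_tensor = tensor


class FakeModel:
    """Hypothetical model exposing the Megatron-style set_input_tensor contract."""

    def __init__(self) -> None:
        self.decoder = FakeDecoder()

    def set_input_tensor(self, input_tensor: Union[Any, List[Any]]) -> None:
        # Schedules may pass a bare tensor or a one-element list; normalize to a list.
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        if len(input_tensor) != 1:
            raise ValueError("expected exactly one input tensor")
        self.decoder.set_input_tensor(input_tensor[0])
```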
freeze(
freeze_language_model: bool = False,
freeze_vision_model: bool = False,
freeze_audio_model: bool = False,
)
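freeze() takes one flag per component. In PyTorch, freezing a submodule usually means setting requires_grad = False on each of its parameters, so those parameters stop receiving gradients. The sketch below illustrates that pattern without PyTorch, assuming freeze() works this way: Param, freeze_params, and the name prefixes language_model., vision_model. and audio_model. are all hypothetical, and the real submodule attribute names may differ.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Param:
    """Hypothetical stand-in for a parameter; only the trainability flag matters."""

    requires_grad: bool = True


# Hypothetical name prefixes; the real submodule attribute names may differ.
LANGUAGE_PREFIX = "language_model."
VISION_PREFIX = "vision_model."
AUDIO_PREFIX = "audio_model."


def freeze_params(
    params: Dict[str, Param],
    freeze_language_model: bool = False,
    freeze_vision_model: bool = False,
    freeze_audio_model: bool = False,
) -> None:
    """Disable gradients for every parameter under a selected submodule."""
    selected: List[str] = []
    if freeze_language_model:
        selected.append(LANGUAGE_PREFIX)
    if freeze_vision_model:
        selected.append(VISION_PREFIX)
    if freeze_audio_model:
        selected.append(AUDIO_PREFIX)
    for name, param in params.items():
        if any(name.startswith(prefix) for prefix in selected):
            # A frozen parameter no longer receives gradients during backward.
            param.requires_grad = False
```

For example, fine-tuning only the language model would correspond to calling freeze(freeze_vision_model=True, freeze_audio_model=True) on the real model.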
forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
inference_params: megatron.core.InferenceParams | None = None,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams | None = None,
extra_block_kwargs: dict | None = None,
runtime_gather_output: bool | None = None,
**kwargs,
) -> torch.Tensor
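forward has no docstring either. In Megatron Core's GPTModel, forward returns logits when labels is None and unreduced per-token losses when labels are given; the training loop then averages those losses over the positions selected by loss_mask. Assuming Qwen3OmniModel.forward follows the same convention (not confirmed here), the reduction looks like the sketch below. masked_mean_loss is a hypothetical helper that operates on [batch][seq] nested lists instead of tensors.

```python
from typing import Sequence


def masked_mean_loss(
    per_token_loss: Sequence[Sequence[float]],
    loss_mask: Sequence[Sequence[float]],
) -> float:
    """Average per-token losses over the positions where loss_mask is 1.

    Both arguments are [batch][seq] nested lists standing in for tensors.
    """
    if len(per_token_loss) != len(loss_mask):
        raise ValueError("batch sizes differ")
    total = 0.0
    count = 0.0
    for loss_row, mask_row in zip(per_token_loss, loss_mask):
        if len(loss_row) != len(mask_row):
            raise ValueError("sequence lengths differ")
        for loss, mask in zip(loss_row, mask_row):
            total += loss * mask  # masked positions (mask == 0) contribute nothing
            count += mask
    if count == 0:
        raise ValueError("loss_mask selects no tokens")
    return total / count
```

Padding and prompt positions typically get a mask value of 0, so they neither add to the loss nor dilute the average.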