bridge.models.qwen_omni.qwen3_omni_step

Qwen3-Omni thinker training step helpers.

Module Contents

Functions

get_batch_from_iterator

Get a thinker-training batch from the iterator.

_normalize_multimodal_inputs

Normalize multimodal batch tensors for Qwen3-Omni model forward.

get_batch

Generate a minimal thinker-training batch.

forward_step

Forward training step for Qwen3-Omni thinker.

Data

_MULTIMODAL_KEYS

API

bridge.models.qwen_omni.qwen3_omni_step._MULTIMODAL_KEYS

('pixel_values', 'image_grid_thw', 'pixel_values_videos', 'video_grid_thw', 'video_second_per_grid',…

bridge.models.qwen_omni.qwen3_omni_step.get_batch_from_iterator(
data_iterator: Iterable,
use_mtp: bool = False,
skip_getting_attention_mask_from_dataset: bool = True,
*,
is_first_pp_stage: bool,
is_last_pp_stage: bool,
) → dict[str, Any]

Get a thinker-training batch from the iterator.
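The keyword-only `is_first_pp_stage` and `is_last_pp_stage` flags suggest that the returned dictionary depends on the pipeline-parallel stage. The following is a minimal sketch of that idea, not the module's implementation: the function name, the key groups, and the rule that the first stage takes inputs while the last stage takes targets are all assumptions made for illustration, and `use_mtp` and `skip_getting_attention_mask_from_dataset` are left out.

```python
from typing import Any, Iterable

# Hypothetical key groups; the real key names are not documented on this page.
FIRST_STAGE_KEYS = ("input_ids", "position_ids")
LAST_STAGE_KEYS = ("labels", "loss_mask")


def sketch_get_batch_from_iterator(
    data_iterator: Iterable,
    *,
    is_first_pp_stage: bool,
    is_last_pp_stage: bool,
) -> dict[str, Any]:
    """Pull one batch and keep only the fields this pipeline stage needs."""
    batch = next(iter(data_iterator))
    wanted: list[str] = []
    if is_first_pp_stage:
        wanted.extend(FIRST_STAGE_KEYS)  # assumed: the first stage consumes inputs
    if is_last_pp_stage:
        wanted.extend(LAST_STAGE_KEYS)  # assumed: the last stage computes the loss
    return {key: batch[key] for key in wanted if key in batch}


sample = {
    "input_ids": [1, 2, 3],
    "position_ids": [0, 1, 2],
    "labels": [2, 3, 4],
    "loss_mask": [1, 1, 0],
}
first_only = sketch_get_batch_from_iterator(
    [sample], is_first_pp_stage=True, is_last_pp_stage=False
)
last_only = sketch_get_batch_from_iterator(
    [sample], is_first_pp_stage=False, is_last_pp_stage=True
)
```

Making the stage flags keyword-only, as the real signature does, keeps call sites from swapping two adjacent booleans by accident.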

bridge.models.qwen_omni.qwen3_omni_step._normalize_multimodal_inputs(
batch: dict[str, Any],
) → dict[str, torch.Tensor]

Normalize multimodal batch tensors for Qwen3-Omni model forward.
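What "normalize" covers is not spelled out on this page. One plausible reading, sketched below purely as an assumption, is that the function selects the multimodal entries from a batch and drops the ones that are absent. The helper name is hypothetical, the key tuple contains only the five names visible in the truncated `_MULTIMODAL_KEYS` value, and plain Python lists stand in for tensors.

```python
from typing import Any

# Only the five keys visible in the truncated _MULTIMODAL_KEYS value;
# the real tuple continues past this point.
VISIBLE_MULTIMODAL_KEYS = (
    "pixel_values",
    "image_grid_thw",
    "pixel_values_videos",
    "video_grid_thw",
    "video_second_per_grid",
)


def sketch_select_multimodal(batch: dict[str, Any]) -> dict[str, Any]:
    """Keep multimodal entries that are present and not None (assumed behavior)."""
    return {
        key: batch[key]
        for key in VISIBLE_MULTIMODAL_KEYS
        if batch.get(key) is not None
    }


batch = {
    "input_ids": [1, 2, 3],          # text field, not a multimodal key
    "pixel_values": [[0.1, 0.2]],    # stand-in for an image tensor
    "image_grid_thw": None,          # missing modality metadata
}
selected = sketch_select_multimodal(batch)
```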

bridge.models.qwen_omni.qwen3_omni_step.get_batch(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
*,
pg_collection,
) → tuple[...]

Generate a minimal thinker-training batch.

bridge.models.qwen_omni.qwen3_omni_step.forward_step(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) → tuple[torch.Tensor, functools.partial]

Forward training step for Qwen3-Omni thinker.
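The return type pairs an output tensor with a `functools.partial`, which points to a loss callable that has some arguments bound in advance and is invoked later by the caller. The sketch below illustrates that pattern only: the function names and the masked-mean loss are hypothetical, and plain lists of floats stand in for `torch.Tensor` values.

```python
from functools import partial


def sketch_masked_mean_loss(loss_mask, per_token_loss):
    """Average the per-token loss over positions where the mask is 1."""
    total = sum(m * l for m, l in zip(loss_mask, per_token_loss))
    count = sum(loss_mask)
    return total / count if count else 0.0


def sketch_forward_step(per_token_loss, loss_mask):
    """Return the model output together with a loss function bound to the mask."""
    # A real step would run the model here; this sketch passes the output through.
    return per_token_loss, partial(sketch_masked_mean_loss, loss_mask)


# The caller receives both values and applies the loss function afterwards.
output, loss_fn = sketch_forward_step([0.5, 1.5, 4.0], [1, 1, 0])
loss = loss_fn(output)
```

Binding the mask with `partial` lets the caller compute the loss without knowing which batch fields the loss needs.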