bridge.training.audio_lm_step#

Audio-language model training step, independent of vlm_step.py.

Module Contents#

Functions#

get_batch_from_iterator

Get a batch of data from the iterator for audio-language models.

get_batch

Generate a batch for audio-language models.

forward_step

Forward training step for audio-language models.

Data#

API#

bridge.training.audio_lm_step.logger#

'getLogger(...)'

bridge.training.audio_lm_step.get_batch_from_iterator(
data_iterator: Iterable,
use_mtp: bool = False,
skip_getting_attention_mask_from_dataset: bool = True,
*,
is_first_pp_stage: bool,
is_last_pp_stage: bool,
) → dict[str, Any]#

Get a batch of data from the iterator for audio-language models.

Uses the audio_inputs batch key instead of visual_inputs.
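A minimal sketch of this behavior, assuming a pipeline-parallel split where the first stage needs the model inputs (including the `audio_inputs` key that replaces vlm_step's `visual_inputs`) and the last stage needs the loss targets. The key names `tokens`, `position_ids`, `labels`, and `loss_mask` are assumptions for illustration, not the confirmed batch schema:

```python
from typing import Any, Iterator


def get_batch_from_iterator_sketch(
    data_iterator: Iterator[dict[str, Any]],
    *,
    is_first_pp_stage: bool,
    is_last_pp_stage: bool,
) -> dict[str, Any]:
    """Pull one batch and keep only the keys this pipeline stage needs."""
    batch = next(data_iterator)
    needed: dict[str, Any] = {}
    if is_first_pp_stage:
        # First stage consumes the raw inputs, including the audio features
        # under "audio_inputs" (vlm_step uses "visual_inputs" here instead).
        for key in ("tokens", "position_ids", "audio_inputs"):
            if key in batch:
                needed[key] = batch[key]
    if is_last_pp_stage:
        # Last stage needs the targets to compute the loss.
        for key in ("labels", "loss_mask"):
            if key in batch:
                needed[key] = batch[key]
    return needed


it = iter([{"tokens": [1, 2], "audio_inputs": "features",
            "labels": [1, 2], "loss_mask": [1, 1]}])
first_stage_batch = get_batch_from_iterator_sketch(
    it, is_first_pp_stage=True, is_last_pp_stage=False
)
# first_stage_batch keeps "tokens" and "audio_inputs" but drops "labels".
```

Intermediate pipeline stages (neither first nor last) would receive an empty dict under this sketch, since they consume activations from the previous stage rather than data-loader output.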

bridge.training.audio_lm_step.get_batch(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
*,
pg_collection,
) → tuple[...]#

Generate a batch for audio-language models.

Adapted from vlm_step.get_batch but uses audio_inputs key.

bridge.training.audio_lm_step.forward_step(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) → tuple[torch.Tensor, functools.partial]#

Forward training step for audio-language models.

Uses a local get_batch that extracts audio_inputs instead of visual_inputs.

Parameters:
  • state – Global state for the run

  • data_iterator – Input data iterator

  • model – The audio-language model

  • return_schedule_plan (bool) – Whether to return the schedule plan instead of the output tensor

Returns:

A tuple containing the output tensor and the loss function.
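The `tuple[torch.Tensor, functools.partial]` return type follows the Megatron-style convention in which the forward step returns the model output together with a partially-applied loss function that the pipeline scheduler later calls on that output. A minimal sketch with plain Python values in place of tensors; `loss_func` and its masked-mean behavior are illustrative assumptions, not the module's actual loss:

```python
import functools
from typing import Any


def loss_func(loss_mask: list[int], output: list[float]) -> float:
    # Hypothetical masked mean: only positions where loss_mask == 1 contribute.
    masked = [o for o, m in zip(output, loss_mask) if m]
    return sum(masked) / max(len(masked), 1)


def forward_step_sketch(
    batch: dict[str, Any],
) -> tuple[list[float], functools.partial]:
    # Stand-in for the audio-LM forward pass; a real step would run the model
    # on batch["tokens"] and batch["audio_inputs"].
    output_tensor = [float(t) for t in batch["tokens"]]
    # Convention: bind the loss mask now, so the scheduler can later invoke
    # the returned partial simply as loss_fn(output_tensor).
    return output_tensor, functools.partial(loss_func, batch["loss_mask"])


output, loss_fn = forward_step_sketch({"tokens": [2, 4], "loss_mask": [1, 0]})
loss = loss_fn(output)  # masked mean over the unmasked positions
```

Binding the mask via `functools.partial` lets the scheduler treat the loss as a single-argument callable of the output tensor, which is why the documented return type pairs the tensor with a partial rather than a raw loss value.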