bridge.training.vlm_step#

Module Contents#

Functions#

• get_batch_from_iterator – Get a batch of data from the iterator.

• get_batch – Generate a batch.

• forward_step – Forward training step.

Data#

API#

bridge.training.vlm_step.logger#

'getLogger(…)'

bridge.training.vlm_step.get_batch_from_iterator(
data_iterator: Iterable,
use_mtp: bool = False,
skip_getting_attention_mask_from_dataset: bool = True,
) → dict[str, Any]#

Get a batch of data from the iterator.

Parameters:
  • data_iterator – The data iterator to get the batch from.

  • use_mtp – Whether Multi-Token Prediction layers are enabled.

  • skip_getting_attention_mask_from_dataset – If set, the attention mask is not read from the dataset and None is passed in its place.

Returns:

A dictionary containing the batch data.

Return type:

dict[str, torch.Tensor]
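
A minimal usage sketch (not part of this module's documentation): it assumes a dataloader that yields the dict-shaped samples this function expects, and the batch key shown is illustrative only.

```python
from megatron.bridge.training.vlm_step import get_batch_from_iterator

# Assumption: `train_dataloader` is an existing iterable of VLM training samples.
data_iterator = iter(train_dataloader)

batch = get_batch_from_iterator(
    data_iterator,
    use_mtp=False,
    skip_getting_attention_mask_from_dataset=True,  # dataset supplies no attention mask
)

# The concrete keys depend on the dataset; "tokens" is shown only for illustration.
tokens = batch.get("tokens")
```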

bridge.training.vlm_step.get_batch(
data_iterator: Iterable,
cfg: megatron.bridge.training.config.ConfigContainer,
use_mtp: bool = False,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]#

Generate a batch.

Parameters:
  • data_iterator – Input data iterator

  • cfg – Configuration container

  • use_mtp – Whether Multi-Token Prediction layers are enabled

Returns:

Tuple containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens, cu_seqlens_argmin, max_seqlen, and visual_inputs (a container of optional modality inputs).
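
A hedged sketch of unpacking the returned tuple, assuming `cfg` is an already-built ConfigContainer and `train_dataloader` yields the expected samples; the names mirror the return description above.

```python
from megatron.bridge.training.vlm_step import get_batch

# Assumptions: `cfg` (ConfigContainer) and `train_dataloader` are created elsewhere.
(
    tokens,
    labels,
    loss_mask,
    attention_mask,
    position_ids,
    cu_seqlens,
    cu_seqlens_argmin,
    max_seqlen,
    visual_inputs,  # container of optional modality inputs (e.g. images)
) = get_batch(iter(train_dataloader), cfg, use_mtp=False)
```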

bridge.training.vlm_step.forward_step(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) → tuple[torch.Tensor, functools.partial]#

Forward training step.

Parameters:
  • state – Global state for the run

  • data_iterator – Input data iterator

  • model – The GPT Model

  • return_schedule_plan (bool) – Whether to return the schedule plan instead of the output tensor

Returns:

Tuple containing the output tensor and the loss function (a functools.partial).
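
A rough sketch of how a training loop might invoke forward_step, assuming `state`, `data_iterator`, and `model` are constructed elsewhere; applying the returned partial loss function directly is shown only as an assumption, since Megatron-style schedules normally apply it for you.

```python
from megatron.bridge.training.vlm_step import forward_step

# Assumptions: `state` (GlobalState), `data_iterator`, and `model` (GPTModel)
# are set up by the surrounding training framework.
output_tensor, loss_func = forward_step(state, data_iterator, model)

# In Megatron-style pipelines the schedule usually applies `loss_func` to the
# model output; calling it directly here is illustrative only.
loss = loss_func(output_tensor)
```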