bridge.data.finetuning#

Finetuning-specific data handling utilities.

Module Contents#

Functions#

  • split_batch_into_microbatches – Split a batch dictionary into microbatches.

  • prepare_finetuning_batch – Prepare a finetuning batch by fetching the global batch and splitting it into microbatches.

API#

bridge.data.finetuning.split_batch_into_microbatches(
batch: dict[str, Any],
num_microbatches: int,
enforce_divisible: bool = True,
) → list[dict[str, Any]]#

Split a batch dictionary into microbatches.

Takes a global batch (e.g., [16, 240] for tokens) and splits it into num_microbatches smaller batches (e.g., 4 batches of [4, 240]).

Parameters:
  • batch – Dictionary containing tensors with batch_size = num_microbatches * micro_batch_size

  • num_microbatches – Number of microbatches to split into

  • enforce_divisible – Whether to enforce batch_size % num_microbatches == 0

Returns:

List of microbatch dictionaries, each containing the same keys as the input batch

Example

>>> batch = {'tokens': torch.rand(16, 240), 'labels': torch.rand(16, 240)}
>>> microbatches = split_batch_into_microbatches(batch, num_microbatches=4)
>>> len(microbatches)
4
>>> microbatches[0]['tokens'].shape
torch.Size([4, 240])
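For reference, the splitting logic can be approximated as follows. This is a minimal sketch, not the actual bridge implementation; it assumes every tensor in the batch is batched along dimension 0 and that non-tensor values are passed through unchanged.

import torch
from typing import Any

def split_batch_into_microbatches_sketch(
    batch: dict[str, Any],
    num_microbatches: int,
    enforce_divisible: bool = True,
) -> list[dict[str, Any]]:
    # Infer the global batch size from the first tensor in the dict.
    batch_size = next(v.shape[0] for v in batch.values() if torch.is_tensor(v))
    if enforce_divisible and batch_size % num_microbatches != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by "
            f"num_microbatches={num_microbatches}"
        )
    micro_batch_size = batch_size // num_microbatches
    # Slice every tensor along dim 0; copy non-tensor values through as-is.
    return [
        {
            k: v[i * micro_batch_size : (i + 1) * micro_batch_size]
            if torch.is_tensor(v)
            else v
            for k, v in batch.items()
        }
        for i in range(num_microbatches)
    ]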

bridge.data.finetuning.prepare_finetuning_batch(
data_iterator: Iterator,
num_microbatches: int,
default_seq_length: int,
seq_key: str = 'tokens',
) → tuple[Iterator, int]#

Prepare a finetuning batch by fetching the global batch and splitting it into microbatches.

This function handles the finetuning-specific data flow, sketched in code after this list:

  1. Gets the full global batch from the iterator

  2. Extracts the dynamic sequence length from the batch

  3. Splits the batch into microbatches with consistent sequence length

  4. Returns an iterator over microbatches and the extracted sequence length
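A minimal sketch of these four steps, assuming batches are dicts of [batch, seq] tensors; the helper name is illustrative, and the import of split_batch_into_microbatches follows the module path documented on this page:

from typing import Any, Iterator
import torch
from bridge.data.finetuning import split_batch_into_microbatches

def prepare_finetuning_batch_sketch(
    data_iterator: Iterator,
    num_microbatches: int,
    default_seq_length: int,
    seq_key: str = "tokens",
) -> tuple[Iterator, int]:
    # 1. Pull one full global batch from the iterator.
    global_batch = next(data_iterator)
    # 2. Read the dynamic sequence length from the sequence tensor,
    #    falling back to default_seq_length when the key is absent.
    seq_tensor = global_batch.get(seq_key)
    seq_length = (
        seq_tensor.shape[1] if torch.is_tensor(seq_tensor) else default_seq_length
    )
    # 3. Split the global batch into equally sized microbatches.
    microbatches = split_batch_into_microbatches(global_batch, num_microbatches)
    # 4. Return an iterator over the microbatches plus the sequence length.
    return iter(microbatches), seq_length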

Parameters:
  • data_iterator – Iterator that yields global batches (e.g., from DataLoader with batch sampler)

  • num_microbatches – Number of microbatches to split each global batch into

  • default_seq_length – Fallback sequence length if it cannot be extracted from the batch

  • seq_key – Key in the batch dict containing the sequence tensor (default: 'tokens')

Returns:

  • Iterator over microbatches (each microbatch is a dict with same keys as global batch)

  • Sequence length extracted from the global batch (or default_seq_length if not found)

Return type:

tuple[Iterator, int]

Example

>>> # DataLoader yields global batches of shape [16, 240]
>>> microbatch_iter, seq_len = prepare_finetuning_batch(
...     data_iterator=iter(dataloader),
...     num_microbatches=4,
...     default_seq_length=2048,
... )
>>> seq_len  # extracted from the batch
240
>>> batch1 = next(microbatch_iter)
>>> batch1['tokens'].shape
torch.Size([4, 240])
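In a training step, the returned iterator is typically drained once per global batch, accumulating gradients across microbatches. A sketch of that pattern, where dataloader, model, and the loss call are hypothetical placeholders:

microbatch_iter, seq_len = prepare_finetuning_batch(
    data_iterator=iter(dataloader),  # hypothetical DataLoader over the dataset
    num_microbatches=4,
    default_seq_length=2048,
)
for micro in microbatch_iter:  # yields exactly num_microbatches microbatch dicts
    loss = model(micro['tokens'], labels=micro['labels'])  # hypothetical model call
    loss.backward()  # gradients accumulate across microbatches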