bridge.data.finetuning#
Finetuning-specific data handling utilities.
Module Contents#
Functions#
| Function | Description |
|---|---|
| split_batch_into_microbatches | Split a batch dictionary into microbatches. |
| prepare_finetuning_batch | Prepare a finetuning batch by getting global batch and splitting into microbatches. |
API#
- bridge.data.finetuning.split_batch_into_microbatches(
- batch: dict[str, Any],
- num_microbatches: int,
- enforce_divisible: bool = True,
- )
Split a batch dictionary into microbatches.
Takes a global batch (e.g., [16, 240] for tokens) and splits it into num_microbatches smaller batches (e.g., 4 batches of [4, 240]).
- Parameters:
batch – Dictionary containing tensors with batch_size = num_microbatches * micro_batch_size
num_microbatches – Number of microbatches to split into
enforce_divisible – Whether to enforce batch_size % num_microbatches == 0
- Returns:
List of microbatch dictionaries, each containing the same keys as the input batch
Example
batch = {'tokens': torch.rand(16, 240), 'labels': torch.rand(16, 240)}
microbatches = split_batch_into_microbatches(batch, num_microbatches=4)
len(microbatches)  # 4
microbatches[0]['tokens'].shape  # torch.Size([4, 240])
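For illustration, here is a minimal sketch of how such a split could work. This is not the library source; the handling of non-tensor entries and of non-divisible batch sizes when enforce_divisible is False are assumptions.

```python
from typing import Any

import torch


def split_batch_into_microbatches_sketch(
    batch: dict[str, Any],
    num_microbatches: int,
    enforce_divisible: bool = True,
) -> list[dict[str, Any]]:
    """Illustrative re-implementation of the splitting behaviour described above."""
    # Infer the global batch size from any tensor in the dict.
    batch_size = next(v.shape[0] for v in batch.values() if isinstance(v, torch.Tensor))
    if enforce_divisible and batch_size % num_microbatches != 0:
        raise ValueError(
            f"batch_size {batch_size} not divisible by num_microbatches {num_microbatches}"
        )
    micro_batch_size = batch_size // num_microbatches

    microbatches = []
    for i in range(num_microbatches):
        start, end = i * micro_batch_size, (i + 1) * micro_batch_size
        microbatches.append(
            {
                # Slice tensors along the batch dimension; non-tensor values are
                # carried through unchanged (an assumption about the real behaviour).
                key: value[start:end] if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
        )
    return microbatches
```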
- bridge.data.finetuning.prepare_finetuning_batch(
- data_iterator: Iterator,
- num_microbatches: int,
- default_seq_length: int,
- seq_key: str = 'tokens',
- )
Prepare a finetuning batch by getting global batch and splitting into microbatches.
This function handles the finetuning-specific data flow (an implementation sketch follows the example below):
1. Gets the full global batch from the iterator
2. Extracts the dynamic sequence length from the batch
3. Splits the batch into microbatches with a consistent sequence length
4. Returns an iterator over the microbatches and the extracted sequence length
- Parameters:
data_iterator – Iterator that yields global batches (e.g., from DataLoader with batch sampler)
num_microbatches – Number of microbatches to split each global batch into
default_seq_length – Fallback sequence length if it cannot be extracted from batch
seq_key – Key in batch dict containing the sequence tensor (default: 'tokens')
- Returns:
Iterator over microbatches (each microbatch is a dict with same keys as global batch)
Sequence length extracted from the global batch (or default_seq_length if not found)
- Return type:
Tuple of (iterator over microbatches, extracted sequence length)
Example
# DataLoader yields a global batch of shape [16, 240]
microbatch_iter, seq_len = prepare_finetuning_batch(
    data_iterator=iter(dataloader),
    num_microbatches=4,
    default_seq_length=2048,
)
seq_len  # 240 (extracted from batch)
batch1 = next(microbatch_iter)
batch1['tokens'].shape  # torch.Size([4, 240])
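The four steps listed above can be sketched as follows. This is an illustrative reimplementation, not the library source; it assumes split_batch_into_microbatches is importable from this module and that the sequence length is read from dimension 1 of the seq_key tensor.

```python
from typing import Any, Iterator

import torch

# Module path as documented on this page; an assumption about how it is exposed.
from bridge.data.finetuning import split_batch_into_microbatches


def prepare_finetuning_batch_sketch(
    data_iterator: Iterator,
    num_microbatches: int,
    default_seq_length: int,
    seq_key: str = "tokens",
) -> tuple[Iterator[dict[str, Any]], int]:
    """Illustrative walk-through of the four documented steps."""
    # 1. Get the full global batch from the iterator.
    global_batch = next(data_iterator)

    # 2. Extract the dynamic sequence length, falling back to the default.
    seq_tensor = global_batch.get(seq_key)
    if isinstance(seq_tensor, torch.Tensor) and seq_tensor.dim() >= 2:
        seq_length = seq_tensor.shape[1]
    else:
        seq_length = default_seq_length

    # 3. Split the global batch into microbatches with a consistent sequence length.
    microbatches = split_batch_into_microbatches(global_batch, num_microbatches)

    # 4. Return an iterator over the microbatches and the extracted sequence length.
    return iter(microbatches), seq_length
```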