Sampler
SizeAwareBatchSampler
Bases: Sampler[List[int]]
Varying-size batching data sampler class that ensures the batch size doesn't exceed a maximum.
A sampler that batches elements of varying sizes while ensuring that the total size of each batch does not exceed a specified maximum.
This is useful when dealing with datasets where each element has a
different size, such as graphs or sequences of varying lengths.
The sampler uses a provided `sizeof` function to determine the size of each element in the dataset and ensures that the total size of each batch does not exceed the specified `max_total_size`.
Examples:
>>> import torch
>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
>>> # Define a sample dataset with torch.tensor
>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
... torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> # Define a function that returns the size of each element in the dataset.
>>> def sizeof(index):
... return dataset[index].numel()
>>> # Create a SizeAwareBatchSampler with a maximum total batch size of 10.
>>> batch_sampler = SizeAwareBatchSampler(
... sampler=torch.utils.data.SequentialSampler(dataset),
... sizeof=sizeof,
... max_total_size=4
... )
>>> # Iterate over batches of indices that do not exceed the maximum total size.
>>> print(list(batch_sampler))
[[0, 1], [2, 3], [4]]
Source code in bionemo/size_aware_batching/sampler.py
__init__(sampler, sizeof, max_total_size, info_logger=None, warn_logger=None)
Initializes the SizeAwareBatchSampler.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`sampler` | `Union[Sampler[List[int]], Iterable[int]]` | The underlying sampler. | required |
`sizeof` | `Callable[[int], Real]` | A function that returns the size at each index. E.g., this can be used to determine how much memory an element consumes. Its return type must be comparable with `max_total_size`. | required |
`max_total_size` | `Real` | The maximum total size of a mini-batch. The semantics of "size" is defined by the `sizeof` argument. | required |
`info_logger` | `Optional[Callable[[str], None]]` | A function to log info. Defaults to None. | `None` |
`warn_logger` | `Optional[Callable[[str], None]]` | A function to log warnings. Defaults to None. | `None` |
Raises:

Type | Description |
---|---|
`TypeError` | If `sampler` is not an instance of `Sampler` or `Iterable`, or if `sizeof` is not a callable, dictionary, or sequence container. |
`ValueError` | If `max_total_size` is not a positive number. |
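As the `TypeError` entry above implies, `sizeof` may be a precomputed container rather than a function. A minimal pure-Python sketch (illustrative only, not using the library) of precomputing per-index sizes once and exposing them either as a sequence container or as an equivalent callable:

```python
# Hypothetical sketch: precompute per-index sizes once, rather than
# re-measuring each element on every epoch.
dataset = [[1, 2], [3, 4, 5], [6], [7, 8, 9, 10]]

# Sequence container mapping index -> size
sizes = [len(item) for item in dataset]

# Equivalent callable form, usable wherever a Callable[[int], Real] is expected
sizeof = sizes.__getitem__

print(sizes)      # [2, 3, 1, 4]
print(sizeof(3))  # 4
```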
Source code in bionemo/size_aware_batching/sampler.py
__iter__()
Iterate over batches of indices.
This function yields batches of indices that do not exceed the maximum total size.
Yields:

Type | Description |
---|---|
`List[int]` | A batch of indices whose total size does not exceed the maximum. |
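The batching behavior of `__iter__` can be sketched as a simplified pure-Python greedy loop. This is an illustrative approximation, not the library's actual implementation (which also supports logging via the configured loggers):

```python
def greedy_size_batches(indices, sizeof, max_total_size):
    """Yield lists of indices whose summed sizes stay <= max_total_size."""
    batch, total = [], 0
    for idx in indices:
        size = sizeof(idx)
        if size > max_total_size:
            # A single element alone exceeds the cap; skip it
            # (a reasonable assumption of how such a sampler would cope).
            continue
        if batch and total + size > max_total_size:
            yield batch
            batch, total = [], 0
        batch.append(idx)
        total += size
    if batch:  # flush the final (possibly partial) batch
        yield batch

# Matches the class docstring example: five elements of size 2, cap of 4
print(list(greedy_size_batches(range(5), lambda i: 2, 4)))
# [[0, 1], [2, 3], [4]]
```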
Source code in bionemo/size_aware_batching/sampler.py
size_aware_batching(dataset, sizeof, max_total_size, collate_fn=None, info_logger=None, warn_logger=None)
Creates a batching iterator where each batch's size varies (within a max limit) according to memory consumption.
A generator that batches elements from an iterable while ensuring that the total size of each batch does not exceed a specified maximum. Here the size can be a measurement of the memory consumption of the elements in the batch. This can be useful for both indexable data and non-indexable but iterable data.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Iterable[Data]` | The input iterable. | required |
`sizeof` | `Callable[[Data], Real]` | A function or mapping that returns the "size" of each element in `dataset`. Its return type must be comparable with `max_total_size`. | required |
`max_total_size` | `Real` | The maximum total "size" of each batch. The semantics of "size" is defined by the `sizeof` argument. | required |
`collate_fn` | `Optional[Callable[[Iterable[Data]], BatchCollated]]` | An optional function to collate batches. Defaults to None, in which case each batch is a list of elements from the input dataset. | `None` |
`info_logger` | `Optional[Callable[[str], None]]` | A function to log info. Defaults to None. | `None` |
`warn_logger` | `Optional[Callable[[str], None]]` | A function to log warnings. Defaults to None. | `None` |
Yields:

Type | Description |
---|---|
`Union[List[Data], BatchCollated]` | Batches from `dataset` whose total size does not exceed `max_total_size`, collated by `collate_fn` if provided. |
Assumptions

1. Linear complexity. This function consumes the given Iterable of data (`dataset`) once, going over the data items one by one to build a batch, and yields the batch as soon as adding the next data item would exceed `max_total_size`, or when the batch is the last one (end of iteration).
2. Additive size measurement. For the common use case of building mini-batches under a memory-consumption threshold, it assumes that the size of a batch is the sum of the sizes of all its elements (additive property).
3. Comparable types for `max_total_size` and `sizeof`'s return value. `sizeof`'s return values must be comparable with `max_total_size` in order to threshold the size of batches.

Caveats

1. The generated batch sizes may have large variance.
   - Workaround: filter the output of this generator using a batch size threshold.
2. The number of batches may vary widely across different epochs.
   - Workaround: increase the number of steps that compose an epoch, e.g., in the Lightning training/validation loop, which effectively increases the input dataset size per epoch.
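The first workaround (filtering batches by a size threshold) can be sketched as a thin generator wrapper; `min_batch_size` is a hypothetical parameter name chosen for illustration:

```python
def filter_small_batches(batches, min_batch_size):
    """Drop batches with fewer than `min_batch_size` elements to reduce
    batch-size variance (at the cost of discarding some data per epoch)."""
    for batch in batches:
        if len(batch) >= min_batch_size:
            yield batch

# Example: drop the trailing one-element batch
print(list(filter_small_batches([[0, 1], [2, 3], [4]], min_batch_size=2)))
# [[0, 1], [2, 3]]
```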
Example:
>>> import torch
>>> from torch.utils.data import default_collate
>>> from bionemo.size_aware_batching.sampler import size_aware_batching
>>> # Define a sample dataset with torch.tensor
>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
... torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> # Define a sizeof function that returns the size of each tensor
>>> def sizeof(x):
... return x.numel()
>>> # Create a generator with max_total_size=4 and default_collate_fn
>>> gen = size_aware_batching(dataset, sizeof, 4, collate_fn=default_collate)
>>> batches = list(gen)
>>> print(batches)
[tensor([[1, 2], [3, 4]]), tensor([[5, 6], [7, 8]]), tensor([[9, 10]])]
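Because `size_aware_batching` only iterates its input once, it also works on non-indexable streams. A simplified pure-Python sketch of the same greedy logic over an arbitrary iterable (illustrative only; the real function additionally supports the info/warn loggers):

```python
def batch_by_size(items, sizeof, max_total_size, collate_fn=None):
    """Greedily pack items into batches whose summed size stays <= max_total_size."""
    batch, total = [], 0
    for item in items:
        size = sizeof(item)
        if batch and total + size > max_total_size:
            yield collate_fn(batch) if collate_fn else batch
            batch, total = [], 0
        batch.append(item)
        total += size
    if batch:  # flush the final (possibly partial) batch
        yield collate_fn(batch) if collate_fn else batch

# Works on a generator, not just an indexable dataset
stream = (list(range(n)) for n in (2, 2, 2, 2, 2))
print(list(batch_by_size(stream, len, 4)))
# [[[0, 1], [0, 1]], [[0, 1], [0, 1]], [[0, 1]]]
```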
Source code in bionemo/size_aware_batching/sampler.py