Collate

bert_padding_collate_fn(batch, padding_value, min_length=None, max_length=None)

Padding collate function for BERT dataloaders.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `list` | List of samples. | *required* |
| `padding_value` | `int` | The tokenizer's pad token ID. | *required* |
| `min_length` | `int \| None` | Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond `max_length` will be added. | `None` |
| `max_length` | `int \| None` | Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch. | `None` |
Source code in bionemo/llm/data/collate.py
def bert_padding_collate_fn(
    batch: Sequence[types.BertSample],
    padding_value: int,
    min_length: int | None = None,
    max_length: int | None = None,
) -> types.BertSample:
    """Padding collate function for BERT dataloaders.

    Args:
        batch (list): List of samples.
        padding_value (int): The tokenizer's pad token ID.
        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
            provided, no extra padding beyond the max_length will be added.
        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
            longest sequence in the batch.
    """
    padding_values = {
        "text": padding_value,
        "types": 0,
        "attention_mask": False,
        "labels": -1,
        "loss_mask": False,
        "is_random": 0,
    }
    return padding_collate_fn(
        batch=batch,  # type: ignore[assignment]
        padding_values=padding_values,
        min_length=min_length,
        max_length=max_length,
    )
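
A minimal usage sketch (not part of the library docs): two variable-length samples are collated through a `DataLoader`. The sample keys are inferred from the `padding_values` dict in the source above, and `padding_value=0` stands in for a real tokenizer's pad token ID.

```python
import functools

import torch
from torch.utils.data import DataLoader

from bionemo.llm.data.collate import bert_padding_collate_fn


def make_sample(length: int) -> dict:
    # Key set matches the padding_values dict used by bert_padding_collate_fn.
    return {
        "text": torch.randint(5, 100, (length,)),
        "types": torch.zeros(length, dtype=torch.int64),
        "attention_mask": torch.ones(length, dtype=torch.bool),
        "labels": torch.full((length,), -1, dtype=torch.int64),
        "loss_mask": torch.zeros(length, dtype=torch.bool),
        "is_random": torch.zeros(length, dtype=torch.int64),
    }


samples = [make_sample(7), make_sample(12)]

# Bind the pad token ID (and an optional minimum length) so the DataLoader can
# call the collate function with just the batch.
collate = functools.partial(bert_padding_collate_fn, padding_value=0, min_length=16)
loader = DataLoader(samples, batch_size=2, collate_fn=collate)

batch = next(iter(loader))
print(batch["text"].shape)  # torch.Size([2, 16]): padded up to min_length
```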

padding_collate_fn(batch, padding_values, min_length=None, max_length=None)

Collate function with padding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Sequence[_T]` | List of samples, each of which is a dictionary of tensors. | *required* |
| `padding_values` | `dict[str, int]` | A dictionary of padding values for each tensor key. | *required* |
| `min_length` | `int \| None` | Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond `max_length` will be added. | `None` |
| `max_length` | `int \| None` | Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `_T` | A collated batch with the same dictionary input structure. |

Source code in bionemo/llm/data/collate.py
def padding_collate_fn(
    batch: Sequence[_T],
    padding_values: dict[str, int],
    min_length: int | None = None,
    max_length: int | None = None,
) -> _T:
    """Collate function with padding.

    Args:
        batch: List of samples, each of which is a dictionary of tensors.
        padding_values: A dictionary of padding values for each tensor key.
        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
            provided, no extra padding beyond the max_length will be added.
        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
            longest sequence in the batch.

    Returns:
        A collated batch with the same dictionary input structure.
    """
    for entry in batch:
        if entry.keys() != padding_values.keys():
            raise ValueError("All keys in inputs must match provided padding_values.")

    def _pad(tensors, padding_value):
        if max_length is not None:
            tensors = [t[:max_length] for t in tensors]
        batched_tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)
        if min_length is None:
            return batched_tensors
        return torch.nn.functional.pad(batched_tensors, (0, min_length - batched_tensors.size(1)), value=padding_value)

    return {k: _pad([s[k] for s in batch], padding_values[k]) for k in batch[0].keys()}  # type: ignore[return-value]
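
A minimal sketch of calling `padding_collate_fn` directly on dictionaries of 1-D tensors. The keys (`tokens`, `mask`) and padding values here are illustrative, not taken from the library; every sample must use exactly the keys listed in `padding_values`.

```python
import torch

from bionemo.llm.data.collate import padding_collate_fn

batch = [
    {"tokens": torch.tensor([1, 2, 3]), "mask": torch.tensor([True, True, True])},
    {"tokens": torch.tensor([4, 5]), "mask": torch.tensor([True, True])},
]

collated = padding_collate_fn(
    batch,
    padding_values={"tokens": 0, "mask": False},  # one padding value per key
    max_length=4,  # sequences longer than 4 would be truncated first
)

print(collated["tokens"])
# tensor([[1, 2, 3],
#         [4, 5, 0]])  <- padded to the longest sequence in the batch (length 3)
```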