bridge.training.utils.padding_utils#

Padding and truncation helpers for training batches.

These utilities centralize common sequence length adjustments used to ensure fixed or efficient shapes for tensors such as tokens, labels, position ids, and attention masks.

Module Contents#

Functions#

pad_or_truncate_2d_to_len

Pad or truncate a 2D tensor to a desired target length with an upper cap.

pad_or_truncate_pos_to_len

Pad or truncate position ids to a target length with an upper cap.

pad_or_truncate_attn_to_len

Pad or truncate a 4D attention mask to the target length with an upper cap.

Data#

API#

bridge.training.utils.padding_utils.__all__#

['pad_or_truncate_2d_to_len', 'pad_or_truncate_pos_to_len', 'pad_or_truncate_attn_to_len']

bridge.training.utils.padding_utils.pad_or_truncate_2d_to_len(
x: torch.Tensor | None,
target_len: int,
max_cap: int,
pad_value: int | float,
) -> torch.Tensor | None#

Pad or truncate a 2D tensor to a desired target length with an upper cap.

Expects input of shape (batch, seq_len). Pads/truncates along the last dimension.

bridge.training.utils.padding_utils.pad_or_truncate_pos_to_len(
pos: torch.Tensor | None,
target_len: int,
max_cap: int,
) -> torch.Tensor | None#

Pad or truncate position ids to a target length with an upper cap.

Extends positions by appending a monotonically increasing range starting from the current length to the target length.

bridge.training.utils.padding_utils.pad_or_truncate_attn_to_len(
mask: torch.Tensor | None,
target_len: int,
max_cap: int,
) -> torch.Tensor | None#

Pad or truncate a 4D attention mask to the target length with an upper cap.

Expects input of shape (batch, heads, seq_len, seq_len). Pads the last two dims.