bridge.training.utils.packed_seq_utils
Module Contents
Functions
get_packed_seq_params | Build packed sequence parameters from a batch dictionary.
API
bridge.training.utils.packed_seq_utils.get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams
Build packed sequence parameters from a batch dictionary.
The function squeezes out any size-1 batch dimensions and removes padding entries marked with -1. It returns a PackedSeqParams instance suitable for packed-sequence attention kernels.
Parameters:
batch: A dictionary containing packed-sequence metadata. Expected keys: cu_seqlens, optional cu_seqlens_unpadded, optional argmins, and optional max_seqlen.
Returns:
PackedSeqParams with identical query and key/value parameters and qkv_format set to "thd".
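To make the expected batch layout concrete, the following sketch builds a hypothetical packed batch from per-sequence lengths. It assumes that cu_seqlens holds cumulative sequence boundaries starting at 0, carries a leading batch dimension of size 1, and is right-padded with -1 to a fixed length, and that max_seqlen is stored with a [1, 1] layout. Plain Python lists stand in for torch.Tensor, and make_packed_batch is an illustrative helper, not part of the library.

```python
from itertools import accumulate


def make_packed_batch(seq_lengths: list[int], pad_to: int) -> dict:
    """Build a hypothetical packed batch dictionary (illustrative only)."""
    # Cumulative boundaries: sequence i spans tokens [cu[i], cu[i + 1]).
    cu_seqlens = [0] + list(accumulate(seq_lengths))
    if len(cu_seqlens) > pad_to:
        raise ValueError("pad_to is smaller than the number of boundaries")
    # Right-pad with -1 so every batch has the same cu_seqlens length (assumed).
    cu_seqlens += [-1] * (pad_to - len(cu_seqlens))
    return {
        "cu_seqlens": [cu_seqlens],  # leading batch dimension of size 1
        "max_seqlen": [[max(seq_lengths)]],  # assumed [1, 1] layout
    }


batch = make_packed_batch([5, 7, 8], pad_to=6)
print(batch["cu_seqlens"])  # [[0, 5, 12, 20, -1, -1]]
print(batch["max_seqlen"])  # [[8]]
```

Packing three sequences of lengths 5, 7, and 8 yields the boundaries 0, 5, 12, 20; the two trailing -1 entries are the kind of padding that get_packed_seq_params is described as removing.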
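The following is a minimal sketch of the conversion described above, written under stated assumptions rather than taken from the library. PackedSeqParamsSketch is a hypothetical stand-in whose field names (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, qkv_format) are assumed to mirror Megatron Core's PackedSeqParams. Lists replace tensors, the -1 cutoff is found with list.index instead of any precomputed argmin, and the optional cu_seqlens_unpadded and argmins inputs are ignored.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PackedSeqParamsSketch:
    """Hypothetical stand-in for Megatron Core's PackedSeqParams."""

    cu_seqlens_q: list
    cu_seqlens_kv: list
    max_seqlen_q: Optional[int]
    max_seqlen_kv: Optional[int]
    qkv_format: str


def _squeeze(value: Any) -> Any:
    # Unwrap single-element lists, roughly mimicking torch.Tensor.squeeze().
    while isinstance(value, list) and len(value) == 1:
        value = value[0]
    return value


def get_packed_seq_params_sketch(batch: dict) -> PackedSeqParamsSketch:
    # Drop the leading batch dimension (micro-batch size 1 assumed).
    cu_seqlens = _squeeze(batch["cu_seqlens"])
    # Remove -1 padding: keep only the entries before the first -1.
    if -1 in cu_seqlens:
        cu_seqlens = cu_seqlens[: cu_seqlens.index(-1)]
    max_seqlen = _squeeze(batch["max_seqlen"]) if "max_seqlen" in batch else None
    # Queries and keys/values share the same boundaries and maximum length.
    return PackedSeqParamsSketch(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )


params = get_packed_seq_params_sketch(
    {"cu_seqlens": [[0, 5, 12, 20, -1, -1]], "max_seqlen": [[8]]}
)
print(params.cu_seqlens_q)  # [0, 5, 12, 20]
print(params.max_seqlen_q, params.qkv_format)  # 8 thd
```

Sharing one set of boundaries for queries and keys/values fits the self-attention case, where each packed sequence attends only within its own [cu_seqlens[i], cu_seqlens[i + 1]) token range.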