bridge.training.utils.packed_seq_utils
Module Contents
Functions
get_packed_seq_params | Build packed sequence parameters from a batch dictionary.
API
bridge.training.utils.packed_seq_utils.get_packed_seq_params(batch: dict[str, torch.Tensor])
Build packed sequence parameters from a batch dictionary.
The function squeezes out any size-1 batch dimensions and strips the padding entries marked by -1 values. It returns a PackedSeqParams instance suitable for packed-sequence attention kernels.
Parameters:
batch: A dictionary containing a cu_seqlens tensor and, optionally, cu_seqlens_argmin and max_seqlen tensors.
Returns:
A PackedSeqParams instance with identical q and kv parameters and qkv_format set to "thd".
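As a rough illustration of the behavior described above, here is a minimal sketch that re-implements the logic on plain Python lists so it runs without torch or Megatron Core. Everything in it is an assumption, not the library's implementation: PackedSeqParamsSketch is a hypothetical stand-in whose field names (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) are assumed to mirror PackedSeqParams, build_padded_cu_seqlens is a hypothetical helper for producing example input, and cu_seqlens_argmin is assumed to hold the index of the first -1 padding entry.

```python
from dataclasses import dataclass
from itertools import accumulate
from typing import Any, Optional


@dataclass
class PackedSeqParamsSketch:
    """Hypothetical stand-in for PackedSeqParams (field names are assumed)."""
    cu_seqlens_q: list
    cu_seqlens_kv: list
    max_seqlen_q: Optional[int]
    max_seqlen_kv: Optional[int]
    qkv_format: str


def build_padded_cu_seqlens(lengths: list, pad_to: int) -> list:
    """Hypothetical helper: cumulative sequence boundaries, right-padded with -1."""
    cu = list(accumulate(lengths, initial=0))  # [3, 5, 2] -> [0, 3, 8, 10]
    return cu + [-1] * (pad_to - len(cu))


def _squeeze(x: Any) -> Any:
    """Unwrap size-1 lists, loosely mimicking torch.Tensor.squeeze()."""
    while isinstance(x, list) and len(x) == 1:
        x = x[0]
    return x


def get_packed_seq_params_sketch(batch: dict) -> PackedSeqParamsSketch:
    # Remove the leading batch dimension, e.g. [[0, 3, 8, -1]] -> [0, 3, 8, -1].
    cu_seqlens = _squeeze(batch["cu_seqlens"])

    # Cut off the -1 padding: use the precomputed index if the batch has one
    # (assumed meaning of cu_seqlens_argmin), otherwise find the first -1.
    argmin = batch.get("cu_seqlens_argmin")
    if argmin is not None:
        end = _squeeze(argmin)
    elif -1 in cu_seqlens:
        end = cu_seqlens.index(-1)
    else:
        end = len(cu_seqlens)
    cu_seqlens = cu_seqlens[:end]

    # max_seqlen is optional; squeeze it down to a scalar when present.
    max_seqlen = _squeeze(batch["max_seqlen"]) if "max_seqlen" in batch else None

    # Per the description, q and kv get identical boundary parameters.
    return PackedSeqParamsSketch(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )


# Three sequences of lengths 3, 5 and 2 packed into one row, padded to 6 entries.
batch = {"cu_seqlens": [build_padded_cu_seqlens([3, 5, 2], pad_to=6)], "max_seqlen": [[5]]}
params = get_packed_seq_params_sketch(batch)
print(params.cu_seqlens_q)  # [0, 3, 8, 10]
print(params.max_seqlen_q)  # 5
print(params.qkv_format)    # thd
```

Here cu_seqlens [0, 3, 8, 10] marks three packed sequences occupying token ranges [0, 3), [3, 8) and [8, 10), the boundary layout that "thd" (total tokens, heads, head dim) attention consumes. When no cu_seqlens_argmin is supplied, the sketch locates the padding with list.index(-1); that is a choice made here for illustration and not necessarily what the library does.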