bridge.training.utils.packed_seq_utils

Module Contents

Functions

get_packed_seq_params

Build packed sequence parameters from a batch dictionary.

API

bridge.training.utils.packed_seq_utils.get_packed_seq_params(
batch: dict[str, torch.Tensor],
) -> megatron.core.packed_seq_params.PackedSeqParams

Build packed sequence parameters from a batch dictionary.

The function squeezes out any size-1 batch dimensions and removes the padding entries marked with -1. It returns a PackedSeqParams instance suitable for packed-sequence attention kernels.
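The squeeze-and-trim step can be sketched in plain Python as follows. This is a minimal illustration, not the library's implementation: it assumes cu_seqlens arrives with a leading batch dimension of size 1 (shape [1, N]) and that its trailing entries are -1 padding, and it assumes cu_seqlens_argmin, when supplied, is the index of the first padding entry. The helpers squeeze and trim_cu_seqlens are hypothetical, and nested lists stand in for torch.Tensor values.

```python
from typing import Optional, Sequence


def squeeze(values: Sequence) -> list:
    # Strip outer size-1 wrapper dimensions, e.g. [[0, 3, 7]] -> [0, 3, 7].
    while len(values) == 1 and isinstance(values[0], (list, tuple)):
        values = values[0]
    return list(values)


def trim_cu_seqlens(cu_seqlens: Sequence, argmin: Optional[int] = None) -> list:
    """Squeeze batch dimensions and cut off -1 padding (hypothetical helper)."""
    flat = squeeze(cu_seqlens)
    if argmin is None:
        # Assumption: without a precomputed index, the first -1 marks
        # where the padding begins.
        argmin = flat.index(-1) if -1 in flat else len(flat)
    return flat[:argmin]


print(trim_cu_seqlens([[0, 3, 7, 12, -1, -1]]))            # [0, 3, 7, 12]
print(trim_cu_seqlens([[0, 3, 7, 12, -1, -1]], argmin=4))  # [0, 3, 7, 12]
```

Passing a precomputed index avoids scanning for the padding boundary; without one, the sketch falls back to searching for the first -1.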

Parameters:

batch – A dictionary of tensors that may include cu_seqlens, plus the optional cu_seqlens_argmin and max_seqlen entries.
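To make the expected input concrete, the hypothetical helper below builds such a dictionary for one packed row from a list of per-sequence lengths. It assumes cu_seqlens holds cumulative token offsets starting at 0 (so sequence i occupies tokens cu_seqlens[i] up to cu_seqlens[i + 1] - 1 of the packed row), padded with -1 to a fixed width, and that max_seqlen is the length of the longest individual sequence. The exact shapes of cu_seqlens_argmin and max_seqlen in real batches are an assumption; plain nested lists stand in for torch.Tensor values so the sketch runs without PyTorch.

```python
from itertools import accumulate


def make_packed_batch(seq_lens: list, width: int) -> dict:
    """Build a batch dict for one packed row (hypothetical helper)."""
    # Cumulative boundaries: [0, l0, l0 + l1, ...].
    cu_seqlens = [0, *accumulate(seq_lens)]
    if len(cu_seqlens) > width:
        raise ValueError("width too small for the number of sequences")
    argmin = len(cu_seqlens)  # index of the first -1 padding entry
    padded = cu_seqlens + [-1] * (width - len(cu_seqlens))
    return {
        "cu_seqlens": [padded],           # leading size-1 batch dimension
        "cu_seqlens_argmin": [[argmin]],  # optional entry (assumed shape)
        "max_seqlen": [[max(seq_lens)]],  # optional entry (assumed shape)
    }


batch = make_packed_batch([3, 4, 5], width=6)
print(batch["cu_seqlens"])  # [[0, 3, 7, 12, -1, -1]]
```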

Returns:

PackedSeqParams with identical q/kv parameters and qkv_format set to "thd".
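The shape of the return value can be pictured with the stand-in dataclass below. It is a sketch, not Megatron Core's class: the field names cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv and qkv_format are assumed to mirror megatron.core.packed_seq_params.PackedSeqParams, and the hypothetical build_params helper reuses one set of boundaries for both the query and key/value sides, as the Returns entry describes. The "thd" format lays packed tokens out as (total tokens, heads, head dim), with sequences concatenated and no separate batch dimension.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PackedSeqParamsSketch:
    """Stand-in for PackedSeqParams; field names are assumed."""
    cu_seqlens_q: list
    cu_seqlens_kv: list
    max_seqlen_q: Optional[int]
    max_seqlen_kv: Optional[int]
    qkv_format: str = "thd"


def build_params(cu_seqlens: list, max_seqlen: Optional[int] = None) -> PackedSeqParamsSketch:
    # Queries and keys/values share the same boundaries here,
    # so a single list and a single max length serve both sides.
    return PackedSeqParamsSketch(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )


params = build_params([0, 3, 7, 12], max_seqlen=5)
print(params.qkv_format)  # thd
```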