pvaAplSortPayloadVpu.hpp

Fully qualified name: public/src/primitive/pvaAplSortPayloadVpu.hpp

File members: public/src/primitive/pvaAplSortPayloadVpu.hpp

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#ifndef PVA_APL_SORT_PAYLOAD_VPU_HPP
#define PVA_APL_SORT_PAYLOAD_VPU_HPP

#include <cupva_device.h>
#include <stdio.h>
#include <string.h>

#include <type_traits>

namespace pvaApl {

// Traits mapping a sort element type (DataType) to the VPU vector types that
// SortPayloadVpu operates on:
//   DVTYPE     - double-vector type used for agen base-pointer casts
//   DVTYPEX    - extended double-vector type held in registers while sorting
//   PMT_VTYPEX - extended vector type for permutation indices
// Only 32-bit integer element types are supported. The primary template is
// declared but never defined, so any other DataType fails to compile.
template<typename DataType>
struct SortPayloadVectorTypes;

// Unsigned 32-bit elements.
template<>
struct SortPayloadVectorTypes<uint32_t>
{
    typedef dvint DVTYPE;
    typedef dvintx DVTYPEX;
    typedef vshortx PMT_VTYPEX;
};

// Signed 32-bit elements use exactly the same vector types as unsigned ones,
// so the three typedefs are inherited from the uint32_t specialization.
template<>
struct SortPayloadVectorTypes<int32_t> : SortPayloadVectorTypes<uint32_t>
{
};

/**
 * Working memory shared by SortPayloadVpu::Init and SortPayloadVpu::Execute.
 *
 * cfgs    - AGEN configurations written in order by Init and read back in the
 *           same order by Execute. 256 entries is well above what the
 *           largest supported Size needs (the height-512 kernel sequence
 *           consumes 91).
 * scratch - the "pong" ping-pong buffer used by the merge passes. It is sized
 *           for the largest configuration: 512 rows, each with a pitch of
 *           (16 + 2) 32-bit elements. The 16 is presumably the vector width
 *           pva_elementsof(dvintx) and the +2 matches the row padding
 *           lofst = vecw + 2 used by SortPayloadVpu (confirm if vector width
 *           ever changes).
 */
struct SortPayloadContext
{
    AgenCFG cfgs[256];
    static constexpr size_t SCRATCH_SIZE = ((16 + 2) * 512 * sizeof(uint32_t));
    // SortPayloadVpu::Init reinterprets this storage as DataType*
    // (int32_t / uint32_t). Guarantee at least 32-bit alignment explicitly
    // instead of relying on the size/alignment of the preceding cfgs array.
    alignas(uint32_t) uint8_t scratch[SCRATCH_SIZE];
};

/**
 * SortPayloadVpu - VPU sort of `Size` elements of `DataType` built on the
 * payload-aware compare-exchange intrinsic vsort2pl.
 *
 * Usage contract (as implemented below):
 *  1. Init(src, dst, context) precomputes every AGEN configuration the sort
 *     needs and appends them, in order, to context->cfgs.
 *  2. Execute(context) replays those configurations through the kernel
 *     sequence in sort_payload_height_<height>().
 * Every kernel pops its configurations from the shared cursor `(*cfgs)++`.
 * The sequence of *_init helpers called by Init must therefore match, one for
 * one and in order, the sequence of kernels called by
 * sort_payload_height_<height>(). Change both sides together.
 *
 * Data is treated as `height` rows of `vecw` lanes, where
 * vecw = pva_elementsof(DVTYPEX) and height = Size / vecw. The passes are:
 *  - an 8-input odd-even merge network applied to every group of 8
 *    consecutive row vectors;
 *  - bitonic merge passes over growing vertical distances (v_dist);
 *  - bitonic merge passes over horizontal distances (h_dist) using the
 *    permuting/transposing kernels, each followed by a vertical merge.
 * The very last kernel stores into `dst` with a transposing store.
 *
 * Buffers: Init uses `src` as the "ping" buffer and context->scratch as the
 * "pong" buffer, so Execute overwrites the contents of `src`. Intermediate
 * rows use a pitch of lofst = vecw + 2 elements. That pitch matches
 * MIN_INPUT_BUFFER_SIZE = (vecw + 2) * height * sizeof(DataType).
 *
 * NOTE(review): the key/payload arrangement inside each vector is whatever
 * vsort2pl expects. It is not spelled out in this file; confirm against the
 * cupva device API documentation.
 */
template<typename DataType, int Size>
class SortPayloadVpu
{
private:
    // Vector types for DataType, chosen by SortPayloadVectorTypes.
    // NOTE(review): PMT_VTYPEX is not referenced below; permutation vectors
    // are passed as vshortx directly.
    using DVTYPE     = typename SortPayloadVectorTypes<DataType>::DVTYPE;
    using DVTYPEX    = typename SortPayloadVectorTypes<DataType>::DVTYPEX;
    using PMT_VTYPEX = typename SortPayloadVectorTypes<DataType>::PMT_VTYPEX;

    // Loads one double vector through agen `ag`. It uses the load intrinsic
    // that matches DataType's signedness and returns the result as dvintx.
    template<typename AgenType>
    auto DVLOAD(AgenType &ag) -> decltype(auto)
    {
        if constexpr (std::is_same_v<DataType, uint32_t>)
        {
            return static_cast<dvintx>(dvuint_load(ag));
        }
        else if constexpr (std::is_same_v<DataType, int32_t>)
        {
            return static_cast<dvintx>(dvint_load(ag));
        }
        else
        {
            // Dependent "false": fires only if this branch is instantiated,
            // i.e. for a DataType other than uint32_t / int32_t.
            static_assert(!std::is_same_v<DataType, DataType>, "Unsupported type for dvload");
        }
    }

    // Same as DVLOAD, but uses the *_load_perm_transp2 intrinsic with the
    // permutation vector `vx` (used by the horizontal-merge kernels).
    template<typename AgenType, typename VecType>
    auto DVLOAD_T2(AgenType &ag, VecType vx) -> decltype(auto)
    {
        if constexpr (std::is_same_v<DataType, uint32_t>)
        {
            return static_cast<dvintx>(dvuint_load_perm_transp2(ag, vx));
        }
        else if constexpr (std::is_same_v<DataType, int32_t>)
        {
            return static_cast<dvintx>(dvint_load_perm_transp2(ag, vx));
        }
        else
        {
            // Dependent "false" (see DVLOAD).
            // NOTE(review): the message says "dvload" although this is the
            // perm/transp2 load.
            static_assert(!std::is_same_v<DataType, DataType>, "Unsupported type for dvload");
        }
    }

    // Compare-exchange between v1 and v2 via vsort2pl, applied separately to
    // the .lo halves and to the .hi halves. Results are written back in place.
    // Presumably v1 receives the smaller and v2 the larger key of each pair;
    // confirm against the vsort2pl documentation.
    inline void VSORT2PL_INPLACE(DVTYPEX &v1, DVTYPEX &v2)
    {
        vsort2pl(v1.lo, v2.lo, v1.lo, v2.lo);
        vsort2pl(v1.hi, v2.hi, v1.hi, v2.hi);
    }

    // One bitonic merge stage over two vectors: a single compare-exchange.
    inline void MERGE_2_WAY(DVTYPEX &v1, DVTYPEX &v2)
    {
        VSORT2PL_INPLACE(v1, v2);
    }

    // Two bitonic merge stages over four vectors: first compare-exchange at
    // distance 2 ((v1,v3), (v2,v4)), then at distance 1 ((v1,v2), (v3,v4)).
    // chess_separator() is placed between the stages, presumably to constrain
    // compiler reordering; confirm with the chess compiler documentation.
    inline void MERGE_4_WAY(DVTYPEX &v1, DVTYPEX &v2, DVTYPEX &v3, DVTYPEX &v4)
    {
        chess_separator();
        VSORT2PL_INPLACE(v1, v3);
        chess_separator();
        VSORT2PL_INPLACE(v2, v4);
        chess_separator();
        MERGE_2_WAY(v1, v2);
        chess_separator();
        MERGE_2_WAY(v3, v4);
    }

    // Kernel: runs Batcher's 8-input odd-even merge sorting network
    // (19 compare-exchanges) across each group of 8 consecutive row vectors.
    // Consumes 2 cfgs (input, output) produced by
    // even_odd_merge_sort_payload_init. Loops height / 8 times.
    // chess_loop_range(3, ) promises at least 3 iterations; the smallest
    // height handled by Execute (32) gives 4.
    void even_odd_merge_sort_payload(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4, v5, v6, v7, v8;

        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        int niter = height >> 3;

        for (int i = 0; i < niter; i++) chess_loop_range(3, )
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);
            v3 = DVLOAD(in);
            v4 = DVLOAD(in);
            v5 = DVLOAD(in);
            v6 = DVLOAD(in);
            v7 = DVLOAD(in);
            v8 = DVLOAD(in);

            // Stage 1: adjacent pairs.
            VSORT2PL_INPLACE(v1, v2);
            VSORT2PL_INPLACE(v3, v4);
            VSORT2PL_INPLACE(v5, v6);
            VSORT2PL_INPLACE(v7, v8);

            // Stages 2-3: merge pairs into sorted groups of 4.
            VSORT2PL_INPLACE(v1, v3);
            VSORT2PL_INPLACE(v5, v7);
            VSORT2PL_INPLACE(v2, v4);
            VSORT2PL_INPLACE(v6, v8);

            VSORT2PL_INPLACE(v2, v3);
            VSORT2PL_INPLACE(v6, v7);

            // Stages 4-6: merge the two groups of 4 into a sorted group of 8.
            VSORT2PL_INPLACE(v1, v5);
            VSORT2PL_INPLACE(v2, v6);
            VSORT2PL_INPLACE(v3, v7);
            VSORT2PL_INPLACE(v4, v8);

            VSORT2PL_INPLACE(v3, v5);
            VSORT2PL_INPLACE(v4, v6);

            VSORT2PL_INPLACE(v2, v3);
            VSORT2PL_INPLACE(v4, v5);
            VSORT2PL_INPLACE(v6, v7);

            vstore(v1, out);
            vstore(v2, out);
            vstore(v3, out);
            vstore(v4, out);
            vstore(v5, out);
            vstore(v6, out);
            vstore(v7, out);
            vstore(v8, out);
        }
    }

    // Kernel: first (reversed) 4-way vertical bitonic merge step. Consumes
    // 3 cfgs (in1, in2, out) produced by bitonic_merge_n_way_reverse_init.
    // That producer walks in2 backwards, which supplies the reversal. Each
    // iteration loads 2 vectors from in1 and 2 from in2, applies MERGE_4_WAY,
    // and stores 4. Loops height / 4 times (8 for the smallest height, 32).
    void bitonic_merge_4_way_reverse(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;

        agen_A in1, in2;
        agen_B out;

        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);

        int niter = height >> 2;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(2)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in1);
            v2 = DVLOAD(in1);
            v3 = DVLOAD(in2);
            v4 = DVLOAD(in2);

            MERGE_4_WAY(v1, v2, v3, v4);

            vstore(v1, out);
            vstore(v2, out);
            vstore(v3, out);
            vstore(v4, out);
        }
    }

    // Kernel: 4-way vertical merge step that stores with vstore_transp2.
    // Used as the final step of the sort (writes into dst). Consumes 2 cfgs
    // produced by bitonic_merge_n_way_transpose_init.
    void bitonic_merge_4_way_transp(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;

        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        int niter = height >> 2;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(2)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);
            v3 = DVLOAD(in);
            v4 = DVLOAD(in);

            MERGE_4_WAY(v1, v2, v3, v4);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
            vstore_transp2(v3, out);
            vstore_transp2(v4, out);
        }
    }

    // Kernel: plain 4-way vertical merge step (two distances at once).
    // Consumes 2 cfgs produced by bitonic_merge_n_way_init.
    void bitonic_merge_4_way(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;

        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        int niter = height >> 2;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(2)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);
            v3 = DVLOAD(in);
            v4 = DVLOAD(in);

            MERGE_4_WAY(v1, v2, v3, v4);

            vstore(v1, out);
            vstore(v2, out);
            vstore(v3, out);
            vstore(v4, out);
        }
    }

    // Kernel: 2-way merge reading a single agen_A stream. Consumes 2 cfgs.
    // NOTE(review): no sort_payload_height_* below calls this. Its 2-cfg
    // layout also does not match bitonic_merge_n_way_reverse_init, which
    // emits 3 cfgs.
    void bitonic_merge_2_way_reverse(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;

        agen_A in;
        agen_B out;

        in  = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);

        int niter = height >> 1;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(4)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);

            MERGE_2_WAY(v1, v2);

            vstore(v1, out);
            vstore(v2, out);
        }
    }

    // Kernel: 2-way vertical merge step that stores with vstore_transp2.
    // Used as the final step of the sort (writes into dst). Consumes 2 cfgs
    // produced by bitonic_merge_n_way_transpose_init. Loops height / 2 times.
    void bitonic_merge_2_way_transp(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;

        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        int niter = height >> 1;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(4)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);

            MERGE_2_WAY(v1, v2);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
        }
    }

    // Kernel: plain 2-way vertical merge step (single distance).
    // Consumes 2 cfgs produced by bitonic_merge_n_way_init.
    void bitonic_merge_2_way(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;

        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        int niter = height >> 1;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(4)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD(in);
            v2 = DVLOAD(in);

            MERGE_2_WAY(v1, v2);

            vstore(v1, out);
            vstore(v2, out);
        }
    }

    // Kernel: first (reversed) 4-way horizontal merge step. Consumes 3 cfgs
    // (in1, in2, out) produced by bitonic_merge_transpose_n_way_reverse_init.
    // in1 is loaded with the identity permutation (*permute_short_ptr) and in2
    // with the pair-reversing permutation (*permute_reverse_short_ptr), both
    // via DVLOAD_T2. Results are stored with vstore_transp2.
    void bitonic_merge_transpose_4_way_reverse(AgenCFG **cfgs, int height, vshortx *permute_short_ptr,
                                               vshortx *permute_reverse_short_ptr)
    {
        DVTYPEX v1, v2, v3, v4;

        vshortx pmt_short, pmt_reverse_short;
        agen_A in1, in2;
        agen_B out;

        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);

        out = init_agen_B_from_cfg(*(*cfgs)++);

        pmt_short         = *permute_short_ptr;
        pmt_reverse_short = *permute_reverse_short_ptr;

        int niter = height >> 2;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(2)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD_T2(in1, pmt_short);
            v2 = DVLOAD_T2(in1, pmt_short);
            v3 = DVLOAD_T2(in2, pmt_reverse_short);
            v4 = DVLOAD_T2(in2, pmt_reverse_short);

            MERGE_4_WAY(v1, v2, v3, v4);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
            vstore_transp2(v3, out);
            vstore_transp2(v4, out);
        }
    }

    // Kernel: non-reversed 4-way horizontal merge step (2 cfgs, as produced by
    // bitonic_merge_transpose_n_way_init).
    // NOTE(review): no sort_payload_height_* below calls this. Init never
    // requests a non-reversed horizontal 4-way step for h_dist < vecw / 2.
    void bitonic_merge_transpose_4_way(AgenCFG **cfgs, int height, vshortx *permute_short_ptr)
    {
        DVTYPEX v1, v2, v3, v4;

        vshortx pmt_short;
        agen in, out;

        in  = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);

        pmt_short = *permute_short_ptr;

        int niter = height >> 2;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(2)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD_T2(in, pmt_short);
            v2 = DVLOAD_T2(in, pmt_short);
            v3 = DVLOAD_T2(in, pmt_short);
            v4 = DVLOAD_T2(in, pmt_short);

            MERGE_4_WAY(v1, v2, v3, v4);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
            vstore_transp2(v3, out);
            vstore_transp2(v4, out);
        }
    }

    // Kernel: first (reversed) 2-way horizontal merge step. Consumes 3 cfgs
    // (in1, in2, out) produced by bitonic_merge_transpose_n_way_reverse_init
    // with n == 2. in1 uses the identity permutation and in2 the
    // pair-reversing one. Results are stored with vstore_transp2.
    void bitonic_merge_transpose_2_way_reverse(AgenCFG **cfgs, int height, vshortx *permute_short_ptr,
                                               vshortx *permute_reverse_short_ptr)
    {
        DVTYPEX v1, v2;

        vshortx pmt_short, pmt_reverse_short;
        agen_A in1;
        agen_A in2;
        agen_B out;

        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);

        out = init_agen_B_from_cfg(*(*cfgs)++);

        pmt_short         = *permute_short_ptr;
        pmt_reverse_short = *permute_reverse_short_ptr;

        int niter = height >> 1;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(4)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD_T2(in1, pmt_short);
            v2 = DVLOAD_T2(in2, pmt_reverse_short);

            VSORT2PL_INPLACE(v1, v2);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
        }
    }

    // Kernel: non-reversed 2-way horizontal merge step. Consumes 2 cfgs
    // produced by bitonic_merge_transpose_n_way_init. Loads with the identity
    // permutation and stores with vstore_transp2.
    void bitonic_merge_transpose_2_way(AgenCFG **cfgs, int height, vshortx *permute_short_ptr)
    {
        DVTYPEX v1, v2;

        vshortx pmt_short;
        agen_A in;
        agen_B out;

        in  = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);

        pmt_short = *permute_short_ptr;

        int niter = height >> 1;

        for (int i = 0; i < niter; i++) chess_loop_range(8, )
        chess_unroll_loop(4)
        chess_prepare_for_pipelining
        {
            v1 = DVLOAD_T2(in, pmt_short);
            v2 = DVLOAD_T2(in, pmt_short);

            VSORT2PL_INPLACE(v1, v2);

            vstore_transp2(v1, out);
            vstore_transp2(v2, out);
        }
    }

    // Emits the 2 cfgs consumed by even_odd_merge_sort_payload:
    //  - input: `height` vectors read from `ping` with a stride of vecw
    //    elements (i.e. densely packed rows);
    //  - output: `height` vectors written to `pong` with a row pitch of lofst.
    // Strides appear to be in DataType elements (wrapper.size is
    // sizeof(DataType)); confirm against the AgenWrapper/INIT_AGEN* docs.
    void even_odd_merge_sort_payload_init(DataType *ping, DataType *pong, AgenCFG **cfgs, int lofst, int height)
    {
        AgenWrapper wrapper;
        int vecw = pva_elementsof(DVTYPEX);

        agen in      = init((DVTYPE *)ping);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = height;
        wrapper.s1   = vecw;
        INIT_AGEN1(in, wrapper);

        agen out     = init((DVTYPE *)pong);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = height;
        wrapper.s1   = lofst;
        INIT_AGEN1(out, wrapper);

        *(*cfgs)++ = extract_agen_cfg(in);
        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Emits the 3 cfgs consumed by bitonic_merge_transpose_{2,4}_way_reverse
    // (n == 2 or n == 4). Source and destination alternate between ping and
    // pong by the parity of *it, and *it is advanced.
    //  - in1: 5-D forward walk over src;
    //  - in2: 5-D walk starting near the end of src with negative strides
    //    (s2, s4, s5), which produces the reversed half;
    //  - out: 4-D walk over dst.
    // All three set lane_ofst = lofst / vecw.
    void bitonic_merge_transpose_n_way_reverse_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                                    int height, int h_dist, int n)
    {
        agen in1, in2, out;
        int vecw = pva_elementsof(DVTYPEX);

        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;

        int niter1 = 1;
        int niter2 = n / 2;
        int niter3 = (vecw / 2) / h_dist / n;
        int niter4 = height / (vecw / 2);
        int niter5 = h_dist;

        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);

        in1           = init((DVTYPE *)src);
        in1.lane_ofst = lofst / vecw;
        wrapper.n1    = niter1;
        wrapper.n2    = niter2;
        wrapper.n3    = niter3;
        wrapper.n4    = niter4;
        wrapper.n5    = niter5;
        wrapper.s1    = 1;
        wrapper.s2    = h_dist * 2;
        wrapper.s3    = n * h_dist * 2;
        wrapper.s4    = vecw / 2 * lofst;
        wrapper.s5    = 2;
        INIT_AGEN5(in1, wrapper);

        *(*cfgs)++ = extract_agen_cfg(in1);

        // Same iteration counts as in1 (wrapper.n1..n5 are kept); only the
        // base and the strides change to walk the mirrored elements.
        in2           = init((DVTYPE *)(src + (height - vecw / 2) * lofst + h_dist * n * 2 - 2));
        in2.lane_ofst = lofst / vecw;
        wrapper.s1    = 1;
        wrapper.s2    = -h_dist * 2;
        wrapper.s3    = n * h_dist * 2;
        wrapper.s4    = -vecw / 2 * lofst;
        wrapper.s5    = -2;
        INIT_AGEN5(in2, wrapper);

        *(*cfgs)++ = extract_agen_cfg(in2);

        niter1 = 1;
        niter2 = (vecw / 2) / h_dist;
        niter3 = height / (vecw / 2);
        niter4 = h_dist;

        out           = init((DVTYPE *)dst);
        out.lane_ofst = lofst / vecw;
        wrapper.n1    = niter1;
        wrapper.n2    = niter2;
        wrapper.n3    = niter3;
        wrapper.n4    = niter4;
        wrapper.s1    = 1;
        wrapper.s2    = h_dist * 2;
        wrapper.s3    = vecw / 2 * lofst;
        wrapper.s4    = 2;
        INIT_AGEN4(out, wrapper);

        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Emits the 2 cfgs consumed by bitonic_merge_transpose_{2,4}_way:
    // identical 4-D access patterns on src (input) and dst (output), both
    // with lane_ofst = lofst / vecw. src/dst alternate by the parity of *it,
    // and *it is advanced.
    void bitonic_merge_transpose_n_way_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                            int height, int h_dist)
    {
        agen in, out;
        int vecw = pva_elementsof(DVTYPEX);

        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;

        int niter1 = 1;
        int niter2 = (vecw / 2) / h_dist;
        int niter3 = height / (vecw / 2);
        int niter4 = h_dist;

        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);

        in           = init((DVTYPE *)src);
        in.lane_ofst = lofst / vecw;
        wrapper.n1   = niter1;
        wrapper.n2   = niter2;
        wrapper.n3   = niter3;
        wrapper.n4   = niter4;
        wrapper.s1   = 1;
        wrapper.s2   = h_dist * 2;
        wrapper.s3   = vecw / 2 * lofst;
        wrapper.s4   = 2;
        INIT_AGEN4(in, wrapper);

        *(*cfgs)++ = extract_agen_cfg(in);

        out           = init((DVTYPE *)dst);
        out.lane_ofst = lofst / vecw;
        wrapper.n1    = niter1;
        wrapper.n2    = niter2;
        wrapper.n3    = niter3;
        wrapper.n4    = niter4;
        wrapper.s1    = 1;
        wrapper.s2    = h_dist * 2;
        wrapper.s3    = vecw / 2 * lofst;
        wrapper.s4    = 2;
        INIT_AGEN4(out, wrapper);

        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Emits the cfgs for one horizontal bitonic merge, starting at distance
    // h_dist and going down to distance 1. While h_dist > 1, each step is a
    // 4-way step (configured with h_dist / 2, then h_dist /= 4). A remaining
    // h_dist of 1 is handled with a 2-way step. Only the first step uses the
    // reversed variant.
    void bitonic_merge_transpose_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst, int height,
                                      int h_dist)
    {
        int reverse = 1;
        while (h_dist > 0)
        {
            if (h_dist > 1)
            {
                if (reverse)
                {
                    bitonic_merge_transpose_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, h_dist >> (2 - 1),
                                                               4);
                }
                else
                {
                    bitonic_merge_transpose_n_way_init(ping, pong, it, cfgs, lofst, height, h_dist >> (2 - 1));
                }
                h_dist = h_dist >> 2;
            }
            else
            {
                if (reverse)
                {
                    bitonic_merge_transpose_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, h_dist, 2);
                }
                else
                {
                    bitonic_merge_transpose_n_way_init(ping, pong, it, cfgs, lofst, height, h_dist);
                }
                h_dist = h_dist >> 1;
            }
            reverse = 0;
        }
    }

    // Emits the 3 cfgs consumed by bitonic_merge_4_way_reverse. src/dst
    // alternate by the parity of *it, and *it is advanced.
    //  - stream 1: 3-D forward walk over src rows;
    //  - stream 2: same counts, starting at row (n * v_dist - 1) of src and
    //    walking backwards (negative s1 and s3) to form the reversed half;
    //  - output: 2-D walk over dst rows with pitch lofst.
    // Callers in this file always pass n == 4.
    void bitonic_merge_n_way_reverse_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                          int height, int v_dist, int n)
    {
        agen stream;

        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;

        int niter1 = n / 2;
        int niter2 = height / v_dist / n;
        int niter3 = v_dist;

        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);

        stream     = init((DVTYPE *)src);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = n * v_dist * lofst;
        wrapper.s3 = lofst;
        INIT_AGEN3(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);

        stream     = init((DVTYPE *)(src + (n * v_dist - 1) * lofst));
        wrapper.s1 = -v_dist * lofst;
        wrapper.s2 = n * v_dist * lofst;
        wrapper.s3 = -lofst;
        INIT_AGEN3(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);

        niter1 = height / v_dist;
        niter2 = v_dist;

        stream       = init((DVTYPE *)dst);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = niter1;
        wrapper.n2   = niter2;
        wrapper.s1   = v_dist * lofst;
        wrapper.s2   = lofst;
        INIT_AGEN2(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Emits the 2 cfgs consumed by bitonic_merge_{2,4}_way_transp, the final
    // kernel of the sort. Unlike its siblings, it does not choose buffers by
    // the parity of *it: the caller passes the current source as `ping` and
    // the final output buffer as `pong`. *it is still advanced.
    //  - input: 2-D walk over src rows with pitch lofst;
    //  - output: 3-D walk over dst with lane_ofst = height * 2 / vecw, which
    //    sets up the transposed output layout written by vstore_transp2.
    void bitonic_merge_n_way_transpose_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                            int height, int v_dist)
    {
        agen stream;

        DataType *src = ping;
        DataType *dst = pong;
        (*it)++;

        int vecw   = pva_elementsof(DVTYPEX);
        int niter1 = height / v_dist;
        int niter2 = v_dist;

        AgenWrapper wrapper;

        stream       = init((DVTYPE *)src);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = niter1;
        wrapper.n2   = niter2;
        wrapper.s1   = v_dist * lofst;
        wrapper.s2   = lofst;
        INIT_AGEN2(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);

        stream           = init((DVTYPE *)dst);
        wrapper.size     = sizeof(DataType);
        stream.lane_ofst = height * 2 / vecw;
        wrapper.n1       = 1;
        wrapper.n2       = niter1;
        wrapper.n3       = niter2;
        wrapper.s1       = 1;
        wrapper.s2       = 2 * v_dist;
        wrapper.s3       = 1;
        INIT_AGEN3(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Emits the 2 cfgs consumed by bitonic_merge_{2,4}_way: identical 2-D
    // row walks (pitch lofst) over src (input) and dst (output). src/dst
    // alternate by the parity of *it, and *it is advanced.
    void bitonic_merge_n_way_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst, int height,
                                  int v_dist)
    {
        agen stream;

        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;

        int niter1 = height / v_dist;
        int niter2 = v_dist;

        AgenWrapper wrapper;

        stream       = init((DVTYPE *)src);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = niter1;
        wrapper.n2   = niter2;
        wrapper.s1   = v_dist * lofst;
        wrapper.s2   = lofst;
        INIT_AGEN2(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);

        stream       = init((DVTYPE *)dst);
        wrapper.size = sizeof(DataType);
        wrapper.n1   = niter1;
        wrapper.n2   = niter2;
        wrapper.s1   = v_dist * lofst;
        wrapper.s2   = lofst;
        INIT_AGEN2(stream, wrapper);

        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Emits the cfgs for one vertical bitonic merge, starting at distance
    // v_dist and going down to distance 1.
    //  - While v_dist >= 2, each step is a 4-way step (configured with
    //    v_dist / 2, then v_dist /= 4).
    //  - A remaining v_dist of 1 is handled with a 2-way step.
    //  - If `reverse` is non-zero, only the first step uses the reversed
    //    variant.
    //  - If `last` is non-zero, the final step reads the current source
    //    buffer and writes `output` through bitonic_merge_n_way_transpose_init.
    // NOTE(review): the reverse branch for v_dist == 1 passes n == 4, whereas
    // bitonic_merge_transpose_init passes 2 in its analogous branch. Init
    // never reaches that branch, because it only passes reverse == 1 with
    // v_dist >= 8. Confirm the intended value before relying on it.
    void bitonic_merge_init(DataType *ping, DataType *pong, DataType *output, int *it, AgenCFG **cfgs, int lofst,
                            int height, int v_dist, int reverse, int last)
    {
        while (v_dist > 0)
        {
            if (v_dist >= 2)
            {
                if (reverse)
                {
                    bitonic_merge_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, v_dist >> (2 - 1), 4);
                }
                else
                {
                    if (last && (v_dist >> 2) == 0)
                    {
                        DataType *src = (*it) & 1 ? ping : pong;
                        bitonic_merge_n_way_transpose_init(src, output, it, cfgs, lofst, height, v_dist >> (2 - 1));
                    }
                    else
                    {
                        bitonic_merge_n_way_init(ping, pong, it, cfgs, lofst, height, v_dist >> (2 - 1));
                    }
                }
                v_dist = v_dist >> 2;
            }
            else
            {
                if (reverse)
                {
                    bitonic_merge_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, v_dist >> (1 - 1), 4);
                }
                else
                {
                    if (last && (v_dist >> 1) == 0)
                    {
                        DataType *src = (*it) & 1 ? ping : pong;
                        bitonic_merge_n_way_transpose_init(src, output, it, cfgs, lofst, height, v_dist >> (1 - 1));
                    }
                    else
                    {
                        bitonic_merge_n_way_init(ping, pong, it, cfgs, lofst, height, v_dist >> (1 - 1));
                    }
                }
                v_dist = v_dist >> 1;
            }
            reverse = 0;
        }
    }

    // Kernel sequence for height 512. It must mirror, call for call, the
    // configurations Init emits for this height. The group comments
    // (v_dist / h_dist) assume vecw == 16, which is what maps Size 8192 to
    // height 512.
    void sort_payload_height_512(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 512;

        // Sort groups of 8 rows.
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merge, v_dist = 8.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 32.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 64.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 128.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 256.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 1, then vertical merge from height / 2.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 2, then vertical merge from height / 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 4, then the final vertical merge into dst.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

    // Kernel sequence for height 256 (Size 4096 with vecw == 16). It must
    // mirror the configurations Init emits for this height.
    void sort_payload_height_256(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 256;
        // Sort groups of 8 rows.
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merge, v_dist = 8.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 32.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 64.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 128.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 1, then vertical merge from height / 2.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 2, then vertical merge from height / 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 4, then the final vertical merge into dst.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_transp(&cfgs, height);
    }

    // Kernel sequence for height 128 (Size 2048 with vecw == 16). It must
    // mirror the configurations Init emits for this height.
    void sort_payload_height_128(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 128;
        // Sort groups of 8 rows.
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merge, v_dist = 8.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 32.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 64.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 1, then vertical merge from height / 2.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 2, then vertical merge from height / 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 4, then the final vertical merge into dst.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

    // Kernel sequence for height 64 (Size 1024 with vecw == 16). It must
    // mirror the configurations Init emits for this height.
    void sort_payload_height_64(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 64;
        // Sort groups of 8 rows.
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merge, v_dist = 8.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Vertical merge, v_dist = 32.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 1, then vertical merge from height / 2.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 2, then vertical merge from height / 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Horizontal merge h_dist = 4, then the final vertical merge into dst.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_transp(&cfgs, height);
    }

    // Kernel sequence for height 32 (Size 512 with vecw == 16). It must
    // mirror the configurations Init emits for this height.
    void sort_payload_height_32(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 32;
        // Sort groups of 8 rows.
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merge, v_dist = 8.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // Vertical merge, v_dist = 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 1, then vertical merge from height / 2.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 2, then vertical merge from height / 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // Horizontal merge h_dist = 4, then the final vertical merge into dst.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

public:
    // Restricts Size to the supported values. Because the static_assert sits
    // in the constructor body, it is checked when the constructor is
    // instantiated.
    SortPayloadVpu()
    {
        static_assert(Size == 512 || Size == 1024 || Size == 2048 || Size == 4096 || Size == 8192,
                      "Unsupported size value for SortPayloadVpu template instantiation.");
    }

    /**
     * Precomputes all AGEN configurations for sorting `Size` elements and
     * stores them, in execution order, in context->cfgs.
     *
     * @param src     Input data. It is also used as the "ping" work buffer,
     *                so its contents are overwritten by Execute and it must
     *                provide at least MIN_INPUT_BUFFER_SIZE bytes.
     * @param dst     Output buffer. Only the final (transposing) kernel writes
     *                here; presumably it must provide MIN_OUTPUT_BUFFER_SIZE
     *                bytes.
     * @param context Receives the configurations. Its scratch area is used
     *                as the "pong" work buffer.
     */
    void Init(DataType *src, DataType *dst, SortPayloadContext *context)
    {
        int it = 0;
        int v_dist, h_dist;

        int vecw   = pva_elementsof(DVTYPEX);
        int lofst  = vecw + 2;
        int height = Size / vecw;

        AgenCFG *cfgs = context->cfgs;
        DataType *tmp = reinterpret_cast<DataType *>(context->scratch);

        // Sort groups of 8 rows: src (dense) -> tmp (pitch lofst).
        even_odd_merge_sort_payload_init(src, tmp, &cfgs, lofst, height);

        // Vertical bitonic merges for v_dist = 8, 16, ..., height / 2. The
        // first step of each merge uses the reversed variant.
        for (v_dist = 8; v_dist < height; v_dist *= 2)
        {
            bitonic_merge_init(src, tmp, dst, &it, &cfgs, lofst, height, v_dist, 1, 0);
        }

        // Horizontal merges for h_dist = 1, 2, ... < vecw / 2. Each is followed
        // by a full vertical merge from height / 2 down to 1. The merge after
        // h_dist == vecw / 4 is flagged `last` and writes the result to dst.
        for (h_dist = 1; h_dist < vecw / 2; h_dist *= 2)
        {
            bitonic_merge_transpose_init(src, tmp, &it, &cfgs, lofst, height, h_dist);
            bitonic_merge_init(src, tmp, dst, &it, &cfgs, lofst, height, height / 2, 0, h_dist == vecw / 4);
        }
    }

    /**
     * Runs the sort using the configurations prepared by Init. The kernel
     * sequence is selected at compile time from height = Size / vecw.
     */
    void Execute(SortPayloadContext *context)
    {
        AgenCFG *cfgs        = context->cfgs;
        constexpr int height = Size / pva_elementsof(DVTYPEX);

        // Identity permutation used by the perm/transp2 loads.
        const short permute_index_short[] = {
            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
        };

        // Permutation that reverses the order of 2-element groups while
        // keeping the order inside each pair: (14,15), (12,13), ..., (0,1).
        const short permute_index_reverse_short[] = {
            0xe, 0xf, 0xc, 0xd, 0xa, 0xb, 0x8, 0x9, 0x6, 0x7, 0x4, 0x5, 0x2, 0x3, 0x0, 0x1,
        };

        vshortx permute_short         = sign_extend(*((vshort *)permute_index_short));
        vshortx permute_reverse_short = sign_extend(*((vshort *)permute_index_reverse_short));

        // No else branch: with the Size values accepted by the constructor,
        // height is one of these values (assuming vecw == 16).
        if constexpr (height == 512)
            sort_payload_height_512(cfgs, &permute_short, &permute_reverse_short);
        else if constexpr (height == 256)
            sort_payload_height_256(cfgs, &permute_short, &permute_reverse_short);
        else if constexpr (height == 128)
            sort_payload_height_128(cfgs, &permute_short, &permute_reverse_short);
        else if constexpr (height == 64)
            sort_payload_height_64(cfgs, &permute_short, &permute_reverse_short);
        else if constexpr (height == 32)
            sort_payload_height_32(cfgs, &permute_short, &permute_reverse_short);
    }

    // Minimum bytes for `src`: height rows with a pitch of (vecw + 2)
    // elements, because src also serves as the "ping" work buffer.
    static constexpr size_t MIN_INPUT_BUFFER_SIZE{(pva_elementsof(DVTYPEX) + 2) * (Size / pva_elementsof(DVTYPEX)) *
                                                  sizeof(DataType)};
    // Minimum bytes for `dst`: (2 * height + 2) * (vecw / 2) elements.
    // Presumably this is vecw / 2 rows of the transposed layout written by the
    // final *_transp kernel; confirm against
    // bitonic_merge_n_way_transpose_init.
    static constexpr size_t MIN_OUTPUT_BUFFER_SIZE{(2 * Size / pva_elementsof(DVTYPEX) + 2) *
                                                   (pva_elementsof(DVTYPEX) / 2) * sizeof(DataType)};
};
} // namespace pvaApl

#endif /* PVA_APL_SORT_PAYLOAD_VPU_HPP */