pvaAplSortPayloadVpu.hpp
Fully qualified name: public/src/primitive/pvaAplSortPayloadVpu.hpp
File members: public/src/primitive/pvaAplSortPayloadVpu.hpp
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#ifndef PVA_APL_SORT_PAYLOAD_VPU_HPP
#define PVA_APL_SORT_PAYLOAD_VPU_HPP
#include <cupva_device.h>
#include <stdio.h>
#include <string.h>
#include <type_traits>
namespace pvaApl {
// Traits mapping the sorted element type to the VPU vector types used by SortPayloadVpu:
//   DVTYPE     - type that agen base pointers are cast to,
//   DVTYPEX    - extended vector type the sorting network operates on,
//   PMT_VTYPEX - vector type matching the permutation-index tables (vshortx).
// The primary template is declared but never defined, so any element type without a
// specialization below is rejected at compile time when SortPayloadVpu is instantiated.
template<typename DataType>
struct SortPayloadVectorTypes;

namespace detail {
// Shared vector-type set for the two supported 32-bit element types.
struct SortPayloadVectorTypes32
{
    using DVTYPE = dvint;
    using DVTYPEX = dvintx;
    using PMT_VTYPEX = vshortx;
};
} // namespace detail

// Unsigned 32-bit elements.
template<>
struct SortPayloadVectorTypes<uint32_t> : detail::SortPayloadVectorTypes32
{
};

// Signed 32-bit elements.
template<>
struct SortPayloadVectorTypes<int32_t> : detail::SortPayloadVectorTypes32
{
};
// Working memory for one SortPayloadVpu instance: written by SortPayloadVpu::Init() and
// read back by SortPayloadVpu::Execute().
struct SortPayloadContext
{
    // Precomputed agen configurations, appended by Init() in exactly the order Execute()'s
    // kernels pop them. The height-512 schedule (Size 8192) uses 91 of the 256 entries.
    AgenCFG cfgs[256];

    // Bytes of scratch: 512 rows x (16 + 2) elements x 4 bytes. The (16 + 2) mirrors the padded
    // row pitch (vecw + 2) used by Init(); assumes a 16-lane DVTYPEX and a maximum height of
    // 512 rows - TODO confirm if new vector widths or sizes are added.
    static constexpr size_t SCRATCH_SIZE = sizeof(uint32_t) * 512 * (16 + 2);

    // Second half of the ping/pong buffer pair; Init() reinterprets it as DataType*.
    uint8_t scratch[SCRATCH_SIZE];
};
/**
 * @brief Sorts Size DataType elements on the PVA VPU using vectorized sorting networks.
 *
 * Usage is two-phase:
 *   1. Init() builds one AgenCFG per address stream into SortPayloadContext::cfgs. It only
 *      records addresses; no data is moved.
 *   2. Execute() runs a fixed, height-specific schedule of kernels. Each kernel pops the configs
 *      it needs from the front of the table, so Init() and the schedule must stay in lock-step
 *      (same passes, same order, same number of configs per pass).
 *
 * The data is handled as height = Size / vecw rows of vecw lanes (vecw = lanes of DVTYPEX).
 * NOTE(review): the pervasive "vecw / 2", "* 2" and stride-2 addressing suggest each row holds
 * vecw / 2 interleaved key/payload pairs - confirm against the vsort2pl documentation.
 *
 * @tparam DataType uint32_t or int32_t (see SortPayloadVectorTypes).
 * @tparam Size     Number of DataType elements: 512, 1024, 2048, 4096 or 8192.
 */
template<typename DataType, int Size>
class SortPayloadVpu
{
    // Checked at class scope (not only in the constructor) so that every use of the template,
    // including MIN_INPUT_BUFFER_SIZE / MIN_OUTPUT_BUFFER_SIZE, rejects unsupported sizes.
    static_assert(Size == 512 || Size == 1024 || Size == 2048 || Size == 4096 || Size == 8192,
                  "Unsupported size value for SortPayloadVpu template instantiation.");

private:
    using DVTYPE = typename SortPayloadVectorTypes<DataType>::DVTYPE;
    using DVTYPEX = typename SortPayloadVectorTypes<DataType>::DVTYPEX;
    using PMT_VTYPEX = typename SortPayloadVectorTypes<DataType>::PMT_VTYPEX;

    // Loads one row through `ag` with the signed or unsigned load matching DataType and widens
    // it to dvintx. The kernels call this repeatedly on the same agen and rely on the agen
    // advancing between calls.
    template<typename AgenType>
    auto DVLOAD(AgenType &ag) -> decltype(auto)
    {
        if constexpr (std::is_same_v<DataType, uint32_t>)
        {
            return static_cast<dvintx>(dvuint_load(ag));
        }
        else if constexpr (std::is_same_v<DataType, int32_t>)
        {
            return static_cast<dvintx>(dvint_load(ag));
        }
        else
        {
            static_assert(!std::is_same_v<DataType, DataType>, "Unsupported type for dvload");
        }
    }

    // Same as DVLOAD, but uses the *_load_perm_transp2 variant with permutation vector `vx`.
    template<typename AgenType, typename VecType>
    auto DVLOAD_T2(AgenType &ag, VecType vx) -> decltype(auto)
    {
        if constexpr (std::is_same_v<DataType, uint32_t>)
        {
            return static_cast<dvintx>(dvuint_load_perm_transp2(ag, vx));
        }
        else if constexpr (std::is_same_v<DataType, int32_t>)
        {
            return static_cast<dvintx>(dvint_load_perm_transp2(ag, vx));
        }
        else
        {
            static_assert(!std::is_same_v<DataType, DataType>, "Unsupported type for dvload");
        }
    }

    // Applies vsort2pl to the lo and hi halves of (v1, v2), writing results back in place.
    // NOTE(review): presumably a lane-wise compare-exchange carrying payload - confirm.
    inline void VSORT2PL_INPLACE(DVTYPEX &v1, DVTYPEX &v2)
    {
        vsort2pl(v1.lo, v2.lo, v1.lo, v2.lo);
        vsort2pl(v1.hi, v2.hi, v1.hi, v2.hi);
    }

    // One bitonic half-cleaner stage on two rows.
    inline void MERGE_2_WAY(DVTYPEX &v1, DVTYPEX &v2)
    {
        VSORT2PL_INPLACE(v1, v2);
    }

    // Two bitonic stages on four rows: (v1,v3),(v2,v4) then (v1,v2),(v3,v4).
    // chess_separator() calls are scheduling separators for the Chess compiler.
    inline void MERGE_4_WAY(DVTYPEX &v1, DVTYPEX &v2, DVTYPEX &v3, DVTYPEX &v4)
    {
        chess_separator();
        VSORT2PL_INPLACE(v1, v3);
        chess_separator();
        VSORT2PL_INPLACE(v2, v4);
        chess_separator();
        MERGE_2_WAY(v1, v2);
        chess_separator();
        MERGE_2_WAY(v3, v4);
    }

    // First pass: reads 8 rows per iteration and applies Batcher's odd-even merge-sort network
    // for 8 inputs (19 compare-exchanges) across them, then stores the 8 rows.
    // Pops 2 configs: in, out.
    void even_odd_merge_sort_payload(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4, v5, v6, v7, v8;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        int niter = height >> 3;
        for (int i = 0; i < niter; i++) chess_loop_range(3, )
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                v3 = DVLOAD(in);
                v4 = DVLOAD(in);
                v5 = DVLOAD(in);
                v6 = DVLOAD(in);
                v7 = DVLOAD(in);
                v8 = DVLOAD(in);
                // Sort adjacent pairs.
                VSORT2PL_INPLACE(v1, v2);
                VSORT2PL_INPLACE(v3, v4);
                VSORT2PL_INPLACE(v5, v6);
                VSORT2PL_INPLACE(v7, v8);
                // Merge pairs into sorted quads.
                VSORT2PL_INPLACE(v1, v3);
                VSORT2PL_INPLACE(v5, v7);
                VSORT2PL_INPLACE(v2, v4);
                VSORT2PL_INPLACE(v6, v8);
                VSORT2PL_INPLACE(v2, v3);
                VSORT2PL_INPLACE(v6, v7);
                // Merge quads into a sorted group of eight.
                VSORT2PL_INPLACE(v1, v5);
                VSORT2PL_INPLACE(v2, v6);
                VSORT2PL_INPLACE(v3, v7);
                VSORT2PL_INPLACE(v4, v8);
                VSORT2PL_INPLACE(v3, v5);
                VSORT2PL_INPLACE(v4, v6);
                VSORT2PL_INPLACE(v2, v3);
                VSORT2PL_INPLACE(v4, v5);
                VSORT2PL_INPLACE(v6, v7);
                vstore(v1, out);
                vstore(v2, out);
                vstore(v3, out);
                vstore(v4, out);
                vstore(v5, out);
                vstore(v6, out);
                vstore(v7, out);
                vstore(v8, out);
            }
    }

    // 4-way merge whose first stage compares against a second, reverse-walking stream.
    // Pops 3 configs: in1, in2, out (pushed by bitonic_merge_n_way_reverse_init with n == 4).
    void bitonic_merge_4_way_reverse(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;
        agen_A in1, in2;
        agen_B out;
        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);
        int niter = height >> 2;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(2)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in1);
                v2 = DVLOAD(in1);
                v3 = DVLOAD(in2);
                v4 = DVLOAD(in2);
                MERGE_4_WAY(v1, v2, v3, v4);
                vstore(v1, out);
                vstore(v2, out);
                vstore(v3, out);
                vstore(v4, out);
            }
    }

    // 4-way merge whose results are written with vstore_transp2 (final pass when the last
    // vertical step is radix-4). Pops 2 configs: in, out.
    void bitonic_merge_4_way_transp(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        int niter = height >> 2;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(2)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                v3 = DVLOAD(in);
                v4 = DVLOAD(in);
                MERGE_4_WAY(v1, v2, v3, v4);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
                vstore_transp2(v3, out);
                vstore_transp2(v4, out);
            }
    }

    // Plain 4-way merge over one input stream. Pops 2 configs: in, out.
    void bitonic_merge_4_way(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2, v3, v4;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        int niter = height >> 2;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(2)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                v3 = DVLOAD(in);
                v4 = DVLOAD(in);
                MERGE_4_WAY(v1, v2, v3, v4);
                vstore(v1, out);
                vstore(v2, out);
                vstore(v3, out);
                vstore(v4, out);
            }
    }

    // 2-way merge over one agen_A stream. Pops 2 configs: in, out.
    // NOTE(review): not referenced by any schedule in this class.
    void bitonic_merge_2_way_reverse(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;
        agen_A in;
        agen_B out;
        in = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);
        int niter = height >> 1;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(4)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                MERGE_2_WAY(v1, v2);
                vstore(v1, out);
                vstore(v2, out);
            }
    }

    // 2-way merge whose results are written with vstore_transp2 (final pass when the last
    // vertical step is radix-2). Pops 2 configs: in, out.
    void bitonic_merge_2_way_transp(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        int niter = height >> 1;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(4)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                MERGE_2_WAY(v1, v2);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
            }
    }

    // Plain 2-way merge over one input stream. Pops 2 configs: in, out.
    void bitonic_merge_2_way(AgenCFG **cfgs, int height)
    {
        DVTYPEX v1, v2;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        int niter = height >> 1;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(4)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD(in);
                v2 = DVLOAD(in);
                MERGE_2_WAY(v1, v2);
                vstore(v1, out);
                vstore(v2, out);
            }
    }

    // Horizontal (transposed) 4-way reverse merge: forward stream loaded with the identity
    // permutation, reverse stream with the pair-reversing permutation; results stored with
    // vstore_transp2. Pops 3 configs: in1, in2, out.
    void bitonic_merge_transpose_4_way_reverse(AgenCFG **cfgs, int height, vshortx *permute_short_ptr,
                                               vshortx *permute_reverse_short_ptr)
    {
        DVTYPEX v1, v2, v3, v4;
        vshortx pmt_short, pmt_reverse_short;
        agen_A in1, in2;
        agen_B out;
        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);
        pmt_short = *permute_short_ptr;
        pmt_reverse_short = *permute_reverse_short_ptr;
        int niter = height >> 2;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(2)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD_T2(in1, pmt_short);
                v2 = DVLOAD_T2(in1, pmt_short);
                v3 = DVLOAD_T2(in2, pmt_reverse_short);
                v4 = DVLOAD_T2(in2, pmt_reverse_short);
                MERGE_4_WAY(v1, v2, v3, v4);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
                vstore_transp2(v3, out);
                vstore_transp2(v4, out);
            }
    }

    // Horizontal (transposed) 4-way merge over one stream. Pops 2 configs: in, out.
    // NOTE(review): not referenced by any schedule in this class.
    void bitonic_merge_transpose_4_way(AgenCFG **cfgs, int height, vshortx *permute_short_ptr)
    {
        DVTYPEX v1, v2, v3, v4;
        vshortx pmt_short;
        agen in, out;
        in = init_agen_from_cfg(*(*cfgs)++);
        out = init_agen_from_cfg(*(*cfgs)++);
        pmt_short = *permute_short_ptr;
        int niter = height >> 2;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(2)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD_T2(in, pmt_short);
                v2 = DVLOAD_T2(in, pmt_short);
                v3 = DVLOAD_T2(in, pmt_short);
                v4 = DVLOAD_T2(in, pmt_short);
                MERGE_4_WAY(v1, v2, v3, v4);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
                vstore_transp2(v3, out);
                vstore_transp2(v4, out);
            }
    }

    // Horizontal (transposed) 2-way reverse merge: one forward and one reverse stream.
    // Pops 3 configs: in1, in2, out (pushed by bitonic_merge_transpose_n_way_reverse_init, n == 2).
    void bitonic_merge_transpose_2_way_reverse(AgenCFG **cfgs, int height, vshortx *permute_short_ptr,
                                               vshortx *permute_reverse_short_ptr)
    {
        DVTYPEX v1, v2;
        vshortx pmt_short, pmt_reverse_short;
        agen_A in1;
        agen_A in2;
        agen_B out;
        in1 = init_agen_A_from_cfg(*(*cfgs)++);
        in2 = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);
        pmt_short = *permute_short_ptr;
        pmt_reverse_short = *permute_reverse_short_ptr;
        int niter = height >> 1;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(4)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD_T2(in1, pmt_short);
                v2 = DVLOAD_T2(in2, pmt_reverse_short);
                VSORT2PL_INPLACE(v1, v2);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
            }
    }

    // Horizontal (transposed) 2-way merge over one stream. Pops 2 configs: in, out.
    void bitonic_merge_transpose_2_way(AgenCFG **cfgs, int height, vshortx *permute_short_ptr)
    {
        DVTYPEX v1, v2;
        vshortx pmt_short;
        agen_A in;
        agen_B out;
        in = init_agen_A_from_cfg(*(*cfgs)++);
        out = init_agen_B_from_cfg(*(*cfgs)++);
        pmt_short = *permute_short_ptr;
        int niter = height >> 1;
        for (int i = 0; i < niter; i++) chess_loop_range(8, )
            chess_unroll_loop(4)
            chess_prepare_for_pipelining
            {
                v1 = DVLOAD_T2(in, pmt_short);
                v2 = DVLOAD_T2(in, pmt_short);
                VSORT2PL_INPLACE(v1, v2);
                vstore_transp2(v1, out);
                vstore_transp2(v2, out);
            }
    }

    // Pushes 2 configs for even_odd_merge_sort_payload: read `ping` with a dense row pitch of
    // vecw elements, write `pong` with the padded row pitch `lofst`.
    void even_odd_merge_sort_payload_init(DataType *ping, DataType *pong, AgenCFG **cfgs, int lofst, int height)
    {
        AgenWrapper wrapper;
        int vecw = pva_elementsof(DVTYPEX);
        agen in = init((DVTYPE *)ping);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = height;
        wrapper.s1 = vecw;
        INIT_AGEN1(in, wrapper);
        agen out = init((DVTYPE *)pong);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = height;
        wrapper.s1 = lofst;
        INIT_AGEN1(out, wrapper);
        *(*cfgs)++ = extract_agen_cfg(in);
        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Pushes 3 configs (in1, in2, out) for a transposed reverse merge at horizontal distance
    // h_dist over groups of n. in2 starts at the far end of each group and walks with negative
    // strides. Source/destination alternate between ping and pong on the parity of *it.
    void bitonic_merge_transpose_n_way_reverse_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs,
                                                    int lofst, int height, int h_dist, int n)
    {
        agen in1, in2, out;
        int vecw = pva_elementsof(DVTYPEX);
        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;
        int niter1 = 1;
        int niter2 = n / 2;
        int niter3 = (vecw / 2) / h_dist / n;
        int niter4 = height / (vecw / 2);
        int niter5 = h_dist;
        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);
        in1 = init((DVTYPE *)src);
        in1.lane_ofst = lofst / vecw;
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.n4 = niter4;
        wrapper.n5 = niter5;
        wrapper.s1 = 1;
        wrapper.s2 = h_dist * 2;
        wrapper.s3 = n * h_dist * 2;
        wrapper.s4 = vecw / 2 * lofst;
        wrapper.s5 = 2;
        INIT_AGEN5(in1, wrapper);
        *(*cfgs)++ = extract_agen_cfg(in1);
        in2 = init((DVTYPE *)(src + (height - vecw / 2) * lofst + h_dist * n * 2 - 2));
        in2.lane_ofst = lofst / vecw;
        wrapper.s1 = 1;
        wrapper.s2 = -h_dist * 2;
        wrapper.s3 = n * h_dist * 2;
        wrapper.s4 = -vecw / 2 * lofst;
        wrapper.s5 = -2;
        INIT_AGEN5(in2, wrapper);
        *(*cfgs)++ = extract_agen_cfg(in2);
        niter1 = 1;
        niter2 = (vecw / 2) / h_dist;
        niter3 = height / (vecw / 2);
        niter4 = h_dist;
        out = init((DVTYPE *)dst);
        out.lane_ofst = lofst / vecw;
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.n4 = niter4;
        wrapper.s1 = 1;
        wrapper.s2 = h_dist * 2;
        wrapper.s3 = vecw / 2 * lofst;
        wrapper.s4 = 2;
        INIT_AGEN4(out, wrapper);
        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Pushes 2 configs (in, out) with identical addressing for a transposed forward merge at
    // horizontal distance h_dist. Source/destination alternate on the parity of *it.
    void bitonic_merge_transpose_n_way_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                            int height, int h_dist)
    {
        agen in, out;
        int vecw = pva_elementsof(DVTYPEX);
        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;
        int niter1 = 1;
        int niter2 = (vecw / 2) / h_dist;
        int niter3 = height / (vecw / 2);
        int niter4 = h_dist;
        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);
        in = init((DVTYPE *)src);
        in.lane_ofst = lofst / vecw;
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.n4 = niter4;
        wrapper.s1 = 1;
        wrapper.s2 = h_dist * 2;
        wrapper.s3 = vecw / 2 * lofst;
        wrapper.s4 = 2;
        INIT_AGEN4(in, wrapper);
        *(*cfgs)++ = extract_agen_cfg(in);
        out = init((DVTYPE *)dst);
        out.lane_ofst = lofst / vecw;
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.n4 = niter4;
        wrapper.s1 = 1;
        wrapper.s2 = h_dist * 2;
        wrapper.s3 = vecw / 2 * lofst;
        wrapper.s4 = 2;
        INIT_AGEN4(out, wrapper);
        *(*cfgs)++ = extract_agen_cfg(out);
    }

    // Emits configs for all horizontal (transposed) stages at distances h_dist, h_dist/2, ..., 1.
    // Radix-4 steps (two stages at once) are used while h_dist > 1, radix-2 for the last stage;
    // only the first emitted step is the reverse variant.
    void bitonic_merge_transpose_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst, int height,
                                      int h_dist)
    {
        int reverse = 1;
        while (h_dist > 0)
        {
            if (h_dist > 1)
            {
                if (reverse)
                {
                    bitonic_merge_transpose_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, h_dist >> 1, 4);
                }
                else
                {
                    bitonic_merge_transpose_n_way_init(ping, pong, it, cfgs, lofst, height, h_dist >> 1);
                }
                h_dist = h_dist >> 2;
            }
            else
            {
                if (reverse)
                {
                    bitonic_merge_transpose_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, h_dist, 2);
                }
                else
                {
                    bitonic_merge_transpose_n_way_init(ping, pong, it, cfgs, lofst, height, h_dist);
                }
                h_dist = h_dist >> 1;
            }
            reverse = 0;
        }
    }

    // Pushes 3 configs (in1, in2, out) for a vertical reverse merge at row distance v_dist over
    // groups of n * v_dist rows. in2 starts at the last row of each group and walks backwards.
    // Source/destination alternate between ping and pong on the parity of *it.
    void bitonic_merge_n_way_reverse_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                          int height, int v_dist, int n)
    {
        agen stream;
        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;
        int niter1 = n / 2;
        int niter2 = height / v_dist / n;
        int niter3 = v_dist;
        AgenWrapper wrapper;
        wrapper.size = sizeof(DataType);
        stream = init((DVTYPE *)src);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.n3 = niter3;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = n * v_dist * lofst;
        wrapper.s3 = lofst;
        INIT_AGEN3(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
        stream = init((DVTYPE *)(src + (n * v_dist - 1) * lofst));
        wrapper.s1 = -v_dist * lofst;
        wrapper.s2 = n * v_dist * lofst;
        wrapper.s3 = -lofst;
        INIT_AGEN3(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
        niter1 = height / v_dist;
        niter2 = v_dist;
        stream = init((DVTYPE *)dst);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = lofst;
        INIT_AGEN2(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Pushes 2 configs for the final vertical pass: read `ping` (the current source buffer,
    // chosen by the caller) with the padded pitch, write `pong` (the user output) with the
    // layout expected by vstore_transp2. Unlike its siblings it does not pick buffers from *it,
    // but still advances *it.
    void bitonic_merge_n_way_transpose_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst,
                                            int height, int v_dist)
    {
        agen stream;
        DataType *src = ping;
        DataType *dst = pong;
        (*it)++;
        int vecw = pva_elementsof(DVTYPEX);
        int niter1 = height / v_dist;
        int niter2 = v_dist;
        AgenWrapper wrapper;
        stream = init((DVTYPE *)src);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = lofst;
        INIT_AGEN2(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
        stream = init((DVTYPE *)dst);
        wrapper.size = sizeof(DataType);
        stream.lane_ofst = height * 2 / vecw;
        wrapper.n1 = 1;
        wrapper.n2 = niter1;
        wrapper.n3 = niter2;
        wrapper.s1 = 1;
        wrapper.s2 = 2 * v_dist;
        wrapper.s3 = 1;
        INIT_AGEN3(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Pushes 2 configs (in, out) with identical addressing for a vertical forward merge at row
    // distance v_dist. Source/destination alternate on the parity of *it.
    void bitonic_merge_n_way_init(DataType *ping, DataType *pong, int *it, AgenCFG **cfgs, int lofst, int height,
                                  int v_dist)
    {
        agen stream;
        DataType *src = (*it) & 1 ? ping : pong;
        DataType *dst = (*it) & 1 ? pong : ping;
        (*it)++;
        int niter1 = height / v_dist;
        int niter2 = v_dist;
        AgenWrapper wrapper;
        stream = init((DVTYPE *)src);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = lofst;
        INIT_AGEN2(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
        stream = init((DVTYPE *)dst);
        wrapper.size = sizeof(DataType);
        wrapper.n1 = niter1;
        wrapper.n2 = niter2;
        wrapper.s1 = v_dist * lofst;
        wrapper.s2 = lofst;
        INIT_AGEN2(stream, wrapper);
        *(*cfgs)++ = extract_agen_cfg(stream);
    }

    // Emits configs for all vertical stages at row distances v_dist, v_dist/2, ..., 1.
    // Radix-4 steps (two stages at once) while v_dist >= 2, radix-2 for a remaining single stage.
    // Only the first step may be the reverse variant. When `last` is set, the final step writes
    // the user `output` buffer through bitonic_merge_n_way_transpose_init.
    void bitonic_merge_init(DataType *ping, DataType *pong, DataType *output, int *it, AgenCFG **cfgs, int lofst,
                            int height, int v_dist, int reverse, int last)
    {
        while (v_dist > 0)
        {
            if (v_dist >= 2)
            {
                // Radix-4 step: covers distances v_dist and v_dist / 2.
                if (reverse)
                {
                    bitonic_merge_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, v_dist >> 1, 4);
                }
                else if (last && (v_dist >> 2) == 0)
                {
                    DataType *src = (*it) & 1 ? ping : pong;
                    bitonic_merge_n_way_transpose_init(src, output, it, cfgs, lofst, height, v_dist >> 1);
                }
                else
                {
                    bitonic_merge_n_way_init(ping, pong, it, cfgs, lofst, height, v_dist >> 1);
                }
                v_dist = v_dist >> 2;
            }
            else
            {
                // Radix-2 step: the single remaining stage at distance v_dist (== 1).
                if (reverse)
                {
                    // A single stage is a 2-way merge, so n must be 2 (matching the radix-2
                    // branch of bitonic_merge_transpose_init); n == 4 would emit the
                    // addressing of the v_dist == 2 radix-4 step instead.
                    // NOTE(review): Init() never reaches this path (its reverse passes start at
                    // v_dist >= 8), and no kernel in Execute() consumes a 3-config 2-way reverse
                    // merge (bitonic_merge_2_way_reverse pops only 2 configs).
                    bitonic_merge_n_way_reverse_init(ping, pong, it, cfgs, lofst, height, v_dist, 2);
                }
                else if (last && (v_dist >> 1) == 0)
                {
                    DataType *src = (*it) & 1 ? ping : pong;
                    bitonic_merge_n_way_transpose_init(src, output, it, cfgs, lofst, height, v_dist);
                }
                else
                {
                    bitonic_merge_n_way_init(ping, pong, it, cfgs, lofst, height, v_dist);
                }
                v_dist = v_dist >> 1;
            }
            reverse = 0;
        }
    }

    // Kernel schedule for height 512 (Size 8192 with 16-lane vectors). The call sequence mirrors,
    // pass for pass, the configs Init() emits for this height (91 configs in total).
    void sort_payload_height_512(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 512;
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merges, v_dist = 8 .. 256.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 1.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 4 (last pass writes the output transposed).
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

    // Kernel schedule for height 256; mirrors Init()'s config sequence for this height.
    void sort_payload_height_256(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 256;
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merges, v_dist = 8 .. 128.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 1.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 4 (last pass writes the output transposed).
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_transp(&cfgs, height);
    }

    // Kernel schedule for height 128; mirrors Init()'s config sequence for this height.
    void sort_payload_height_128(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 128;
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merges, v_dist = 8 .. 64.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 1.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 4 (last pass writes the output transposed).
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

    // Kernel schedule for height 64; mirrors Init()'s config sequence for this height.
    void sort_payload_height_64(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 64;
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merges, v_dist = 8 .. 32.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 1.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        // h_dist = 4 (last pass writes the output transposed).
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_transp(&cfgs, height);
    }

    // Kernel schedule for height 32; mirrors Init()'s config sequence for this height.
    void sort_payload_height_32(AgenCFG *cfgs, vshortx *permute_short_ptr, vshortx *permute_reverse_short_ptr)
    {
        int height = 32;
        even_odd_merge_sort_payload(&cfgs, height);
        // Vertical merges, v_dist = 8 .. 16.
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way_reverse(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 1.
        bitonic_merge_transpose_2_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 2.
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way(&cfgs, height);
        // h_dist = 4 (last pass writes the output transposed).
        bitonic_merge_transpose_4_way_reverse(&cfgs, height, permute_short_ptr, permute_reverse_short_ptr);
        bitonic_merge_transpose_2_way(&cfgs, height, permute_short_ptr);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_4_way(&cfgs, height);
        bitonic_merge_2_way_transp(&cfgs, height);
    }

public:
    // Size is validated by the class-scope static_assert above.
    SortPayloadVpu() = default;

    /**
     * @brief Precomputes every agen configuration used by Execute().
     *
     * No data is moved here; only addresses are recorded into context->cfgs.
     * @param src     Input data. Also used as a ping buffer with row pitch vecw + 2, so it is
     *                overwritten by Execute() and must be at least MIN_INPUT_BUFFER_SIZE bytes.
     * @param dst     Output buffer written by the final (transposed) pass; at least
     *                MIN_OUTPUT_BUFFER_SIZE bytes.
     * @param context Receives the config table; context->scratch is used as the pong buffer.
     */
    void Init(DataType *src, DataType *dst, SortPayloadContext *context)
    {
        int it = 0;
        int v_dist, h_dist;
        int vecw = pva_elementsof(DVTYPEX);
        // Two padding elements per row; the transposed passes address the buffer with this pitch.
        int lofst = vecw + 2;
        int height = Size / vecw;
        AgenCFG *cfgs = context->cfgs;
        DataType *tmp = reinterpret_cast<DataType *>(context->scratch);
        even_odd_merge_sort_payload_init(src, tmp, &cfgs, lofst, height);
        // Vertical bitonic merges: runs of 8 sorted rows grow to the full height.
        for (v_dist = 8; v_dist < height; v_dist *= 2)
        {
            bitonic_merge_init(src, tmp, dst, &it, &cfgs, lofst, height, v_dist, 1, 0);
        }
        // Horizontal (transposed) merges across lanes, each followed by a full vertical merge;
        // the last round writes dst.
        for (h_dist = 1; h_dist < vecw / 2; h_dist *= 2)
        {
            bitonic_merge_transpose_init(src, tmp, &it, &cfgs, lofst, height, h_dist);
            bitonic_merge_init(src, tmp, dst, &it, &cfgs, lofst, height, height / 2, 0, h_dist == vecw / 4);
        }
    }

    /**
     * @brief Runs the sort using the configs prepared by Init() on the same context.
     */
    void Execute(SortPayloadContext *context)
    {
        AgenCFG *cfgs = context->cfgs;
        constexpr int height = Size / pva_elementsof(DVTYPEX);
        // Identity permutation over 16 shorts.
        const short permute_index_short[] = {
            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
        };
        // Reverses the order of the 8 adjacent index pairs while keeping each pair's order.
        const short permute_index_reverse_short[] = {
            0xe, 0xf, 0xc, 0xd, 0xa, 0xb, 0x8, 0x9, 0x6, 0x7, 0x4, 0x5, 0x2, 0x3, 0x0, 0x1,
        };
        // NOTE(review): vector loads through these pointers may require vector alignment of the
        // stack arrays - confirm for the target.
        vshortx permute_short = sign_extend(*reinterpret_cast<const vshort *>(permute_index_short));
        vshortx permute_reverse_short = sign_extend(*reinterpret_cast<const vshort *>(permute_index_reverse_short));
        if constexpr (height == 512)
        {
            sort_payload_height_512(cfgs, &permute_short, &permute_reverse_short);
        }
        else if constexpr (height == 256)
        {
            sort_payload_height_256(cfgs, &permute_short, &permute_reverse_short);
        }
        else if constexpr (height == 128)
        {
            sort_payload_height_128(cfgs, &permute_short, &permute_reverse_short);
        }
        else if constexpr (height == 64)
        {
            sort_payload_height_64(cfgs, &permute_short, &permute_reverse_short);
        }
        else if constexpr (height == 32)
        {
            sort_payload_height_32(cfgs, &permute_short, &permute_reverse_short);
        }
        else
        {
            // Previously this case silently did nothing (e.g. for a vector width other than
            // 16 lanes); make it a compile-time error instead.
            static_assert(!std::is_same_v<DataType, DataType>,
                          "SortPayloadVpu::Execute: no kernel schedule for this height (Size / vector width)");
        }
    }

    // Minimum size in bytes of the `src` buffer passed to Init(): height rows of vecw + 2
    // elements, because src is reused as a ping buffer with that padded row pitch.
    static constexpr size_t MIN_INPUT_BUFFER_SIZE{(pva_elementsof(DVTYPEX) + 2) * (Size / pva_elementsof(DVTYPEX)) *
                                                  sizeof(DataType)};
    // Minimum size in bytes of the `dst` buffer passed to Init(): vecw / 2 rows of
    // 2 * height + 2 elements, as addressed by the final vstore_transp2 pass.
    static constexpr size_t MIN_OUTPUT_BUFFER_SIZE{(2 * Size / pva_elementsof(DVTYPEX) + 2) *
                                                   (pva_elementsof(DVTYPEX) / 2) * sizeof(DataType)};
};
} // namespace pvaApl
#endif /* PVA_APL_SORT_PAYLOAD_VPU_HPP */