Program Listing for File matx_util.hpp

Return to documentation for file (python/morpheus/morpheus/_lib/include/morpheus/utilities/matx_util.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "morpheus/objects/dev_mem_info.hpp"
#include "morpheus/objects/dtype.hpp"
#include "morpheus/objects/rmm_tensor.hpp"
#include "morpheus/objects/tensor_object.hpp"
#include "morpheus/types.hpp"  // for ShapeType, TensorIndex

#include <memory>
#include <vector>

namespace morpheus {

struct MatxUtil
{
    static std::shared_ptr<rmm::device_buffer> cast(const DevMemInfo& input, TypeId output_type);

    static std::shared_ptr<rmm::device_buffer> create_seq_ids(TensorIndex row_count,
                                                              TensorIndex fea_len,
                                                              TypeId output_type,
                                                              std::shared_ptr<MemoryDescriptor> md,
                                                              TensorIndex start_idx = 0);

    static void offset_seq_ids(const DevMemInfo& input, TensorIndex offset);

    static std::shared_ptr<rmm::device_buffer> logits(const DevMemInfo& input);

    static std::shared_ptr<rmm::device_buffer> transpose(const DevMemInfo& input);

    static std::shared_ptr<rmm::device_buffer> threshold(const DevMemInfo& input, double thresh_val, bool by_row);

    static std::shared_ptr<rmm::device_buffer> reduce_max(const DevMemInfo& input,
                                                          const ShapeType& seq_ids,
                                                          TensorIndex seq_id_offset,
                                                          const ShapeType& output_shape);
};  // end of group
}  // namespace morpheus