Program Listing for File multi.hpp

Return to documentation for file (morpheus/_lib/include/morpheus/messages/multi.hpp)

Copy
Copied!
            

/* * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "morpheus/messages/meta.hpp" #include "morpheus/objects/table_info.hpp" #include "morpheus/objects/tensor_object.hpp" #include "morpheus/types.hpp" // for TensorIndex #include <mrc/utils/macros.hpp> // for MRC_PTR_CAST #include <pybind11/pytypes.h> #include <pybind11/stl.h> // IWYU pragma: keep #include <memory> #include <string> #include <vector> namespace morpheus { /****** Component public implementations *******************/ /****** MultiMessage****************************************/ #pragma GCC visibility push(default) class MultiMessage; template <typename DerivedT, typename... BasesT> class DerivedMultiMessage : public BasesT... { public: virtual ~DerivedMultiMessage() = default; std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->get_slice_impl(new_message, start, stop); return MRC_PTR_CAST(DerivedT, new_message); } std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->copy_ranges_impl(new_message, ranges, num_selected_rows); return MRC_PTR_CAST(DerivedT, new_message); } protected: virtual void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const = 0; virtual void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message, const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const = 0; private: virtual std::shared_ptr<MultiMessage> clone_impl() const { // Cast `this` to the derived type auto derived_this = static_cast<const DerivedT*>(this); // Use copy constructor to make a clone return std::make_shared<DerivedT>(*derived_this); } }; // Single base class version. Should be the version used by default template <typename DerivedT, typename BaseT> class DerivedMultiMessage<DerivedT, BaseT> : public BaseT { public: using BaseT::BaseT; ~DerivedMultiMessage() override = default; std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->get_slice_impl(new_message, start, stop); return MRC_PTR_CAST(DerivedT, new_message); } std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->copy_ranges_impl(new_message, ranges, num_selected_rows); return MRC_PTR_CAST(DerivedT, new_message); } protected: void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const override { return BaseT::get_slice_impl(new_message, start, stop); } void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message, const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const override { return BaseT::copy_ranges_impl(new_message, ranges, num_selected_rows); } private: std::shared_ptr<MultiMessage> clone_impl() const override { // Cast `this` to the derived type auto derived_this = static_cast<const DerivedT*>(this); // Use copy constructor to make a clone return std::make_shared<DerivedT>(*derived_this); } }; // No base class version. This should only be used by `MultiMessage` itself. template <typename DerivedT> class DerivedMultiMessage<DerivedT> { public: virtual ~DerivedMultiMessage() = default; std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->get_slice_impl(new_message, start, stop); return MRC_PTR_CAST(DerivedT, new_message); } std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const { std::shared_ptr<MultiMessage> new_message = this->clone_impl(); this->copy_ranges_impl(new_message, ranges, num_selected_rows); return MRC_PTR_CAST(DerivedT, new_message); } protected: virtual void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const = 0; virtual void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message, const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const = 0; private: virtual std::shared_ptr<MultiMessage> clone_impl() const { // Cast `this` to the derived type auto derived_this = static_cast<const DerivedT*>(this); // Use copy constructor to make a clone return std::make_shared<DerivedT>(*derived_this); } }; class MultiMessage : public DerivedMultiMessage<MultiMessage> { public: MultiMessage(const MultiMessage& other) = default; MultiMessage(std::shared_ptr<MessageMeta> m, TensorIndex offset = 0, TensorIndex count = -1); std::shared_ptr<MessageMeta> meta; TensorIndex mess_offset{0}; TensorIndex mess_count{0}; std::vector<std::string> get_meta_column_names() const; TableInfo get_meta(); TableInfo get_meta(const std::string& col_name); TableInfo get_meta(const std::vector<std::string>& column_names); void set_meta(const std::string& col_name, TensorObject tensor); void set_meta(const std::vector<std::string>& column_names, const std::vector<TensorObject>& tensors); protected: void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const override; void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message, const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const override; virtual std::shared_ptr<MessageMeta> copy_meta_ranges(const std::vector<RangeType>& ranges) const; std::vector<RangeType> apply_offset_to_ranges(TensorIndex offset, const std::vector<RangeType>& ranges) const; }; /****** MultiMessageInterfaceProxy**************************/ struct MultiMessageInterfaceProxy { static std::shared_ptr<MultiMessage> init(std::shared_ptr<MessageMeta> meta, TensorIndex mess_offset, TensorIndex mess_count); static std::shared_ptr<MessageMeta> meta(const MultiMessage& self); static TensorIndex mess_offset(const MultiMessage& self); static TensorIndex mess_count(const MultiMessage& self); static std::vector<std::string> get_meta_column_names(const MultiMessage& self); static pybind11::object get_meta(MultiMessage& self); static pybind11::object get_meta(MultiMessage& self, std::string col_name); static pybind11::object get_meta(MultiMessage& self, std::vector<std::string> columns); // This overload is necessary to match the python signature where you can call self.get_meta(None) static pybind11::object get_meta(MultiMessage& self, pybind11::none none_obj); static pybind11::object get_meta_list(MultiMessage& self, pybind11::object col_name); static void set_meta(MultiMessage& self, pybind11::object columns, pybind11::object value); static std::shared_ptr<MultiMessage> get_slice(MultiMessage& self, TensorIndex start, TensorIndex stop); static std::shared_ptr<MultiMessage> copy_ranges(MultiMessage& self, const std::vector<RangeType>& ranges, pybind11::object num_selected_rows); }; #pragma GCC visibility pop// end of group } // namespace morpheus

© Copyright 2024, NVIDIA. Last updated on Apr 11, 2024.