Program Listing for File multi.hpp
↰ Return to documentation for file (morpheus/_lib/include/morpheus/messages/multi.hpp)
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "morpheus/messages/meta.hpp"
#include "morpheus/objects/table_info.hpp"
#include "morpheus/objects/tensor_object.hpp"
#include "morpheus/types.hpp" // for TensorIndex
#include <mrc/utils/macros.hpp> // for MRC_PTR_CAST
#include <pybind11/pytypes.h>
#include <pybind11/stl.h> // IWYU pragma: keep
#include <memory>
#include <string>
#include <vector>
namespace morpheus {
/****** Component public implementations *******************/
/****** MultiMessage****************************************/
#pragma GCC visibility push(default)
class MultiMessage;
/**
 * @brief CRTP mix-in that adds `get_slice()` / `copy_ranges()` returning `std::shared_ptr<DerivedT>`.
 *
 * General form for a message type with several base classes. It inherits publicly from every type in
 * `BasesT...`. It also re-declares `get_slice_impl()` / `copy_ranges_impl()` as pure virtual, so a concrete
 * subclass has to provide both hooks itself.
 *
 * @tparam DerivedT Concrete message type. `clone_impl()` copy-constructs it, and the resulting
 *                  `std::shared_ptr<DerivedT>` must convert to `std::shared_ptr<MultiMessage>`.
 * @tparam BasesT   Base classes the message type derives from.
 */
template <typename DerivedT, typename... BasesT>
class DerivedMultiMessage : public BasesT...
{
  public:
    virtual ~DerivedMultiMessage() = default;

    /**
     * @brief Clone this message and hand the clone to `get_slice_impl()` together with `start` / `stop`.
     *
     * `*this` is only read: both this method and the hook are `const`, and only the clone is passed on.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const
    {
        auto sliced = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->get_slice_impl(sliced, start, stop);

        return MRC_PTR_CAST(DerivedT, sliced);
    }

    /**
     * @brief Clone this message and hand the clone to `copy_ranges_impl()` with `ranges` / `num_selected_rows`.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const
    {
        auto copied = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->copy_ranges_impl(copied, ranges, num_selected_rows);

        return MRC_PTR_CAST(DerivedT, copied);
    }

  protected:
    // Hook that applies `start` / `stop` to the freshly cloned `new_message`; it returns nothing.
    virtual void get_slice_impl(std::shared_ptr<MultiMessage> new_message,
                                TensorIndex start,
                                TensorIndex stop) const = 0;

    // Hook that applies `ranges` / `num_selected_rows` to the freshly cloned `new_message`; it returns nothing.
    virtual void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message,
                                  const std::vector<RangeType>& ranges,
                                  TensorIndex num_selected_rows) const = 0;

  private:
    // Copy-construct a new `DerivedT` from `*this`.
    // The downcast ensures DerivedT's copy constructor is used, not a base class's.
    virtual std::shared_ptr<MultiMessage> clone_impl() const
    {
        const DerivedT& self = static_cast<const DerivedT&>(*this);

        return std::make_shared<DerivedT>(self);
    }
};
/**
 * @brief Single-base form of `DerivedMultiMessage`; this is the specialization to use by default.
 *
 * Inherits from exactly one `BaseT` and exposes BaseT's constructors. Both slicing hooks are implemented
 * by delegating straight to BaseT, so a subclass only needs to override them if it adds state of its own.
 *
 * @tparam DerivedT Concrete message type. `clone_impl()` copy-constructs it.
 * @tparam BaseT    The single base class of the message type.
 */
template <typename DerivedT, typename BaseT>
class DerivedMultiMessage<DerivedT, BaseT> : public BaseT
{
  public:
    // Inherit every constructor of BaseT.
    using BaseT::BaseT;

    ~DerivedMultiMessage() override = default;

    /**
     * @brief Clone this message and hand the clone to `get_slice_impl()` together with `start` / `stop`.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const
    {
        auto sliced = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->get_slice_impl(sliced, start, stop);

        return MRC_PTR_CAST(DerivedT, sliced);
    }

    /**
     * @brief Clone this message and hand the clone to `copy_ranges_impl()` with `ranges` / `num_selected_rows`.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const
    {
        auto copied = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->copy_ranges_impl(copied, ranges, num_selected_rows);

        return MRC_PTR_CAST(DerivedT, copied);
    }

  protected:
    // Delegates directly to BaseT's implementation.
    void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const override
    {
        BaseT::get_slice_impl(new_message, start, stop);
    }

    // Delegates directly to BaseT's implementation.
    void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message,
                          const std::vector<RangeType>& ranges,
                          TensorIndex num_selected_rows) const override
    {
        BaseT::copy_ranges_impl(new_message, ranges, num_selected_rows);
    }

  private:
    // Copy-construct a new `DerivedT` from `*this`.
    // The downcast ensures DerivedT's copy constructor is used, not BaseT's.
    std::shared_ptr<MultiMessage> clone_impl() const override
    {
        const DerivedT& self = static_cast<const DerivedT&>(*this);

        return std::make_shared<DerivedT>(self);
    }
};
/**
 * @brief Root form of `DerivedMultiMessage` with no base class. Only `MultiMessage` itself should use it.
 *
 * Declares the public `get_slice()` / `copy_ranges()` entry points, the pure-virtual hooks they call,
 * and the virtual `clone_impl()` that the other specializations override.
 *
 * @tparam DerivedT Concrete message type. `clone_impl()` copy-constructs it.
 */
template <typename DerivedT>
class DerivedMultiMessage<DerivedT>
{
  public:
    virtual ~DerivedMultiMessage() = default;

    /**
     * @brief Clone this message and hand the clone to `get_slice_impl()` together with `start` / `stop`.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> get_slice(TensorIndex start, TensorIndex stop) const
    {
        auto sliced = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->get_slice_impl(sliced, start, stop);

        return MRC_PTR_CAST(DerivedT, sliced);
    }

    /**
     * @brief Clone this message and hand the clone to `copy_ranges_impl()` with `ranges` / `num_selected_rows`.
     *
     * @return The clone, cast to `DerivedT` with `MRC_PTR_CAST`.
     */
    std::shared_ptr<DerivedT> copy_ranges(const std::vector<RangeType>& ranges, TensorIndex num_selected_rows) const
    {
        auto copied = this->clone_impl();  // std::shared_ptr<MultiMessage>

        this->copy_ranges_impl(copied, ranges, num_selected_rows);

        return MRC_PTR_CAST(DerivedT, copied);
    }

  protected:
    // Hook that applies `start` / `stop` to the freshly cloned `new_message`; it returns nothing.
    virtual void get_slice_impl(std::shared_ptr<MultiMessage> new_message,
                                TensorIndex start,
                                TensorIndex stop) const = 0;

    // Hook that applies `ranges` / `num_selected_rows` to the freshly cloned `new_message`; it returns nothing.
    virtual void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message,
                                  const std::vector<RangeType>& ranges,
                                  TensorIndex num_selected_rows) const = 0;

  private:
    // Copy-construct a new `DerivedT` from `*this`.
    // The downcast ensures DerivedT's copy constructor is used.
    virtual std::shared_ptr<MultiMessage> clone_impl() const
    {
        const DerivedT& self = static_cast<const DerivedT&>(*this);

        return std::make_shared<DerivedT>(self);
    }
};
/**
 * @brief Base message type: a shared `MessageMeta` plus an offset/count pair.
 *
 * Derives from the no-base `DerivedMultiMessage<MultiMessage>`, which supplies the typed `get_slice()` and
 * `copy_ranges()` entry points. Those entry points clone the message and then call the `get_slice_impl()` /
 * `copy_ranges_impl()` overrides declared below on the clone.
 *
 * NOTE(review): `mess_offset` / `mess_count` presumably select a window of rows within `meta`;
 * confirm against the out-of-line definitions in multi.cpp.
 */
class MultiMessage : public DerivedMultiMessage<MultiMessage>
{
  public:
    // Defaulted (member-wise) copy. The copy's `meta` points at the same MessageMeta as the original.
    // For a plain MultiMessage, clone_impl() creates the clone through this constructor.
    MultiMessage(const MultiMessage& other) = default;

    /**
     * @brief Construct a message that refers to `m`.
     *
     * @param m      Metadata the message refers to (assumed non-null — TODO confirm).
     * @param offset Offset value, defaults to 0 (presumably the first row of `m` covered by this message).
     * @param count  Count value, defaults to -1. NOTE(review): -1 presumably means "through the end of `m`";
     *               verify in multi.cpp.
     */
    MultiMessage(std::shared_ptr<MessageMeta> m, TensorIndex offset = 0, TensorIndex count = -1);

    // Metadata this message refers to. Held by shared_ptr, so copies and clones share the same object.
    std::shared_ptr<MessageMeta> meta;
    // Offset into `meta`; default-initialized to 0 (presumably a row offset).
    TensorIndex mess_offset{0};
    // Number of entries covered; default-initialized to 0 (presumably a row count).
    TensorIndex mess_count{0};

    // Returns this message's data as a TableInfo.
    // NOTE(review): presumably limited to this message's rows — confirm.
    TableInfo get_meta();
    // Returns a TableInfo for the single column `col_name`.
    TableInfo get_meta(const std::string& col_name);
    // Returns a TableInfo for the listed `column_names`.
    TableInfo get_meta(const std::vector<std::string>& column_names);

    // Writes `tensor` into the column `col_name`.
    void set_meta(const std::string& col_name, TensorObject tensor);
    // Writes `tensors` into `column_names` (presumably paired element-wise by index — confirm).
    void set_meta(const std::vector<std::string>& column_names, const std::vector<TensorObject>& tensors);

  protected:
    // Called by get_slice() on a freshly cloned `new_message` with the requested `start` / `stop`.
    void get_slice_impl(std::shared_ptr<MultiMessage> new_message, TensorIndex start, TensorIndex stop) const override;
    // Called by copy_ranges() on a freshly cloned `new_message` with the requested `ranges`.
    void copy_ranges_impl(std::shared_ptr<MultiMessage> new_message,
                          const std::vector<RangeType>& ranges,
                          TensorIndex num_selected_rows) const override;

    // Builds a MessageMeta from `ranges`. Virtual so derived messages can customize it.
    // NOTE(review): presumably copies only the rows inside `ranges` — confirm.
    virtual std::shared_ptr<MessageMeta> copy_meta_ranges(const std::vector<RangeType>& ranges) const;

    // Returns `ranges` adjusted by `offset`.
    // NOTE(review): presumably adds `offset` to both ends of each range — confirm.
    std::vector<RangeType> apply_offset_to_ranges(TensorIndex offset, const std::vector<RangeType>& ranges) const;
};
/****** MultiMessageInterfaceProxy**************************/
/**
 * @brief Static helpers behind the Python-facing interface of `MultiMessage`.
 *
 * Each helper receives the message explicitly as `self`. Results that Python consumes are returned as
 * `pybind11::object`. The declarations are grouped by purpose; their definitions live out of line.
 */
struct MultiMessageInterfaceProxy
{
    // ---- Construction ------------------------------------------------------------------------------
    static std::shared_ptr<MultiMessage> init(std::shared_ptr<MessageMeta> meta,
                                              TensorIndex mess_offset,
                                              TensorIndex mess_count);

    // ---- Read-only accessors for the message's public fields ------------------------------------------
    static std::shared_ptr<MessageMeta> meta(const MultiMessage& self);
    static TensorIndex mess_offset(const MultiMessage& self);
    static TensorIndex mess_count(const MultiMessage& self);

    // ---- Producing new messages -----------------------------------------------------------------------
    static std::shared_ptr<MultiMessage> get_slice(MultiMessage& self, TensorIndex start, TensorIndex stop);
    static std::shared_ptr<MultiMessage> copy_ranges(MultiMessage& self,
                                                     const std::vector<RangeType>& ranges,
                                                     pybind11::object num_selected_rows);

    // ---- Reading metadata -----------------------------------------------------------------------------
    // Overload used when Python calls `self.get_meta(None)`.
    static pybind11::object get_meta(MultiMessage& self, pybind11::none none_obj);
    static pybind11::object get_meta(MultiMessage& self);
    static pybind11::object get_meta(MultiMessage& self, std::string col_name);
    static pybind11::object get_meta(MultiMessage& self, std::vector<std::string> columns);
    static pybind11::object get_meta_list(MultiMessage& self, pybind11::object col_name);

    // ---- Writing metadata -----------------------------------------------------------------------------
    static void set_meta(MultiMessage& self, pybind11::object columns, pybind11::object value);
};
#pragma GCC visibility pop// end of group
} // namespace morpheus