// File: morpheus/_lib/include/morpheus/objects/rmm_tensor.hpp
#pragma once
#include "morpheus/objects/tensor_object.hpp"
#include "morpheus/utilities/type_util.hpp"
#include "morpheus/utilities/type_util_detail.hpp"// for DataType
#include <rmm/device_buffer.hpp>
#include <cstddef>// for size_t
#include <memory>
#include <utility>// for pair
#include <vector>
namespace morpheus {
/****** Component public implementations *******************/
/****** RMMTensor ****************************************/

/**
 * @brief ITensor implementation backed by an rmm::device_buffer.
 *
 * The backing buffer is held through a std::shared_ptr, so several RMMTensor objects may refer to
 * the same allocation. A tensor is described by that buffer together with an offset, an element
 * dtype, a shape and a stride.
 */
class RMMTensor : public ITensor
{
  public:
    /**
     * @brief Construct a tensor over an existing device buffer.
     *
     * @param device_buffer Buffer holding the tensor's storage (ownership is shared).
     * @param offset        Starting offset into @p device_buffer. NOTE(review): presumably counted
     *                      in elements rather than bytes, given the separate offset_bytes() helper
     *                      — confirm against the implementation.
     * @param dtype         Element type of the tensor.
     * @param shape         Extent of each dimension.
     * @param stride        Per-dimension stride; defaults to empty. NOTE(review): an empty stride is
     *                      presumably derived from @p shape by the implementation — confirm.
     */
    RMMTensor(std::shared_ptr<rmm::device_buffer> device_buffer,
              std::size_t offset,
              DType dtype,
              std::vector<TensorIndex> shape,
              std::vector<TensorIndex> stride = {});

    ~RMMTensor() override = default;

    bool is_compact() const final;
    DataType dtype() const override;
    RankType rank() const final;

    std::shared_ptr<ITensor> deep_copy() const override;
    std::shared_ptr<ITensor> reshape(const std::vector<TensorIndex>& dims) const override;
    std::shared_ptr<ITensor> slice(const std::vector<TensorIndex>& min_dims,
                                   const std::vector<TensorIndex>& max_dims) const override;
    std::shared_ptr<ITensor> copy_rows(const std::vector<std::pair<TensorIndex, TensorIndex>>& selected_rows,
                                       TensorIndex num_rows) const override;

    std::shared_ptr<MemoryDescriptor> get_memory() const override;

    std::size_t bytes() const final;
    std::size_t count() const final;
    std::size_t shape(std::size_t idx) const final;
    std::size_t stride(std::size_t idx) const final;

    void* data() const override;

    // Fill the out-parameter with this tensor's shape / stride.
    // NOTE(review): not marked `override`; if ITensor declares these virtual, add `override`.
    void get_shape(std::vector<TensorIndex>& s) const;
    void get_stride(std::vector<TensorIndex>& s) const;

    std::shared_ptr<ITensor> as_type(DataType dtype) const override;

  private:
    // NOTE(review): by its name, presumably converts m_offset into a byte offset within m_md
    // (e.g. scaling by the dtype's element size) — confirm in the implementation.
    std::size_t offset_bytes() const;

    // Memory info
    std::shared_ptr<rmm::device_buffer> m_md;  // shared backing storage
    std::size_t m_offset;                      // offset as passed to the constructor

    // Type info
    DType m_dtype;

    // Shape info
    std::vector<TensorIndex> m_shape;
    std::vector<TensorIndex> m_stride;
};  // end of group
}  // namespace morpheus