// NVIDIA Holoscan SDK v2.8.0 (2025-01-27)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef HOLOSCAN_CORE_IO_CONTEXT_HPP
#define HOLOSCAN_CORE_IO_CONTEXT_HPP

#include <cuda_runtime.h>

#include <any>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include <common/type_name.hpp>
#include <gxf/cuda/cuda_stream.hpp>
#include "./common.hpp"
#include "./domain/tensor_map.hpp"
#include "./errors.hpp"
#include "./expected.hpp"
#include "./gxf/entity.hpp"
#include "./gxf/gxf_cuda.hpp"
#include "./message.hpp"
#include "./operator.hpp"
#include "./type_traits.hpp"

namespace holoscan {

// To indicate that data is not available for the input port
struct NoMessageType {};
constexpr NoMessageType kNoReceivedMessage;

// To indicate that input port is not accessible
struct NoAccessibleMessageType : public std::string {
  NoAccessibleMessageType() : std::string("Port is not accessible") {}
  explicit NoAccessibleMessageType(const std::string& message) : std::string(message) {}
  explicit NoAccessibleMessageType(const char* message) : std::string(message) {}
  explicit NoAccessibleMessageType(std::string&& message) : std::string(std::move(message)) {}
};

/**
 * @brief Resolve the port name to use for a possibly missing name.
 *
 * @param name Requested port name; may be null or empty.
 * @param io_list Map from port name to IO spec that the name refers to.
 * @return `name` when it is non-null and non-empty; otherwise the key of the only entry of
 * `io_list` when it holds exactly one entry; otherwise an empty string.
 */
static inline std::string get_well_formed_name(
    const char* name, const std::unordered_map<std::string, std::shared_ptr<IOSpec>>& io_list) {
  const bool has_explicit_name = (name != nullptr) && (name[0] != '\0');
  if (has_explicit_name) { return std::string(name); }

  // No name given: it is only unambiguous when there is exactly one port.
  if (io_list.size() == 1) { return io_list.begin()->first; }

  return std::string{};
}

class InputContext {
 public:
  InputContext(ExecutionContext* execution_context, Operator* op,
               std::unordered_map<std::string, std::shared_ptr<IOSpec>>& inputs)
      : execution_context_(execution_context), op_(op), inputs_(inputs) {}

  InputContext(ExecutionContext* execution_context, Operator* op)
      : execution_context_(execution_context), op_(op), inputs_(op->spec()->inputs()) {}

  virtual ~InputContext() = default;

  ExecutionContext* execution_context() const { return execution_context_; }

  Operator* op() const { return op_; }

  std::unordered_map<std::string, std::shared_ptr<IOSpec>>& inputs() const { return inputs_; }

  bool empty(const char* name = nullptr) {
    // First see if the name could be found in the inputs
    auto& inputs = op_->spec()->inputs();
    auto it = inputs.find(std::string(name));
    if (it != inputs.end()) { return empty_impl(name); }

    // Then see if it is in the parameters
    auto& params = op_->spec()->params();
    auto it2 = params.find(std::string(name));
    if (it2 != params.end()) {
      auto& param_wrapper = it2->second;
      auto& arg_type = param_wrapper.arg_type();
      if ((arg_type.element_type() != ArgElementType::kIOSpec) ||
          (arg_type.container_type() != ArgContainerType::kVector)) {
        HOLOSCAN_LOG_ERROR("Input parameter with name '{}' is not of type 'std::vector<IOSpec*>'",
                           name);
        return true;
      }
      std::any& any_param = param_wrapper.value();
      // Note that the type of any_param is Parameter<typeT>*, not Parameter<typeT>.
      auto& param = *std::any_cast<Parameter<std::vector<IOSpec*>>*>(any_param);
      int num_inputs = param.get().size();
      for (int i = 0; i < num_inputs; ++i) {
        // if any of them is not empty return false
        if (!empty_impl(fmt::format("{}:{}", name, i).c_str())) { return false; }
      }
      return true;  // all of them are empty, so return true.
    }

    HOLOSCAN_LOG_ERROR("Input port '{}' not found", name);
    return true;
  }

  template <typename DataT>
  holoscan::expected<DataT, holoscan::RuntimeError> receive(const char* name = nullptr) {
    auto& params = op_->spec()->params();
    auto param_it = params.find(std::string(name));

    if constexpr (holoscan::is_vector_v<DataT>) {
      DataT input_vector;
      std::string error_message;

      if (param_it != params.end()) {
        auto& param_wrapper = param_it->second;
        if (!is_valid_param_type(param_wrapper.arg_type())) {
          return make_unexpected<holoscan::RuntimeError>(
              create_receive_error(name, "Input parameter is not of type 'std::vector<IOSpec*>'"));
        }
        if (!fill_input_vector_from_params(param_wrapper, name, input_vector, error_message)) {
          return make_unexpected<holoscan::RuntimeError>(
              create_receive_error(name, error_message.c_str()));
        }
      } else {
        if (!fill_input_vector_from_inputs(name, input_vector, error_message)) {
          return make_unexpected<holoscan::RuntimeError>(
              create_receive_error(name, error_message.c_str()));
        }
      }
      return input_vector;
    } else {
      return receive_single_value<DataT>(name);
    }
  }

  std::shared_ptr<gxf::CudaObjectHandler> cuda_object_handler() { return cuda_object_handler_; }

  void cuda_object_handler(std::shared_ptr<gxf::CudaObjectHandler> handler) {
    cuda_object_handler_ = handler;
  }

  virtual cudaStream_t receive_cuda_stream(const char* input_port_name = nullptr,
                                           bool allocate = true, bool sync_to_default = false) = 0;

  virtual std::vector<std::optional<cudaStream_t>> receive_cuda_streams(
      const char* input_port_name = nullptr) = 0;

 protected:
  virtual bool empty_impl(const char* name = nullptr) {
    (void)name;
    return true;
  }
  virtual std::any receive_impl(const char* name = nullptr, bool no_error_message = false) {
    (void)name;
    (void)no_error_message;
    return nullptr;
  }

  // --------------- Start of helper functions for the receive method ---------------
  inline bool is_valid_param_type(const ArgType& arg_type) {
    return (arg_type.element_type() == ArgElementType::kIOSpec) &&
           (arg_type.container_type() == ArgContainerType::kVector);
  }

  template <typename DataT>
  inline bool fill_input_vector_from_params(ParameterWrapper& param_wrapper, const char* name,
                                            DataT& input_vector, std::string& error_message) {
    auto& param = *std::any_cast<Parameter<std::vector<IOSpec*>>*>(param_wrapper.value());
    int num_inputs = param.get().size();
    input_vector.reserve(num_inputs);

    for (int index = 0; index < num_inputs; ++index) {
      std::string port_name = fmt::format("{}:{}", name, index);
      auto value = receive_impl(port_name.c_str(), true);
      const std::type_info& value_type = value.type();

      if (value_type == typeid(kNoReceivedMessage)) {
        error_message =
            fmt::format("No data is received from the input port with name '{}'", port_name);
        return false;
      }

      if (!process_received_value(value, value_type, name, index, input_vector, error_message)) {
        return false;
      }
    }
    return true;
  }

  template <typename DataT>
  inline bool fill_input_vector_from_inputs(const char* name, DataT& input_vector,
                                            std::string& error_message) {
    const auto& inputs = op_->spec()->inputs();
    const auto input_it = inputs.find(std::string(name));

    if (input_it == inputs.end()) { return false; }

    int index = 0;
    while (true) {
      auto value = receive_impl(name);
      const std::type_info& value_type = value.type();

      if (value_type == typeid(kNoReceivedMessage)) {
        if (index == 0) {
          error_message =
              fmt::format("No data is received from the input port with name '{}'", name);
          return false;
        }
        break;
      }
      if (index == 0 && value_type == typeid(DataT)) {
        // If the first input is of type DataT (such as `std::vector<bool>`), then return the value
        // directly
        input_vector = std::move(std::any_cast<DataT>(value));
        return true;
      }

      if (!process_received_value(value, value_type, name, index++, input_vector, error_message)) {
        return false;
      }
    }
    return true;
  }

  inline bool populate_tensor_map(const holoscan::gxf::Entity& gxf_entity,
                                  holoscan::TensorMap& tensor_map) {
    auto tensor_components_expected = gxf_entity.findAllHeap<nvidia::gxf::Tensor>();
    for (const auto& gxf_tensor : tensor_components_expected.value()) {
      // Do zero-copy conversion to holoscan::Tensor (as in gxf_entity.get<holoscan::Tensor>())
      auto maybe_dl_ctx = (*gxf_tensor->get()).toDLManagedTensorContext();
      if (!maybe_dl_ctx) {
        HOLOSCAN_LOG_ERROR(
            "Failed to get std::shared_ptr<DLManagedTensorContext> from nvidia::gxf::Tensor");
        return false;
      }
      auto holoscan_tensor = std::make_shared<Tensor>(maybe_dl_ctx.value());
      tensor_map.insert({gxf_tensor->name(), holoscan_tensor});
    }
    return true;
  }

  template <typename DataT>
  inline bool process_received_value(std::any& value, const std::type_info& value_type,
                                     const char* name, int index, DataT& input_vector,
                                     std::string& error_message) {
    bool is_bad_any_cast = false;

    // Assume that the received data is not of type NoMessageType
    // (this case should be handled by the caller)

    if (value_type == typeid(NoAccessibleMessageType)) {
      auto casted_value = std::any_cast<NoAccessibleMessageType>(value);
      HOLOSCAN_LOG_ERROR(static_cast<std::string>(casted_value));
      error_message = std::move(static_cast<std::string>(casted_value));
      return false;
    }

    if constexpr (std::is_same_v<typename DataT::value_type, std::any>) {
      input_vector.push_back(std::move(value));
    } else if (value_type == typeid(std::nullptr_t)) {
      handle_null_value<DataT>(input_vector);
    } else {
      try {
        auto casted_value = std::any_cast<typename DataT::value_type>(value);
        input_vector.push_back(casted_value);
      } catch (const std::bad_any_cast& e) {
        is_bad_any_cast = true;
      } catch (const std::exception& e) {
        error_message = fmt::format(
            "Unable to cast the received data to the specified type for input '{}:{}' of "
            "type {}: {}",
            name,
            index,
            value_type.name(),
            e.what());
        return false;
      }
    }

    if (is_bad_any_cast) {
      return handle_bad_any_cast<DataT>(value, name, index, input_vector, error_message);
    }

    return true;
  }

  template <typename DataT>
  inline void handle_null_value(DataT& input_vector) {
    if constexpr (holoscan::is_shared_ptr_v<typename DataT::value_type> ||
                  std::is_pointer_v<typename DataT::value_type>) {
      input_vector.push_back(typename DataT::value_type{nullptr});
    }
  }

  template <typename DataT>
  inline bool handle_bad_any_cast(std::any& value, const char* name, int index, DataT& input_vector,
                                  std::string& error_message) {
    if constexpr (is_one_of_derived_v<typename DataT::value_type, nvidia::gxf::Entity>) {
      error_message = fmt::format(
          "Unable to cast the received data to the specified type (holoscan::gxf::Entity) for "
          "input "
          "'{}:{}'",
          name,
          index);
      HOLOSCAN_LOG_DEBUG(error_message);
      return false;
    } else if constexpr (is_one_of_derived_v<typename DataT::value_type, holoscan::TensorMap>) {
      TensorMap tensor_map;
      try {
        auto gxf_entity = std::any_cast<holoscan::gxf::Entity>(value);
        bool is_tensor_map_populated = populate_tensor_map(gxf_entity, tensor_map);
        if (!is_tensor_map_populated) {
          error_message = fmt::format(
              "Unable to populate the TensorMap from the received GXF Entity for input '{}:{}'",
              name,
              index);
          HOLOSCAN_LOG_DEBUG(error_message);
          return false;
        }
      } catch (const std::bad_any_cast& e) {
        error_message = fmt::format(
            "Unable to cast the received data to the specified type (holoscan::TensorMap) for "
            "input "
            "'{}:{}'",
            name,
            index);
        HOLOSCAN_LOG_DEBUG(error_message);
        return false;
      }
      input_vector.push_back(std::move(tensor_map));
    } else {
      error_message = fmt::format(
          "Unable to cast the received data to the specified type for input '{}:{}' of type {}: {}",
          name,
          index,
          value.type().name(),
          error_message);
      HOLOSCAN_LOG_DEBUG(error_message);
      return false;
    }
    return true;
  }

  template <typename DataT>
  inline holoscan::expected<DataT, holoscan::RuntimeError> receive_single_value(const char* name) {
    auto value = receive_impl(name);
    const std::type_info& value_type = value.type();

    if (value_type == typeid(NoMessageType)) {
      return make_unexpected<holoscan::RuntimeError>(
          create_receive_error(name, "No message received from the input port"));
    } else if (value_type == typeid(NoAccessibleMessageType)) {
      auto casted_value = std::any_cast<NoAccessibleMessageType>(value);
      HOLOSCAN_LOG_ERROR(static_cast<std::string>(casted_value));
      auto error_message = std::move(static_cast<std::string>(casted_value));
      return make_unexpected<holoscan::RuntimeError>(
          create_receive_error(name, error_message.c_str()));
    }

    try {
      if constexpr (std::is_same_v<DataT, std::any>) {
        return value;
      } else if (value_type == typeid(std::nullptr_t)) {
        return handle_null_value<DataT>();
      } else if constexpr (is_one_of_derived_v<DataT, nvidia::gxf::Entity>) {
        // Handle nvidia::gxf::Entity
        return std::any_cast<DataT>(value);
      } else if constexpr (is_one_of_derived_v<DataT, holoscan::TensorMap>) {
        // Handle holoscan::TensorMap
        TensorMap tensor_map;
        bool is_tensor_map_populated =
            populate_tensor_map(std::any_cast<holoscan::gxf::Entity>(value), tensor_map);
        if (!is_tensor_map_populated) {
          auto error_message = fmt::format(
              "Unable to populate the TensorMap from the received GXF Entity for input '{}'", name);
          HOLOSCAN_LOG_DEBUG(error_message);
          return make_unexpected<holoscan::RuntimeError>(
              create_receive_error(name, error_message.c_str()));
        }
        return tensor_map;
      } else {
        return std::any_cast<DataT>(value);
      }
    } catch (const std::bad_any_cast& e) {
      auto error_message = fmt::format(
          "Unable to cast the received data to the specified type for input '{}' of type {}",
          name,
          value.type().name());
      HOLOSCAN_LOG_DEBUG(error_message);

      return make_unexpected<holoscan::RuntimeError>(
          create_receive_error(name, error_message.c_str()));
    }
  }

  inline holoscan::RuntimeError create_receive_error(const char* name, const char* message) {
    auto error_message =
        fmt::format("Failure receiving message from input port '{}': {}", name, message);
    HOLOSCAN_LOG_TRACE(error_message);
    return holoscan::RuntimeError(holoscan::ErrorCode::kReceiveError, error_message.c_str());
  }

  template <typename DataT>
  inline holoscan::expected<DataT, holoscan::RuntimeError> handle_null_value() {
    if constexpr (holoscan::is_shared_ptr_v<DataT> || std::is_pointer_v<DataT>) {
      return DataT{nullptr};
    } else {
      auto error_message = "Received nullptr for a non-pointer type";
      return make_unexpected<holoscan::RuntimeError>(create_receive_error("input", error_message));
    }
  }

  // --------------- End of helper functions for the receive method ---------------

  ExecutionContext* execution_context_ =
      nullptr;
  Operator* op_ = nullptr;
  std::unordered_map<std::string, std::shared_ptr<IOSpec>>& inputs_;

 private:
  std::shared_ptr<gxf::CudaObjectHandler> cuda_object_handler_{};
};

/**
 * @brief Base class through which an operator emits messages on its output ports.
 *
 * All `emit()` overloads forward to the virtual `emit_impl()` hook together with an `OutputType`
 * tag describing how the payload was passed. The `emit_impl()` of this class does nothing.
 */
class OutputContext {
 public:
  /**
   * @brief Construct the context using the outputs declared in the operator's spec.
   *
   * @param execution_context Execution context stored by this object.
   * @param op Operator this context belongs to. Must not be null: it is dereferenced here to
   * obtain `op->spec()->outputs()`.
   */
  OutputContext(ExecutionContext* execution_context, Operator* op)
      : execution_context_(execution_context), op_(op), outputs_(op->spec()->outputs()) {}

  /**
   * @brief Construct the context with an explicit map of output specs.
   *
   * @param execution_context Execution context stored by this object.
   * @param op Operator this context belongs to.
   * @param outputs Map from port name to output spec. Only a reference is stored, so the map must
   * outlive this object.
   */
  OutputContext(ExecutionContext* execution_context, Operator* op,
                std::unordered_map<std::string, std::shared_ptr<IOSpec>>& outputs)
      : execution_context_(execution_context), op_(op), outputs_(outputs) {}

  virtual ~OutputContext() = default;

  /// @return The execution context passed at construction.
  ExecutionContext* execution_context() const { return execution_context_; }

  /// @return The operator passed at construction.
  Operator* op() const { return op_; }

  /// @return The map of output specs this context refers to.
  std::unordered_map<std::string, std::shared_ptr<IOSpec>>& outputs() const { return outputs_; }

  /// Tag passed to emit_impl() describing which emit() overload produced the payload.
  enum class OutputType {
    kSharedPointer,  ///< Payload is a `std::shared_ptr<DataT>`.
    kGXFEntity,      ///< Payload is an `nvidia::gxf::Entity`.
    kAny,            ///< Payload is any other value.
  };

  /**
   * @brief Emit a shared pointer.
   *
   * Selected for `std::shared_ptr<DataT>` lvalues where `DataT` is neither related to
   * `nvidia::gxf::Entity` nor to `std::any`.
   *
   * @param data Shared pointer to emit (copied into the `std::any` given to emit_impl()).
   * @param name Output port name, forwarded as is (may be null).
   * @param acq_timestamp Forwarded to emit_impl(); defaults to -1.
   */
  template <typename DataT, typename = std::enable_if_t<!holoscan::is_one_of_derived_v<
                                DataT, nvidia::gxf::Entity, std::any>>>
  void emit(std::shared_ptr<DataT>& data, const char* name = nullptr,
            const int64_t acq_timestamp = -1) {
    emit_impl(data, name, OutputType::kSharedPointer, acq_timestamp);
  }

  /**
   * @brief Emit a GXF entity (or an object of a type derived from it).
   *
   * @param data Entity to emit.
   * @param name Output port name, forwarded as is (may be null).
   * @param acq_timestamp Forwarded to emit_impl(); defaults to -1.
   */
  template <typename DataT,
            typename = std::enable_if_t<holoscan::is_one_of_derived_v<DataT, nvidia::gxf::Entity>>>
  void emit(DataT& data, const char* name = nullptr, const int64_t acq_timestamp = -1) {
    // if it is the same as nvidia::gxf::Entity then just pass it to emit_impl
    if constexpr (holoscan::is_one_of_v<DataT, nvidia::gxf::Entity>) {
      emit_impl(data, name, OutputType::kGXFEntity, acq_timestamp);
    } else {
      // Convert it to nvidia::gxf::Entity and then pass it to emit_impl
      // Otherwise, we will lose the type information and cannot cast appropriately in emit_impl
      emit_impl(nvidia::gxf::Entity(data), name, OutputType::kGXFEntity, acq_timestamp);
    }
  }

  /**
   * @brief Emit an arbitrary value that is not related to `nvidia::gxf::Entity`.
   *
   * @param data Value to emit; taken by value and moved into the `std::any` given to emit_impl().
   * @param name Output port name, forwarded as is (may be null).
   * @param acq_timestamp Forwarded to emit_impl(); defaults to -1.
   */
  template <typename DataT,
            typename = std::enable_if_t<!holoscan::is_one_of_derived_v<DataT, nvidia::gxf::Entity>>>
  void emit(DataT data, const char* name = nullptr, const int64_t acq_timestamp = -1) {
    emit_impl(std::move(data), name, OutputType::kAny, acq_timestamp);
  }

  /**
   * @brief Emit a TensorMap as a GXF entity.
   *
   * A new entity is created from the stored execution context, every tensor of `data` is added
   * to it under its map key, and the entity is emitted through the entity overload.
   *
   * @param data Tensors to emit.
   * @param name Output port name, forwarded as is (may be null).
   * @param acq_timestamp Forwarded to emit_impl(); defaults to -1.
   */
  void emit(holoscan::TensorMap& data, const char* name = nullptr,
            const int64_t acq_timestamp = -1) {
    auto out_message = holoscan::gxf::Entity::New(execution_context_);
    for (auto& [key, tensor] : data) { out_message.add(tensor, key.c_str()); }
    emit(out_message, name, acq_timestamp);
  }

  /**
   * @brief Associate a CUDA stream with an output port (implemented by subclasses).
   *
   * @param stream The CUDA stream.
   * @param output_port_name Output port name (may be null).
   */
  virtual void set_cuda_stream(const cudaStream_t stream,
                               const char* output_port_name = nullptr) = 0;

  /// @return The CUDA object handler currently attached to this context (may be empty).
  std::shared_ptr<gxf::CudaObjectHandler> cuda_object_handler() { return cuda_object_handler_; }

  /// Attach a CUDA object handler to this context, replacing any previous one.
  void cuda_object_handler(std::shared_ptr<gxf::CudaObjectHandler> handler) {
    // `handler` is already a by-value copy; move it to avoid a second refcount update.
    cuda_object_handler_ = std::move(handler);
  }

 protected:
  /**
   * @brief Hook that performs the actual emission; this base implementation does nothing.
   *
   * @param data Type-erased payload.
   * @param name Output port name as given to emit() (may be null).
   * @param out_type How the payload was passed to emit().
   * @param acq_timestamp Value given to emit(); -1 by default.
   */
  virtual void emit_impl([[maybe_unused]] std::any data,
                         [[maybe_unused]] const char* name = nullptr,
                         [[maybe_unused]] OutputType out_type = OutputType::kSharedPointer,
                         [[maybe_unused]] const int64_t acq_timestamp = -1) {}

  ExecutionContext* execution_context_ = nullptr;  ///< Execution context given at construction.
  Operator* op_ = nullptr;                         ///< Operator given at construction.
  /// Output specs this context refers to (not owned).
  std::unordered_map<std::string, std::shared_ptr<IOSpec>>& outputs_;
  std::shared_ptr<gxf::CudaObjectHandler> cuda_object_handler_{};  ///< Optional CUDA handler.
};

}  // namespace holoscan

#endif/* HOLOSCAN_CORE_IO_CONTEXT_HPP */

