NVIDIA Holoscan SDK v2.7.0

Class TorchInfer

Base Type

class TorchInfer : public holoscan::inference::InferBase

Libtorch-based inference class.

Public Functions

TorchInfer(const std::string &model_file_path, bool cuda_flag, bool cuda_buf_in, bool cuda_buf_out)

Constructor.

Parameters
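  • model_file_path – Path to the Torch model file

  • cuda_flag – Flag indicating whether inference runs on CUDA

  • cuda_buf_in – Flag indicating whether the input data buffers are on the CUDA device

  • cuda_buf_out – Flag indicating whether the output data buffers will be on the CUDA device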

~TorchInfer()

Destructor.

virtual InferStatus do_inference(const std::vector<std::shared_ptr<DataBuffer>> &input_data, std::vector<std::shared_ptr<DataBuffer>> &output_buffer, cudaEvent_t cuda_event_data, cudaEvent_t *cuda_event_inference)

Performs the core inference. The provided CUDA data event (cuda_event_data) is the event used when preparing the input data; any CUDA work must be synchronized with it. If inference runs on CUDA, the implementation should record a CUDA event and pass it back in cuda_event_inference.

Parameters
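  • input_data – Vector of input DataBuffer objects

  • output_buffer – Vector of output DataBuffer objects, populated with the inferred results

  • cuda_event_data – CUDA event associated with preparing the input data; any CUDA work must be synchronized with it

  • cuda_event_inference – Output: CUDA event recorded by the inference when it runs on CUDA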

Returns

InferStatus
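The ordering contract of do_inference (wait for the input data to be ready, run inference, then signal completion back to the caller) can be illustrated with a host-only analogy. The sketch below is not the TorchInfer implementation and uses no CUDA: std::promise/std::future pairs stand in for the two cudaEvent_t parameters, HostBuffer is a hypothetical stand-in for DataBuffer, and the "model" simply doubles each input value. In real CUDA code the corresponding steps would typically be cudaStreamWaitEvent on the data event before launching work and cudaEventRecord afterwards, which order GPU work without blocking the host thread the way this analogy does.

```cpp
#include <cstddef>
#include <future>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical host-memory stand-in for holoscan::inference::DataBuffer.
struct HostBuffer {
  std::vector<float> data;
};
using Buffers = std::vector<std::shared_ptr<HostBuffer>>;

// Host-side analogy of the do_inference contract: data_ready plays the role of
// cuda_event_data and inference_done the role of *cuda_event_inference.
void infer_sketch(const Buffers& input_data, Buffers& output_buffer,
                  std::shared_future<void> data_ready,
                  std::promise<void>& inference_done) {
  data_ready.wait();  // do not read inputs until the producer has signaled
  for (std::size_t i = 0; i < input_data.size(); ++i) {
    const auto& in = input_data[i]->data;
    auto& out = output_buffer[i]->data;
    out.resize(in.size());
    for (std::size_t j = 0; j < in.size(); ++j) out[j] = 2.0f * in[j];  // placeholder "model"
  }
  inference_done.set_value();  // outputs are valid from this point on
}

// Drives one call: inference starts on a worker thread, the producer fills the
// input and signals readiness, and the caller waits for completion before reading.
std::vector<float> run_once(std::vector<float> input) {
  Buffers inputs{std::make_shared<HostBuffer>()};
  Buffers outputs{std::make_shared<HostBuffer>()};
  std::promise<void> ready;
  std::promise<void> done;
  std::future<void> done_future = done.get_future();

  std::thread worker(infer_sketch, std::cref(inputs), std::ref(outputs),
                     ready.get_future().share(), std::ref(done));
  inputs[0]->data = std::move(input);  // prepare the input data...
  ready.set_value();                   // ...then mark it ready
  done_future.wait();                  // wait for inference before reading outputs
  worker.join();
  return outputs[0]->data;
}
```

For example, run_once({1.0f, 2.0f, 3.0f}) returns {2.0f, 4.0f, 6.0f}. The point of the design is that neither side touches shared data until the other side's signal has fired, which is the role the two CUDA events play in the real interface.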

InferStatus populate_model_details()

Populate class parameters with model details and values.

void print_model_details()

Print model details.

virtual std::vector<std::vector<int64_t>> get_input_dims() const

Get input data dimensions to the model.

Returns

Vector of input dimensions, one entry per input tensor. Each entry is a vector of int64_t giving the shape of that tensor.

virtual std::vector<std::vector<int64_t>> get_output_dims() const

Get output data dimensions from the model.

Returns

Vector of output dimensions, one entry per output tensor. Each entry is a vector of int64_t giving the shape of that tensor.
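Because each entry returned by get_input_dims() or get_output_dims() is the shape of one tensor, the number of elements in that tensor is the product of its dimensions. The sketch below computes this count per tensor using only the standard library. The helper name element_counts is hypothetical, and it assumes every dimension is a fixed, non-negative size; any negative value is treated as "size unknown" and reported rather than multiplied in.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Shape list in the same form returned by get_input_dims()/get_output_dims():
// one std::vector<int64_t> per tensor.
using ShapeList = std::vector<std::vector<int64_t>>;

// Hypothetical helper: number of elements in each tensor (product of its
// dimensions). Returns std::nullopt if any dimension is negative.
std::optional<std::vector<int64_t>> element_counts(const ShapeList& shapes) {
  std::vector<int64_t> counts;
  counts.reserve(shapes.size());
  for (const auto& shape : shapes) {
    for (int64_t d : shape) {
      if (d < 0) return std::nullopt;  // size not known: no meaningful count
    }
    // An empty shape (a scalar) yields 1, the identity of multiplication.
    counts.push_back(std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                     std::multiplies<int64_t>()));
  }
  return counts;
}
```

For instance, a single input of shape {1, 3, 224, 224} gives 1 × 3 × 224 × 224 = 150528 elements.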

virtual std::vector<holoinfer_datatype> get_input_datatype() const

Get input data types from the model.

Returns

Vector containing the data type of each input tensor

virtual std::vector<holoinfer_datatype> get_output_datatype() const

Get output data types from the model.

Returns

Vector containing the data type of each output tensor
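The dimension getters and the data type getters are parallel: the i-th shape and the i-th data type describe the same tensor, so together they determine how many bytes each tensor occupies. The sketch below pairs them to compute per-tensor byte sizes. DType is a hypothetical stand-in for holoinfer_datatype (the real enum's members may differ), and tensor_bytes is a hypothetical helper; it rejects mismatched list lengths and negative dimensions.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for holoinfer_datatype; the real enum's members may differ.
enum class DType { Float32, Float16, Int64, Int32, Int8, UInt8 };

// Size in bytes of one element of the given type.
std::size_t bytes_per_element(DType t) {
  switch (t) {
    case DType::Float32: return 4;
    case DType::Float16: return 2;
    case DType::Int64: return 8;
    case DType::Int32: return 4;
    case DType::Int8:
    case DType::UInt8: return 1;
  }
  throw std::invalid_argument("unknown data type");
}

// Hypothetical helper: bytes needed for each tensor, pairing the i-th shape
// (as from get_*_dims()) with the i-th type (as from get_*_datatype()).
std::vector<std::size_t> tensor_bytes(const std::vector<std::vector<int64_t>>& dims,
                                      const std::vector<DType>& types) {
  if (dims.size() != types.size())
    throw std::invalid_argument("dims and types must describe the same tensors");
  std::vector<std::size_t> bytes;
  bytes.reserve(dims.size());
  for (std::size_t i = 0; i < dims.size(); ++i) {
    std::size_t count = 1;
    for (int64_t d : dims[i]) {
      if (d < 0) throw std::invalid_argument("negative dimension");
      count *= static_cast<std::size_t>(d);
    }
    bytes.push_back(count * bytes_per_element(types[i]));
  }
  return bytes;
}
```

With shapes {1, 3, 224, 224} and {1, 10} typed as 32-bit float and 64-bit integer, this yields 602112 and 80 bytes respectively.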

© Copyright 2022-2024, NVIDIA. Last updated on Dec 2, 2024.