Create a Custom Operator in C++

DALI allows you to create a custom operator in C++ and load it at runtime. Here are several reasons you might need to write a custom operator:

  • DALI does not support the operation that you want to perform and it cannot be expressed by a composition of other operators.

  • You want to write an operator that depends on a third-party library.

  • You want to optimize your pipeline by providing a manually fused operation in C++.

In this tutorial, we will walk you through the process of writing, compiling, and loading a plugin with a DALI custom operator. For demonstration purposes we will provide a CPU and a GPU implementation for the CustomDummy operator. The implementation only copies the input data to the output without any modifications.

Prerequisites

  • DALI is installed from the binary distribution or compiled from source.

  • You can write in C++.

  • You have a basic knowledge of CMake.

Operator Definition

  1. Declare the operator in a header file.

  2. Provide common Setup functions.

The implementation of CanInferOutputs and SetupImpl can be shared across backends. SetupImpl provides the shape and type description of the output based on the input, and CanInferOutputs informs the executor that the Operator can provide that output description for the entire batch before executing RunImpl.

[1]:
! cat customdummy/dummy.h
#ifndef EXAMPLE_DUMMY_H_
#define EXAMPLE_DUMMY_H_

#include <vector>

#include "dali/pipeline/operator/operator.h"

namespace other_ns {

template <typename Backend>
class Dummy : public ::dali::Operator<Backend> {
 public:
  inline explicit Dummy(const ::dali::OpSpec &spec) :
    ::dali::Operator<Backend>(spec) {}

  virtual inline ~Dummy() = default;

  Dummy(const Dummy&) = delete;
  Dummy& operator=(const Dummy&) = delete;
  Dummy(Dummy&&) = delete;
  Dummy& operator=(Dummy&&) = delete;

 protected:
  bool CanInferOutputs() const override {
    return true;
  }

  bool SetupImpl(std::vector<::dali::OutputDesc> &output_desc,
                 const ::dali::Workspace &ws) override {
    const auto &input = ws.Input<Backend>(0);
    output_desc.resize(1);
    output_desc[0] = {input.shape(), input.type()};
    return true;
  }

  void RunImpl(::dali::Workspace &ws) override;
};

}  // namespace other_ns

#endif  // EXAMPLE_DUMMY_H_

CPU Operator Implementation

  1. Provide the CPU implementation in a C++ implementation file by overriding the RunImpl method for the CPU backend.

  2. Register the schema for the custom operator with the DALI_SCHEMA macro, and register the CPU version of the operator with the DALI_REGISTER_OPERATOR macro.

In RunImpl we gain access to the entire batch being processed. We get a reference to the CPU thread pool from the workspace ws and create tasks that copy samples from input to output in parallel. The thread pool orders the tasks from longest to shortest, based on tensor size, to best utilize the worker threads.

The outputs are already allocated as we provided the SetupImpl and CanInferOutputs functions.

[2]:
! cat customdummy/dummy.cc
#include "dummy.h"

namespace other_ns {

template <>
void Dummy<::dali::CPUBackend>::RunImpl(::dali::Workspace &ws) {
  const auto &input = ws.Input<::dali::CPUBackend>(0);
  auto &output = ws.Output<::dali::CPUBackend>(0);

  ::dali::TypeInfo type = input.type_info();
  auto &tp = ws.GetThreadPool();
  const auto &in_shape = input.shape();
  for (int sample_id = 0; sample_id < in_shape.num_samples(); sample_id++) {
    tp.AddWork(
        [&, sample_id](int thread_id) {
          type.Copy<::dali::CPUBackend, ::dali::CPUBackend>(output.raw_mutable_tensor(sample_id),
                                                            input.raw_tensor(sample_id),
                                                            in_shape.tensor_size(sample_id), 0);
        },
        in_shape.tensor_size(sample_id));
  }
  tp.RunAll();
}

}  // namespace other_ns

DALI_REGISTER_OPERATOR(CustomDummy, ::other_ns::Dummy<::dali::CPUBackend>, ::dali::CPU);

DALI_SCHEMA(CustomDummy)
    .DocStr("Make a copy of the input tensor")
    .NumInput(1)
    .NumOutput(1);
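
The schema above is minimal. If the operator needs parameters, the schema can also declare them, and the operator can read them from the OpSpec passed to its constructor. A hypothetical sketch (the scale argument and the CustomDummyScaled name are illustrative, not part of this tutorial's files):

DALI_SCHEMA(CustomDummyScaled)
    .DocStr("Make a scaled copy of the input tensor")
    .NumInput(1)
    .NumOutput(1)
    .AddOptionalArg("scale", "Illustrative scaling factor", 1.0f);

// In the operator's constructor, the argument can be read from the spec:
//   float scale = spec.GetArgument<float>("scale");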

GPU Operator Implementation

  1. Provide the GPU implementation in a CUDA implementation file by overriding the RunImpl method for the GPU backend.

  2. Register the GPU version of the operator with the DALI_REGISTER_OPERATOR macro.

As was the case for the CPU implementation, we obtain the entire batch in the RunImpl function. The outputs are already allocated based on the return value of the SetupImpl function that we provided earlier.

It is important that we issue the GPU operations on the stream provided by the workspace. Here we copy the batch using cudaMemcpyAsync.

[3]:
! cat customdummy/dummy.cu
#include <cuda_runtime_api.h>
#include "dummy.h"

namespace other_ns {

template<>
void Dummy<::dali::GPUBackend>::RunImpl(::dali::Workspace &ws) {
  const auto &input = ws.Input<::dali::GPUBackend>(0);
  const auto &shape = input.shape();
  auto &output = ws.Output<::dali::GPUBackend>(0);
  for (int sample_idx = 0; sample_idx < shape.num_samples(); sample_idx++) {
    CUDA_CALL(cudaMemcpyAsync(
            output.raw_mutable_tensor(sample_idx),
            input.raw_tensor(sample_idx),
            shape[sample_idx].num_elements() * input.type_info().size(),
            cudaMemcpyDeviceToDevice,
            ws.stream()));
  }
}

}  // namespace other_ns

DALI_REGISTER_OPERATOR(CustomDummy, ::other_ns::Dummy<::dali::GPUBackend>, ::dali::GPU);
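
A plain copy is the simplest case. An operator that actually transforms the data would launch its own kernel on ws.stream(), so that the work is correctly ordered with the rest of the pipeline. A hypothetical sketch (IdentityKernel is an illustrative name, not part of DALI):

#include <cstdint>

__global__ void IdentityKernel(const uint8_t *in, uint8_t *out, size_t n) {
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n)
    out[idx] = in[idx];  // a real operator would transform the data here
}

// Inside RunImpl, per sample, instead of cudaMemcpyAsync:
//   size_t n = shape[sample_idx].num_elements() * input.type_info().size();
//   unsigned blocks = static_cast<unsigned>((n + 255) / 256);
//   IdentityKernel<<<blocks, 256, 0, ws.stream()>>>(
//       static_cast<const uint8_t *>(input.raw_tensor(sample_idx)),
//       static_cast<uint8_t *>(output.raw_mutable_tensor(sample_idx)), n);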

Building the Plugin

  1. Specify the build configuration.

To retrieve the build configuration parameters, use nvidia.dali.sysconfig.

[4]:
import nvidia.dali.sysconfig as sysconfig
[5]:
print(sysconfig.get_include_dir())
/usr/local/lib/python3.10/dist-packages/nvidia/dali/include
[6]:
print(sysconfig.get_lib_dir())
/usr/local/lib/python3.10/dist-packages/nvidia/dali
[7]:
print(sysconfig.get_compile_flags())
['-I/usr/local/lib/python3.10/dist-packages/nvidia/dali/include', '-D_GLIBCXX_USE_CXX11_ABI=1']
[8]:
print(sysconfig.get_link_flags())
['-L/usr/local/lib/python3.10/dist-packages/nvidia/dali', '-ldali']

Important: Only one version of libdali.so should be loaded in a process at a time. A plugin must be linked against the exact same libdali.so that resides in the DALI Python package directory you intend to use for loading the plugin. As a result of this limitation, whenever you upgrade DALI you must relink your plugin against the new library.

  2. In this example, we use CMake to build the plugin.

[9]:
! cat customdummy/CMakeLists.txt
cmake_minimum_required(VERSION 3.10)
set(CMAKE_CUDA_ARCHITECTURES "50;60;70;80;90")

project(custom_dummy_plugin LANGUAGES CUDA CXX C)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_C_STANDARD 11)

set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

include_directories(SYSTEM "${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}")

execute_process(
        COMMAND python -c "import nvidia.dali as dali; print(dali.sysconfig.get_lib_dir())"
        OUTPUT_VARIABLE DALI_LIB_DIR)
string(STRIP ${DALI_LIB_DIR} DALI_LIB_DIR)

execute_process(
        COMMAND python -c "import nvidia.dali as dali; print(\" \".join(dali.sysconfig.get_compile_flags()))"
        OUTPUT_VARIABLE DALI_COMPILE_FLAGS)
string(STRIP ${DALI_COMPILE_FLAGS} DALI_COMPILE_FLAGS)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DALI_COMPILE_FLAGS} ")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${DALI_COMPILE_FLAGS} ")
link_directories("${DALI_LIB_DIR}")

add_library(customdummy SHARED dummy.cc dummy.cu)
target_link_libraries(customdummy dali)
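
CMake is not strictly required. As a sketch of an alternative, the same plugin could be built with a single nvcc invocation that consumes the flags reported by sysconfig (assuming nvcc and python are on your PATH):

nvcc -std=c++17 -shared -Xcompiler -fPIC dummy.cc dummy.cu -o libcustomdummy.so \
    $(python -c "import nvidia.dali as dali; print(' '.join(dali.sysconfig.get_compile_flags()))") \
    $(python -c "import nvidia.dali as dali; print(' '.join(dali.sysconfig.get_link_flags()))")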

  3. We are now ready to compile the plugin that contains the CustomDummy custom operator.

[10]:
! rm -rf customdummy/build
! mkdir -p customdummy/build
! cd customdummy/build && \
  cmake .. && \
  make
-- The CUDA compiler identification is NVIDIA 12.0.76
-- The CXX compiler identification is GNU 11.3.0
-- The C compiler identification is GNU 11.3.0
-- Detecting CUDA compiler ABI info
-- Detecting CUDA compiler ABI info - done
-- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc - skipped
-- Detecting CUDA compile features
-- Detecting CUDA compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Configuring done (4.4s)
-- Generating done (0.0s)
-- Build files have been written to: /dali/docs/examples/custom_operations/custom_operator/customdummy/build
[ 33%] Building CXX object CMakeFiles/customdummy.dir/dummy.cc.o
[ 66%] Building CUDA object CMakeFiles/customdummy.dir/dummy.cu.o
[100%] Linking CXX shared library libcustomdummy.so
[100%] Built target customdummy

  4. After the build is complete, we have a dynamic library file that is ready to use.

[11]:
! ls customdummy/build/*.so
customdummy/build/libcustomdummy.so
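
If you need to confirm which libdali.so the plugin resolves to (see the earlier note about matching library versions), you can inspect it with ldd, assuming a Linux environment; depending on your rpath settings the entry may show a resolved path or "not found" until the loader can locate it:

! ldd customdummy/build/libcustomdummy.so | grep dali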

Importing the Plugin

  1. We can see that there is currently no operator called custom_dummy.

Note: Operations available in nvidia.dali.fn are automatically converted from camel case to snake case, while the legacy operator objects in nvidia.dali.ops keep the camel case format (Example: fn.custom_dummy vs. ops.CustomDummy).

[12]:
import nvidia.dali.fn as fn

try:
    help(fn.custom_dummy)
except Exception as e:
    print("Error: " + str(e))
Error: module 'nvidia.dali.fn' has no attribute 'custom_dummy'

  2. Load the plugin.

[13]:
import nvidia.dali.plugin_manager as plugin_manager

plugin_manager.load_library("./customdummy/build/libcustomdummy.so")

  3. Verify that the new operator is available.

[14]:
help(fn.custom_dummy)
Help on function custom_dummy in module nvidia.dali.fn:

custom_dummy(*inputs, **kwargs)
    Make a copy of the input tensor

    Supported backends
     * 'cpu'
     * 'gpu'


    Args
    ----
    `input` : TensorList
        Input to the operator.


    Keyword args
    ------------
    `bytes_per_sample_hint` : int or list of int, optional, default = `[0]`
        Output size hint, in bytes per sample.

        If specified, the operator's outputs residing in GPU or page-locked host memory will be preallocated
        to accommodate a batch of samples of this size.
    `preserve` : bool, optional, default = `False`
        Prevents the operator from being removed from the
        graph even if its outputs are not used.
    `seed` : int, optional, default = `-1`
        Random seed.

        If not provided, it will be populated based on the global seed of the pipeline.

For the sake of completeness, it is worth mentioning that, although discouraged, it is also possible to access the custom operator through the legacy operator object API:

[15]:
import nvidia.dali.ops as ops

help(ops.CustomDummy)
Help on class CustomDummy in module nvidia.dali.ops:

class CustomDummy(builtins.object)
 |  CustomDummy(*, device='cpu', **kwargs)
 |
 |  Make a copy of the input tensor
 |
 |  Supported backends
 |   * 'cpu'
 |   * 'gpu'
 |
 |
 |  Keyword args
 |  ------------
 |  `bytes_per_sample_hint` : int or list of int, optional, default = `[0]`
 |      Output size hint, in bytes per sample.
 |
 |      If specified, the operator's outputs residing in GPU or page-locked host memory will be preallocated
 |      to accommodate a batch of samples of this size.
 |  `preserve` : bool, optional, default = `False`
 |      Prevents the operator from being removed from the
 |      graph even if its outputs are not used.
 |  `seed` : int, optional, default = `-1`
 |      Random seed.
 |
 |      If not provided, it will be populated based on the global seed of the pipeline.
 |
 |  Methods defined here:
 |
 |  __call__(self, *inputs, **kwargs)
 |      __call__(data, **kwargs)
 |
 |      Operator call to be used in graph definition.
 |
 |      Args
 |      ----
 |      `data` : TensorList
 |          Input to the operator.
 |
 |  __init__(self, *, device='cpu', **kwargs)
 |
 |  ----------------------------------------------------------------------
 |  Readonly properties defined here:
 |
 |  device
 |
 |  preserve
 |
 |  schema
 |
 |  spec
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |
 |  __dict__
 |      dictionary for instance variables (if defined)
 |
 |  __weakref__
 |      list of weak references to the object (if defined)
 |
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |
 |  schema_name = 'CustomDummy'
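
Finally, the new operator can be used like any other DALI operator inside a pipeline. A minimal sketch (the batch size, the data shape, and the use of fn.random.uniform as a data source are illustrative choices; device_id=0 assumes a visible GPU):

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn

@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def dummy_pipeline():
    data = fn.random.uniform(shape=[2, 2])    # placeholder input data
    copied_cpu = fn.custom_dummy(data)        # CPU variant
    copied_gpu = fn.custom_dummy(data.gpu())  # GPU variant
    return copied_cpu, copied_gpu

pipe = dummy_pipeline()
pipe.build()
cpu_out, gpu_out = pipe.run()
print(cpu_out.as_array())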