******************************
Device-Initiated Communication
******************************

Starting with version 2.28, NCCL provides a device-side communication API, making it possible to use communication primitives directly from user CUDA kernels.

Device API
----------

The device API consists of the following modules:

* **LSA (Load/Store Accessible)** -- for communication between devices accessible via memory load/store operations, using CUDA P2P. This includes devices connected over NVLink and some devices connected over PCIe (the latter are currently limited to devices with P2P connectivity, as indicated by ``nvidia-smi topo -p2p p``, and are subject to the :ref:`env_NCCL_P2P_LEVEL` distance check).
* **Multimem** -- for communication between devices using the hardware multicast feature provided by NVLink SHARP (available on some datacenter GPUs since the Hopper generation).
* **GIN (GPU-Initiated Networking)** -- for communication over the network. This module is under active development and will not be covered here at this time.

The device API relies on symmetric memory (see :ref:`window_reg`), which in turn depends on GPU virtual memory management (see :ref:`env_NCCL_CUMEM_ENABLE`) and, optionally for multimem support, on NVLink SHARP (see :ref:`env_NCCL_NVLS_ENABLE`).

Host-Side Setup
---------------

To perform communication from the device, a device communicator needs to be created using :c:func:`ncclDevCommCreate`. Data transfer operations on buffers require symmetric memory windows (see :ref:`window_reg`). A custom communication kernel can then be launched using the standard CUDA syntax. The code excerpt below demonstrates these steps:

.. code:: C

   int main() {
     [...]
     NCCLCHECK(ncclCommInitRank(&comm, nranks, id, rank));

     /* Buffer initialization and window creation */
     char* buffer;
     size_t size = 256*1048576;
     NCCLCHECK(ncclMemAlloc((void**)&buffer, size));
     ncclWindow_t win;
     NCCLCHECK(ncclCommWindowRegister(comm, buffer, size, &win, NCCL_WIN_COLL_SYMMETRIC));

     /* Get device communicator */
     ncclDevComm devComm;
     ncclDevCommRequirements reqs;
     memset(&reqs, 0, sizeof(ncclDevCommRequirements));
     int nCTAs = 16;
     reqs.lsaBarrierCount = nCTAs;
     NCCLCHECK(ncclDevCommCreate(comm, &reqs, &devComm));

     /* Launch user kernel */
     customKernel<<<nCTAs, 256>>>(devComm, win);
     [...]
   }

Depending on the kernel and application requirements, the same window can be used for input and output, or multiple windows may be needed. When creating a device communicator, the resources that the kernel will need should be specified via the requirements list (see :c:type:`ncclDevCommRequirements`). In the above example we specify just the number of barriers the kernel will need, in this case one for each of the CTAs the kernel is to be launched on (16 CTAs, each running 256 threads).
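Once the kernel has completed, the resources can be released. The excerpt below is a minimal sketch of one possible cleanup path, assuming the same variable names as above; it deregisters the window before freeing the buffer and destroying the communicator.

.. code:: C

   int main() {
     [...]
     /* Wait for the user kernel to finish before releasing resources */
     cudaDeviceSynchronize();

     /* Deregister the window, then free the buffer and destroy the communicator */
     NCCLCHECK(ncclCommWindowDeregister(comm, win));
     NCCLCHECK(ncclMemFree(buffer));
     NCCLCHECK(ncclCommDestroy(comm));
     [...]
   }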
Simple Device Kernel
--------------------

.. code:: C

   template <typename T>
   __global__ void inPlaceAllReduceKernel(ncclDevComm devComm, ncclWindow_t win, size_t offset, size_t count) {
     ncclLsaBarrierSession<ncclCoopCta> bar { ncclCoopCta(), devComm, ncclTeamTagLsa(), blockIdx.x };
     bar.sync(ncclCoopCta(), cuda::memory_order_relaxed);
     const int rank = devComm.lsaRank, nRanks = devComm.lsaSize;
     const int globalTid = threadIdx.x + blockDim.x * (rank + blockIdx.x * nRanks);
     const int globalNthreads = blockDim.x * gridDim.x * nRanks;
     for (size_t o = globalTid; o < count; o += globalNthreads) {
       T v = 0;
       /* Load the element from every LSA peer and reduce it */
       for (int peer = 0; peer < nRanks; peer++) {
         T* ptr = (T*)ncclGetLsaPointer(win, offset, peer);
         v += ptr[o];
       }
       /* Store the result back to every LSA peer */
       for (int peer = 0; peer < nRanks; peer++) {
         T* ptr = (T*)ncclGetLsaPointer(win, offset, peer);
         ptr[o] = v;
       }
     }
     bar.sync(ncclCoopCta(), cuda::memory_order_release);
   }

The kernel above implements a simple in-place AllReduce over the LSA ranks. Each CTA synchronizes with its peers using a barrier session before and after the data exchange. Every thread then iterates over its share of the elements, loading the value from each rank via :c:func:`ncclGetLsaPointer`, summing the values, and storing the result back to each rank.

Multimem Device Kernel
----------------------

.. code:: C

   int main() {
     [...]
     reqs.lsaMultimem = true;
     NCCLCHECK(ncclDevCommCreate(comm, &reqs, &devComm));
     [...]
   }

   template <typename T>
   __global__ void inPlaceAllReduceKernel(ncclDevComm devComm, ncclWindow_t win, size_t offset, size_t count) {
     ncclLsaBarrierSession<ncclCoopCta> bar { ncclCoopCta(), devComm, ncclTeamTagLsa(), blockIdx.x, /*multimem*/true };
     [...]
     T* mmPtr = (T*)ncclGetLsaMultimemPointer(win, offset, devComm);
     for (size_t o = globalTid; o < count; o += globalNthreads) {
       T v = multimem_sum(mmPtr+o);
       multimem_st(mmPtr+o, v);
     }
     [...]
   }

The above code excerpt demonstrates the modifications needed to the earlier code segments to enable multimem support (the critical changes are in the requirements setup, the barrier constructor, and the processing loop). On the host side, ``lsaMultimem`` needs to be set in the requirements prior to creating the device communicator (:c:func:`ncclDevCommCreate` will fail if the necessary hardware support is unavailable). Within the device kernel, we can switch the memory barrier to a multimem-optimized variant by adding an extra argument to the constructor. The processing loop is actually simpler with multimem: :c:func:`ncclGetLsaMultimemPointer` needs to be invoked just once per kernel. The returned multicast memory pointer enables access to the device memory of all the ranks of the communicator without having to iterate over them, and the data can be reduced in hardware. To keep this example simple, the implementations of ``multimem_sum`` and ``multimem_st`` are not included. Those need to be implemented using PTX, e.g., ``multimem.ld_reduce.global.add`` and ``multimem.st.global``.

Thread Groups
-------------

Many functions in the device API take a thread cooperative group as input to indicate which threads within the CTA will take part in the operation. NCCL provides three predefined ones: ``ncclCoopThread()``, ``ncclCoopWarp()`` and ``ncclCoopCta()``. Users may also pass CUDA cooperative groups, or any class which provides the ``thread_rank()``, ``size()`` and ``sync()`` functions (see the sketch below).

Teams
-----

To address remote ranks or perform barriers, NCCL refers to subsets of ranks within the global communicator as "teams". NCCL provides three predefined ones: ``ncclTeamWorld()``, ``ncclTeamLsa()``, and ``ncclTeamRail()``.
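As an illustration of the thread-group interface described above, the sketch below defines a custom group type covering the calling thread's warp. The type name is hypothetical; the point is that any class exposing ``thread_rank()``, ``size()`` and ``sync()`` can be passed wherever the device API expects a thread cooperative group, much like the predefined ``ncclCoopWarp()``.

.. code:: C

   /* Hypothetical custom cooperative group spanning the calling thread's warp.
    * Any class with thread_rank(), size() and sync() members can serve as a
    * thread cooperative group for the device API. */
   struct MyWarpCoop {
     __device__ unsigned thread_rank() const { return threadIdx.x % 32; } /* lane index within the warp */
     __device__ unsigned size() const { return 32; }                      /* number of threads in the group */
     __device__ void sync() const { __syncwarp(); }                       /* synchronize the warp */
   };

An instance of such a type could then be passed to device API calls in place of the predefined groups, provided the set of participating threads matches what the call expects.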