Using cuBLASMp for Tensor Parallelism in Distributed Machine Learning

AG+GEMM and GEMM+RS in terms of traditional PBLAS

Both operations are specific cases of the 2D block-cyclic data layout in which the C/D matrices use a different process grid than the A/B matrices.
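To make "specific case" concrete, the sketch below uses the standard ScaLAPACK-style block-cyclic owner mapping, assuming the first block lives on process (0, 0). The helper `owner` and the sizes are hypothetical illustrations, not cuBLASMp API. A column-wise distribution is a 1 x P grid with one block of n/P columns per process; a row-wise distribution is a P x 1 grid with one block of m/P rows per process.

```python
def owner(i, j, mb, nb, nprow, npcol):
    """Return the (process-row, process-column) owning global element (i, j).

    2D block-cyclic mapping with mb x nb blocks on an nprow x npcol grid,
    assuming (hypothetically) that the first block is on process (0, 0).
    """
    return ((i // mb) % nprow, (j // nb) % npcol)


# Hypothetical example: an 8 x 8 matrix over P = 4 processes.
m, n, P = 8, 8, 4

# Column-wise distribution: 1 x P grid, one block of n // P columns per process.
col_owner = [[owner(i, j, m, n // P, 1, P)[1] for j in range(n)] for i in range(m)]

# Row-wise distribution: P x 1 grid, one block of m // P rows per process.
row_owner = [[owner(i, j, m // P, n, P, 1)[0] for j in range(n)] for i in range(m)]

first_row_owners = col_owner[0]                     # [0, 0, 1, 1, 2, 2, 3, 3]
first_col_owners = [row[0] for row in row_owner]    # [0, 0, 1, 1, 2, 2, 3, 3]
```

In AG+GEMM, for example, A and B would sit on the 1 x P grid while D sits on the P x 1 grid, which is what "different C/D grid than A/B" amounts to.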

AllGather+GEMM

A and B are distributed column-wise; C is distributed row-wise (like A^T). The AllGather collects the column blocks of B so that every process holds all of B.
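A single-process NumPy simulation can make the data movement concrete. This is a sketch under assumptions: the P "ranks" are list entries, the GEMM is the TN form D = A^T @ B used later in this document, the sizes are hypothetical, and communication and computation run sequentially rather than overlapped as a real implementation would do.

```python
import numpy as np

rng = np.random.default_rng(0)
P, k, m, n = 4, 6, 8, 12  # hypothetical sizes; m and n divisible by P

A = rng.standard_normal((k, m))
B = rng.standard_normal((k, n))

# Column-wise distribution: rank r holds a contiguous block of columns.
A_shards = np.split(A, P, axis=1)  # each (k, m // P)
B_shards = np.split(B, P, axis=1)  # each (k, n // P)

# AllGather: every rank reassembles the full B from all ranks' column blocks.
B_gathered = np.concatenate(B_shards, axis=1)  # (k, n) on every rank

# Local TN GEMM: rank r computes its row block of D = A^T @ B.
D_shards = [A_r.T @ B_gathered for A_r in A_shards]  # each (m // P, n)

# D is distributed row-wise, like A^T.
D = np.concatenate(D_shards, axis=0)
assert np.allclose(D, A.T @ B)
```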

GEMM+ReduceScatter

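A and B are distributed row-wise; C is distributed column-wise (like A^T). Each process computes a full-size partial product from its rows, and the ReduceScatter sums these partial products and scatters the result by columns.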
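The same kind of single-process sketch for GEMM+ReduceScatter, again assuming the TN form D = A^T @ B, hypothetical sizes, and list entries standing in for ranks. Because A^T @ B is a sum over the shared dimension k, splitting k across ranks yields partial products that add up to the full result.

```python
import numpy as np

rng = np.random.default_rng(0)
P, k, m, n = 4, 8, 6, 12  # hypothetical sizes; k and n divisible by P

A = rng.standard_normal((k, m))
B = rng.standard_normal((k, n))

# Row-wise distribution: rank r holds a contiguous block of rows (a slice of k).
A_shards = np.split(A, P, axis=0)  # each (k // P, m)
B_shards = np.split(B, P, axis=0)  # each (k // P, n)

# Local TN GEMM: each rank's A_r^T @ B_r is a full-size (m, n) partial sum.
partials = [A_r.T @ B_r for A_r, B_r in zip(A_shards, B_shards)]

# ReduceScatter: sum the partials, then give rank r its block of columns.
D_shards = np.split(sum(partials), P, axis=1)  # each (m, n // P)

# D is distributed column-wise, like A^T.
assert np.allclose(np.concatenate(D_shards, axis=1), A.T @ B)
```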

On Python (C) and cuBLAS (Fortran) data ordering

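  1. C-ordered (row-major) matrices store elements contiguously by rows. This is the default convention of Python array libraries such as NumPy and PyTorch.

  2. Fortran-ordered (column-major) matrices store elements contiguously by columns. This convention is used by cuBLAS and cuBLASMp.

  3. The memory buffer of a C-ordered matrix X, read as a Fortran-ordered matrix, is X^T. In other words, the transpose of a C-ordered (row-major) matrix is effectively a Fortran-ordered (column-major) matrix.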

  4. In a distributed setting, a row-wise distributed matrix A is equivalent to a column-wise distributed matrix A^T.
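The four points above can be checked with NumPy, used here only as a stand-in for any row-major Python array library. The sketch reads a C-ordered buffer in column-major order (item 3) and compares row blocks of a matrix with column blocks of its transpose (item 4); the matrices are arbitrary small examples.

```python
import numpy as np

# Items 1 and 3: a 2 x 3 matrix in C (row-major) order, NumPy's default.
X = np.arange(6).reshape(2, 3)
memory = X.ravel(order="K")  # elements in memory order: [0 1 2 3 4 5]

# Reading that same memory sequence as a 3 x 2 column-major matrix gives X.T.
X_read_as_fortran = memory.reshape(3, 2, order="F")
assert np.array_equal(X_read_as_fortran, X.T)

# NumPy's X.T is a view of the same buffer and is already Fortran-contiguous.
assert X.T.flags["F_CONTIGUOUS"] and np.shares_memory(X, X.T)

# Item 4: row blocks of A are exactly the transposed column blocks of A.T.
A = np.arange(24).reshape(4, 6)
row_blocks = np.split(A, 2, axis=0)    # row-wise distribution of A
col_blocks = np.split(A.T, 2, axis=1)  # column-wise distribution of A^T
assert all(np.array_equal(r.T, c) for r, c in zip(row_blocks, col_blocks))
```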

We follow the torch.nn convention of storing the weights transposed (weights_t, with shape (out_features, in_features)). To compute output = input @ weights, where weights = weights_t.T, we use cuBLASMp as follows:

  1. Python calls a wrapper around a C function (in Transformer Engine (TE), tex.gemm) that calls cublasMpMatmul, requesting a TN GEMM with A = weights_t and B = input, i.e., op(A) @ op(B) = A^T @ B.

  2. cuBLASMp treats the given arrays as Fortran-ordered (see item 3 above), so it sees A as weights_t.T and B as input.T. The TN GEMM therefore computes (weights_t.T).T @ input.T == weights_t @ input.T == output.T.

  3. output.T, written in Fortran order and viewed as a C-ordered array in Python, is the desired output.
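The reinterpretation steps above can be simulated on one process. In this sketch NumPy plays the role of both Python and the column-major library: the hypothetical helper `as_fortran` mimics how a column-major library reads a C-ordered buffer, and the final reshape mimics Python reading back a buffer the library wrote in column-major order. No cuBLASMp or TE call is made.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out = 5, 4, 3  # hypothetical sizes

# torch.nn.Linear-style storage: weights_t has shape (out_features, in_features).
weights_t = rng.standard_normal((d_out, d_in))
inp = rng.standard_normal((batch, d_in))
expected = inp @ weights_t.T  # output = input @ weights, with weights = weights_t.T


def as_fortran(x):
    """Hypothetical helper: read a C-ordered buffer the way a column-major library would."""
    return x.ravel(order="C").reshape(x.shape[::-1], order="F")


A_f = as_fortran(weights_t)  # == weights_t.T, shape (d_in, d_out)
B_f = as_fortran(inp)        # == input.T, shape (d_in, batch)

# TN GEMM as the column-major library computes it: op(A) = A^T, op(B) = B.
D_f = A_f.T @ B_f            # == weights_t @ input.T == output.T, shape (d_out, batch)

# The library writes D_f in column-major order; Python reads that buffer as C-ordered.
output = D_f.ravel(order="F").reshape(D_f.shape[::-1], order="C")
assert np.allclose(output, expected)
```

No data is physically transposed at any point: every transpose in the derivation comes from switching between the row-major and column-major reading of the same buffer.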

AG: cuBLASMp currently requires its operands, weights_t.T and input.T, to be distributed column-wise over processes. From the Python point of view, this is equivalent to weights_t and input being distributed row-wise. The output as cuBLASMp sees it (output.T) is distributed row-wise over processes, which is equivalent to output being distributed column-wise from the Python point of view.
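Seen from Python, the AG case roughly corresponds to a column-parallel linear layer: each rank owns a slice of the output features. The single-process sketch below, with hypothetical sizes and list entries standing in for ranks, shows that row-sharded weights_t and input produce a column-sharded output after the input is all-gathered.

```python
import numpy as np

rng = np.random.default_rng(0)
P, batch, d_in, d_out = 2, 4, 3, 6  # hypothetical; batch and d_out divisible by P

weights_t = rng.standard_normal((d_out, d_in))
inp = rng.standard_normal((batch, d_in))

# Python view: weights_t and input are distributed row-wise.
w_shards = np.split(weights_t, P, axis=0)  # each (d_out // P, d_in)
x_shards = np.split(inp, P, axis=0)        # each (batch // P, d_in)

# AllGather the input rows, then each rank multiplies by its slice of the weights.
x_full = np.concatenate(x_shards, axis=0)
out_shards = [x_full @ w_r.T for w_r in w_shards]  # each (batch, d_out // P)

# Python view: output is distributed column-wise.
assert np.allclose(np.concatenate(out_shards, axis=1), inp @ weights_t.T)
```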

RS: cuBLASMp currently requires its operands, weights_t.T and input.T, to be distributed row-wise over processes. From the Python point of view, this is equivalent to weights_t and input being distributed column-wise. The output as cuBLASMp sees it (output.T) is distributed column-wise over processes, which is equivalent to output being distributed row-wise from the Python point of view.
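Seen from Python, the RS case roughly corresponds to a row-parallel linear layer: each rank owns a slice of the input features. In this single-process sketch (hypothetical sizes, list entries as ranks), column-sharded weights_t and input give full-size partial outputs that the ReduceScatter sums and splits by rows.

```python
import numpy as np

rng = np.random.default_rng(0)
P, batch, d_in, d_out = 2, 4, 6, 3  # hypothetical; batch and d_in divisible by P

weights_t = rng.standard_normal((d_out, d_in))
inp = rng.standard_normal((batch, d_in))

# Python view: weights_t and input are distributed column-wise (both split along d_in).
w_shards = np.split(weights_t, P, axis=1)  # each (d_out, d_in // P)
x_shards = np.split(inp, P, axis=1)        # each (batch, d_in // P)

# Each rank computes a full-size partial product over its slice of d_in.
partials = [x_r @ w_r.T for x_r, w_r in zip(x_shards, w_shards)]  # each (batch, d_out)

# ReduceScatter: sum the partials and give each rank a block of output rows.
out_shards = np.split(sum(partials), P, axis=0)  # each (batch // P, d_out)

# Python view: output is distributed row-wise.
assert np.allclose(np.concatenate(out_shards, axis=0), inp @ weights_t.T)
```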