# Using cuBLASMp for Tensor Parallelism in Distributed Machine Learning
## AG+GEMM and GEMM+RS in terms of traditional PBLAS
This is a specific case of the 2D block-cyclic data layout used by PBLAS, in which C/D live on a different process grid than A/B.
### AllGather+GEMM
A and B are distributed column-wise; C is distributed row-wise (like A^T). AllGather gathers B.
### GEMM+ReduceScatter
A and B are distributed row-wise; C is distributed column-wise (like A^T).
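The two schemes can be simulated with NumPy on two fake "ranks". This is a sketch, assuming the PBLAS-style TN operation C = A^T @ B; `np.concatenate` and `np.add.reduce` stand in for the AllGather and ReduceScatter collectives:

```python
import numpy as np

# Toy simulation of AG+GEMM and GEMM+RS over two "ranks", assuming the
# PBLAS-style operation C = A^T @ B (a TN GEMM). Shapes: A (k, m), B (k, n).
rng = np.random.default_rng(0)
k, m, n, P = 8, 4, 6, 2
A, B = rng.standard_normal((k, m)), rng.standard_normal((k, n))
C_ref = A.T @ B

# AG+GEMM: A and B are split column-wise; each rank AllGathers B, then
# computes its own row block of C (C ends up row-wise, like A^T).
A_shards = np.split(A, P, axis=1)
B_shards = np.split(B, P, axis=1)
B_full = np.concatenate(B_shards, axis=1)           # AllGather
C_row_blocks = [A_i.T @ B_full for A_i in A_shards]
assert np.allclose(np.concatenate(C_row_blocks, axis=0), C_ref)

# GEMM+RS: A and B are split row-wise; each rank computes a full-size
# partial product, then ReduceScatter sums the partials and leaves each
# rank with one column block of C (C ends up column-wise, like A^T).
A_shards = np.split(A, P, axis=0)
B_shards = np.split(B, P, axis=0)
partials = [A_i.T @ B_i for A_i, B_i in zip(A_shards, B_shards)]
C_full = np.add.reduce(partials)                    # Reduce...
C_col_blocks = np.split(C_full, P, axis=1)          # ...Scatter
assert np.allclose(np.concatenate(C_col_blocks, axis=1), C_ref)
```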
## On Python (C) and cuBLAS (Fortran) data ordering
C-ordered (row-major) matrices store elements contiguously by rows. This is the default convention in Python (NumPy, PyTorch).
Fortran-ordered (column-major) matrices store elements contiguously by columns. This is the convention used by cuBLAS and cuBLASMp.
The transpose of a C-ordered (row-major) matrix is effectively a Fortran-ordered (column-major) matrix.
In a distributed setting, a row-wise distributed matrix A is equivalent to a column-wise distributed matrix A^T.
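A quick NumPy check of this equivalence: the bytes of a C-ordered matrix, reinterpreted in Fortran order, describe its transpose:

```python
import numpy as np

# The same bytes, read row-major vs. column-major, describe a matrix and
# its transpose: a C-ordered M has the same buffer as Fortran-ordered M^T.
m = np.arange(6).reshape(2, 3)            # C-ordered (row-major)
assert m.tobytes(order="C") == m.T.tobytes(order="F")

# Equivalently, reinterpreting M's buffer in Fortran order yields M^T:
reinterpreted = np.ndarray(m.T.shape, m.dtype, m.tobytes(), order="F")
assert np.array_equal(reinterpreted, m.T)
```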
We follow the torch.nn convention of storing transposed weights (weights_t).
To compute output = input @ weights, we use cuBLASMp as follows:
- Python calls a wrapper of a C function (in TE, tex.gemm) that calls cublasMpMatmul, requesting a TN GEMM of weights_t @ input (equivalent to weights_t.T @ input).
- Treating the given arrays as Fortran-ordered (see the data-ordering section above), cuBLASMp effectively sees the operation as a TN GEMM of weights_t.T @ input.T (== weights_t.T.T @ input.T (because TN) == weights_t @ input.T == output.T).
- output.T, when viewed as a C-ordered array in Python, is the desired output.
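The steps above can be checked with NumPy, where a plain matmul stands in for cublasMpMatmul and `.T` stands in for "reinterpret this C-ordered buffer as Fortran-ordered" (a sketch; names like `inp` are illustrative):

```python
import numpy as np

# NumPy sketch of the walkthrough above: `.T` plays the role of
# "view this C-ordered buffer as Fortran-ordered".
rng = np.random.default_rng(0)
inp = rng.standard_normal((4, 3))             # C-ordered activations
weights = rng.standard_normal((3, 5))
weights_t = np.ascontiguousarray(weights.T)   # torch.nn stores transposed weights

a_seen = weights_t.T                          # what cuBLASMp "sees" for A
b_seen = inp.T                                # what cuBLASMp "sees" for B
out_seen = a_seen.T @ b_seen                  # TN GEMM: op(A) = A^T
# out_seen == weights_t @ inp.T == output.T; viewed back as C-ordered:
output = out_seen.T
assert np.allclose(output, inp @ weights)
```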
AG:
cuBLASMp currently requires its operands, weights_t.T and input.T, to be distributed column-wise over processes.
This is equivalent to weights_t and input being distributed row-wise over processes from the Python point of view.
The output, from the cuBLASMp point of view (output.T), is distributed row-wise over processes, which is equivalent to output being distributed column-wise over processes from the Python point of view.
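From the Python point of view, the AG case can be sketched on two simulated ranks (a sketch with illustrative shapes; `np.concatenate` stands in for the real AllGather):

```python
import numpy as np

# Python-side view of the AG case on two simulated "ranks": weights_t and
# input are sharded row-wise, the input is AllGathered, and each rank
# produces one column block of the output.
rng = np.random.default_rng(0)
P = 2
inp = rng.standard_normal((4, 6))             # (batch, in_features)
weights_t = rng.standard_normal((8, 6))       # (out_features, in_features)
out_ref = inp @ weights_t.T

w_shards = np.split(weights_t, P, axis=0)     # row-wise (out_features)
x_shards = np.split(inp, P, axis=0)           # row-wise (batch)
x_full = np.concatenate(x_shards, axis=0)     # AllGather of the input
out_cols = [x_full @ w_i.T for w_i in w_shards]
assert np.allclose(np.concatenate(out_cols, axis=1), out_ref)
```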
RS:
cuBLASMp currently requires its operands, weights_t.T and input.T, to be distributed row-wise over processes.
This is equivalent to weights_t and input being distributed column-wise over processes from the Python point of view.
The output, from the cuBLASMp point of view (output.T), is distributed column-wise over processes, which is equivalent to output being distributed row-wise over processes from the Python point of view.
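The RS case, again from the Python point of view on two simulated ranks (a sketch; `np.add.reduce` plus `np.split` stand in for the real ReduceScatter):

```python
import numpy as np

# Python-side view of the RS case on two simulated "ranks": weights_t and
# input are sharded column-wise, each rank computes a full-size partial
# product, and ReduceScatter leaves each rank with one row block of output.
rng = np.random.default_rng(0)
P = 2
inp = rng.standard_normal((4, 6))             # (batch, in_features)
weights_t = rng.standard_normal((8, 6))       # (out_features, in_features)
out_ref = inp @ weights_t.T

w_shards = np.split(weights_t, P, axis=1)     # column-wise (in_features)
x_shards = np.split(inp, P, axis=1)           # column-wise (in_features)
partials = [x_i @ w_i.T for x_i, w_i in zip(x_shards, w_shards)]
out_full = np.add.reduce(partials)            # Reduce...
out_rows = np.split(out_full, P, axis=0)      # ...Scatter (batch row blocks)
assert np.allclose(np.concatenate(out_rows, axis=0), out_ref)
```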