Distributed Reshape#

Overview#

The distributed reshape module nvmath.distributed.reshape in nvmath-python leverages the NVIDIA cuFFTMp library to provide APIs, callable directly from the host, that efficiently redistribute operands across processes on multi-node, multi-GPU systems at scale. Both stateless function-form APIs and stateful class-form APIs are provided:

Reshape is a general-purpose API for changing how data is distributed or partitioned across processes by shuffling it among them. Distributed reshape supports arbitrary data distributions in the form of 1D/2D/3D boxes (see Box distribution).
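To build intuition for how a box describes a distribution, the following is a small illustration in plain Python. The `box_shape` helper is hypothetical (not part of nvmath-python): a box is given by its lower and upper corner in global coordinates, and the local chunk's shape is the element-wise difference between the two corners.

```python
# Hypothetical helper, not part of nvmath-python: illustrates how a box
# (lower corner, upper corner) describes the chunk of a global array
# owned by one process.
def box_shape(lower, upper):
    """Return the local shape of the chunk described by a box."""
    return tuple(u - l for l, u in zip(lower, upper))

# A 4x4 matrix split column-wise across two processes:
# rank 0 owns columns 0-1, rank 1 owns columns 2-3.
shape0 = box_shape((0, 0), (4, 2))  # rank 0's chunk: 4 rows, 2 columns
shape1 = box_shape((0, 2), (4, 4))  # rank 1's chunk: 4 rows, 2 columns
```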

Example#

To perform a distributed reshape, each process specifies its own input and output box, which determine the distribution of the input and output, respectively.

As an example, consider a matrix that is distributed column-wise on two processes (each process owns a contiguous chunk of columns). To redistribute the matrix row-wise, we can use distributed reshape:

Tip

Remember to initialize the distributed runtime first, as described in Initializing the distributed runtime.

import numpy as np

import nvmath.distributed.reshape
from nvmath.distributed.distribution import Box

# The global dimensions of the matrix are 4x4. The matrix is distributed
# column-wise, so each process has 4 rows and 2 columns.

# Get the process rank from the mpi4py communicator with which the
# distributed runtime was initialized.
rank = communicator.Get_rank()

# Initialize the matrix on each process, as a NumPy ndarray (on the CPU).
A = np.zeros((4, 2)) if rank == 0 else np.ones((4, 2))

# Reshape from column-wise to row-wise.
if rank == 0:
    input_box = Box((0, 0), (4, 2))
    output_box = Box((0, 0), (2, 4))
else:
    input_box = Box((0, 2), (4, 4))
    output_box = Box((2, 0), (4, 4))

# Distributed reshape returns a new operand with its own buffer.
B = nvmath.distributed.reshape.reshape(A, input_box, output_box)

# The result is a NumPy ndarray, distributed row-wise:
# [0] B:
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]]
#
# [1] B:
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]]
print(f"[{rank}] B:\n{B}")
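For intuition about what this redistribution computes, the example can be simulated without MPI in plain NumPy: assemble the global matrix from the column-wise chunks, then slice out each process's row-wise output box. This is a sketch for illustration only; the actual distributed reshape exchanges data between processes instead of materializing the global matrix.

```python
import numpy as np

# Column-wise chunks as in the example: rank 0 holds zeros, rank 1 ones.
chunks = [np.zeros((4, 2)), np.ones((4, 2))]

# Assemble the 4x4 global matrix (columns 0-1 from rank 0, 2-3 from rank 1).
global_matrix = np.hstack(chunks)

# Slice out each rank's row-wise output box:
# rank 0 gets rows 0-1, rank 1 gets rows 2-3.
B0 = global_matrix[0:2, 0:4]  # [[0. 0. 1. 1.], [0. 0. 1. 1.]]
B1 = global_matrix[2:4, 0:4]  # [[0. 0. 1. 1.], [0. 0. 1. 1.]]
```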

API Reference#

Reshape support (nvmath.distributed.reshape)#

reshape(operand, input_box, output_box[, ...])

Perform a distributed reshape on the provided operand to change its distribution across processes.

Reshape(operand, /, input_box, output_box, *)

Create a stateful object that encapsulates the specified distributed reshape operation and the resources required to perform it.

ReshapeOptions([logger, blocking])

A data class for providing options to the Reshape object and the wrapper function reshape().