nemo_rl.distributed.named_sharding

Module Contents

Classes

NamedSharding

Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes.

API

class nemo_rl.distributed.named_sharding.NamedSharding(layout: Sequence[Any], names: list[str])

Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes.

Example:

    layout = [
        [[0, 1, 2, 3], [4, 5, 6, 7]],
    ]
    names = ["dp", "pp", "tp"]
    # This represents DP=1, PP=2, TP=4
    sharding = NamedSharding(layout, names)
    print(sharding.shape)  # Per the shape signature (dict[str, int]): {'dp': 1, 'pp': 2, 'tp': 4}
    print(sharding.names)  # Output: ['dp', 'pp', 'tp']
    print(sharding.get_ranks(dp=0, pp=1))  # A NamedSharding over the remaining "tp" axis holding ranks 4, 5, 6, 7

Initialization

Initializes the NamedSharding object.

Parameters:
  • layout – A nested sequence (e.g., list of lists) representing the ND rank layout. All inner lists must contain integer rank IDs.

  • names – A list of strings representing the names of the dimensions, ordered from the outermost to the innermost dimension.

property shape: dict[str, int]

Returns the shape of the rank layout as a dictionary mapping each axis name to its size.

property names: list[str]

Returns the names of the axes.

property ndim: int

Returns the number of dimensions.

property size: int

Returns the total number of ranks.

property layout: numpy.ndarray[Any, numpy.dtype[numpy.int_]]

Returns the underlying NumPy array representing the layout.
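To make the relationship between these properties concrete, here is a minimal standalone sketch that derives the same quantities from a nested list in plain Python. It does not import nemo_rl, and `layout_shape` is a hypothetical helper written for illustration, not part of the library; the real class keeps the layout in a NumPy array.

```python
from math import prod


def layout_shape(layout, names):
    """Hypothetical helper: map each axis name to its size, outermost axis first."""
    sizes = []
    level = layout
    # Descend through the first element of each nesting level to measure it.
    # Assumes a rectangular layout (all siblings have equal length).
    while isinstance(level, list):
        sizes.append(len(level))
        if not level:
            break
        level = level[0]
    if len(sizes) != len(names):
        raise ValueError("number of names must match the number of layout dimensions")
    return dict(zip(names, sizes))


layout = [[[0, 1, 2, 3], [4, 5, 6, 7]]]
names = ["dp", "pp", "tp"]

shape = layout_shape(layout, names)  # {'dp': 1, 'pp': 2, 'tp': 4}
ndim = len(shape)                    # 3 named axes
size = prod(shape.values())          # 8 ranks in total
```

The total number of ranks is the product of the axis sizes, which is why a DP=1, PP=2, TP=4 layout holds 8 ranks.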

get_worker_coords(worker_id: int) → dict[str, int]

Gets the coordinates of a specific worker ID in the sharding layout.

Parameters:

worker_id – The integer ID of the worker.

Returns:

A dictionary mapping axis names to their integer coordinates for the given worker_id.

Raises:

ValueError – If the worker_id is not found in the layout.
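The lookup described above can be sketched without the library as a search over every coordinate tuple. This is an illustration under stated assumptions, not the library's implementation: `find_coords` is a hypothetical stand-in for `get_worker_coords`, and it operates on a non-empty, rectangular nested list rather than a NumPy array.

```python
from itertools import product


def find_coords(layout, names, worker_id):
    """Hypothetical stand-in for get_worker_coords on a nested-list layout."""
    # Measure each axis, outermost first (assumes a rectangular layout).
    sizes, level = [], layout
    while isinstance(level, list):
        sizes.append(len(level))
        if not level:
            break
        level = level[0]
    # Visit every coordinate tuple and compare the rank stored there.
    for coords in product(*(range(n) for n in sizes)):
        rank = layout
        for c in coords:
            rank = rank[c]
        if rank == worker_id:
            return dict(zip(names, coords))
    raise ValueError(f"worker_id {worker_id} not found in layout")


layout = [[[0, 1, 2, 3], [4, 5, 6, 7]]]
names = ["dp", "pp", "tp"]

coords_of_6 = find_coords(layout, names, 6)  # {'dp': 0, 'pp': 1, 'tp': 2}
```

Rank 6 sits in the second pipeline stage (`pp=1`) at tensor-parallel position 2, so its coordinates are `dp=0, pp=1, tp=2`. Asking for a rank that is absent, such as 8, raises `ValueError`.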

get_ranks_by_coord(**coords: int) → list[int]

Gets all ranks that match the specified coordinates for named axes.

Parameters:

**coords – Keyword arguments where the key is the axis name (e.g., “dp”, “tp”) and the value is the integer coordinate along that axis. Axes not specified will match all coordinates along that axis.

Returns:

A sorted list of unique rank integers that match the given coordinate criteria. Returns an empty list if no ranks match.

Raises:

ValueError – If an invalid axis name is provided.
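The matching rule (a named axis is pinned to one coordinate, an unnamed axis matches everything) can be sketched as follows. `ranks_by_coord` is a hypothetical stand-in for `get_ranks_by_coord` on a nested list and does not use nemo_rl. Treating an out-of-range or negative coordinate as "no ranks match" is an assumption of this sketch, not confirmed behavior of the library.

```python
from itertools import product


def ranks_by_coord(layout, names, **coords):
    """Hypothetical stand-in for get_ranks_by_coord on a nested-list layout."""
    for axis in coords:
        if axis not in names:
            raise ValueError(f"unknown axis name: {axis!r}")
    # Measure each axis, outermost first (assumes a non-empty rectangular layout).
    sizes, level = [], layout
    while isinstance(level, list):
        sizes.append(len(level))
        level = level[0]
    # A pinned axis contributes one index; an unspecified axis contributes all.
    # Assumption: an out-of-range coordinate simply matches nothing.
    choices = []
    for name, axis_size in zip(names, sizes):
        if name in coords:
            c = coords[name]
            choices.append([c] if 0 <= c < axis_size else [])
        else:
            choices.append(range(axis_size))
    matches = set()
    for idx in product(*choices):
        rank = layout
        for i in idx:
            rank = rank[i]
        matches.add(rank)
    return sorted(matches)


layout = [[[0, 1, 2, 3], [4, 5, 6, 7]]]
names = ["dp", "pp", "tp"]

first_tp_ranks = ranks_by_coord(layout, names, tp=0)    # [0, 4]
second_stage = ranks_by_coord(layout, names, pp=1)      # [4, 5, 6, 7]
single = ranks_by_coord(layout, names, pp=1, tp=2)      # [6]
```

Pinning only `tp=0` returns one rank per pipeline stage, which is the typical way to collect the "first" rank of each tensor-parallel group.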

get_ranks(**kwargs: int) → Union[nemo_rl.distributed.named_sharding.NamedSharding, int]

Gets the ranks corresponding to specific indices along named axes.

Parameters:

**kwargs – Keyword arguments where the key is the axis name (e.g., “dp”, “tp”) and the value is the index along that axis.

Returns:

A new NamedSharding instance representing the subset of ranks. The shape of the returned sharding corresponds to the axes not specified in the kwargs. If all axes are specified, the single matching rank is returned as an int.

Raises:

ValueError – If an invalid axis name is provided or if an index is out of bounds.
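Unlike `get_ranks_by_coord`, which flattens its result, this method keeps the structure of the unspecified axes. The sketch below illustrates that idea on a nested list. `select` is a hypothetical stand-in for `get_ranks` that returns a `(sub_layout, remaining_names)` pair instead of a new `NamedSharding`, and it rejects negative indices as out of bounds, which is a simplification assumed here.

```python
def select(layout, names, **indices):
    """Hypothetical stand-in for get_ranks on a nested-list layout.

    Returns (sub_layout, remaining_names), or a bare int when every axis is fixed.
    """
    for axis in indices:
        if axis not in names:
            raise ValueError(f"unknown axis name: {axis!r}")

    def pick(level, depth):
        if depth == len(names):
            return level  # reached a single rank
        name = names[depth]
        if name in indices:
            i = indices[name]
            if not 0 <= i < len(level):
                raise ValueError(f"index {i} out of bounds for axis {name!r}")
            # Fixing an axis removes that nesting level from the result.
            return pick(level[i], depth + 1)
        # An unspecified axis is kept, so recurse into every child.
        return [pick(child, depth + 1) for child in level]

    sub = pick(layout, 0)
    remaining = [n for n in names if n not in indices]
    return sub if not remaining else (sub, remaining)


layout = [[[0, 1, 2, 3], [4, 5, 6, 7]]]
names = ["dp", "pp", "tp"]

tp_group = select(layout, names, dp=0, pp=1)          # ([4, 5, 6, 7], ['tp'])
one_rank = select(layout, names, dp=0, pp=1, tp=2)    # 6
per_stage = select(layout, names, tp=0)               # ([[0, 4]], ['dp', 'pp'])
```

Fixing `dp` and `pp` leaves a one-dimensional result over `tp`, while fixing all three axes collapses the result to a single rank.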

get_axis_index(name: str) → int

Gets the numerical index of a named axis.

get_axis_size(name: str) → int

Gets the size of a named axis.

__repr__() → str
__eq__(other: object) → bool