nemo_automodel.components.distributed.parallel_styles#

Module Contents#

Classes#

TPLinear

nn.Linear variant safe for torch.compile + DTensor tensor-parallel weights.

ColwiseParallelLora

RowwiseParallelLora

SequenceParallelLora

Functions#
_distribute_param

translate_to_lora

API#

nemo_automodel.components.distributed.parallel_styles._distribute_param(
_module,
name,
device_mesh,
src_data_rank,
placements,
)#
class nemo_automodel.components.distributed.parallel_styles.TPLinear#

Bases: torch.nn.Linear

nn.Linear variant safe for torch.compile + DTensor tensor-parallel weights.

F.linear decomposes to aten.view + aten.mm + aten.view for 3-D input. In AOT-autograd backward tracing, the view on a sharded DTensor activation hits DTensor’s slow-path sharding propagation (there is no explicit rule for an aten.view that changes the shard-dim index), which recurses infinitely.

torch.bmm is a native 3-D op whose backward is also bmm – no view is ever emitted. DTensor has explicit strategies for bmm covering the ColwiseParallel (Replicate x Shard(2) -> Shard(2)) and RowwiseParallel (Shard(2) x Shard(1) -> Partial) patterns.
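The bmm-based path can be sketched on plain tensors (the real TPLinear.forward operates on DTensors; the helper name below is illustrative, not the module's actual code):

```python
import torch

def bmm_linear(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    """Illustrative bmm-based linear: no aten.view is emitted for 3-D input.

    x: (b, s, in), weight: (out, in) as in nn.Linear, bias: (out,) or None.
    """
    b = x.shape[0]
    # (out, in) -> (in, out), then broadcast to (b, in, out) without copying.
    w = weight.t().unsqueeze(0).expand(b, -1, -1)
    out = torch.bmm(x, w)  # (b, s, in) @ (b, in, out) -> (b, s, out)
    if bias is not None:
        out = out + bias
    return out
```

Note the `expand(b, -1, -1)` on the weight: this is exactly the call that triggers the SymInt hashing issue described in the note below when `b` is dynamic.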

Note: expand(b, -1, -1) dispatches through DTensor’s ShardingPropagator which caches via lru_cache keyed on DTensorSpec. With dynamic shapes, b = x.shape[0] is a SymInt, making DTensorSpec._hash_impl raise TypeError. This is handled by _patch_dtensor_spec_hash_for_symint() in parallelizer.py which falls back to a placement-only hash for SymInt shapes.

Usage: after TP weight sharding, convert an nn.Linear instance by setting linear.__class__ = TPLinear. This is the same class-swap trick used by translate_to_lora, and ensures torch.compile/dynamo sees the correct type(module).forward rather than nn.Linear.forward.
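A minimal sketch of the class swap (the stand-in subclass here is illustrative; the real TPLinear overrides forward with the bmm-based path described above):

```python
import torch.nn as nn

# Illustrative stand-in for the real TPLinear from
# nemo_automodel.components.distributed.parallel_styles.
class TPLinear(nn.Linear):
    pass

linear = nn.Linear(16, 32)
weight_before = linear.weight

# In-place class swap after TP weight sharding: parameters and buffers are
# untouched, but type(linear).forward now resolves to TPLinear.forward,
# which is what torch.compile/dynamo traces.
linear.__class__ = TPLinear
```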

forward(x)#
class nemo_automodel.components.distributed.parallel_styles.ColwiseParallelLora#

Bases: torch.distributed.tensor.parallel.ColwiseParallel

_partition_linear_fn(name, module, device_mesh)#
_partition_embedding_fn(name, module, device_mesh)#
class nemo_automodel.components.distributed.parallel_styles.RowwiseParallelLora#

Bases: torch.distributed.tensor.parallel.RowwiseParallel

_partition_linear_fn(name, module, device_mesh)#
_partition_embedding_fn(name, module, device_mesh)#
class nemo_automodel.components.distributed.parallel_styles.SequenceParallelLora#

Bases: torch.distributed.tensor.parallel.SequenceParallel

_replicate_module_fn(
name: str,
module: torch.nn.Module,
device_mesh: torch.distributed.tensor.DeviceMesh,
)#
nemo_automodel.components.distributed.parallel_styles.translate_to_lora(plan)#
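The class-swap trick that translate_to_lora applies to a TP plan can be illustrated with a hypothetical, dependency-free re-implementation (the function name, mapping dict, and body below are assumptions for illustration, not the module's actual code):

```python
def translate_to_lora_sketch(plan: dict, mapping: dict) -> dict:
    """Hypothetical sketch: swap each parallel style in a TP plan for its
    Lora variant in place, e.g. ColwiseParallel -> ColwiseParallelLora."""
    for style in plan.values():
        lora_cls = mapping.get(type(style))
        if lora_cls is not None:
            style.__class__ = lora_cls  # in-place swap; instance state is kept
    return plan
```

Because the swap mutates `__class__` rather than rebuilding the plan, any configuration already stored on the style instances survives the translation.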