nemo_automodel.components.distributed.parallel_styles#
Module Contents#
Classes#
| TPLinear | nn.Linear variant safe for torch.compile + DTensor tensor-parallel weights. |
Functions#
API#
- nemo_automodel.components.distributed.parallel_styles._distribute_param(_module, name, device_mesh, src_data_rank, placements)#
- class nemo_automodel.components.distributed.parallel_styles.TPLinear#
Bases: torch.nn.Linear

nn.Linear variant safe for torch.compile + DTensor tensor-parallel weights.
F.linear decomposes to aten.view + aten.mm + aten.view for 3-D input. During AOT-autograd backward tracing, the view on a sharded DTensor activation hits DTensor's slow-path sharding propagation (there is no explicit rule for an aten.view that changes the shard-dim index), which recurses infinitely.
torch.bmm is a native 3-D op whose backward is also bmm, so no view is ever emitted. DTensor has explicit sharding strategies for bmm covering the ColwiseParallel (Replicate x Shard(2) -> Shard(2)) and RowwiseParallel (Shard(2) x Shard(1) -> Partial) patterns.
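The bmm-vs-F.linear equivalence for 3-D input described above can be checked directly. The shapes below are illustrative, not taken from the library:

```python
import torch
import torch.nn.functional as F

# For 3-D input, F.linear(x, W) decomposes to aten.view + aten.mm + aten.view.
# The same result can be computed with a single torch.bmm, expanding the
# weight (zero-copy) across the batch dimension so no view op is emitted.
x = torch.randn(2, 4, 8)   # [batch, seq, in_features]
w = torch.randn(16, 8)     # [out_features, in_features]

ref = F.linear(x, w)                                       # view/mm/view path
out = torch.bmm(x, w.t().unsqueeze(0).expand(2, -1, -1))   # bmm path

assert torch.allclose(ref, out, atol=1e-5)
```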
Note: expand(b, -1, -1) dispatches through DTensor's ShardingPropagator, which caches via lru_cache keyed on DTensorSpec. With dynamic shapes, b = x.shape[0] is a SymInt, making DTensorSpec._hash_impl raise TypeError. This is handled by _patch_dtensor_spec_hash_for_symint() in parallelizer.py, which falls back to a placement-only hash for SymInt shapes.
Usage: after TP weight sharding, convert an nn.Linear instance by setting linear.__class__ = TPLinear. This is the same class-swap trick used by translate_to_lora, and it ensures torch.compile/dynamo sees the correct type(module).forward rather than nn.Linear.forward.

- forward(x)#
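A minimal sketch of the class-swap usage, with a bmm-based forward along the lines described above. The forward body here is illustrative of the technique, not the library's actual implementation:

```python
import torch
import torch.nn as nn


class TPLinear(nn.Linear):
    """Sketch: Linear whose 3-D forward is a single torch.bmm (no aten.view)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]
        # Expand (zero-copy) the transposed weight to the batch dim so the
        # whole matmul is one torch.bmm call.
        w = self.weight.t().unsqueeze(0).expand(b, -1, -1)
        out = torch.bmm(x, w)
        if self.bias is not None:
            out = out + self.bias
        return out


# Class-swap trick: convert an existing nn.Linear in place, so
# torch.compile/dynamo sees TPLinear.forward as type(module).forward.
linear = nn.Linear(8, 16)
linear.__class__ = TPLinear
y = linear(torch.randn(2, 4, 8))
assert y.shape == (2, 4, 16)
```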
- class nemo_automodel.components.distributed.parallel_styles.ColwiseParallelLora#
Bases: torch.distributed.tensor.parallel.ColwiseParallel

- _partition_linear_fn(name, module, device_mesh)#
- _partition_embedding_fn(name, module, device_mesh)#
- class nemo_automodel.components.distributed.parallel_styles.RowwiseParallelLora#
Bases: torch.distributed.tensor.parallel.RowwiseParallel

- _partition_linear_fn(name, module, device_mesh)#
- _partition_embedding_fn(name, module, device_mesh)#
- class nemo_automodel.components.distributed.parallel_styles.SequenceParallelLora#
Bases: torch.distributed.tensor.parallel.SequenceParallel

- _replicate_module_fn(name: str, module: torch.nn.Module, device_mesh: torch.distributed.tensor.DeviceMesh)#
- nemo_automodel.components.distributed.parallel_styles.translate_to_lora(plan)#