nemo_automodel.components.distributed.parallel_styles
Module Contents
Classes
Functions
API
Bases: ColwiseParallel
Column-wise tensor parallel style for LoRA-aware modules.
Bases: RowwiseParallel
Row-wise tensor parallel style for LoRA-aware modules.
Bases: SequenceParallel
Sequence parallel style that replicates LoRA module parameters.
Bases: Linear
nn.Linear variant safe for torch.compile + DTensor tensor-parallel weights.
For 3-D input, F.linear decomposes into aten.view (flatten to 2-D) + aten.mm + aten.view (restore 3-D). During AOT-autograd backward tracing, the view on a sharded DTensor activation falls into DTensor’s slow-path sharding propagation (there is no explicit rule for an aten.view that changes the shard-dim index), and that path recurses infinitely.
torch.bmm is a native 3-D op whose backward is also bmm, so no view is ever emitted. DTensor has explicit bmm strategies covering both the ColwiseParallel pattern (Replicate x Shard(2) -> Shard(2)) and the RowwiseParallel pattern (Shard(2) x Shard(1) -> Partial).
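A minimal plain-tensor sketch of the bmm formulation follows. BmmLinear is a hypothetical class used only for illustration, not the library's TPLinear. It shows that bmm against the transposed weight, expanded over the batch dimension, reproduces F.linear for 3-D input. No DTensor or sharding is involved, so it demonstrates the math only, not the DTensor dispatch behavior described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BmmLinear(nn.Linear):
    """Hypothetical illustration: nn.Linear whose 3-D forward uses torch.bmm."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            # Assumption for this sketch: non-3-D inputs keep the stock path.
            return F.linear(x, self.weight, self.bias)
        b = x.shape[0]
        # weight is (out, in); .t() gives (in, out); expand adds a leading
        # batch dim of size b without copying -> (b, in, out).
        w = self.weight.t().expand(b, -1, -1)
        # (b, s, in) bmm (b, in, out) -> (b, s, out); no reshape to 2-D.
        out = torch.bmm(x, w)
        if self.bias is not None:
            out = out + self.bias
        return out


torch.manual_seed(0)
layer = BmmLinear(8, 4)
x = torch.randn(2, 3, 8)
print(torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias), atol=1e-6))
```

Because expand only creates a stride-0 view, the weight is not duplicated b times in memory.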
Note: expand(b, -1, -1) dispatches through DTensor’s ShardingPropagator, which caches results via lru_cache keyed on DTensorSpec. With dynamic shapes, b = x.shape[0] is a SymInt, which makes DTensorSpec._hash_impl raise TypeError. This is handled by _patch_dtensor_spec_hash_for_symint() in parallelizer.py, which falls back to a placement-only hash for SymInt shapes.
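The placement-only fallback described in this note can be sketched in plain Python. The names spec_hash and UnhashableDim are hypothetical: UnhashableDim stands in for a symbolic size that raises TypeError when hashed, and the function is an illustration of the idea, not the code of _patch_dtensor_spec_hash_for_symint().

```python
from typing import Hashable, Tuple


class UnhashableDim:
    """Hypothetical stand-in for a symbolic size: hashing it raises TypeError."""

    __hash__ = None  # makes instances unhashable


def spec_hash(placements: Tuple[Hashable, ...], shape: tuple) -> int:
    """Hash (placements, shape), falling back to placements alone if shape is unhashable."""
    try:
        return hash((placements, shape))
    except TypeError:
        # A symbolic dimension in `shape` makes the full key unhashable;
        # hash only the placements so the key can still be used in a cache.
        return hash(placements)


placements = ("Replicate()", "Shard(2)")
print(spec_hash(placements, (4, 128, 512)) == hash((placements, (4, 128, 512))))
print(spec_hash(placements, (UnhashableDim(), 128, 512)) == hash(placements))
```

In a Python dict-based cache, a coarser hash only causes more collisions; keys with equal hashes are still compared by equality.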
Usage: after TP weight sharding, convert an nn.Linear instance by setting linear.__class__ = TPLinear. This is the same class-swap trick used by translate_to_lora, and it ensures torch.compile/dynamo sees the correct type(module).forward rather than nn.Linear.forward.
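A minimal sketch of the class-swap trick, using a hypothetical subclass SwappedLinear in place of TPLinear: assigning __class__ on an existing nn.Linear keeps its parameters and registered state but changes type(module), so type(module).forward now resolves to the subclass method.

```python
import torch
import torch.nn as nn


class SwappedLinear(nn.Linear):
    """Hypothetical subclass; adds no new attributes, so a class swap is safe."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as nn.Linear; a real variant would change how it is computed.
        out = x @ self.weight.t()
        if self.bias is not None:
            out = out + self.bias
        return out


linear = nn.Linear(8, 4)
weight_before = linear.weight  # handle to the existing Parameter

# The swap: no new module is created and no parameters are copied.
linear.__class__ = SwappedLinear

print(type(linear).forward is SwappedLinear.forward)  # forward now resolves to the subclass
print(linear.weight is weight_before)                 # same Parameter object
print(isinstance(linear, nn.Linear))                  # still an nn.Linear
```

Swapping the class rather than constructing a new module means any already-sharded weights stay attached to the same module object.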
Mutate a tensor-parallel plan, converting each style to its matching LoRA-aware style.