nemo_automodel.components.distributed.parallel_styles

Module Contents

Classes

Functions

API

nemo_automodel.components.distributed.parallel_styles._distribute_param(_module, name, device_mesh, src_data_rank, placements)
class nemo_automodel.components.distributed.parallel_styles.ColwiseParallelLora

Bases: torch.distributed.tensor.parallel.ColwiseParallel

_partition_linear_fn(name, module, device_mesh)
_partition_embedding_fn(name, module, device_mesh)
class nemo_automodel.components.distributed.parallel_styles.RowwiseParallelLora

Bases: torch.distributed.tensor.parallel.RowwiseParallel

_partition_linear_fn(name, module, device_mesh)
_partition_embedding_fn(name, module, device_mesh)
class nemo_automodel.components.distributed.parallel_styles.SequenceParallelLora

Bases: torch.distributed.tensor.parallel.SequenceParallel

_replicate_module_fn(name: str, module: torch.nn.Module, device_mesh: torch.distributed.tensor.DeviceMesh)
nemo_automodel.components.distributed.parallel_styles.translate_to_lora(plan)