nemo_automodel.components.moe.parallelizer#

Module Contents#

Classes#

ExpertParallel

The ExpertParallel class shards the MoE parameters on the EP mesh. Dim 0 of each parameter is sharded, since that is the expert dimension.

Functions#

apply_ep

Applies expert parallelism (EP) to the MoE module.

apply_ac

Apply activation checkpointing to the model.

apply_fsdp

Applies FSDP sharding to the model.

parallelize_model

Applies the configured parallelisms (EP, FSDP, and optional activation checkpointing) to the model.

Data#

API#

nemo_automodel.components.moe.parallelizer.logger#

'getLogger(…)'

class nemo_automodel.components.moe.parallelizer.ExpertParallel#

Bases: torch.distributed.tensor.parallel.ParallelStyle

The ExpertParallel class shards the MoE parameters on the EP mesh. Dim 0 of each parameter is sharded, since that is the expert dimension.

_partition_fn(name, module, device_mesh)#

_apply(
    module: torch.nn.Module,
    device_mesh: torch.distributed.device_mesh.DeviceMesh,
) → torch.nn.Module#
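A minimal sketch of applying a ParallelStyle such as ExpertParallel through torch.distributed.tensor.parallel.parallelize_module, which drives _apply. The ToyExperts stub, mesh size, and device type below are assumptions for illustration only; they are not part of this module's API.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from nemo_automodel.components.moe.parallelizer import ExpertParallel


class ToyExperts(nn.Module):
    """Toy stand-in whose parameters keep the expert dimension at dim 0."""

    def __init__(self, num_experts: int = 8, hidden: int = 16) -> None:
        super().__init__()
        self.w_in = nn.Parameter(torch.randn(num_experts, hidden, hidden))
        self.w_out = nn.Parameter(torch.randn(num_experts, hidden, hidden))


# Assumes torch.distributed is initialized (e.g. launched via torchrun) with 8 ranks.
ep_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))

# parallelize_module invokes ExpertParallel._apply, which registers the partition
# function that shards every parameter along dim 0 (the expert dimension).
experts = parallelize_module(ToyExperts(), ep_mesh, ExpertParallel())
```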

nemo_automodel.components.moe.parallelizer.apply_ep(
    model: torch.nn.Module,
    ep_mesh: torch.distributed.device_mesh.DeviceMesh,
)#

Applies expert parallelism (EP) to the MoE module.
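A hedged usage sketch: build_moe_model() is a hypothetical helper standing in for any MoE model laid out the way this module expects, and the mesh size is illustrative.

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.moe.parallelizer import apply_ep

# Assumes torch.distributed is initialized (e.g. launched via torchrun) with 8 ranks.
ep_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))

model = build_moe_model()  # hypothetical helper; any MoE model in the expected layout
apply_ep(model, ep_mesh)   # shards each expert parameter along dim 0 over the EP mesh
```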

nemo_automodel.components.moe.parallelizer.apply_ac(
    model: torch.nn.Module,
    ignore_router: bool = False,
    hidden_size: int = 7168,
    num_experts: int = 256,
)#

Apply activation checkpointing to the model.
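A hedged sketch of calling apply_ac. The defaults for hidden_size and num_experts are DeepSeek-V3-scale values (7168 and 256); the values below, the build_moe_model() helper, and the reading of ignore_router are assumptions for illustration.

```python
from nemo_automodel.components.moe.parallelizer import apply_ac

model = build_moe_model()  # hypothetical helper; construction elided

# Checkpoint activations to trade recomputation for memory. hidden_size and
# num_experts are set to this (assumed) model's own sizes instead of the
# 7168 / 256 defaults; ignore_router presumably excludes the router from
# checkpointing (an assumption based on the parameter name).
apply_ac(model, ignore_router=False, hidden_size=4096, num_experts=64)
```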

nemo_automodel.components.moe.parallelizer.apply_fsdp(
    model: torch.nn.Module,
    fsdp_mesh: torch.distributed.device_mesh.DeviceMesh,
    pp_enabled: bool,
    ep_enabled: bool,
    ep_shard_enabled: bool,
    ep_shard_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    mp_policy: torch.distributed.fsdp._fully_shard.MixedPrecisionPolicy | None = None,
    offload_policy: torch.distributed.fsdp._fully_shard.OffloadPolicy | None = None,
)#
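Applies FSDP sharding to the model.

A hedged sketch of calling apply_fsdp on a 2-D mesh. The mesh layout, flag values, and build_moe_model() helper are illustrative assumptions, not the documented behavior.

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.moe.parallelizer import apply_fsdp

# Illustrative 8-GPU layout: 2-way data parallel x 4-way expert parallel.
world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))

model = build_moe_model()  # hypothetical helper; construction elided
apply_fsdp(
    model,
    fsdp_mesh=world_mesh["dp"],   # FSDP sharding over the dp dim (assumed layout)
    pp_enabled=False,
    ep_enabled=True,
    ep_shard_enabled=False,       # keep experts EP-sharded only (an assumption)
    # mp_policy and offload_policy keep their None defaults
)
```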

nemo_automodel.components.moe.parallelizer.parallelize_model(
    model: torch.nn.Module,
    world_mesh: torch.distributed.device_mesh.DeviceMesh,
    moe_mesh: torch.distributed.device_mesh.DeviceMesh | None,
    *,
    pp_enabled: bool,
    dp_axis_names: tuple[str, ...],
    cp_axis_name: str | None = None,
    tp_axis_name: str | None = None,
    ep_axis_name: str | None = None,
    ep_shard_axis_names: tuple[str, ...] | None = None,
    activation_checkpointing: bool = False,
)#
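Applies the configured parallelisms (EP, FSDP, and optional activation checkpointing) to the model.

A minimal, data-parallel-only sketch of parallelize_model. The mesh, axis name, and build_moe_model() helper are assumptions, and EP/TP/CP/PP are left disabled for brevity.

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.moe.parallelizer import parallelize_model

# Assumes an 8-rank run launched with torchrun; one flat data-parallel axis.
world_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

model = build_moe_model()  # hypothetical helper; construction elided
parallelize_model(
    model,
    world_mesh,
    moe_mesh=None,                # no separate MoE mesh when EP is not used
    pp_enabled=False,
    dp_axis_names=("dp_shard",),  # must match the mesh_dim_names above
    activation_checkpointing=True,
)
```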