nemo_automodel.components.moe.utils#

Module Contents#

Classes#

Functions#

Data#

API#

nemo_automodel.components.moe.utils.HAVE_TE#

None

nemo_automodel.components.moe.utils.HAVE_DEEP_EP#

None

class nemo_automodel.components.moe.utils.BackendConfig#
attn: Literal['te', 'sdpa', 'flex']#

None

linear: Literal['torch', 'te']#

None

rms_norm: Literal['torch', 'te']#

None

enable_deepep: bool#

None

fake_balanced_gate: bool#

False

enable_hf_state_dict_adapter: bool#

True

enable_fsdp_optimizations: bool#

False

gate_precision: str | torch.dtype | None#

None

__post_init__()#
nemo_automodel.components.moe.utils.initialize_rms_norm_module(
rms_norm_impl: str,
dim: int,
eps: float = 1e-05,
device: torch.device | str = 'meta',
dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module#
nemo_automodel.components.moe.utils.initialize_linear_module(
linear_impl: str,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | str = 'meta',
dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module#