nemo_automodel.components.moe.utils#

Module Contents#

Classes#

BackendConfig

Functions#

initialize_attn_module_and_func

initialize_rms_norm_module

initialize_linear_module

API#

class nemo_automodel.components.moe.utils.BackendConfig#
attn: Literal['te', 'sdpa', 'flex']#

'te'

linear: Literal['torch', 'te']#

'torch'

rms_norm: Literal['torch', 'te']#

'te'

enable_deepep: bool#

False

fake_balanced_gate: bool#

False

enable_hf_state_dict_adapter: bool#

False
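A minimal construction sketch. It assumes BackendConfig is a dataclass-style config whose fields can be passed as keyword arguments; the field names and defaults come from the listing above.

```python
from nemo_automodel.components.moe.utils import BackendConfig

# Defaults from the listing above: TransformerEngine ('te') attention and
# RMSNorm, torch linear layers, DeepEP and the fake balanced gate disabled.
backend = BackendConfig()

# Override individual backends; values must match the Literal choices above.
sdpa_backend = BackendConfig(attn="sdpa", rms_norm="torch")
```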

nemo_automodel.components.moe.utils.initialize_attn_module_and_func(
attn_impl: str,
num_attention_heads: int,
num_qk_channels: int,
num_v_channels: int,
softmax_scale: float,
attn_mask_type: str = 'causal',
qkv_format: str = 'bshd',
) → tuple[torch.nn.Module | None, Callable]#
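A usage sketch based on the signature above. The dimensions are hypothetical, and the comment about the module slot being None is an assumption inferred only from the torch.nn.Module | None return type, not from the implementation.

```python
import math

from nemo_automodel.components.moe.utils import initialize_attn_module_and_func

head_dim = 128  # hypothetical head dimension
attn_module, attn_func = initialize_attn_module_and_func(
    attn_impl="sdpa",               # one of 'te', 'sdpa', 'flex'
    num_attention_heads=32,
    num_qk_channels=head_dim,
    num_v_channels=head_dim,
    softmax_scale=1.0 / math.sqrt(head_dim),
    attn_mask_type="causal",        # default
    qkv_format="bshd",              # default layout: batch, seq, heads, dim
)
# attn_func is the callable used in the forward pass; attn_module may be
# None for purely functional backends (assumption from the return type).
```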
nemo_automodel.components.moe.utils.initialize_rms_norm_module(
rms_norm_impl: str,
dim: int,
eps: float = 1e-05,
device: torch.device | str = 'meta',
) → torch.nn.Module#
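A short sketch of the norm factory. Note the default device is 'meta', so the returned module carries shapes but no allocated weights; to_empty is the standard torch.nn.Module API for materializing such modules, not something specific to this helper.

```python
from nemo_automodel.components.moe.utils import initialize_rms_norm_module

# Built on the 'meta' device by default: shapes only, no allocated storage.
norm = initialize_rms_norm_module("torch", dim=4096, eps=1e-5)

# Allocate real (still uninitialized) storage before loading or
# initializing the weights and running a forward pass.
norm = norm.to_empty(device="cuda")
```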
nemo_automodel.components.moe.utils.initialize_linear_module(
linear_impl: str,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | str = 'meta',
) → torch.nn.Module#
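A sketch of the linear factory under the same meta-device convention; the feature sizes are hypothetical.

```python
from nemo_automodel.components.moe.utils import initialize_linear_module

# A bias-free projection on the 'meta' device (the defaults above).
proj = initialize_linear_module("torch", in_features=4096, out_features=14336)

# Or request the TransformerEngine backend and a concrete device directly,
# assuming device strings are accepted as the signature suggests.
proj_te = initialize_linear_module("te", 4096, 14336, bias=False, device="cuda")
```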