nemo_automodel.components.models.qwen3_next.layers

Module Contents

Classes

API

class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm(dim: int, eps: float = 1e-06)

Bases: torch.nn.Module

Initialization

_norm(x)
forward(x)
extra_repr()
reset_parameters()
class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextAttention(
config: transformers.models.qwen3_next.configuration_qwen3_next.Qwen3NextConfig,
layer_idx: int,
backend: nemo_automodel.components.moe.utils.BackendConfig,
)

Bases: torch.nn.Module

Initialization

forward(
x: torch.Tensor,
*,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) -> torch.Tensor
init_weights(buffer_device: torch.device, init_std: float = 0.02)