nemo_automodel.components.moe.rope_utils

Module Contents

Functions

  • yarn_get_mscale

  • precompute_freqs_cis – Precomputes frequency-based complex exponential values for rotary positional embeddings.

  • apply_rotary_emb – Applies rotary positional embeddings to the input tensor.

  • freqs_cis_from_position_ids

API

nemo_automodel.components.moe.rope_utils.yarn_get_mscale(scale=1, mscale=1)

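
yarn_get_mscale has no docstring. Its name suggests the YaRN attention-magnitude correction ("mscale") applied when the context window is extended by a factor `scale`. The plain-Python sketch below assumes the formula used in DeepSeek-style reference code (no correction for scale <= 1, otherwise 0.1 · mscale · ln(scale) + 1); the helper name `yarn_get_mscale_sketch` is hypothetical and may differ from this module's exact implementation.

```python
import math


def yarn_get_mscale_sketch(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Hypothetical re-implementation of the YaRN attention scaling factor.

    Assumption: DeepSeek-style formula. No correction when the context is not
    extended (scale <= 1); otherwise 0.1 * mscale * ln(scale) + 1.
    """
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# A 40x context extension with mscale=1 gives roughly 1.369.
print(round(yarn_get_mscale_sketch(40.0, 1.0), 3))
```

Under this assumption the correction grows only logarithmically with the extension factor, so even large context extensions change the attention scale modestly.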
nemo_automodel.components.moe.rope_utils.precompute_freqs_cis(
qk_rope_head_dim: int,
max_seq_len: int,
rope_theta: float,
rope_scaling: dict[str, float | int] | None,
) → torch.Tensor

Precomputes frequency-based complex exponential values for rotary positional embeddings.

Parameters:
  • qk_rope_head_dim (int) – Dimensionality of the rotary positional embeddings, i.e. the number of query/key features per head that are rotated.

  • max_seq_len (int) – Maximum sequence length for which values are precomputed.

  • rope_theta (float) – Base value from which the geometric sequence of rotation frequencies is computed.

  • rope_scaling (dict[str, float | int] | None) – Optional context-extension (YaRN) settings: the original sequence length, the fast and slow beta values that bound the frequency-correction range, and the scaling factor. Pass None to use unscaled frequencies.

Returns:

Precomputed complex exponential values for positional embeddings.

Return type:

torch.Tensor
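
For the unscaled case (rope_scaling=None), the precomputed table can be pictured with the standard RoPE schedule: pair i of the head dimension rotates with inverse frequency rope_theta^(-2i/d), and entry [t][i] is the unit complex number exp(j · t · freq_i). The sketch below is plain Python for illustration only; it assumes that standard schedule, omits the YaRN frequency correction that rope_scaling would apply, and returns nested lists rather than a torch.Tensor. The name `precompute_freqs_cis_sketch` is hypothetical.

```python
import cmath


def precompute_freqs_cis_sketch(
    head_dim: int, max_seq_len: int, rope_theta: float
) -> list[list[complex]]:
    """Hypothetical unscaled (rope_scaling=None) version of precompute_freqs_cis.

    Returns a max_seq_len x (head_dim // 2) table whose entry [t][i] is
    exp(1j * t * rope_theta ** (-2i / head_dim)).
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even: features are rotated in pairs")
    # One inverse frequency per feature pair, decaying geometrically with i.
    inv_freqs = [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
    # Unit-magnitude complex number whose angle is position * frequency.
    return [[cmath.exp(1j * t * f) for f in inv_freqs] for t in range(max_seq_len)]


table = precompute_freqs_cis_sketch(head_dim=4, max_seq_len=3, rope_theta=10000.0)
print(len(table), len(table[0]))  # 3 2
```

Row 0 is all ones (position 0 is not rotated), and every entry has magnitude 1, so applying it changes only the angle of each feature pair.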

nemo_automodel.components.moe.rope_utils.apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor,
) → torch.Tensor

Applies rotary positional embeddings to the input tensor.

Parameters:
  • x (torch.Tensor) – Input tensor to which the rotary positional embeddings are applied.

  • freqs_cis (torch.Tensor) – Precomputed complex exponential values for positional embeddings.

Returns:

Tensor with rotary embeddings applied.

Return type:

torch.Tensor
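
Conceptually, applying rotary embeddings treats each pair of features as one complex number and multiplies it by the matching unit-magnitude entry of freqs_cis, which rotates that pair in 2-D. The sketch below shows this for a single head at a single position in plain Python. It assumes interleaved pairing, (x[0], x[1]), (x[2], x[3]), and so on; some implementations instead pair the first and second halves of the vector, and this page does not say which layout apply_rotary_emb uses. The name `apply_rotary_emb_sketch` is hypothetical, and the real function operates on batched torch tensors.

```python
import cmath
import math


def apply_rotary_emb_sketch(x: list[float], freqs_cis_row: list[complex]) -> list[float]:
    """Hypothetical rotary embedding for one position of one head.

    Assumption: interleaved pairing, so (x[2i], x[2i + 1]) is treated as one
    complex number and rotated by freqs_cis_row[i].
    """
    if len(x) != 2 * len(freqs_cis_row):
        raise ValueError("x must have exactly two features per frequency")
    out: list[float] = []
    for i, rot in enumerate(freqs_cis_row):
        # Complex multiplication by a unit-magnitude number is a 2-D rotation.
        z = complex(x[2 * i], x[2 * i + 1]) * rot
        out.extend([z.real, z.imag])
    return out


# Rotating the pair (1, 0) by a quarter turn yields approximately (0, 1).
print([round(v, 6) for v in apply_rotary_emb_sketch([1.0, 0.0], [cmath.exp(1j * math.pi / 2)])])
```

Because every factor has magnitude 1, the rotation preserves the vector's norm, and the dot product between a query rotated for position m and a key rotated for position n depends on the positions only through m - n.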

nemo_automodel.components.moe.rope_utils.freqs_cis_from_position_ids(
position_ids: torch.Tensor,
freqs: torch.Tensor,
) → torch.Tensor
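
freqs_cis_from_position_ids has no docstring. Judging only from its name and signature, it presumably builds rotary factors for explicit per-token position IDs instead of a contiguous range 0 … max_seq_len - 1. The plain-Python sketch below assumes `freqs` holds one inverse frequency per feature pair, so the angle for position p and pair i is p · freqs[i]. Both that interpretation and the name `freqs_cis_from_position_ids_sketch` are assumptions, not the module's documented behavior.

```python
import cmath


def freqs_cis_from_position_ids_sketch(
    position_ids: list[int], inv_freqs: list[float]
) -> list[list[complex]]:
    """Hypothetical lookup of rotary factors for arbitrary token positions.

    Assumption: inv_freqs holds one inverse frequency per feature pair, so the
    factor for position p and pair i is exp(1j * p * inv_freqs[i]).
    """
    return [[cmath.exp(1j * p * f) for f in inv_freqs] for p in position_ids]


# Positions need not be contiguous; e.g. two packed sequences each restart at 0.
rows = freqs_cis_from_position_ids_sketch([0, 1, 2, 0, 1], [1.0, 0.01])
print(len(rows), len(rows[0]))  # 5 2
```

Tokens that share a position ID receive identical factors, which is what lets packed or otherwise non-contiguous sequences reuse the same rotary encoding.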