nemo_automodel.components.models.llama.rope_utils


Rotary Position Embedding utilities for Llama and Qwen2 models.

This module provides a RoPE implementation that follows HuggingFace's architecture.

Supports both:

  • LlamaConfig: reads rope_theta from config.rope_theta and scaling settings from config.rope_scaling
  • Qwen2Config: reads rope_theta from config.rope_parameters["rope_theta"] and scaling settings from config.rope_parameters

Note: gpt_oss and deepseek_v3 have their own specialized rope_utils.py with model-specific optimizations (YaRN, MLA, etc.).

Module Contents

Classes

Name | Description
--- | ---
LlamaRotaryEmbedding | Rotary Position Embedding module for Llama and Qwen2 models.

Functions

Name | Description
--- | ---
_compute_default_inv_freq | Computes inverse frequencies for standard RoPE.
_compute_llama3_inv_freq | Computes inverse frequencies for Llama3-style RoPE with smooth interpolation.
_get_rope_config | Extract rope parameters from config (handles both Llama and Qwen2 formats).
apply_rotary_pos_emb | Applies Rotary Position Embedding to the query and key tensors.
apply_rotary_pos_emb_fused | Applies RoPE using TE's fused kernel.
rotate_half | Rotates half the hidden dims of the input.

Data

Qwen2RotaryEmbedding

RotaryEmbedding

__all__

API

class nemo_automodel.components.models.llama.rope_utils.LlamaRotaryEmbedding(
config,
device: typing.Optional[torch.device] = None,
rope_fusion: bool = False
)

Bases: Module

Rotary Position Embedding module for Llama and Qwen2 models.

Returns a (cos, sin) tuple for use with apply_rotary_pos_emb.

dtype

inv_freq: Tensor

max_seq_len_cached = 0

nemo_automodel.components.models.llama.rope_utils.LlamaRotaryEmbedding._build_cache(
seq_len: int,
device: torch.device
) -> None

Build cos/sin cache in config dtype for positions [0, seq_len).

nemo_automodel.components.models.llama.rope_utils.LlamaRotaryEmbedding._ensure_cache(
seq_len: int,
device: torch.device
) -> None

Build or grow the cos/sin cache so it covers positions [0, seq_len).
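The two cache methods can be sketched as below. This is a minimal illustration, not the module's code: it assumes the cache stores cos/sin of the angles with the two halves concatenated (HuggingFace's layout), computes in float32 before casting to the target dtype, and grows by rebuilding from scratch. The class CosSinCache is hypothetical; only the max_seq_len_cached field name comes from this page.

```python
import torch


class CosSinCache:
    """Hypothetical grow-on-demand cos/sin cache (illustration only)."""

    def __init__(self, inv_freq: torch.Tensor, dtype: torch.dtype = torch.float32):
        self.inv_freq = inv_freq          # [head_dim // 2]
        self.dtype = dtype                # dtype the cache is stored in
        self.max_seq_len_cached = 0
        self.cos_cache = None
        self.sin_cache = None

    def _build_cache(self, seq_len: int, device: torch.device) -> None:
        # Angles for positions [0, seq_len): outer product of positions and inverse frequencies.
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq.to(device))   # [seq_len, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)                     # [seq_len, head_dim]
        # Compute in float32, then store in the target dtype.
        self.cos_cache = emb.cos().to(self.dtype)
        self.sin_cache = emb.sin().to(self.dtype)
        self.max_seq_len_cached = seq_len

    def _ensure_cache(self, seq_len: int, device: torch.device) -> None:
        # Rebuild only if nothing is cached yet, the request is longer than the cache,
        # or the cache lives on a different device.
        if (
            self.cos_cache is None
            or seq_len > self.max_seq_len_cached
            or self.cos_cache.device != device
        ):
            self._build_cache(max(seq_len, self.max_seq_len_cached), device)


inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
cache = CosSinCache(inv_freq)
cache._ensure_cache(16, torch.device("cpu"))
cache._ensure_cache(8, torch.device("cpu"))   # no rebuild: 8 <= 16
print(cache.max_seq_len_cached)  # 16
```

Growing only when a longer sequence arrives keeps the common case (repeated calls with the same or shorter length) to a cheap check.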

nemo_automodel.components.models.llama.rope_utils.LlamaRotaryEmbedding.forward(
x: torch.Tensor,
position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Return (cos, sin) for the given positions.

In the non-fused path, cos and sin are gathered using the values in position_ids, so non-contiguous positions receive the correct rotary phase. EAGLE TTT depth offsets (arange(seq_len) + step_idx), packed sequences, and context parallelism all pass position_ids != arange(seq_len). The earlier implementation returned cos_cache[:seq_len], which depended only on the sequence length and silently ignored the position values. For the common case position_ids == arange(seq_len), the gather is numerically identical to that slice.

The fused TE path (rope_fusion=True) consumes raw angles indexed by sequence position and assumes contiguous [0, seq_len) positions, so it keeps the legacy contiguous slice and does NOT honor non-contiguous position_ids.

Parameters:

x
torch.Tensor

Input tensor (used for device and dtype)

position_ids
torch.Tensor

Position IDs tensor [batch, seq_len]

Returns: tuple[torch.Tensor, torch.Tensor]

(cos, sin) tensors [batch, seq_len, head_dim]
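The difference between gathering by position value and slicing by length can be checked directly. In this sketch (assumptions: a precomputed cos/sin table in HuggingFace's concatenated-halves layout, and a hypothetical offset of 2 standing in for something like an EAGLE TTT step_idx), both lookups agree for arange(seq_len) but diverge once the positions are shifted. The helper names are illustrative only.

```python
import torch

head_dim, max_len = 8, 32
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.outer(torch.arange(max_len, dtype=torch.float32), inv_freq)
emb = torch.cat((angles, angles), dim=-1)          # [max_len, head_dim]
cos_cache, sin_cache = emb.cos(), emb.sin()


def lookup_by_position(position_ids: torch.Tensor):
    # Gather rows by position value: [batch, seq_len] -> [batch, seq_len, head_dim]
    return cos_cache[position_ids], sin_cache[position_ids]


def lookup_by_length(position_ids: torch.Tensor):
    # Legacy behavior: uses only seq_len and ignores the position values.
    batch, seq_len = position_ids.shape
    cos = cos_cache[:seq_len].unsqueeze(0).expand(batch, -1, -1)
    sin = sin_cache[:seq_len].unsqueeze(0).expand(batch, -1, -1)
    return cos, sin


contiguous = torch.arange(4).unsqueeze(0)   # [[0, 1, 2, 3]]
shifted = contiguous + 2                     # [[2, 3, 4, 5]], e.g. a step offset

print(torch.equal(lookup_by_position(contiguous)[0], lookup_by_length(contiguous)[0]))  # True
print(torch.equal(lookup_by_position(shifted)[0], lookup_by_length(shifted)[0]))        # False
```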

nemo_automodel.components.models.llama.rope_utils._compute_default_inv_freq(
config,
device: typing.Optional[torch.device] = None
) -> tuple[torch.Tensor, float]

Computes inverse frequencies for standard RoPE.
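Standard RoPE uses inv_freq[i] = 1 / rope_theta^(2i / head_dim) for i = 0, ..., head_dim/2 - 1, so frequencies decrease geometrically across the head dimension. A minimal sketch follows; it takes rope_theta and head_dim directly rather than a config object (an assumption made for brevity) and omits the float in the return type. In HuggingFace's analogous helpers that float is an attention scaling factor equal to 1.0 for default RoPE; it presumably plays the same role here.

```python
import torch


def default_inv_freq(rope_theta: float, head_dim: int) -> torch.Tensor:
    # inv_freq[i] = 1 / rope_theta ** (2i / head_dim) for i in [0, head_dim / 2)
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)


inv_freq = default_inv_freq(10000.0, 128)
print(inv_freq.shape)       # torch.Size([64])
print(inv_freq[0].item())   # 1.0
```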

nemo_automodel.components.models.llama.rope_utils._compute_llama3_inv_freq(
config,
device: typing.Optional[torch.device] = None
) -> tuple[torch.Tensor, float]

Computes inverse frequencies for Llama3-style RoPE with smooth interpolation.

Branch logic (matches HF _compute_llama3_parameters):

  • Long wavelength (low freq, wavelen > low_freq_wavelen) → inv_freq divided by factor
  • Short wavelength (high freq, wavelen < high_freq_wavelen) → unchanged
  • Medium band (high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen) → smooth interpolation between the scaled and unscaled inverse frequencies

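The three branches can be sketched as below, following HuggingFace's published Llama3 formula. The keyword names mirror the usual rope_scaling keys (factor, low_freq_factor, high_freq_factor, original_max_position_embeddings); the default values are those of Llama 3.1's released config and are used here only as an example. llama3_inv_freq is a hypothetical helper that takes base inverse frequencies rather than a config.

```python
import math

import torch


def llama3_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> torch.Tensor:
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies): divide the inverse frequency by `factor`.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium band: interpolate linearly between inv_freq / factor and inv_freq.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * scaled / factor + smooth * scaled
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)

    # Short wavelengths (high frequencies) fall through unchanged.
    return torch.where(is_medium, smoothed, scaled)


base = 1.0 / (500000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
out = llama3_inv_freq(base)
```

In the medium band the interpolation weight runs from 0 at wavelen = low_freq_wavelen to 1 at wavelen = high_freq_wavelen, so the scaled frequencies join the two outer branches without a jump.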
nemo_automodel.components.models.llama.rope_utils._get_rope_config(
config
) -> tuple[float, dict]

Extract rope parameters from config (handles both Llama and Qwen2 formats).

Returns: tuple[float, dict]

Tuple of (rope_theta, rope_scaling_dict)
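A hypothetical version of this lookup, written against plain objects, shows how the two config layouts can be normalized to the same (rope_theta, rope_scaling_dict) pair. The precedence (checking rope_parameters first) and the fallback to an empty dict when rope_scaling is None are assumptions, not necessarily what _get_rope_config does.

```python
from types import SimpleNamespace


def get_rope_config(config) -> tuple[float, dict]:
    # Qwen2-style (assumed): theta and scaling settings share one `rope_parameters` dict.
    rope_parameters = getattr(config, "rope_parameters", None)
    if rope_parameters is not None:
        return float(rope_parameters["rope_theta"]), dict(rope_parameters)
    # Llama-style: separate `rope_theta` attribute plus an optional `rope_scaling` dict.
    rope_scaling = getattr(config, "rope_scaling", None) or {}
    return float(config.rope_theta), dict(rope_scaling)


llama_cfg = SimpleNamespace(rope_theta=500000.0, rope_scaling={"rope_type": "llama3", "factor": 8.0})
qwen_cfg = SimpleNamespace(rope_parameters={"rope_type": "default", "rope_theta": 1000000.0})
print(get_rope_config(llama_cfg)[0])  # 500000.0
print(get_rope_config(qwen_cfg)[0])   # 1000000.0
```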

nemo_automodel.components.models.llama.rope_utils.apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Applies Rotary Position Embedding to the query and key tensors.

Parameters:

q
torch.Tensor

Query tensor [batch, num_heads, seq_len, head_dim]

k
torch.Tensor

Key tensor [batch, num_kv_heads, seq_len, head_dim]

cos
torch.Tensor

Cosine embeddings [batch, seq_len, head_dim]

sin
torch.Tensor

Sine embeddings [batch, seq_len, head_dim]

Returns: tuple[torch.Tensor, torch.Tensor]

Rotated (q, k) tensors
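The non-fused rotation is conventionally written as q * cos + rotate_half(q) * sin (HuggingFace's formulation). The sketch below assumes that formulation and unsqueezes cos/sin at dim 1 so the [batch, seq_len, head_dim] embeddings broadcast across both num_heads and num_kv_heads; apply_rope is a hypothetical name, and rotate_half is redefined so the block is self-contained. A rotation leaves each vector's norm unchanged and leaves position 0 untouched, which makes a convenient sanity check.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q, k, cos, sin):
    # [batch, seq_len, head_dim] -> [batch, 1, seq_len, head_dim] to broadcast over heads
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


torch.manual_seed(0)
batch, n_heads, n_kv_heads, seq_len, head_dim = 2, 4, 2, 5, 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
emb = torch.cat((angles, angles), dim=-1)                 # [seq_len, head_dim]
cos = emb.cos().unsqueeze(0).expand(batch, -1, -1)        # [batch, seq_len, head_dim]
sin = emb.sin().unsqueeze(0).expand(batch, -1, -1)

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)     # fewer KV heads (GQA)
q_rot, k_rot = apply_rope(q, k, cos, sin)
```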

nemo_automodel.components.models.llama.rope_utils.apply_rotary_pos_emb_fused(
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Applies RoPE using TE’s fused kernel.

Parameters:

q
torch.Tensor

Query tensor [batch, num_heads, seq_len, head_dim]

k
torch.Tensor

Key tensor [batch, num_kv_heads, seq_len, head_dim]

freqs_cis
torch.Tensor

Raw angles [seq_len, 1, 1, head_dim] in TE format

Returns: tuple[torch.Tensor, torch.Tensor]

Rotated (q, k) tensors
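The freqs_cis argument holds raw angles rather than cos/sin. The sketch below builds a tensor of the documented [seq_len, 1, 1, head_dim] shape for contiguous positions; it assumes the same concatenated-halves angle layout as the non-fused path and does not call Transformer Engine. build_te_freqs is a hypothetical helper. Because the angles are indexed purely by row number, the sketch also illustrates why this path cannot represent non-contiguous position_ids.

```python
import torch


def build_te_freqs(inv_freq: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Raw angles for contiguous positions [0, seq_len); the fused kernel is
    # assumed to take cos/sin of these itself.
    positions = torch.arange(seq_len, dtype=inv_freq.dtype, device=inv_freq.device)
    freqs = torch.outer(positions, inv_freq)        # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)         # [seq_len, head_dim]
    return emb[:, None, None, :]                    # [seq_len, 1, 1, head_dim]


head_dim = 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs_cis = build_te_freqs(inv_freq, seq_len=6)
print(freqs_cis.shape)  # torch.Size([6, 1, 1, 8])
```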

nemo_automodel.components.models.llama.rope_utils.rotate_half(
x: torch.Tensor
) -> torch.Tensor

Rotates half the hidden dims of the input.
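A common definition (HuggingFace's, assumed here rather than confirmed for this module) splits the last dimension into halves (x1, x2) and returns (-x2, x1). Combined with x * cos + rotate_half(x) * sin, this rotates each pair (x[i], x[i + head_dim/2]) by the angle for frequency i.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


out = rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))
print(out.tolist())  # [-3.0, -4.0, 1.0, 2.0]
```

Applying it twice negates the input, which is the 90-degree-rotation property the cos/sin combination relies on.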

nemo_automodel.components.models.llama.rope_utils.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
nemo_automodel.components.models.llama.rope_utils.RotaryEmbedding = LlamaRotaryEmbedding
nemo_automodel.components.models.llama.rope_utils.__all__ = ['RotaryEmbedding', 'LlamaRotaryEmbedding', 'Qwen2RotaryEmbedding', 'rotate_half...