nemo_automodel.components.models.llama.rope_utils
Rotary Position Embedding utilities for Llama and Qwen2 models.
This module provides a RoPE implementation that follows HuggingFace's architecture.
Supports both:
- LlamaConfig: uses config.rope_theta and config.rope_scaling
- Qwen2Config: uses config.rope_parameters["rope_theta"] and config.rope_parameters
Note: gpt_oss and deepseek_v3 have their own specialized rope_utils.py with model-specific optimizations (YaRN, MLA, etc.).
API
Bases: Module
Rotary Position Embedding module for Llama and Qwen2 models.
Returns a (cos, sin) tuple for use with apply_rotary_pos_emb.
Build the cos/sin cache in the config dtype for positions [0, seq_len).
Build or grow the cos/sin cache so it covers positions [0, seq_len).
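The build/grow behaviour described above can be sketched as follows. This is a minimal illustration, not the module's code: the class name CosSinCache and the method names _build_cache and ensure are hypothetical, the cache is rebuilt to exactly seq_len (the real module may use a different growth policy), and dtype stands in for the "config dtype".

```python
import torch


class CosSinCache:
    """Hypothetical cos/sin cache that always covers positions [0, seq_len)."""

    def __init__(self, inv_freq: torch.Tensor, dtype: torch.dtype = torch.float32):
        self.inv_freq = inv_freq  # [head_dim // 2]
        self.dtype = dtype  # stands in for the config dtype (assumption)
        self.cos_cache = None
        self.sin_cache = None

    def _build_cache(self, seq_len: int) -> None:
        # Angles for positions 0..seq_len-1, computed in float32 before casting.
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.float())  # [seq_len, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
        self.cos_cache = emb.cos().to(self.dtype)
        self.sin_cache = emb.sin().to(self.dtype)

    def ensure(self, seq_len: int) -> None:
        # Build on first use; rebuild only when a longer range is requested.
        if self.cos_cache is None or self.cos_cache.shape[0] < seq_len:
            self._build_cache(seq_len)


head_dim = 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
cache = CosSinCache(inv_freq, dtype=torch.bfloat16)
cache.ensure(16)
cache.ensure(4)   # no rebuild: the cache already covers [0, 16)
cache.ensure(32)  # grows the cache to cover [0, 32)
print(cache.cos_cache.shape)  # torch.Size([32, 8])
```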
Return (cos, sin) for the given positions.
In the non-fused path, cos/sin are gathered by the values in
position_ids, so non-contiguous positions receive the correct rotary
phase. EAGLE TTT depth offsets (arange(seq_len) + step_idx), packed
sequences, and context parallelism all pass
position_ids != arange(seq_len). The earlier implementation returned
cos_cache[:seq_len], which depended only on the sequence length and
silently ignored the position values. For the common
position_ids == arange(seq_len) case, the gather is numerically
identical to that slice.
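The difference between gathering by position_ids and the legacy slice can be checked directly. This sketch uses plain PyTorch advanced indexing on a precomputed cache; the helper names gather_cos_sin and slice_cos_sin are hypothetical, and step_idx = 3 is an arbitrary illustrative offset.

```python
import torch

head_dim, max_pos = 8, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cache, sin_cache = emb.cos(), emb.sin()  # [max_pos, head_dim]


def gather_cos_sin(position_ids):
    # position_ids: [batch, seq_len] -> cos/sin: [batch, seq_len, head_dim]
    return cos_cache[position_ids], sin_cache[position_ids]


def slice_cos_sin(seq_len, batch):
    # Legacy behaviour: depends only on seq_len and ignores the position values.
    cos = cos_cache[:seq_len].unsqueeze(0).expand(batch, -1, -1)
    sin = sin_cache[:seq_len].unsqueeze(0).expand(batch, -1, -1)
    return cos, sin


batch, seq_len = 2, 5
contiguous = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
offset = contiguous + 3  # e.g. arange(seq_len) + step_idx with step_idx = 3

cos_g, _ = gather_cos_sin(contiguous)
cos_s, _ = slice_cos_sin(seq_len, batch)
print(torch.equal(cos_g, cos_s))  # True: identical for arange positions

cos_off, _ = gather_cos_sin(offset)
print(torch.equal(cos_off, cos_s))  # False: the slice would drop the offset
```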
The fused TE path (rope_fusion=True) consumes raw angles indexed by
sequence position and assumes contiguous [0, seq_len) positions, so
it keeps the legacy contiguous slice and does NOT honor non-contiguous
position_ids.
Parameters:
Input tensor (used for device and dtype)
Position IDs tensor [batch, seq_len]
Returns: tuple[torch.Tensor, torch.Tensor]
(cos, sin) tensors [batch, seq_len, head_dim]
Computes inverse frequencies for standard RoPE.
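Standard RoPE uses inv_freq[i] = 1 / rope_theta^(2i / head_dim) for i in [0, head_dim / 2). A minimal sketch of that formula follows; the function name compute_default_inv_freq is hypothetical and not taken from this module.

```python
import torch


def compute_default_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # Exponents 0, 2/d, 4/d, ..., (d-2)/d, one per rotated pair of dims.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)


inv_freq = compute_default_inv_freq(head_dim=8, rope_theta=10000.0)
# With head_dim = 8 and theta = 1e4 the values are 1, 0.1, 0.01, 0.001.
print(inv_freq)
```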
Computes inverse frequencies for Llama3-style RoPE with smooth interpolation.
Branch logic (matches HF _compute_llama3_parameters):
- Long wavelength (low freq, wavelen > low_freq_wavelen) → inverse frequency divided by factor
- Short wavelength (high freq, wavelen < high_freq_wavelen) → unchanged
- Medium band → smooth interpolation between the scaled and unscaled frequencies
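The three branches above can be sketched in PyTorch following the logic of HuggingFace's _compute_llama3_parameters. The function below is an illustrative reimplementation, not this module's code; its name and keyword arguments are assumptions, and the example values (factor 8, low_freq_factor 1, high_freq_factor 4, original context 8192, theta 500000) are commonly used Llama 3.1 settings chosen only for the demonstration.

```python
import math

import torch


def compute_llama3_inv_freq(inv_freq, factor, low_freq_factor, high_freq_factor,
                            original_max_position_embeddings):
    old_context_len = original_max_position_embeddings
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # Long wavelengths (low frequencies): divide the frequency by `factor`.
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Medium band: blend linearly between the scaled and unscaled frequency.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    # Short wavelengths (high frequencies) keep their original value.
    return torch.where(is_medium, smoothed, inv_freq_llama)


head_dim, theta = 128, 500000.0
base = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
scaled = compute_llama3_inv_freq(base, factor=8.0, low_freq_factor=1.0,
                                 high_freq_factor=4.0, original_max_position_embeddings=8192)
```

At the band edges the blend is continuous: smooth is 0 where wavelen equals low_freq_wavelen (giving inv_freq / factor) and 1 where it equals high_freq_wavelen (giving inv_freq).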
Extract rope parameters from config (handles both Llama and Qwen2 formats).
Returns: tuple[float, dict]
Tuple of (rope_theta, rope_scaling_dict)
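A minimal sketch of handling both config formats, using types.SimpleNamespace in place of real LlamaConfig/Qwen2Config objects. The helper name get_rope_params, the check order (rope_parameters first), and the empty-dict fallback when rope_scaling is None are assumptions for illustration, not the module's actual behaviour.

```python
from types import SimpleNamespace


def get_rope_params(config):
    """Hypothetical helper returning (rope_theta, rope_scaling_dict)."""
    rope_parameters = getattr(config, "rope_parameters", None)
    if rope_parameters is not None:
        # Qwen2-style: theta lives inside rope_parameters, which also holds the scaling keys.
        return float(rope_parameters["rope_theta"]), dict(rope_parameters)
    # Llama-style: separate attributes; rope_scaling may be None.
    return float(config.rope_theta), dict(getattr(config, "rope_scaling", None) or {})


llama_cfg = SimpleNamespace(rope_theta=500000.0,
                            rope_scaling={"rope_type": "llama3", "factor": 8.0})
qwen_cfg = SimpleNamespace(rope_parameters={"rope_type": "default", "rope_theta": 1000000.0})
print(get_rope_params(llama_cfg))  # (500000.0, {'rope_type': 'llama3', 'factor': 8.0})
print(get_rope_params(qwen_cfg))   # (1000000.0, {'rope_type': 'default', 'rope_theta': 1000000.0})
```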
Applies Rotary Position Embedding to the query and key tensors.
Parameters:
Query tensor [batch, num_heads, seq_len, head_dim]
Key tensor [batch, num_kv_heads, seq_len, head_dim]
Cosine embeddings [batch, seq_len, head_dim]
Sine embeddings [batch, seq_len, head_dim]
Returns: tuple[torch.Tensor, torch.Tensor]
Rotated (q, k) tensors
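Because cos/sin are [batch, seq_len, head_dim] while q/k carry a head axis, they need an extra dimension to broadcast over heads. The sketch below assumes the HuggingFace convention (unsqueeze at dim 1, half-split rotate_half); it is a self-contained illustration rather than this module's exact code, and it uses fewer KV heads than query heads to show that the same cos/sin serve both.

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # [batch, seq_len, head_dim] -> [batch, 1, seq_len, head_dim] to broadcast over heads.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


torch.manual_seed(0)
batch, heads, kv_heads, seq_len, head_dim = 2, 4, 2, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, kv_heads, seq_len, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
pos = torch.arange(seq_len, dtype=torch.float32)
emb = torch.cat((torch.outer(pos, inv_freq),) * 2, dim=-1)  # [seq_len, head_dim]
cos = emb.cos().unsqueeze(0).expand(batch, -1, -1)
sin = emb.sin().unsqueeze(0).expand(batch, -1, -1)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
# Each (x_i, x_{i + head_dim/2}) pair is rotated, so vector norms are preserved.
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5))  # True
```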
Applies RoPE using Transformer Engine's (TE) fused kernel.
Parameters:
Query tensor [batch, num_heads, seq_len, head_dim]
Key tensor [batch, num_kv_heads, seq_len, head_dim]
Raw angles [seq_len, 1, 1, head_dim] in TE format
Returns: tuple[torch.Tensor, torch.Tensor]
Rotated (q, k) tensors
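The fused path receives raw angles rather than cos/sin. The sketch below builds angles in the [seq_len, 1, 1, head_dim] layout for contiguous positions [0, seq_len) and checks, in pure PyTorch, that rotating with those angles matches the non-fused cos/sin computation. It does not call Transformer Engine: reference_rope_from_angles is a hypothetical stand-in for the kernel, and permuting q to [seq, batch, heads, head_dim] so the angles broadcast is an assumption made for this illustration. Since the angles are built from a contiguous range, this layout cannot express offset position_ids.

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def build_te_angles(seq_len, inv_freq):
    # Raw angles for contiguous positions [0, seq_len); cos/sin not applied yet.
    pos = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.outer(pos, inv_freq)       # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
    return emb.reshape(seq_len, 1, 1, -1)    # [seq_len, 1, 1, head_dim]


def reference_rope_from_angles(t_bhsd, angles):
    # Hypothetical pure-PyTorch stand-in for the fused kernel (NOT the TE API).
    t = t_bhsd.permute(2, 0, 1, 3)  # [seq, batch, heads, head_dim]
    out = t * angles.cos() + rotate_half(t) * angles.sin()
    return out.permute(1, 2, 0, 3)  # back to [batch, heads, seq, head_dim]


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = build_te_angles(seq_len, inv_freq)

# Non-fused path for comparison: cos/sin [1, seq, head_dim], broadcast over heads.
cos = angles.cos().reshape(1, seq_len, head_dim)
sin = angles.sin().reshape(1, seq_len, head_dim)
q_unfused = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)

q_fused_ref = reference_rope_from_angles(q, angles)
print(angles.shape)  # torch.Size([6, 1, 1, 8])
print(torch.allclose(q_fused_ref, q_unfused))  # True
```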
Rotates half the hidden dims of the input.
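Assuming the HuggingFace half-split convention (not the interleaved even/odd variant), rotate_half splits the last dimension into halves [x1, x2] and returns [-x2, x1]. Applied twice it negates the input, which is what a 90-degree rotation of each (x1_i, x2_i) pair should do. A minimal sketch:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into halves [x1, x2] and return [-x2, x1].
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
```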