nemo_automodel.components.models.llama.rope_utils
Rotary Position Embedding utilities for Llama and Qwen2 models.
This module provides a RoPE implementation following HuggingFace's architecture.
API:
    rotary_emb = RotaryEmbedding(config)
    cos, sin = rotary_emb(x, position_ids)  # Returns (cos, sin) tuple
    q, k = apply_rotary_pos_emb(q, k, cos, sin)  # Applies RoPE
Supports both:
LlamaConfig: uses config.rope_theta and config.rope_scaling
Qwen2Config: uses config.rope_parameters["rope_theta"] and config.rope_parameters
Note: gpt_oss and deepseek_v3 have their own specialized rope_utils.py with model-specific optimizations (YaRN, MLA, etc.).
Module Contents
Classes
LlamaRotaryEmbedding | Rotary Position Embedding module for Llama and Qwen2 models.
Functions
rotate_half | Rotates half the hidden dims of the input.
apply_rotary_pos_emb | Applies Rotary Position Embedding to the query and key tensors.
_get_rope_config | Extract rope parameters from config (handles both Llama and Qwen2 formats).
_compute_default_inv_freq | Computes inverse frequencies for standard RoPE.
_compute_llama3_inv_freq | Computes inverse frequencies for Llama3-style RoPE with smooth interpolation.
Data
RotaryEmbedding
Qwen2RotaryEmbedding
__all__
API
- nemo_automodel.components.models.llama.rope_utils.rotate_half(x: torch.Tensor) → torch.Tensor
Rotates half the hidden dims of the input.
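The following is a minimal sketch of what "rotating half" means, assuming the HuggingFace convention that this module says it follows: the last dimension is split into halves (x1, x2) and returned as (-x2, x1). The name `rotate_half_sketch` is hypothetical and is not the module's own function.

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return (-x2, x1)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
rotated = rotate_half_sketch(x)
print(rotated)  # tensor([-3., -4.,  1.,  2.])
```

Applying the operation twice negates the input, which is what makes it behave like a 90-degree rotation of each (x1, x2) pair.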
- nemo_automodel.components.models.llama.rope_utils.apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor)
Applies Rotary Position Embedding to the query and key tensors.
- Parameters:
q – Query tensor [batch, num_heads, seq_len, head_dim]
k – Key tensor [batch, num_kv_heads, seq_len, head_dim]
cos – Cosine embeddings [batch, seq_len, head_dim]
sin – Sine embeddings [batch, seq_len, head_dim]
- Returns:
Rotated (q, k) tensors
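The shapes above imply that cos and sin need an extra heads axis before they can broadcast against q and k. The sketch below shows one way this could work, assuming the HuggingFace formula `x * cos + rotate_half(x) * sin`; `apply_rope_sketch` and the base of 10000 are illustrative assumptions, not the module's actual code.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rope_sketch(q, k, cos, sin):
    # cos/sin are [batch, seq_len, head_dim]; insert a heads axis so they
    # broadcast against [batch, heads, seq_len, head_dim].
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 4, 2, 5, 8
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Illustrative cos/sin for positions 0..seq_len-1 (base 10000 is an assumption).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
emb = torch.cat((angles, angles), dim=-1)            # [seq_len, head_dim]
cos = emb.cos().unsqueeze(0).expand(batch, -1, -1)   # [batch, seq_len, head_dim]
sin = emb.sin().unsqueeze(0).expand(batch, -1, -1)

q_rot, k_rot = apply_rope_sketch(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)
```

Because each pair of channels is only rotated, the per-token vector norm is unchanged, and position 0 (angle 0) passes through untouched. Note that q and k may have different head counts, as in grouped-query attention.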
- nemo_automodel.components.models.llama.rope_utils._get_rope_config(config) → tuple[float, dict]
Extract rope parameters from config (handles both Llama and Qwen2 formats).
- Returns:
Tuple of (rope_theta, rope_scaling_dict)
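As a rough illustration of the two config formats described at the top of this page, the sketch below reads `rope_parameters` when present (Qwen2 style) and otherwise falls back to `rope_theta` and `rope_scaling` (Llama style). The function name, the order of the checks, and the empty-dict fallback are assumptions; the `SimpleNamespace` objects are stand-ins for real config classes.

```python
from types import SimpleNamespace


def get_rope_config_sketch(config) -> tuple[float, dict]:
    """Return (rope_theta, rope_scaling_dict) for either config style."""
    rope_parameters = getattr(config, "rope_parameters", None)
    if rope_parameters is not None:
        # Qwen2 style: theta lives inside the rope_parameters dict.
        return float(rope_parameters["rope_theta"]), dict(rope_parameters)
    # Llama style: separate rope_theta and rope_scaling attributes.
    rope_scaling = getattr(config, "rope_scaling", None) or {}
    return float(config.rope_theta), dict(rope_scaling)


# Stand-in configs with illustrative values.
llama_cfg = SimpleNamespace(
    rope_theta=500000.0,
    rope_scaling={"rope_type": "llama3", "factor": 8.0},
)
qwen2_cfg = SimpleNamespace(
    rope_parameters={"rope_theta": 1000000.0, "rope_type": "default"},
)

print(get_rope_config_sketch(llama_cfg))
print(get_rope_config_sketch(qwen2_cfg))
```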
- nemo_automodel.components.models.llama.rope_utils._compute_default_inv_freq(config, device: Optional[torch.device] = None)
Computes inverse frequencies for standard RoPE.
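For reference, standard RoPE uses one inverse frequency per pair of channels, `inv_freq[i] = 1 / theta ** (2i / head_dim)`. The sketch below computes this in plain Python from scalars so the arithmetic is visible; the real function takes a config and a device, and `default_inv_freq_sketch` is a hypothetical name.

```python
def default_inv_freq_sketch(head_dim: int, rope_theta: float = 10000.0) -> list[float]:
    """One inverse frequency per channel pair: theta ** (-2i / head_dim)."""
    return [1.0 / (rope_theta ** (i / head_dim)) for i in range(0, head_dim, 2)]


inv_freq = default_inv_freq_sketch(head_dim=8, rope_theta=10000.0)
print(inv_freq)  # approximately [1.0, 0.1, 0.01, 0.001]
```

The first pair rotates fastest (one radian per position), and each later pair rotates more slowly, which is what lets different channels encode position at different scales.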
- nemo_automodel.components.models.llama.rope_utils._compute_llama3_inv_freq(config, device: Optional[torch.device] = None)
Computes inverse frequencies for Llama3-style RoPE with smooth interpolation.
Branch logic (matches HF _compute_llama3_parameters):
Long wavelength (low freq, wavelen > low_freq_wavelen) → scale by factor
Short wavelength (high freq, wavelen < high_freq_wavelen) → unchanged
Medium band → smooth interpolation
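The three branches can be sketched in plain Python as follows, based on the logic of HF's `_compute_llama3_parameters`. The function name is hypothetical, and the default values (factor 8, low/high frequency factors 1 and 4, original context 8192, theta 500000) are illustrative assumptions rather than values taken from this module.

```python
import math


def llama3_inv_freq_sketch(
    inv_freq: list[float],
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> list[float]:
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    scaled = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen > low_freq_wavelen:
            # Long wavelength: divide the frequency by the scaling factor.
            scaled.append(freq / factor)
        elif wavelen < high_freq_wavelen:
            # Short wavelength: leave unchanged.
            scaled.append(freq)
        else:
            # Medium band: blend between the scaled and unscaled frequency.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


head_dim, rope_theta = 128, 500000.0
base = [1.0 / (rope_theta ** (i / head_dim)) for i in range(0, head_dim, 2)]
scaled = llama3_inv_freq_sketch(base)
print(scaled[0] / base[0], scaled[-1] / base[-1])
```

With these values the fastest-rotating channels are untouched, the slowest are divided by the factor, and a handful in between receive an intermediate scale, so there is no abrupt jump between the two regimes.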
- class nemo_automodel.components.models.llama.rope_utils.LlamaRotaryEmbedding(config, device: Optional[torch.device] = None)
Bases: torch.nn.Module
Rotary Position Embedding module for Llama and Qwen2 models.
Returns (cos, sin) tuple for use with apply_rotary_pos_emb.
Usage:
    rotary_emb = RotaryEmbedding(config)
    cos, sin = rotary_emb(x, position_ids)
    q, k = apply_rotary_pos_emb(q, k, cos, sin)
Initialization
- inv_freq: torch.Tensor
None
- _build_cache(seq_len: int, device: torch.device) → None
Build cos/sin cache in config dtype for positions [0, seq_len).
- forward(x: torch.Tensor, position_ids: torch.Tensor)
Return (cos, sin) for the given positions.
- Parameters:
x – Input tensor (used for device and dtype)
position_ids – Position IDs tensor [batch, seq_len]
- Returns:
(cos, sin) tensors [batch, seq_len, head_dim]
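One plausible way a cache and a position lookup could fit together is sketched below: a table of cos/sin values is built for positions [0, seq_len), then rows are gathered by `position_ids`. This is an assumption about the general technique, not a description of how this class actually uses its cache; `build_cache_sketch` and `lookup_sketch` are hypothetical names.

```python
import torch


def build_cache_sketch(inv_freq: torch.Tensor, seq_len: int, dtype=torch.float32):
    """Precompute cos/sin tables of shape [seq_len, head_dim]."""
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq.float())  # [seq_len, head_dim / 2]
    emb = torch.cat((angles, angles), dim=-1)          # [seq_len, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)


def lookup_sketch(cos_cache, sin_cache, position_ids):
    """Gather rows by position: [batch, seq_len] -> [batch, seq_len, head_dim]."""
    return cos_cache[position_ids], sin_cache[position_ids]


head_dim = 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
cos_cache, sin_cache = build_cache_sketch(inv_freq, seq_len=8)

# Two sequences that start at different offsets.
position_ids = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
cos, sin = lookup_sketch(cos_cache, sin_cache, position_ids)
print(cos.shape, sin.shape)
```

Indexing by `position_ids` rather than slicing the first `seq_len` rows keeps the result correct when batch rows start at different positions, for example during incremental decoding.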
- nemo_automodel.components.models.llama.rope_utils.RotaryEmbedding
None
- nemo_automodel.components.models.llama.rope_utils.Qwen2RotaryEmbedding
None
- nemo_automodel.components.models.llama.rope_utils.__all__
['RotaryEmbedding', 'LlamaRotaryEmbedding', 'Qwen2RotaryEmbedding', 'rotate_half', 'apply_rotary_pos…