nemo_automodel.components.models.gpt_oss.rope_utils


Module Contents

Classes

Name | Description
--- | ---
RotaryEmbedding | -

Functions

Name | Description
--- | ---
apply_rotary_emb | Apply rotary embeddings to input tensor.
apply_rotary_emb_qk | Apply rotary embeddings to query and key tensors.
position_ids_to_freqs_cis | -

API

class nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding(
head_dim: int,
base: int,
dtype: torch.dtype,
initial_context_length: int = 4096,
scaling_factor: float = 1.0,
ntk_alpha: float = 1.0,
ntk_beta: float = 32.0,
partial_rotary_factor: float = 1.0,
device: torch.device | None = None
)

Bases: torch.nn.Module

rotary_dim = int(head_dim * partial_rotary_factor): the number of leading channels in each head that receive the rotary embedding.
nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding._compute_concentration_and_inv_freq() -> torch.Tensor

Computes the YaRN concentration and inverse frequencies. See the YaRN paper: https://arxiv.org/abs/2309.00071

Uses rotary_dim instead of head_dim to support partial rotary embeddings.
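As a rough guide to what this method computes, here is a dependency-free sketch of YaRN's concentration term and NTK-by-parts inverse frequencies, modeled on the open-source gpt-oss reference code rather than taken from this module. The helper name `yarn_concentration_and_inv_freq`, the `0.1 * ln(scaling_factor) + 1` concentration, and the linear ramp between the `ntk_beta` and `ntk_alpha` cutoffs are assumptions; the real method works on torch tensors.

```python
import math


def yarn_concentration_and_inv_freq(
    rotary_dim: int,
    base: float,
    initial_context_length: int = 4096,
    scaling_factor: float = 1.0,
    ntk_alpha: float = 1.0,
    ntk_beta: float = 32.0,
) -> tuple[float, list[float]]:
    """Hypothetical helper: YaRN concentration and per-pair inverse frequencies."""
    half = rotary_dim // 2
    # freq[i] = base ** (2i / rotary_dim); its inverse is the angular rate of pair i.
    freq = [base ** (2 * i / rotary_dim) for i in range(half)]
    if scaling_factor <= 1.0:
        # No context extension: plain RoPE and no attention rescaling.
        return 1.0, [1.0 / f for f in freq]

    # Assumed YaRN attention "concentration" (temperature) term.
    concentration = 0.1 * math.log(scaling_factor) + 1.0

    # NTK-by-parts cutoffs: pair indices whose number of full rotations over the
    # original context equals ntk_beta (low) and ntk_alpha (high).
    d_half = rotary_dim / 2
    low = d_half * math.log(initial_context_length / (ntk_beta * 2 * math.pi)) / math.log(base)
    high = d_half * math.log(initial_context_length / (ntk_alpha * 2 * math.pi)) / math.log(base)
    assert 0 < low < high < d_half - 1

    inv_freq = []
    for i, f in enumerate(freq):
        ramp = (i - low) / (high - low)
        # mask = 1 keeps the original frequency (extrapolate); mask = 0 interpolates.
        mask = 1.0 - min(max(ramp, 0.0), 1.0)
        interpolation = 1.0 / (scaling_factor * f)
        extrapolation = 1.0 / f
        inv_freq.append(interpolation * (1.0 - mask) + extrapolation * mask)
    return concentration, inv_freq


# Illustrative values: rotary_dim = 64, base = 150000, 32x context extension.
concentration, inv_freq = yarn_concentration_and_inv_freq(64, 150000, 4096, scaling_factor=32.0)
```

High-frequency pairs (small index) keep their original rate, low-frequency pairs are slowed down by `scaling_factor`, and the pairs in between are blended linearly.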

nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding._compute_cos_sin(
num_tokens: int
)
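`_compute_cos_sin` is undocumented here. Assuming it follows the reference pattern, it builds per-position cos/sin tables from the inverse frequencies and folds in the concentration, which matches the "[cos, sin] with concentration applied" layout mentioned under `apply_rotary_emb_qk`. The sketch below uses a hypothetical `compute_cos_sin` helper over plain lists.

```python
import math


def compute_cos_sin(
    num_tokens: int, inv_freq: list[float], concentration: float = 1.0
) -> tuple[list[list[float]], list[list[float]]]:
    """Hypothetical helper: cos/sin tables of shape (num_tokens, rotary_dim // 2)."""
    cos, sin = [], []
    for t in range(num_tokens):
        # The angle for position t and frequency pair j is t * inv_freq[j].
        angles = [t * w for w in inv_freq]
        # Folding the concentration into both tables rescales q and k, and hence the logits.
        cos.append([math.cos(a) * concentration for a in angles])
        sin.append([math.sin(a) * concentration for a in angles])
    return cos, sin


cos, sin = compute_cos_sin(3, [1.0, 0.5], concentration=2.0)
```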
nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding.forward(
query: torch.Tensor,
key: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.gpt_oss.rope_utils.apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor
) -> torch.Tensor

Apply rotary embeddings to input tensor.

If cos/sin have a smaller last dimension than x (because partial_rotary_factor < 1.0), only the first rotary_dim channels of the last dimension of x are rotated; the remaining channels are passed through unchanged.

Parameters:

x
torch.Tensor

Input tensor (…, head_dim)

cos
torch.Tensor

Cosine tensor (…, rotary_dim // 2)

sin
torch.Tensor

Sine tensor (…, rotary_dim // 2)
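A minimal single-vector sketch of the partial-rotation rule, using plain Python lists. The half-split pairing (channel `j` rotated together with channel `j + rotary_dim // 2`) is an assumption borrowed from the gpt-oss reference code, and `apply_rotary_emb_sketch` is a hypothetical name, not this module's function.

```python
import math


def apply_rotary_emb_sketch(x: list[float], cos: list[float], sin: list[float]) -> list[float]:
    """Hypothetical single-vector version of apply_rotary_emb.

    x has head_dim entries; cos and sin have rotary_dim // 2 entries each.
    """
    rotary_dim = 2 * len(cos)
    x_rot, x_pass = x[:rotary_dim], x[rotary_dim:]
    half = rotary_dim // 2
    # Assumed half-split pairing: channel j is paired with channel j + half.
    x1, x2 = x_rot[:half], x_rot[half:]
    o1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    o2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    # Channels beyond rotary_dim (partial_rotary_factor < 1.0) pass through unchanged.
    return o1 + o2 + x_pass


# head_dim = 4, partial_rotary_factor = 0.5 -> rotary_dim = 2, so cos/sin have 1 entry.
theta = math.pi / 2
out = apply_rotary_emb_sketch([1.0, 0.0, 5.0, 6.0], [math.cos(theta)], [math.sin(theta)])
# out is approximately [0.0, 1.0, 5.0, 6.0]: the first pair is rotated by 90 degrees.
```

Because each pair undergoes a pure 2D rotation, the norm of the rotated channels is preserved.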

nemo_automodel.components.models.gpt_oss.rope_utils.apply_rotary_emb_qk(
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: torch.Tensor,
format: str = 'bshd',
rope_fusion: bool = True,
cu_seqlens: torch.Tensor | None = None,
concentration: float | None = None,
cp_size: int = 1,
cp_rank: int = 0
) -> tuple[torch.Tensor, torch.Tensor]

Apply rotary embeddings to query and key tensors.

Parameters:

q
torch.Tensor

Query tensor.

k
torch.Tensor

Key tensor.

freqs_cis
torch.Tensor

Frequency tensor. Format depends on rope_fusion:

  • If rope_fusion=True: [angles, angles] for the Transformer Engine (TE) fused RoPE kernel
  • If rope_fusion=False: [cos, sin] with concentration applied
format
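str, defaults to 'bshd'

QKV layout: “bshd” for padded [batch, seq, heads, head_dim] tensors, or “thd” for packed [total_tokens, heads, head_dim] tensors delimited by cu_seqlens.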

rope_fusion
bool, defaults to True

If True, use the Transformer Engine (TE) fused RoPE kernel. If False, use the non-fused implementation.

cu_seqlens
torch.Tensor | None, defaults to None

Cumulative sequence lengths for variable-length (packed) sequences, used with the “thd” format.

cp_size
int, defaults to 1

Context-parallel size (number of ranks the sequence is split across).

cp_rank
int, defaults to 0

Index of this rank within the context-parallel group.
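The `format`, `cu_seqlens`, `cp_size`, and `cp_rank` parameters all affect which absolute position each token is rotated by. The sketch below shows two pieces of that bookkeeping with hypothetical helpers: deriving `cu_seqlens` and per-sequence positions for a packed "thd" batch, and the load-balanced context-parallel split used in Megatron-Core and Transformer Engine, where the sequence is cut into `2 * cp_size` chunks and rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`. Whether this module slices its frequencies exactly this way is an assumption.

```python
def cu_seqlens_from_lengths(lengths: list[int]) -> list[int]:
    """Hypothetical helper: sequence i occupies packed tokens [cu[i], cu[i + 1])."""
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu


def packed_position_ids(cu_seqlens: list[int]) -> list[int]:
    """Hypothetical helper: positions restart at 0 at every sequence boundary."""
    return [
        p
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
        for p in range(end - start)
    ]


def cp_positions(seq_len: int, cp_size: int, cp_rank: int) -> list[int]:
    """Hypothetical helper: global positions held by cp_rank under the assumed split."""
    assert seq_len % (2 * cp_size) == 0
    chunk = seq_len // (2 * cp_size)
    first = range(cp_rank * chunk, (cp_rank + 1) * chunk)
    mirror = 2 * cp_size - 1 - cp_rank
    second = range(mirror * chunk, (mirror + 1) * chunk)
    return list(first) + list(second)


cu = cu_seqlens_from_lengths([3, 5, 2])   # three packed sequences
positions = packed_position_ids(cu)
rank0 = cp_positions(8, cp_size=2, cp_rank=0)
```

Pairing an early chunk with a late chunk evens out causal-attention work across ranks, since later tokens attend to more keys; the price is that each rank's positions are not contiguous.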

Returns: tuple[torch.Tensor, torch.Tensor]

Tuple of (q, k) with rotary embeddings applied.
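Applying the same position-dependent rotation to q and k is what makes RoPE relative: the score between a query at position m and a key at position n depends only on m - n. The check below illustrates this with a hypothetical `rotate` helper; the half-split pairing, the frequencies, and the vectors are illustrative assumptions.

```python
import math


def rotate(x: list[float], pos: int, inv_freq: list[float]) -> list[float]:
    """Hypothetical full-width RoPE on one vector (half-split pairing assumed)."""
    half = len(inv_freq)
    x1, x2 = x[:half], x[half:]
    cos = [math.cos(pos * w) for w in inv_freq]
    sin = [math.sin(pos * w) for w in inv_freq]
    o1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    o2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return o1 + o2


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


inv_freq = [1.0, 0.1]  # illustrative frequencies for head_dim = 4
q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.9, 0.6]

# The same offset (m - n = 3) at different absolute positions gives the same score.
score_a = dot(rotate(q, 5, inv_freq), rotate(k, 2, inv_freq))
score_b = dot(rotate(q, 103, inv_freq), rotate(k, 100, inv_freq))
```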

nemo_automodel.components.models.gpt_oss.rope_utils.position_ids_to_freqs_cis(
rotary_emb: nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding,
position_ids: torch.Tensor,
qkv_format: str = 'bshd',
for_fused_rope: bool = True,
cp_size: int = 1
) -> torch.Tensor
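`position_ids_to_freqs_cis` is undocumented here; judging from its `for_fused_rope` flag, it presumably produces the two `freqs_cis` layouts described under `apply_rotary_emb_qk`. The sketch below builds those layouts per token for arbitrary position ids with a hypothetical helper. Concatenation along the last dimension is an assumption, and the real function returns a `torch.Tensor` and also handles `qkv_format` and `cp_size`, which this sketch ignores.

```python
import math


def position_ids_to_freqs_sketch(
    position_ids: list[int],
    inv_freq: list[float],
    for_fused_rope: bool = True,
    concentration: float = 1.0,
) -> list[list[float]]:
    """Hypothetical per-token rows for the [angles, angles] and [cos, sin] layouts."""
    rows = []
    for p in position_ids:
        angles = [p * w for w in inv_freq]
        if for_fused_rope:
            # [angles, angles]: raw angles duplicated to span rotary_dim;
            # the concentration is not part of this layout in the sketch.
            rows.append(angles + angles)
        else:
            # [cos, sin] with the concentration already folded in.
            rows.append(
                [math.cos(a) * concentration for a in angles]
                + [math.sin(a) * concentration for a in angles]
            )
    return rows


fused = position_ids_to_freqs_sketch([0, 2], [1.0, 0.5])
unfused = position_ids_to_freqs_sketch([0, 2], [1.0, 0.5], for_fused_rope=False, concentration=1.5)
```

Taking explicit position ids rather than a plain `range(num_tokens)` presumably lets packed sequences (positions restarting at 0) and context-parallel shards (non-contiguous positions) receive the correct angles.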