nemo_automodel.components.models.gpt_oss.rope_utils#

Module Contents#

Classes#

RotaryEmbedding

Rotary position embedding with YaRN-style scaling and partial-rotary support.

Functions#

apply_rotary_emb

Apply rotary embeddings to input tensor.

position_ids_to_freqs_cis

Build the freqs_cis tensor for explicit position IDs.

API#

nemo_automodel.components.models.gpt_oss.rope_utils.apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) → torch.Tensor#

Apply rotary embeddings to input tensor.

If the last dimension of cos/sin covers fewer features than x's head dimension (because partial_rotary_factor < 1.0), only the first rotary_dim features of x are rotated; the remaining features are passed through unchanged.

Parameters:
  • x – Input tensor (…, head_dim)

  • cos – Cosine tensor (…, rotary_dim // 2)

  • sin – Sine tensor (…, rotary_dim // 2)

Returns:
  Tensor of the same shape as x with rotary embeddings applied.
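A minimal usage sketch; all shapes are hypothetical, and broadcasting cos/sin over batch and heads is an assumption of this sketch rather than a documented contract:

```python
import torch
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb

head_dim, rotary_dim, seq = 64, 32, 128   # i.e. partial_rotary_factor = 0.5
x = torch.randn(2, seq, 8, head_dim)      # assumed (batch, seq, heads, head_dim)

# cos/sin over the rotated half of the features only, shape (seq, rotary_dim // 2).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
angles = torch.arange(seq).float()[:, None] * inv_freq[None, :]
cos, sin = angles.cos(), angles.sin()

out = apply_rotary_emb(x, cos, sin)  # first rotary_dim features rotated,
assert out.shape == x.shape          # the remaining features pass through
```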

class nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding(
head_dim: int,
base: int,
dtype: torch.dtype,
initial_context_length: int = 4096,
scaling_factor: float = 1.0,
ntk_alpha: float = 1.0,
ntk_beta: float = 32.0,
partial_rotary_factor: float = 1.0,
device: torch.device | None = None,
)#

Bases: torch.nn.Module

Initialization

_compute_concentration_and_inv_freq() → torch.Tensor#

See the YaRN paper: https://arxiv.org/abs/2309.00071

Uses rotary_dim instead of head_dim so that partial rotary embeddings are supported.
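A hedged sketch of the "NTK-by-parts" frequency blend the YaRN paper describes; the helper name yarn_inv_freq and the exact constants below follow the paper, not this module's verified code:

```python
import math
import torch

def yarn_inv_freq(rotary_dim: int, base: float, scaling_factor: float,
                  initial_context_length: int, ntk_alpha: float,
                  ntk_beta: float) -> torch.Tensor:
    """Hypothetical YaRN-style inverse frequencies (arXiv:2309.00071)."""
    freq = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    if scaling_factor <= 1.0:
        return 1.0 / freq  # plain RoPE, no context extension
    d_half = rotary_dim / 2
    # Dims below `low` complete many rotations inside the original context
    # (extrapolate them); dims above `high` complete few (interpolate them).
    low = d_half * math.log(initial_context_length / (ntk_beta * 2 * math.pi)) / math.log(base)
    high = d_half * math.log(initial_context_length / (ntk_alpha * 2 * math.pi)) / math.log(base)
    ramp = ((torch.arange(int(d_half), dtype=torch.float32) - low) / (high - low)).clamp(0, 1)
    interpolation = 1.0 / (scaling_factor * freq)  # stretch positions for slow dims
    extrapolation = 1.0 / freq                     # leave fast dims unchanged
    return interpolation * ramp + extrapolation * (1 - ramp)
```

The "concentration" in the method name presumably refers to YaRN's attention-temperature term (roughly 0.1 · ln(scaling_factor) + 1 in the paper), which this sketch omits.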

_compute_cos_sin(num_tokens: int)#

forward(
query: torch.Tensor,
key: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor]#

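A minimal construction-and-forward sketch; the query/key layout is an assumption, since this reference does not state the expected shapes:

```python
import torch
from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding

rope = RotaryEmbedding(
    head_dim=64,
    base=10000,
    dtype=torch.float32,
    initial_context_length=4096,
    scaling_factor=1.0,          # values > 1.0 enable YaRN-style context extension
    partial_rotary_factor=0.5,   # rotate only the first half of each head's features
)

q = torch.randn(2, 128, 8, 64)   # assumed (batch, seq, heads, head_dim)
k = torch.randn(2, 128, 2, 64)   # e.g. fewer KV heads under GQA
q_rot, k_rot = rope(q, k)
```
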
nemo_automodel.components.models.gpt_oss.rope_utils.position_ids_to_freqs_cis(
rotary_emb: nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding,
position_ids: torch.Tensor,
qkv_format: str = 'bshd',
) → torch.Tensor#

Build the freqs_cis tensor for the given position IDs using the frequencies from rotary_emb.
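A usage sketch under the same assumptions as above; reading qkv_format='bshd' as a (batch, seq, heads, dim) layout is an assumption of this sketch:

```python
import torch
from nemo_automodel.components.models.gpt_oss.rope_utils import (
    RotaryEmbedding,
    position_ids_to_freqs_cis,
)

rope = RotaryEmbedding(head_dim=64, base=10000, dtype=torch.float32)

# Explicit positions, e.g. two packed sequences whose positions restart at 0.
position_ids = torch.cat([torch.arange(96), torch.arange(32)]).unsqueeze(0)
freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd")
```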