nemo_automodel.components.models.gpt_oss.rope_utils#
Module Contents#
Classes#
RotaryEmbedding
Functions#
apply_rotary_emb – Apply rotary embeddings to input tensor.
apply_rotary_emb_qk – Apply rotary embeddings to query and key tensors.
position_ids_to_freqs_cis
API#
- nemo_automodel.components.models.gpt_oss.rope_utils.apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor)#
Apply rotary embeddings to input tensor.
If cos/sin cover fewer features than x's last dimension (because partial_rotary_factor < 1.0), only the first rotary_dim features of x are rotated and the rest are passed through unchanged.
- Parameters:
x – Input tensor (…, head_dim)
cos – Cosine tensor (…, rotary_dim // 2)
sin – Sine tensor (…, rotary_dim // 2)
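A minimal usage sketch, assuming the cos/sin tensors broadcast against the leading dimensions of x; the shapes follow the parameter descriptions above, while the random angles and concrete sizes are illustrative only.

```python
import math

import torch

from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb

seq_len, num_heads, head_dim = 16, 8, 64
rotary_dim = 32  # partial_rotary_factor = 0.5, so only half of head_dim is rotated

x = torch.randn(seq_len, num_heads, head_dim)

# cos/sin carry rotary_dim // 2 frequencies and broadcast over the head axis.
angles = torch.rand(seq_len, 1, rotary_dim // 2) * 2 * math.pi
cos, sin = angles.cos(), angles.sin()

# The first rotary_dim features of x are rotated; the remaining
# head_dim - rotary_dim features are passed through unchanged.
x_rot = apply_rotary_emb(x, cos, sin)  # expected to keep x's shape
```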
- class nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding(
- head_dim: int,
- base: int,
- dtype: torch.dtype,
- initial_context_length: int = 4096,
- scaling_factor: float = 1.0,
- ntk_alpha: float = 1.0,
- ntk_beta: float = 32.0,
- partial_rotary_factor: float = 1.0,
- device: torch.device | None = None,
- )#
Bases:
torch.nn.Module
Initialization
- _compute_concentration_and_inv_freq() → torch.Tensor#
See YaRN paper: https://arxiv.org/abs/2309.00071
Uses rotary_dim instead of head_dim to support partial rotary embeddings.
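For orientation, the YaRN "NTK-by-parts" computation can be sketched as below. This is a generic reconstruction from the paper's formulas (attention concentration plus an interpolation/extrapolation blend over rotary_dim // 2 frequencies), not necessarily the exact code of this method; the helper name and return layout are assumptions, and the defaults simply mirror the constructor signature above.

```python
import math

import torch


def yarn_concentration_and_inv_freq(
    rotary_dim: int,
    base: float,
    initial_context_length: int = 4096,
    scaling_factor: float = 1.0,
    ntk_alpha: float = 1.0,
    ntk_beta: float = 32.0,
) -> tuple[float, torch.Tensor]:
    """Sketch of YaRN NTK-by-parts frequencies plus the attention concentration."""
    # Base RoPE frequencies over the rotated half of the head dimension.
    freq = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    if scaling_factor > 1.0:
        # Attention "concentration" (a.k.a. mscale) from the YaRN paper.
        concentration = 0.1 * math.log(scaling_factor) + 1.0
        # Dimensions below `low` are interpolated, above `high` extrapolated,
        # with a linear ramp in between (NTK-by-parts).
        d_half = rotary_dim / 2
        low = d_half * math.log(initial_context_length / (ntk_beta * 2 * math.pi)) / math.log(base)
        high = d_half * math.log(initial_context_length / (ntk_alpha * 2 * math.pi)) / math.log(base)
        interpolation = 1.0 / (scaling_factor * freq)
        extrapolation = 1.0 / freq
        ramp = ((torch.arange(d_half, dtype=torch.float32) - low) / (high - low)).clamp(0.0, 1.0)
        mask = 1.0 - ramp
        inv_freq = interpolation * (1.0 - mask) + extrapolation * mask
    else:
        concentration = 1.0
        inv_freq = 1.0 / freq
    return concentration, inv_freq
```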
- _compute_cos_sin(num_tokens: int)#
- forward(query: torch.Tensor, key: torch.Tensor)#
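A hypothetical construction and call of the module; the concrete sizes are illustrative, and the assumption that forward returns the rotated (query, key) pair is inferred from the signature rather than stated on this page.

```python
import torch

from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding

rope = RotaryEmbedding(
    head_dim=64,
    base=10000,
    dtype=torch.bfloat16,
    initial_context_length=4096,
    scaling_factor=8.0,  # YaRN-style context extension
    partial_rotary_factor=1.0,
)

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
query = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
key = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16)

# Assumed to return (query, key) with rotary embeddings applied.
query, key = rope(query, key)
```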
- nemo_automodel.components.models.gpt_oss.rope_utils.apply_rotary_emb_qk(
- q: torch.Tensor,
- k: torch.Tensor,
- freqs_cis: torch.Tensor,
- format: str = 'bshd',
- rope_fusion: bool = True,
- cu_seqlens: torch.Tensor | None = None,
- concentration: float | None = None,
- cp_size: int = 1,
- cp_rank: int = 0,
- )#
Apply rotary embeddings to query and key tensors.
- Parameters:
q – Query tensor.
k – Key tensor.
freqs_cis –
Frequency tensor. Format depends on rope_fusion:
- If rope_fusion=True: [angles, angles] for TE fused rope.
- If rope_fusion=False: [cos, sin] with concentration applied.
format – QKV format ('bshd' or 'thd').
rope_fusion – If True, use TE fused rope. If False, use non-fused rope.
cu_seqlens – Cumulative sequence lengths for variable-length sequences.
cp_size – Context parallelism size.
cp_rank – Context parallelism rank.
- Returns:
Tuple of (q, k) with rotary embeddings applied.
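A hypothetical end-to-end sketch of the non-fused path, combining this function with the position_ids_to_freqs_cis helper documented below. The tensor layouts and the helper's return format are assumptions; only the call signatures follow this page.

```python
import torch

from nemo_automodel.components.models.gpt_oss.rope_utils import (
    RotaryEmbedding,
    apply_rotary_emb_qk,
    position_ids_to_freqs_cis,
)

rope = RotaryEmbedding(head_dim=64, base=10000, dtype=torch.float32)

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_heads, head_dim)

# Build [cos, sin] frequencies (concentration folded in) for the non-fused path.
position_ids = torch.arange(seq_len).expand(batch, seq_len)
freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False)

# 'bshd' layout, non-fused rope; returns the rotated (q, k) pair.
q, k = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False)
```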
- nemo_automodel.components.models.gpt_oss.rope_utils.position_ids_to_freqs_cis(
- rotary_emb: nemo_automodel.components.models.gpt_oss.rope_utils.RotaryEmbedding,
- position_ids: torch.Tensor,
- qkv_format: str = 'bshd',
- for_fused_rope: bool = True,
- cp_size: int = 1,