core.models.common.embeddings.rope_utils#
Module Contents#
Functions#
Function | Description
get_pos_emb_on_this_cp_rank | Get the position embedding on the current context parallel rank.
_rotate_half | Change sign so the last dimension becomes [-odd, +even].
_apply_rotary_pos_emb_bshd | Apply rotary positional embedding to input tensor T.
_get_thd_freqs_on_this_cp_rank | Get the correct frequency slice for this context parallel rank with optional sequence offset.
_apply_rotary_pos_emb_thd | A baseline implementation of applying RoPE for thd (packed sequence) format.
apply_rotary_pos_emb | Route to the appropriate apply_rotary_pos_emb implementation depending on fused/unfused kernels and bshd (conventional) / thd (packed sequence) format.
apply_rotary_pos_emb_with_cos_sin | Apply rotary positional embedding to the target tensor t using precomputed cos and sin of size (seq_len, d_rot / 2).
Data#
API#
- core.models.common.embeddings.rope_utils.logger#
'getLogger(…)'
- core.models.common.embeddings.rope_utils.__all__#
['apply_rotary_pos_emb', 'apply_rotary_emb_flash', 'apply_rotary_pos_emb_with_cos_sin', 'fused_apply…
- core.models.common.embeddings.rope_utils.get_pos_emb_on_this_cp_rank(
- pos_emb: torch.Tensor,
- seq_dim: int,
- cp_group: torch.distributed.ProcessGroup,
- )
Get the position embedding on the current context parallel rank.
- Parameters:
pos_emb (Tensor) – Positional embedding tensor
seq_dim (int) – Sequence dimension
cp_group (torch.distributed.ProcessGroup) – The context parallel group
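For reference, the following is a minimal sketch (not the library implementation) of the load-balanced slicing typically used with context parallelism: the sequence dimension is viewed as 2 * cp_size chunks and rank r keeps chunk r plus its mirror chunk. Here cp_rank and cp_size are plain integers rather than values derived from cp_group, and the function name is illustrative.

```python
import torch

def slice_pos_emb_for_cp_rank(pos_emb: torch.Tensor, seq_dim: int,
                              cp_rank: int, cp_size: int) -> torch.Tensor:
    """Keep the two sequence chunks assigned to this context-parallel rank."""
    seq_len = pos_emb.size(seq_dim)
    assert seq_len % (2 * cp_size) == 0, "sequence length must divide into 2 * cp_size chunks"
    # View the sequence dimension as (2 * cp_size, chunk) so chunks can be picked out.
    chunked = pos_emb.reshape(
        *pos_emb.shape[:seq_dim], 2 * cp_size, seq_len // (2 * cp_size),
        *pos_emb.shape[seq_dim + 1:],
    )
    # Rank r keeps chunk r and its mirror chunk, balancing causal-attention work.
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank], device=pos_emb.device)
    picked = chunked.index_select(seq_dim, index)
    # Collapse the two kept chunks back into a single (shorter) sequence dimension.
    return picked.reshape(*picked.shape[:seq_dim], -1, *picked.shape[seq_dim + 2:])
```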
- core.models.common.embeddings.rope_utils._rotate_half(
- x: torch.Tensor,
- rotary_interleaved: bool,
- )
Change sign so the last dimension becomes [-odd, +even]
- Parameters:
x (Tensor) – Input tensor
- Returns:
Tensor rotated half
- Return type:
Tensor
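A minimal sketch of the half-rotation for both layouts, assuming the non-interleaved layout splits the last dimension into two halves and the interleaved layout pairs adjacent even/odd channels; the function name here is illustrative.

```python
import torch

def rotate_half(x: torch.Tensor, rotary_interleaved: bool = False) -> torch.Tensor:
    """Return the rotated half of x: the last dimension becomes [-odd, +even]."""
    if not rotary_interleaved:
        # Non-interleaved layout: split the last dimension into two halves.
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    # Interleaved layout: swap and negate within each (even, odd) channel pair.
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
```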
- core.models.common.embeddings.rope_utils._apply_rotary_pos_emb_bshd(
- t: torch.Tensor,
- freqs: torch.Tensor,
- rotary_interleaved: bool = False,
- multi_latent_attention: bool = False,
- mscale: float = 1.0,
- )
Apply rotary positional embedding to input tensor T.
See https://kexue.fm/archives/8265 for the detailed formulas.
- Parameters:
t (Tensor) – Input tensor t of shape [seq_length, …, dim]
freqs (Tensor) – Rotary positional embedding tensor freqs of shape [seq_length, …, dim]
- Returns:
The input tensor after applying RoPE
- Return type:
Tensor
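The rotation itself follows the standard RoPE identity, t_rot' = t_rot * cos(freqs) + rotate_half(t_rot) * sin(freqs), with channels beyond the rotary dimension passed through unchanged. Below is a hedged sketch of this rule; the function name is illustrative, it reuses the rotate_half sketch above, and the multi-latent-attention path is omitted.

```python
import torch

def apply_rope_bshd(t: torch.Tensor, freqs: torch.Tensor,
                    rotary_interleaved: bool = False,
                    mscale: float = 1.0) -> torch.Tensor:
    """Rotate the first freqs.size(-1) channels of t; leave the rest untouched."""
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # Precompute the (scaled) cosine and sine tables in the input dtype.
    cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
    sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
    t_rot = t_rot * cos_ + rotate_half(t_rot, rotary_interleaved) * sin_
    return torch.cat((t_rot, t_pass), dim=-1)
```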
- core.models.common.embeddings.rope_utils._get_thd_freqs_on_this_cp_rank(
- cp_rank: int,
- cp_size: int,
- x: torch.Tensor,
- freqs: torch.Tensor,
- offset: int = 0,
- )
Get the correct frequency slice for this context parallel rank with optional sequence offset.
- Parameters:
cp_rank – Current context parallel rank
cp_size – Total context parallel size
x – Input tensor for current sequence
freqs – Frequency tensor - either full batch positions or max sequence length
offset – Starting position offset for this sequence in the original batch (default: 0)
- Returns:
Frequency slice corresponding to this CP rank’s portion of the sequence
- Return type:
Tensor
Note: This function supports two modes, selected by the offset parameter:
- offset > 0 (exact mapping mode): freqs contains all positions across all sequences. The offset ensures each sequence gets frequencies from its actual position within the overall batch, which is critical for non-1D RoPE in VLMs where spatial positions matter.
- offset = 0 (traditional mode): freqs contains only max-sequence-length positions. All sequences use frequencies starting from position 0, preserving backward compatibility.
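A sketch of the slicing described above, assuming each rank holds two mirrored half-chunks of every sequence (the same layout as the bshd path) and that offset simply shifts the window taken from freqs; the function name is illustrative.

```python
import torch

def thd_freqs_for_cp_rank(cp_rank: int, cp_size: int, x: torch.Tensor,
                          freqs: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Pick the frequency rows for this rank's portion of one packed sequence."""
    seq_here = x.size(0)                 # tokens of this sequence held by this rank
    if cp_size == 1:
        return freqs[offset : offset + seq_here]
    half = seq_here // 2                 # each rank holds two mirrored half-chunks
    full = seq_here * cp_size            # length of the unsplit sequence
    first = freqs[offset + cp_rank * half : offset + (cp_rank + 1) * half]
    second = freqs[offset + full - (cp_rank + 1) * half : offset + full - cp_rank * half]
    return torch.cat((first, second), dim=0)
```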
- core.models.common.embeddings.rope_utils._apply_rotary_pos_emb_thd(
- t: torch.Tensor,
- cu_seqlens: torch.Tensor,
- freqs: torch.Tensor,
- rotary_interleaved: bool = False,
- multi_latent_attention: bool = False,
- mscale: float = 1.0,
- cp_group: torch.distributed.ProcessGroup = None,
- )
A baseline implementation of applying RoPE for thd (packed sequence) format.
- Parameters:
t (Tensor) – Input tensor t of shape [t, h, d]
cu_seqlens (Tensor) – Cumulative sum of sequence lengths in a batch for t, with shape [b + 1] and dtype torch.int32
freqs (Tensor) – Rotary positional embedding tensor freqs of shape [max_s, 1, 1, d]
cp_group (torch.distributed.ProcessGroup) – The context parallel group
- Returns:
Shape [t, h, d]. The input tensor after applying RoPE.
- Return type:
Tensor
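Conceptually, the baseline splits the packed tensor at the cu_seqlens boundaries and applies the bshd rule to each sequence. The following simplified sketch covers the cp_size == 1 case only, reusing the illustrative apply_rope_bshd helper from above.

```python
import torch

def apply_rope_thd(t: torch.Tensor, cu_seqlens: torch.Tensor,
                   freqs: torch.Tensor, rotary_interleaved: bool = False,
                   mscale: float = 1.0) -> torch.Tensor:
    """t is [t, h, d]; cu_seqlens is [b + 1]; freqs is [max_s, 1, 1, d]."""
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    rotated = [
        # Treat each packed sequence as a batch of one and reuse the bshd rule.
        apply_rope_bshd(seq.unsqueeze(1), freqs[: seq.size(0)],
                        rotary_interleaved, mscale).squeeze(1)
        for seq in torch.split(t, seqlens, dim=0)
    ]
    return torch.cat(rotated, dim=0)
```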
- core.models.common.embeddings.rope_utils.apply_rotary_pos_emb(
- t: torch.Tensor,
- freqs: torch.Tensor,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- cu_seqlens: Optional[torch.Tensor] = None,
- mscale: float = 1.0,
- cp_group: torch.distributed.ProcessGroup = None,
- )
Route to the appropriate apply_rotary_pos_emb implementation depending on whether fused or unfused kernels are used and whether the input is in bshd (conventional) or thd (packed sequence) format.
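The dispatch can be pictured as below. This is a simplified sketch that shows only the unfused paths, reuses the illustrative helpers sketched earlier, and assumes config exposes a rotary_interleaved flag; the real function also consults fusion settings such as config.apply_rope_fusion and handles cp_group defaults and multi-latent attention.

```python
def route_rope(t, freqs, config, cu_seqlens=None, mscale=1.0, cp_group=None):
    """Illustrative dispatch between the bshd and thd code paths (unfused only)."""
    if cu_seqlens is None:
        # Conventional [s, b, h, d] layout.
        return apply_rope_bshd(t, freqs, config.rotary_interleaved, mscale)
    # Packed [t, h, d] layout with per-sequence boundaries.
    return apply_rope_thd(t, cu_seqlens, freqs, config.rotary_interleaved, mscale)
```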
- core.models.common.embeddings.rope_utils.apply_rotary_pos_emb_with_cos_sin(
- t: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- rotary_interleaved: bool = False,
- )
Apply rotary positional embedding to the target tensor t using precomputed cos and sin tensors of size (seq_len, d_rot / 2).
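A sketch of how half-width cos/sin tables of shape (seq_len, d_rot / 2) can be expanded and applied, assuming the whole head dimension is rotated. The layout-dependent expansion mirrors the rotate_half sketch above, and the function name is illustrative.

```python
import torch

def apply_rope_with_cos_sin(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                            rotary_interleaved: bool = False) -> torch.Tensor:
    """t: [s, b, h, d]; cos/sin: [s, d / 2] precomputed tables."""
    if rotary_interleaved:
        # Interleaved layout: each angle covers one (even, odd) channel pair.
        cos_full = cos.repeat_interleave(2, dim=-1)
        sin_full = sin.repeat_interleave(2, dim=-1)
    else:
        # Half-split layout: the same angles apply to both halves of the last dim.
        cos_full = torch.cat((cos, cos), dim=-1)
        sin_full = torch.cat((sin, sin), dim=-1)
    # Broadcast over the batch and head dimensions.
    cos_full = cos_full[:, None, None, :].to(t.dtype)
    sin_full = sin_full[:, None, None, :].to(t.dtype)
    return t * cos_full + rotate_half(t, rotary_interleaved) * sin_full
```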