core.models.common.embeddings.rope_utils#

Module Contents#

Functions#

  • get_pos_emb_on_this_cp_rank – Get the position embedding on the current context parallel rank.

  • _rotate_half – Change the sign so the last dimension becomes [-odd, +even].

  • _apply_rotary_pos_emb_bshd – Apply rotary positional embedding to input tensor t (bshd format).

  • _get_thd_freqs_on_this_cp_rank – Get the correct frequency slice for this context parallel rank, with an optional sequence offset.

  • _apply_rotary_pos_emb_thd – A baseline implementation of applying RoPE for the thd (packed-sequence) format.

  • apply_rotary_pos_emb – Route to the appropriate apply_rotary_pos_emb implementation depending on whether fused or unfused kernels are used and whether the input is in bshd (conventional) or thd (packed-sequence) format.

  • apply_rotary_pos_emb_with_cos_sin – Apply rotary positional embedding to the target tensor t using precomputed cos and sin of size (seq_len, d_rot / 2).

Data#

API#

core.models.common.embeddings.rope_utils.logger#

'getLogger(…)'

core.models.common.embeddings.rope_utils.__all__#

['apply_rotary_pos_emb', 'apply_rotary_emb_flash', 'apply_rotary_pos_emb_with_cos_sin', 'fused_apply…

core.models.common.embeddings.rope_utils.get_pos_emb_on_this_cp_rank(
pos_emb: torch.Tensor,
seq_dim: int,
cp_group: torch.distributed.ProcessGroup,
) torch.Tensor#

Get the position embedding on the current context parallel rank.

Parameters:
  • pos_emb (Tensor) – Positional embedding tensor

  • seq_dim (int) – Sequence dimension

  • cp_group (torch.distributed.ProcessGroup) – The context parallel group
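A minimal single-process sketch of the chunking this performs (an illustration only, with hypothetical names: cp_rank and cp_size are passed explicitly here instead of being read from cp_group). The sequence dimension is split into 2 * cp_size chunks and each rank keeps the pair (cp_rank, 2 * cp_size - 1 - cp_rank), which balances causal-attention work across ranks:

    import torch

    def get_pos_emb_on_this_cp_rank_sketch(pos_emb, seq_dim, cp_rank, cp_size):
        # Split the sequence dimension into 2 * cp_size chunks and keep the
        # pair (cp_rank, 2 * cp_size - 1 - cp_rank) for this rank.
        cp_idx = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=pos_emb.device)
        pos_emb = pos_emb.view(
            *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[seq_dim + 1:]
        )
        pos_emb = pos_emb.index_select(seq_dim, cp_idx)
        return pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[seq_dim + 2:])

    # seq_len 16, cp_size 2: rank 0 keeps positions 0-3 and 12-15.
    emb = torch.arange(16.0).view(16, 1, 1, 1)
    print(get_pos_emb_on_this_cp_rank_sketch(emb, 0, cp_rank=0, cp_size=2).flatten())
    # tensor([ 0.,  1.,  2.,  3., 12., 13., 14., 15.])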

core.models.common.embeddings.rope_utils._rotate_half(
x: torch.Tensor,
rotary_interleaved: bool,
) torch.Tensor#

Change the sign so the last dimension becomes [-odd, +even].

Parameters:
  • x (Tensor) – Input tensor

  • rotary_interleaved (bool) – Whether the rotary channels are interleaved (adjacent even/odd pairs) rather than split into two contiguous halves

Returns:

Tensor rotated half

Return type:

Tensor
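A self-contained sketch of the rotation for both layouts (hypothetical helper name; same signature as above):

    import torch

    def rotate_half_sketch(x: torch.Tensor, rotary_interleaved: bool) -> torch.Tensor:
        if not rotary_interleaved:
            # Two contiguous halves: [x1, x2] -> [-x2, x1]
            x1, x2 = torch.chunk(x, 2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        # Interleaved pairs: [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]
        even, odd = x[..., 0::2], x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)

    t = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(rotate_half_sketch(t, False))  # tensor([-3., -4.,  1.,  2.])
    print(rotate_half_sketch(t, True))   # tensor([-2.,  1., -4.,  3.])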

core.models.common.embeddings.rope_utils._apply_rotary_pos_emb_bshd(
t: torch.Tensor,
freqs: torch.Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
) torch.Tensor#

Apply rotary positional embedding to input tensor t (bshd format).

Check https://kexue.fm/archives/8265 for the detailed formulas.

Parameters:
  • t (Tensor) – Input tensor of shape [seq_length, …, dim]

  • freqs (Tensor) – Rotary positional embedding tensor freqs of shape [seq_length, …, dim]

  • rotary_interleaved (bool) – Whether the rotary channels are interleaved rather than split into two contiguous halves (default: False)

  • multi_latent_attention (bool) – Whether the input comes from multi-latent attention (default: False)

  • mscale (float) – Scale factor applied to the rotary cos/sin (default: 1.0)

Returns:

The input tensor after applying RoPE

Return type:

Tensor
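The core computation is the standard RoPE identity t * cos(freqs) + rotate_half(t) * sin(freqs), applied to the first rot_dim channels only. A sketch under that assumption, reusing rotate_half_sketch from above (mscale scales cos/sin; the multi-latent-attention channel reordering is omitted here):

    import torch

    def apply_rotary_pos_emb_bshd_sketch(t, freqs, rotary_interleaved=False, mscale=1.0):
        # Rotate only the first rot_dim channels; pass the rest through unchanged.
        rot_dim = freqs.shape[-1]
        t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
        cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
        sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
        t_rot = t_rot * cos_ + rotate_half_sketch(t_rot, rotary_interleaved) * sin_
        return torch.cat((t_rot, t_pass), dim=-1)

    q = torch.randn(1024, 2, 8, 128)      # [seq, batch, heads, head_dim]
    freqs = torch.randn(1024, 1, 1, 64)   # [seq, 1, 1, rot_dim]; comes from RotaryEmbedding in practice
    print(apply_rotary_pos_emb_bshd_sketch(q, freqs).shape)  # torch.Size([1024, 2, 8, 128])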

core.models.common.embeddings.rope_utils._get_thd_freqs_on_this_cp_rank(
cp_rank: int,
cp_size: int,
x: torch.Tensor,
freqs: torch.Tensor,
offset: int = 0,
) torch.Tensor#

Get the correct frequency slice for this context parallel rank with optional sequence offset.

Parameters:
  • cp_rank – Current context parallel rank

  • cp_size – Total context parallel size

  • x – Input tensor for current sequence

  • freqs – Frequency tensor - either full batch positions or max sequence length

  • offset – Starting position offset for this sequence in the original batch (default: 0)

Returns:

Frequency slice corresponding to this CP rank’s portion of the sequence

Return type:

Tensor

Note:

This function supports two modes based on the offset parameter:

  1. offset > 0: Exact mapping mode - freqs contains all positions across all sequences. The offset ensures each sequence gets frequencies from its actual position within the overall batch. Critical for non-1D RoPE in VLMs where spatial positions matter.

  2. offset = 0: Traditional mode - freqs contains only max sequence length positions. All sequences use frequencies starting from position 0, preserving backward compatibility.
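A sketch of the slicing under the load-balanced CP layout described for get_pos_emb_on_this_cp_rank (each rank holds two half-chunks of every sequence); the exact indexing in the actual implementation may differ:

    import torch

    def get_thd_freqs_on_this_cp_rank_sketch(cp_rank, cp_size, x, freqs, offset=0):
        # x is this rank's share of one packed sequence, stored as two
        # half-chunks (front chunk + mirrored back chunk) when cp_size > 1.
        if cp_size > 1:
            cp_seg = x.size(0) // 2
            full_seqlen = cp_size * x.size(0)
            return torch.cat(
                [
                    freqs[offset + cp_rank * cp_seg : offset + (cp_rank + 1) * cp_seg],
                    freqs[offset + full_seqlen - (cp_rank + 1) * cp_seg :
                          offset + full_seqlen - cp_rank * cp_seg],
                ]
            )
        # cp_size == 1: offset = 0 reproduces the traditional per-sequence mode,
        # offset > 0 selects this sequence's absolute positions within the batch.
        return freqs[offset : offset + x.size(0)]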

core.models.common.embeddings.rope_utils._apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
) torch.Tensor#

A baseline implementation of applying RoPE for thd format.

Parameters:
  • t (Tensor) – Input tensor of shape [t, h, d]

  • cu_seqlens (Tensor) – Cumulative sum of sequence lengths in a batch for t, with shape [b + 1] and dtype torch.int32

  • freqs (Tensor) – Rotary positional embedding tensor freqs of shape [max_s, 1, 1, d]

  • cp_group (torch.distributed.ProcessGroup) – The context parallel group

Returns:

Shape [t, h, d]. The input tensor after applying RoPE.

Return type:

Tensor
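A sketch of the baseline loop, reusing the two sketches above (hypothetical helper names; cp_rank/cp_size passed explicitly instead of derived from cp_group): split the packed tensor at cu_seqlens, slice freqs for each sequence, apply the bshd rotation with a dummy batch dimension, and re-concatenate:

    import torch

    def apply_rotary_pos_emb_thd_sketch(t, cu_seqlens, freqs, cp_rank=0, cp_size=1):
        cu_seqlens = cu_seqlens // cp_size                     # per-rank sequence boundaries
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return torch.cat(
            [
                apply_rotary_pos_emb_bshd_sketch(
                    x.unsqueeze(1),                            # [s, 1, h, d]
                    get_thd_freqs_on_this_cp_rank_sketch(cp_rank, cp_size, x, freqs),
                )
                for x in torch.split(t, seqlens)
            ]
        ).squeeze(1)

    t = torch.randn(10, 8, 64)                                 # two packed sequences: 4 + 6 tokens
    cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
    freqs = torch.randn(6, 1, 1, 64)                           # [max_s, 1, 1, d]
    print(apply_rotary_pos_emb_thd_sketch(t, cu_seqlens, freqs).shape)  # torch.Size([10, 8, 64])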

core.models.common.embeddings.rope_utils.apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
config: megatron.core.transformer.transformer_config.TransformerConfig,
cu_seqlens: Optional[torch.Tensor] = None,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
)#

Route to the appropriate apply_rotary_pos_emb implementation depending on whether fused or unfused kernels are used and whether the input is in bshd (conventional) or thd (packed-sequence) format.
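A minimal usage sketch. The exact set of required TransformerConfig fields may vary by Megatron-Core version, and depending on the version the dispatcher may also expect Megatron's parallel state to be initialized (for a default context-parallel group); apply_rope_fusion=False selects the unfused path implemented in this module:

    import torch
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb

    config = TransformerConfig(
        num_layers=2,
        hidden_size=128,
        num_attention_heads=8,
        apply_rope_fusion=False,   # take the unfused bshd/thd code path
    )

    q = torch.randn(1024, 2, 8, 128)      # [seq, batch, heads, head_dim] (bshd path)
    freqs = torch.randn(1024, 1, 1, 64)   # comes from RotaryEmbedding in practice
    q_rot = apply_rotary_pos_emb(q, freqs, config=config)

    # Packed-sequence (thd) path: additionally pass cu_seqlens.
    # q_rot = apply_rotary_pos_emb(q_thd, freqs, config=config, cu_seqlens=cu_seqlens)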

core.models.common.embeddings.rope_utils.apply_rotary_pos_emb_with_cos_sin(
t: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rotary_interleaved: bool = False,
) torch.Tensor#

Apply rotary positional embedding to the target tensor t using precomputed cos and sin of size (seq_len, d_rot / 2).
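cos and sin are typically precomputed once per sequence length from the rotary inverse frequencies. A sketch of that precomputation (standard RoPE with base 10000, an assumed value; in practice cos/sin come from the model's rotary embedding):

    import torch

    seq_len, d_rot = 1024, 64
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d_rot, 2).float() / d_rot))   # (d_rot / 2,)
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)             # (seq_len, d_rot / 2)
    cos, sin = angles.cos(), angles.sin()

    # t = ...  # e.g. a query tensor
    # out = apply_rotary_pos_emb_with_cos_sin(t, cos, sin, rotary_interleaved=False)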