bridge.diffusion.models.common.nemotron_labs_diffusion_attention#
NemotronLabsDiffusionAttention for sbd_block_diff diffusion LM training with YARN RoPE.
Module Contents#
Classes#
Ministral3RotaryEmbedding | RoPE with YARN support, driven by HF rope_parameters config. |
NemotronLabsDiffusionAttention | NemotronLabsDiffusionAttention for semi-block-diffusion (sbd_block_diff) training. |
Functions#
fused_flex_attention | Thin compiled wrapper around flex_attention. |
rotate_half | Rotate the last half of the hidden dimension for RoPE. |
apply_rotary_pos_emb | Apply rotary position embeddings to query and key tensors. |
repeat_kv | Expand KV heads to match query heads for GQA. |
API#
- bridge.diffusion.models.common.nemotron_labs_diffusion_attention.fused_flex_attention(q, k, v, score_mod=None, block_mask=None, return_lse=False)#
Thin compiled wrapper around flex_attention.
- bridge.diffusion.models.common.nemotron_labs_diffusion_attention.rotate_half(x)#
Rotate the last half of the hidden dimension for RoPE.
- bridge.diffusion.models.common.nemotron_labs_diffusion_attention.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)#
Apply rotary position embeddings to query and key tensors.
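As a rough illustration of what `rotate_half` and `apply_rotary_pos_emb` compute, here is a minimal sketch that follows the common Hugging Face RoPE convention. It assumes `q` and `k` are laid out as `[batch, heads, seq_len, head_dim]` and `cos`/`sin` as `[batch, seq_len, head_dim]`; the module's actual bodies may differ, and the frequency base of 10000 is only an example value.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map [x1, x2] -> [-x2, x1] along the last dimension."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate q and k by position-dependent angles (assumed HF-style layout)."""
    # Add a head axis so cos/sin broadcast against [batch, heads, seq, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


# Example inputs (all sizes are illustrative).
batch, heads, seq_len, head_dim = 1, 2, 4, 8
torch.manual_seed(0)
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Plain (non-YARN) inverse frequencies, used here only to build cos/sin.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                        # [seq, head_dim]
cos, sin = emb.cos()[None], emb.sin()[None]                    # [1, seq, head_dim]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
```

Because each pair of channels is rotated by an angle, the per-token vector norm is unchanged and position 0 is left as-is.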
- bridge.diffusion.models.common.nemotron_labs_diffusion_attention.repeat_kv(hidden_states: torch.Tensor, n_rep: int) torch.Tensor#
Expand KV heads to match query heads for GQA.
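The GQA expansion can be sketched as follows. This is an assumption-based illustration: it takes `hidden_states` to be `[batch, num_kv_heads, seq_len, head_dim]`, which is the usual Hugging Face layout, and is not necessarily the exact body used in this module.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times so KV heads line up with query heads."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


torch.manual_seed(0)
kv = torch.randn(2, 2, 5, 4)   # 2 KV heads
out = repeat_kv(kv, n_rep=3)   # expanded to 6 heads
```

The result matches `torch.repeat_interleave(kv, 3, dim=1)`: copies of the same KV head sit next to each other, so consecutive groups of query heads share one KV head.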
- bridge.diffusion.models.common.nemotron_labs_diffusion_attention._get_llama_4_attn_scale(position_ids: torch.Tensor, beta: float, max_position_embeddings: int)#
- class bridge.diffusion.models.common.nemotron_labs_diffusion_attention.Ministral3RotaryEmbedding(config, device=None)#
Bases: torch.nn.Module
RoPE with YARN support, driven by HF rope_parameters config.
Initialization
- inv_freq: torch.Tensor#
None
- static _compute_default_rope_parameters(config=None, device=None, seq_len=None)#
- forward(x, position_ids)#
- class bridge.diffusion.models.common.nemotron_labs_diffusion_attention.NemotronLabsDiffusionAttention(
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    layer_number: int,
    attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
    attention_type: str,
    attention_dropout: float = None,
    softmax_scale: float = None,
    cp_comm_type: str = None,
    pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
  )#
Bases: megatron.core.transformer.module.MegatronModule
NemotronLabsDiffusionAttention for semi-block-diffusion (sbd_block_diff) training.
The sequence is doubled to [xt | x0], where xt are the noised tokens and x0 are the clean tokens. RoPE is applied independently to each half. Llama-4 style query-key layer scaling is applied when configured.
Initialization
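To make the doubled `[xt | x0]` layout concrete, here is a small sketch of how the position IDs could look if "applied independently to each half" means that both halves restart at position 0. That reading is an assumption, and `doubled_position_ids` and the token ids below are hypothetical, not part of this module.

```python
import torch


def doubled_position_ids(seq_len: int, batch_size: int = 1) -> torch.Tensor:
    """Positions for [xt | x0]: each half is numbered 0..seq_len-1 (assumed)."""
    base = torch.arange(seq_len)
    return torch.cat((base, base)).unsqueeze(0).expand(batch_size, -1)


# Hypothetical token ids: 9 stands in for a noised/masked token.
xt = torch.tensor([[9, 9, 7, 9]])   # noised tokens
x0 = torch.tensor([[5, 6, 7, 8]])   # clean tokens

doubled = torch.cat((xt, x0), dim=1)   # [batch, 2 * seq_len]
pos = doubled_position_ids(seq_len=4)  # [[0, 1, 2, 3, 0, 1, 2, 3]]
```

Under this reading, a noised token and its clean counterpart receive the same rotary phase, so attention between the halves sees them at the same position.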
- set_inference_mode(enabled: bool)#
Enable or disable inference mode. Clears cache on disable.
- set_inference_params(causal: bool, cache_enabled: bool)#
- clear_kv_cache()#
- forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor = None,
    attn_mask_type: megatron.core.transformer.enums.AttnMaskType = None,
    attention_bias: torch.Tensor = None,
    packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
  )#
- _inference_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor)#
SDPA-based forward for inference with KV cache support.
- Parameters:
query – [seq_len, batch, num_heads, head_dim] (Megatron layout)
key – [seq_len, batch, num_heads, head_dim] (Megatron layout)
value – [seq_len, batch, num_heads, head_dim] (Megatron layout)
The method:
1. Computes position IDs, accounting for cached tokens
2. Applies RoPE (same module as training)
3. Applies Llama-4 attention scaling
4. Concatenates new K/V with cached K/V
5. Applies GQA repeat_kv
6. Runs SDPA with a causal or bidirectional mask
7. Optionally stores the new K/V in the cache
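The cache-related steps above can be sketched as a standalone function. This is a simplified illustration, not the method's actual code: `cached_sdpa_step` and its `cache`, `causal`, and `store` arguments are hypothetical, RoPE and Llama-4 scaling are omitted, and K/V are assumed to carry fewer heads than the query so that the GQA expansion is visible.

```python
import torch
import torch.nn.functional as F


def cached_sdpa_step(query, key, value, cache=None, causal=True, store=True):
    """One attention step with a KV cache.

    query: [q_len, batch, num_heads, head_dim] (Megatron layout)
    key, value: [q_len, batch, num_kv_heads, head_dim]
    cache: None or a (k, v) pair shaped [batch, num_kv_heads, past_len, head_dim]
    """
    # Megatron [s, b, h, d] -> SDPA [b, h, s, d].
    q, k, v = (t.permute(1, 2, 0, 3) for t in (query, key, value))

    # Position IDs start after the tokens already in the cache.
    past_len = 0 if cache is None else cache[0].shape[2]
    position_ids = torch.arange(past_len, past_len + q.shape[2])

    # RoPE and Llama-4 attention scaling would go here; omitted in this sketch.

    # Concatenate new K/V with cached K/V along the sequence axis.
    if cache is not None:
        k = torch.cat((cache[0], k), dim=2)
        v = torch.cat((cache[1], v), dim=2)
    new_cache = (k, v) if store else cache

    # GQA: repeat KV heads to match the number of query heads.
    n_rep = q.shape[1] // k.shape[1]
    k_rep = k.repeat_interleave(n_rep, dim=1)
    v_rep = v.repeat_interleave(n_rep, dim=1)

    # Causal mask over absolute positions (True = may attend); None = bidirectional.
    mask = None
    if causal:
        kv_pos = torch.arange(k.shape[2])
        mask = kv_pos[None, :] <= position_ids[:, None]  # [q_len, kv_len]

    out = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=mask)
    return out.permute(2, 0, 1, 3), new_cache  # back to [s, b, h, d]


torch.manual_seed(0)
q = torch.randn(4, 1, 4, 8)
k = torch.randn(4, 1, 2, 8)
v = torch.randn(4, 1, 2, 8)

# All four tokens at once versus prefill of three tokens plus one decode step.
full, _ = cached_sdpa_step(q, k, v)
_, cache = cached_sdpa_step(q[:3], k[:3], v[:3])
last, cache = cached_sdpa_step(q[3:], k[3:], v[3:], cache=cache)
```

The explicit mask is built from absolute positions because the query is shorter than the key sequence once a cache exists, and SDPA's `is_causal=True` flag aligns the mask to the top-left corner, which would be wrong for a decode step.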