bridge.diffusion.models.common.nemotron_labs_diffusion_attention#

NemotronLabsDiffusionAttention for sbd_block_diff diffusion LM training with YARN RoPE.

Module Contents#

Classes#

Ministral3RotaryEmbedding

RoPE with YARN support, driven by HF rope_parameters config.

NemotronLabsDiffusionAttention

NemotronLabsDiffusionAttention for semi-block-diffusion (sbd_block_diff) training.

Functions#

fused_flex_attention

Thin compiled wrapper around flex_attention.

rotate_half

Rotate the last half of the hidden dimension for RoPE.

apply_rotary_pos_emb

Apply rotary position embeddings to query and key tensors.

repeat_kv

Expand KV heads to match query heads for GQA.

_get_llama_4_attn_scale

API#

bridge.diffusion.models.common.nemotron_labs_diffusion_attention.fused_flex_attention(
q,
k,
v,
score_mod=None,
block_mask=None,
return_lse=False,
)#

Thin compiled wrapper around flex_attention.

bridge.diffusion.models.common.nemotron_labs_diffusion_attention.rotate_half(x)#

Rotate the last half of the hidden dimension for RoPE.
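A minimal sketch of this operation, assuming it follows the common Hugging Face convention (split the last dimension in two, negate the second half, and swap the halves):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two equal halves.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Map (x1, x2) -> (-x2, x1).
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
```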

bridge.diffusion.models.common.nemotron_labs_diffusion_attention.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)#

Apply rotary position embeddings to query and key tensors.
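The sketch below illustrates how such a function is typically written, assuming the Hugging Face convention in which `cos` and `sin` have shape `[batch, seq, head_dim]` and `unsqueeze_dim` inserts the head axis for broadcasting. The helper `_rotate_half`, the tensor shapes, and the RoPE base of 10000 are illustrative assumptions, not taken from this module.

```python
import torch


def _rotate_half(x):
    # Hypothetical local helper: (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Insert a broadcast axis so cos/sin line up with the heads dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + _rotate_half(q) * sin
    k_embed = k * cos + _rotate_half(k) * sin
    return q_embed, k_embed


batch, heads, seq, dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)

# Default (unscaled) RoPE frequencies, used only for this example.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.arange(seq).float()[:, None] * inv_freq[None, :]  # [seq, dim/2]
emb = torch.cat((angles, angles), dim=-1)[None]                  # [1, seq, dim]
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
```

Because each pair of channels is rotated by an angle, the per-token vector norm is unchanged and position 0 is left as-is.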

bridge.diffusion.models.common.nemotron_labs_diffusion_attention.repeat_kv(hidden_states: torch.Tensor, n_rep: int) torch.Tensor#

Expand KV heads to match query heads for GQA.
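A minimal sketch of the expansion, assuming the input layout `[batch, num_kv_heads, seq_len, head_dim]` used by the Hugging Face helper of the same name (the layout used inside this module may differ):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis after the KV-head axis, expand it, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.arange(2 * 2 * 3 * 4).float().reshape(2, 2, 3, 4)
out = repeat_kv(kv, n_rep=3)  # 2 KV heads -> 6 heads
```

Each KV head appears `n_rep` times consecutively, which is equivalent to `torch.repeat_interleave(kv, n_rep, dim=1)`.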

bridge.diffusion.models.common.nemotron_labs_diffusion_attention._get_llama_4_attn_scale(
position_ids: torch.Tensor,
beta: float,
max_position_embeddings: int,
) torch.Tensor#
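
No description is given for this helper. As a hedged illustration only, the sketch below assumes it mirrors the Llama-4-style position-dependent query scale found in Hugging Face's Ministral3 code, `1 + beta * log(1 + floor(position / max_position_embeddings))`. The function name `llama_4_attn_scale_sketch` and the trailing broadcast axis are assumptions, not confirmed behavior of this module.

```python
import torch


def llama_4_attn_scale_sketch(
    position_ids: torch.Tensor, beta: float, max_position_embeddings: int
) -> torch.Tensor:
    # Assumed form: the scale stays at 1 inside the original context window
    # and grows logarithmically with each additional window.
    scaling = 1 + beta * torch.log(
        1 + torch.floor(position_ids / max_position_embeddings)
    )
    # Trailing axis so the result broadcasts over head_dim.
    return scaling.unsqueeze(-1)


scale = llama_4_attn_scale_sketch(
    torch.tensor([[0, 1, 2, 3]]), beta=0.1, max_position_embeddings=2
)
```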

class bridge.diffusion.models.common.nemotron_labs_diffusion_attention.Ministral3RotaryEmbedding(config, device=None)#

Bases: torch.nn.Module

RoPE with YARN support, driven by HF rope_parameters config.

Initialization

inv_freq: torch.Tensor#

None

static _compute_default_rope_parameters(
config=None,
device=None,
seq_len=None,
)#

forward(x, position_ids)#

class bridge.diffusion.models.common.nemotron_labs_diffusion_attention.NemotronLabsDiffusionAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: float = None,
softmax_scale: float = None,
cp_comm_type: str = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

NemotronLabsDiffusionAttention for semi-block-diffusion (sbd_block_diff) training.

The input sequence is doubled to [xt | x0], where xt holds the noised tokens and x0 holds the clean tokens. RoPE is applied independently to each half. Llama-4-style query-key layer scaling is applied when configured.
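
One way to read "applied independently to each half" is that both halves share the same positions 0..L-1, so a noised token and its clean counterpart receive identical rotary phases. The sketch below assumes that reading; `doubled_position_ids`, `rope_per_half`, and the `[batch, heads, 2*L, head_dim]` layout are hypothetical and used only for illustration.

```python
import torch


def doubled_position_ids(seq_len: int, batch: int = 1) -> torch.Tensor:
    # Positions restart at 0 for the second (clean) half: [0..L-1 | 0..L-1].
    base = torch.arange(seq_len)
    return torch.cat((base, base)).unsqueeze(0).expand(batch, -1)


def rope_per_half(x: torch.Tensor, rope_fn) -> torch.Tensor:
    # x: [batch, heads, 2*L, head_dim]; split into xt and x0 along the sequence axis,
    # apply the same positional function to each half, then re-join.
    xt, x0 = x.chunk(2, dim=2)
    return torch.cat((rope_fn(xt), rope_fn(x0)), dim=2)


ids = doubled_position_ids(3, batch=2)

# Toy positional function that just adds the local position index.
add_pos = lambda h: h + torch.arange(h.shape[2]).view(1, 1, -1, 1)
marked = rope_per_half(torch.zeros(1, 1, 6, 2), add_pos)
```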

Initialization

set_inference_mode(enabled: bool)#

Enable or disable inference mode. The KV cache is cleared when inference mode is disabled.

set_inference_params(causal: bool, cache_enabled: bool)#

clear_kv_cache()#

forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor = None,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType = None,
attention_bias: torch.Tensor = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
)#

_inference_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) torch.Tensor#

SDPA-based forward for inference with KV cache support.

Parameters:
  • query – [seq_len, batch, num_heads, head_dim] (Megatron layout)

  • key – [seq_len, batch, num_heads, head_dim] (Megatron layout)

  • value – [seq_len, batch, num_heads, head_dim] (Megatron layout)

The method:

  1. Computes position IDs accounting for cached tokens

  2. Applies RoPE (same module as training)

  3. Applies Llama-4 attention scaling

  4. Concatenates new K/V with cached K/V

  5. Applies GQA repeat_kv

  6. Runs SDPA with causal or bidirectional mask

  7. Optionally stores the new K/V in cache
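
The seven steps above can be sketched as a small stand-alone class. This is an illustration under stated assumptions, not the module's implementation: `TinyCachedAttention` is a hypothetical name, RoPE uses default (non-YaRN) frequencies, the Llama-4 scale uses an assumed logarithmic form, and the cache stores K/V before GQA expansion.

```python
import torch
import torch.nn.functional as F


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class TinyCachedAttention:
    """Hypothetical stand-in for the inference path; all defaults are assumptions."""

    def __init__(self, head_dim, n_rep, causal=True, cache_enabled=True,
                 beta=0.0, max_pos=4096, base=10000.0):
        self.n_rep, self.causal, self.cache_enabled = n_rep, causal, cache_enabled
        self.beta, self.max_pos = beta, max_pos
        self.inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.k_cache = self.v_cache = None

    def clear_kv_cache(self):
        self.k_cache = self.v_cache = None

    def forward(self, query, key, value):
        # Megatron layout [seq, batch, heads, dim] -> SDPA layout [batch, heads, seq, dim].
        q, k, v = (t.permute(1, 2, 0, 3) for t in (query, key, value))
        past = 0 if self.k_cache is None else self.k_cache.shape[2]
        q_len = q.shape[2]
        # 1. Position IDs continue after the cached tokens.
        pos = torch.arange(past, past + q_len)
        # 2. RoPE on the new queries and keys.
        angles = pos[:, None].float() * self.inv_freq[None, :]
        emb = torch.cat((angles, angles), dim=-1)  # [q_len, dim]
        cos, sin = emb.cos(), emb.sin()
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        # 3. Llama-4-style query scaling (assumed logarithmic form).
        scale = 1 + self.beta * torch.log(1 + torch.floor(pos.float() / self.max_pos))
        q = q * scale[:, None]
        # 4. Concatenate new K/V with cached K/V along the sequence axis.
        if self.k_cache is not None:
            k = torch.cat((self.k_cache, k), dim=2)
            v = torch.cat((self.v_cache, v), dim=2)
        # 5. GQA: expand KV heads to match the query heads.
        k_rep = k.repeat_interleave(self.n_rep, dim=1)
        v_rep = v.repeat_interleave(self.n_rep, dim=1)
        # 6. SDPA; the explicit mask keeps causality correct when a cache is present.
        mask = None
        if self.causal:
            kv_pos = torch.arange(past + q_len)
            mask = kv_pos[None, :] <= pos[:, None]  # [q_len, kv_len], True = attend
        out = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=mask)
        # 7. Optionally store the (unexpanded) K/V for the next call.
        if self.cache_enabled:
            self.k_cache, self.v_cache = k, v
        return out.permute(2, 0, 1, 3)  # back to [seq, batch, heads, dim]


torch.manual_seed(0)
q = torch.randn(6, 1, 4, 8)   # 4 query heads
k = torch.randn(6, 1, 2, 8)   # 2 KV heads, so n_rep = 2
v = torch.randn(6, 1, 2, 8)

full = TinyCachedAttention(head_dim=8, n_rep=2, beta=0.1, max_pos=2).forward(q, k, v)

step = TinyCachedAttention(head_dim=8, n_rep=2, beta=0.1, max_pos=2)
chunked = torch.cat((step.forward(q[:4], k[:4], v[:4]),
                     step.forward(q[4:], k[4:], v[4:])), dim=0)
```

Feeding the sequence in two chunks through the cache should reproduce the single-pass causal result, which is a useful sanity check for any KV-cache path.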