bridge.models.llama.llama4_utils#

Module Contents#

Classes#

Llama4SelfAttention

Self-attention variant that can skip RoPE (rotary position embeddings) in some layers.

Functions#

chunkify_cu_seqlens

Splits cumulative sequence lengths into chunks based on attention_chunk_size.

chunkify

Pads and reshapes a tensor for chunked processing.

get_llama4_layer_spec

Get the Llama 4 transformer layer spec.

API#

bridge.models.llama.llama4_utils.chunkify_cu_seqlens(
cu_seqlens,
cu_seqlens_padded,
attention_chunk_size,
)#

Splits cumulative sequence lengths into chunks based on attention_chunk_size.

Parameters:
  • cu_seqlens (list[int]) – List of cumulative sequence lengths.

  • cu_seqlens_padded (list[int]) – List of padded cumulative sequence lengths.

  • attention_chunk_size (int) – The maximum size of each chunk.

Returns:

A tuple containing the new chunked cumulative sequence lengths and the new chunked padded cumulative sequence lengths.

Return type:

Tuple[list[int], list[int]]
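
A hedged usage sketch follows. The input values and the resulting chunk boundaries are illustrative assumptions; the exact split is determined by the implementation.

```python
from bridge.models.llama.llama4_utils import chunkify_cu_seqlens

# Three packed sequences of lengths 5, 7, and 8 (unpadded) and their padded
# counterparts. These values are illustrative only.
cu_seqlens = [0, 5, 12, 20]
cu_seqlens_padded = [0, 8, 16, 24]

chunked, chunked_padded = chunkify_cu_seqlens(
    cu_seqlens, cu_seqlens_padded, attention_chunk_size=4
)
# Both returned lists remain cumulative (start at 0, monotonically increasing);
# sequences longer than attention_chunk_size are split so that no chunk spans
# more than 4 tokens.
```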

bridge.models.llama.llama4_utils.chunkify(x, attention_chunk_size)#

Pads and reshapes a tensor for chunked processing.

This function takes an input tensor x (typically representing query, key, or value in attention mechanisms) and pads its sequence dimension (dim 0) to be a multiple of attention_chunk_size. It then reshapes the tensor so that the sequence dimension is split into chunks, and the chunk dimension is combined with the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor, expected shape [seq_length, batch_size, …].

  • attention_chunk_size (int) – The desired size of chunks along the sequence dimension.

Returns:

The reshaped tensor with shape [attention_chunk_size, num_chunks * batch_size, …].

Return type:

torch.Tensor
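
The padding and reshaping described above can be reproduced with a short sketch. This illustrates the documented behavior only; the helper name and the exact ordering of the folded chunk/batch dimension are assumptions, not the library's implementation.

```python
import torch
import torch.nn.functional as F

def chunkify_sketch(x: torch.Tensor, attention_chunk_size: int) -> torch.Tensor:
    """Illustrative sketch: pad dim 0 to a multiple of attention_chunk_size,
    then fold the resulting chunk dimension into the batch dimension."""
    seq_len, batch_size = x.shape[0], x.shape[1]
    pad = (-seq_len) % attention_chunk_size
    if pad:
        # F.pad pads from the last dimension backwards; this pads only dim 0.
        x = F.pad(x, [0, 0] * (x.dim() - 1) + [0, pad])
    num_chunks = x.shape[0] // attention_chunk_size
    # [seq, batch, ...] -> [num_chunks, chunk, batch, ...]
    x = x.view(num_chunks, attention_chunk_size, batch_size, *x.shape[2:])
    # -> [chunk, num_chunks, batch, ...] -> [chunk, num_chunks * batch, ...]
    x = x.transpose(0, 1)
    return x.reshape(attention_chunk_size, num_chunks * batch_size, *x.shape[3:])

# A [10, 2, 16] tensor with attention_chunk_size=4 is padded to [12, 2, 16]
# and reshaped to [4, 6, 16] (3 chunks * batch of 2).
y = chunkify_sketch(torch.randn(10, 2, 16), attention_chunk_size=4)
assert y.shape == (4, 6, 16)
```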

bridge.models.llama.llama4_utils.get_llama4_layer_spec(
config,
) megatron.core.transformer.spec_utils.ModuleSpec#

Get the Llama 4 transformer layer spec.
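
A hedged usage sketch. Whether a plain megatron.core TransformerConfig is sufficient here, and which fields are minimally required, are assumptions; real Llama 4 configs carry many more settings.

```python
from megatron.core.transformer import TransformerConfig
from bridge.models.llama.llama4_utils import get_llama4_layer_spec

# Minimal placeholder config for illustration only.
config = TransformerConfig(num_layers=2, hidden_size=512, num_attention_heads=8)

layer_spec = get_llama4_layer_spec(config)
# layer_spec is a megatron.core ModuleSpec; given this module's purpose, it is
# expected to wire Llama4SelfAttention into the transformer layer so that RoPE
# can be skipped on designated (NoPE) layers.
```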

class bridge.models.llama.llama4_utils.Llama4SelfAttention(
is_nope_layer=False,
attention_chunk_size=8192,
*args,
**kwargs,
)#

Bases: megatron.core.transformer.attention.SelfAttention

Self-attention variant that can skip RoPE (rotary position embeddings) in some layers.

Initialization

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) Tuple[torch.Tensor, torch.Tensor]#

Perform a forward pass through the attention module.

Parameters:
  • hidden_states (Tensor) – Hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Optional[Tensor]) – Key/value states (for cross attention).

  • inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.

  • rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).

  • rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.

  • rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.

  • attention_bias (Optional[Tensor]) – Attention bias.

  • packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple[Tensor, Tensor]) Attention output and bias.
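
A hedged sketch of wiring Llama4SelfAttention into a ModuleSpec using its extra constructor arguments. The submodule wiring is omitted and the parameter values (including attn_mask_type) are placeholders; in practice the full spec is produced by get_llama4_layer_spec.

```python
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.enums import AttnMaskType
from bridge.models.llama.llama4_utils import Llama4SelfAttention

attention_spec = ModuleSpec(
    module=Llama4SelfAttention,
    params={
        "attn_mask_type": AttnMaskType.causal,  # typical masking choice (assumption)
        "is_nope_layer": True,          # skip RoPE for this layer
        "attention_chunk_size": 8192,   # chunk size for chunked attention
    },
    # submodules for linear_qkv / core_attention / linear_proj omitted here;
    # get_llama4_layer_spec(config) builds the complete spec.
)
```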