bridge.models.llama.llama4_utils#
Module Contents#
Classes#
| Llama4SelfAttention | Updated transformer layer that enables skipping RoPE in some layers. |
Functions#
| chunkify_cu_seqlens | Splits cumulative sequence lengths into chunks based on attention_chunk_size. |
| chunkify | Pads and reshapes a tensor for chunked processing. |
| get_llama4_layer_spec | Get the Llama4 layer spec. |
API#
- bridge.models.llama.llama4_utils.chunkify_cu_seqlens(
- cu_seqlens,
- cu_seqlens_padded,
- attention_chunk_size,
- )#
Splits cumulative sequence lengths into chunks based on attention_chunk_size.
- Parameters:
cu_seqlens (list[int]) – List of cumulative sequence lengths.
cu_seqlens_padded (list[int]) – List of padded cumulative sequence lengths.
attention_chunk_size (int) – The maximum size of each chunk.
- Returns:
A tuple containing the new chunked cumulative sequence lengths and the new chunked padded cumulative sequence lengths.
- Return type:
Tuple[list[int], list[int]]
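A minimal sketch of the idea, assuming chunk boundaries are placed every attention_chunk_size tokens within each packed sequence. The helper name is hypothetical and it processes only one list; the real function also walks cu_seqlens_padded and returns both chunked lists.

```python
def split_cu_seqlens(cu_seqlens: list[int], attention_chunk_size: int) -> list[int]:
    """Toy version: cut every packed sequence into pieces of at most attention_chunk_size."""
    chunked = [0]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + attention_chunk_size, end)
            chunked.append(pos)
    return chunked

# Two packed sequences of lengths 5 and 7 with attention_chunk_size=4:
# [0, 5, 12] -> [0, 4, 5, 9, 12]
assert split_cu_seqlens([0, 5, 12], 4) == [0, 4, 5, 9, 12]
```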
- bridge.models.llama.llama4_utils.chunkify(x, attention_chunk_size)#
Pads and reshapes a tensor for chunked processing.
This function takes an input tensor x (typically representing query, key, or value in attention mechanisms) and pads its sequence dimension (dim 0) to be a multiple of attention_chunk_size. It then reshapes the tensor so that the sequence dimension is split into chunks, and the chunk dimension is combined with the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor, expected shape [seq_length, batch_size, …].
attention_chunk_size (int) – The desired size of chunks along the sequence dimension.
- Returns:
The reshaped tensor with shape [attention_chunk_size, num_chunks * batch_size, …].
- Return type:
torch.Tensor
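A self-contained sketch of the shape transformation described above, assuming zero padding on the sequence dimension; the function name and padding value are assumptions, not the library implementation.

```python
import torch


def chunkify_sketch(x: torch.Tensor, attention_chunk_size: int) -> torch.Tensor:
    """Toy version: pad dim 0 to a multiple of attention_chunk_size,
    then fold the chunks into the batch dimension."""
    seq_len = x.shape[0]
    pad = (-seq_len) % attention_chunk_size
    if pad:
        # Zero-pad the sequence dimension (dim 0); the real padding value is an assumption.
        x = torch.cat([x, x.new_zeros(pad, *x.shape[1:])], dim=0)
    num_chunks = x.shape[0] // attention_chunk_size
    # [num_chunks * chunk, B, ...] -> [num_chunks, chunk, B, ...]
    x = x.view(num_chunks, attention_chunk_size, *x.shape[1:])
    # -> [chunk, num_chunks, B, ...] -> [chunk, num_chunks * B, ...]
    x = x.transpose(0, 1)
    return x.reshape(attention_chunk_size, -1, *x.shape[3:])


q = torch.randn(10, 2, 8, 64)                 # [seq, batch, heads, head_dim]
out = chunkify_sketch(q, attention_chunk_size=4)
assert out.shape == (4, 6, 8, 64)             # seq padded to 12 -> 3 chunks * batch 2
```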
- bridge.models.llama.llama4_utils.get_llama4_layer_spec(
- config,
- )#
Get the Llama4 layer spec.
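A hedged usage sketch, assuming the module path shown on this page is importable as written and that config is the transformer config built elsewhere by the Llama4 model provider:

```python
# Hypothetical call site; the concrete config object is constructed elsewhere,
# so it is passed through untouched here.
from bridge.models.llama.llama4_utils import get_llama4_layer_spec


def build_llama4_layer_spec(transformer_config):
    # Presumably returns a Megatron layer spec wired with Llama4SelfAttention
    # (an assumption; the return type is not documented on this page).
    return get_llama4_layer_spec(transformer_config)
```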
- class bridge.models.llama.llama4_utils.Llama4SelfAttention(
- is_nope_layer=False,
- attention_chunk_size=8192,
- *args,
- **kwargs,
- )#
Bases: megatron.core.transformer.attention.SelfAttention
Updated transformer layer that enables skipping RoPE in some layers.
Initialization
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- )#
Perform a forward pass through the attention module.
- Parameters:
hidden_states (Tensor) – Hidden states.
attention_mask (Tensor) – Attention mask.
key_value_states (Optional[Tensor]) – Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).
rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.
attention_bias (Optional[Tensor]) – Attention bias.
packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.
sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.
- Returns:
(Tuple[Tensor, Tensor]) Attention output and bias.
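The is_nope_layer flag marks a layer on which rotary position embeddings are skipped (a NoPE layer). Below is a toy, self-contained illustration of that switch using a generic RoPE formulation; it is not the megatron.core rotary implementation, and every name in it is illustrative.

```python
import torch


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)


def maybe_apply_rope(q, k, cos, sin, is_nope_layer: bool):
    """Toy switch: NoPE layers leave queries/keys unrotated."""
    if is_nope_layer:
        return q, k
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


seq, batch, heads, dim = 8, 2, 4, 16
q = torch.randn(seq, batch, heads, dim)
k = torch.randn(seq, batch, heads, dim)

# Generic RoPE angles, broadcast over the batch and head dimensions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)     # [seq, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                       # [seq, dim]
cos, sin = emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]

q_rope, _ = maybe_apply_rope(q, k, cos, sin, is_nope_layer=False)
q_nope, _ = maybe_apply_rope(q, k, cos, sin, is_nope_layer=True)
assert torch.equal(q_nope, q)   # the NoPE branch is a no-op on q and k
```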