bridge.models.gemma.gemma2_provider#
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| Gemma2DotProductAttention | Region where selective activation recomputation is applied; memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). |
| TERowParallelLinearLayerNorm | Modified from TERowParallelLinear with an additional Post-LN. |
| Gemma2OutputLayer | Extends ColumnParallelLinear with logit soft capping. |
| Gemma2ModelProvider | Configuration class for Gemma2 models. Extends GPTModelProvider with settings optimized for Gemma2 architectures, including normalization, activation functions, attention logit softcapping, and sliding window attention. |
| Gemma2ModelProvider2B | Configuration for the 2B-parameter Gemma2 model: 26 layers, hidden size 2304, 8 attention heads. |
| Gemma2ModelProvider9B | Configuration for the 9B-parameter Gemma2 model: 42 layers, hidden size 3584, 16 attention heads. |
| Gemma2ModelProvider27B | Configuration for the 27B-parameter Gemma2 model: 46 layers, hidden size 4608, 32 attention heads. |
Functions#
| Function | Description |
| --- | --- |
| logit_softcapping | Prevents logits from growing excessively by scaling them to a fixed range. |
| get_swa | Creates the equivalent attention mask for sliding window attention (SWA) in [seq_q, seq_kv] shape. |
| gemma2_layer_spec | Gemma2-specific layer specification. |
API#
- class bridge.models.gemma.gemma2_provider.Gemma2DotProductAttention(
- config: megatron.core.transformer.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: float = None,
- **kwargs,
)#
Bases:
megatron.core.transformer.MegatronModule
Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details. We use the following notation:
- h: hidden size
- n: number of attention heads
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
Initialization
- forward(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType = None,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
- **kwargs,
)#
Forward. Modified from mcore.transformer.dot_product_attention to support Gemma2-specific final_logit_softcapping.
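A rough, framework-free sketch of where the capping fits in such a forward pass (not the class's actual implementation): the raw query-key scores are bounded with a tanh-based cap before masking and softmax. The 50.0 value mirrors the provider's attn_logit_softcapping default documented below; shapes and the scale factor are illustrative.

```python
import torch

def capped_attention_scores(q, k, scale_factor, softcapping=50.0):
    # Raw attention logits: [b, n, s_q, s_kv].
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
    # Gemma2-style soft capping keeps the logits inside (-softcapping, softcapping).
    if softcapping is not None:
        scores = softcapping * torch.tanh(scores / softcapping)
    return scores

# Example shapes: batch 2, 8 heads, 16 query/key positions, head dim 256.
q = torch.randn(2, 8, 16, 256)
k = torch.randn(2, 8, 16, 256)
probs = torch.softmax(capped_attention_scores(q, k, scale_factor=256 ** -0.5), dim=-1)
```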
- class bridge.models.gemma.gemma2_provider.TERowParallelLinearLayerNorm(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.transformer.TransformerConfig,
- **kwargs,
)#
Bases:
megatron.core.extensions.transformer_engine.TERowParallelLinear
Modified from TERowParallelLinear with an additional Post-LN.
Initialization
- forward(x)#
Forward with an additional Post-LN applied to the output.
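A minimal single-GPU sketch of the Post-LN pattern this class adds, using plain PyTorch stand-ins: the real class inherits the tensor-parallel matmul from TERowParallelLinear and normalizes with the config's normalization (RMSNorm for Gemma2); nn.LayerNorm is used here only to keep the sketch dependency-free.

```python
import torch
import torch.nn as nn

class RowLinearWithPostLN(nn.Module):
    """Illustrative stand-in: project first, then layer-normalize the output."""

    def __init__(self, input_size: int, output_size: int, eps: float = 1e-6):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size, bias=False)
        self.post_ln = nn.LayerNorm(output_size, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-LN: normalization is applied to the projection output.
        return self.post_ln(self.linear(x))
```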
- class bridge.models.gemma.gemma2_provider.Gemma2OutputLayer#
Bases:
megatron.core.tensor_parallel.ColumnParallelLinear
Extends from ColumnParallelLinear with logit soft capping.
- forward(*args, **kwargs)#
Forward with logit soft capping.
- bridge.models.gemma.gemma2_provider.logit_softcapping(
- logits: torch.Tensor,
- scale: Optional[float],
)#
Prevents logits from growing excessively by scaling them to a fixed range.
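A sketch of the tanh-based capping Gemma2 uses for this purpose; a None scale is assumed to disable capping, matching the Optional type above, and the 30.0 value corresponds to the provider's final_logit_softcapping default.

```python
import torch
from typing import Optional

def logit_softcapping_sketch(logits: torch.Tensor, scale: Optional[float]) -> torch.Tensor:
    # tanh maps the scaled logits into (-1, 1), so the result stays in (-scale, scale).
    if scale is None:
        return logits
    return scale * torch.tanh(logits / scale)

capped = logit_softcapping_sketch(torch.randn(4, 256000), scale=30.0)
```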
- bridge.models.gemma.gemma2_provider.get_swa(
- seq_q: int,
- seq_kv: int,
- window_size: tuple[int, int],
)#
Creates the equivalent attention mask for sliding window attention (SWA) in [seq_q, seq_kv] shape.
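A hedged sketch of how such a mask can be built for a window of (left, right) tokens. It assumes query and key positions share the same indexing, and follows the common Megatron convention that True marks positions to ignore; the actual function may differ in convention and alignment details.

```python
import torch

def swa_mask_sketch(seq_q: int, seq_kv: int, window_size: tuple[int, int]) -> torch.Tensor:
    """Boolean [seq_q, seq_kv] mask where True marks positions outside the window."""
    left, right = window_size
    q_pos = torch.arange(seq_q).unsqueeze(1)    # [seq_q, 1]
    kv_pos = torch.arange(seq_kv).unsqueeze(0)  # [1, seq_kv]
    # Visible iff the key/value index lies within `left` tokens behind and
    # `right` tokens ahead of the query index (the Gemma2 default is (4096, 0)).
    visible = (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)
    return ~visible

mask = swa_mask_sketch(8, 8, (4, 0))  # causal window over the last 4 positions
```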
- bridge.models.gemma.gemma2_provider.gemma2_layer_spec(
- config: megatron.bridge.models.gpt_provider.GPTModelProvider,
)#
Gemma2-specific layer specification.
- class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider#
Bases:
megatron.bridge.models.gpt_provider.GPTModelProvider
Configuration class for Gemma2 models. Extends GPTModelProvider with specific settings optimized for Gemma2 architectures. Includes configurations for normalization, activation functions, and various Gemma2-specific options like attention logit softcapping and sliding window attention.
- normalization: str#
‘RMSNorm’
- activation_func: Callable#
None
- gated_linear_unit: bool#
True
- position_embedding_type: str#
‘rope’
- add_bias_linear: bool#
False
- seq_length: int#
8192
- kv_channels: int#
256
- attention_dropout: float#
0.0
- hidden_dropout: float#
0.0
- share_embeddings_and_output_weights: bool#
True
- layernorm_zero_centered_gamma: bool#
True
- layernorm_epsilon: float#
1e-06
- rotary_base: float#
10000
- window_size: tuple[int, int]#
(4096, 0)
- vocab_size: int#
256000
- gradient_accumulation_fusion: bool#
False
- transformer_layer_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[megatron.bridge.models.gpt_provider.GPTModelProvider], megatron.core.transformer.ModuleSpec]]#
None
- query_pre_attn_scalar: int#
224
- attn_logit_softcapping: float#
50.0
- final_logit_softcapping: float#
30.0
- provide(
- pre_process=None,
- post_process=None,
- vp_stage=None,
)#
Configure and instantiate a Megatron Core Gemma2 model. Extends the base configuration with Gemma2-specific embedding scaling and output layer modifications.
- Parameters:
pre_process – Whether to include pre-processing in the model
post_process – Whether to include post-processing in the model
vp_stage – Virtual pipeline stage
tokenizer – Tokenizer used with the model
- Returns:
Configured Megatron Core GPT model instance
- Return type:
MCoreGPTModel
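A hypothetical end-to-end usage, with the import path and keyword overrides assumed rather than taken from the source; it also assumes Megatron's model-parallel state has already been initialized. The example uses the 2B preset documented below.

```python
from megatron.bridge.models.gemma.gemma2_provider import Gemma2ModelProvider2B

# Build the 2B preset, override a configuration field, then instantiate the
# model for a single pipeline stage (pre- and post-processing both enabled).
provider = Gemma2ModelProvider2B(seq_length=4096)
model = provider.provide(pre_process=True, post_process=True)
```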
- class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider2B#
Bases:
bridge.models.gemma.gemma2_provider.Gemma2ModelProvider
Configuration for a 2B parameter Gemma2 model. Specific configuration for the 2B Gemma2 model with 26 layers, 2304 hidden size, and 8 attention heads.
- num_layers: int#
26
- hidden_size: int#
2304
- num_attention_heads: int#
8
- num_query_groups: int#
4
- ffn_hidden_size: int#
9216
- query_pre_attn_scalar: int#
256
- class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider9B#
Bases:
bridge.models.gemma.gemma2_provider.Gemma2ModelProvider
Configuration for a 9B parameter Gemma2 model. Specific configuration for the 9B Gemma2 model with 42 layers, 3584 hidden size, and 16 attention heads.
- num_layers: int#
42
- hidden_size: int#
3584
- num_attention_heads: int#
16
- num_query_groups: int#
8
- ffn_hidden_size: int#
14336
- query_pre_attn_scalar: int#
256
- class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider27B#
Bases:
bridge.models.gemma.gemma2_provider.Gemma2ModelProvider
Configuration for a 27B parameter Gemma2 model. Specific configuration for the 27B Gemma2 model with 46 layers, 4608 hidden size, and 32 attention heads.
- num_layers: int#
46
- hidden_size: int#
4608
- num_attention_heads: int#
32
- num_query_groups: int#
16
- kv_channels: int#
128
- ffn_hidden_size: int#
36864
- query_pre_attn_scalar: int#
144
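A small sketch comparing the key dimensions of the three presets. The import path is assumed, and instantiating the dataclasses is expected to work without any distributed setup since only configuration fields are read.

```python
from megatron.bridge.models.gemma.gemma2_provider import (
    Gemma2ModelProvider2B,
    Gemma2ModelProvider9B,
    Gemma2ModelProvider27B,
)

for cls in (Gemma2ModelProvider2B, Gemma2ModelProvider9B, Gemma2ModelProvider27B):
    cfg = cls()
    print(
        f"{cls.__name__}: layers={cfg.num_layers}, hidden={cfg.hidden_size}, "
        f"heads={cfg.num_attention_heads}, query_pre_attn_scalar={cfg.query_pre_attn_scalar}"
    )
```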