bridge.models.gemma.gemma2_provider#

Module Contents#

Classes#

Gemma2DotProductAttention

Region where selective activation recomputation is applied. This region is memory-intensive but less compute-intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details. Notation: h = hidden size, n = number of attention heads, p = number of tensor model parallel partitions, b = batch size, s = sequence length.

TERowParallelLinearLayerNorm

Modified from TERowParallelLinear with an additional Post-LN.

Gemma2OutputLayer

Extends from ColumnParallelLinear with logit soft capping.

Gemma2ModelProvider

Configuration class for Gemma2 models. Extends GPTModelProvider with specific settings optimized for Gemma2 architectures. Includes configurations for normalization, activation functions, and various Gemma2-specific options like attention logit softcapping and sliding window attention.

Gemma2ModelProvider2B

Configuration for a 2B parameter Gemma2 model. Specific configuration for the 2B Gemma2 model with 26 layers, 2304 hidden size, and 8 attention heads.

Gemma2ModelProvider9B

Configuration for a 9B parameter Gemma2 model. Specific configuration for the 9B Gemma2 model with 42 layers, 3584 hidden size, and 16 attention heads.

Gemma2ModelProvider27B

Configuration for a 27B parameter Gemma2 model. Specific configuration for the 27B Gemma2 model with 46 layers, 4608 hidden size, and 32 attention heads.

Functions#

logit_softcapping

Prevents logits from growing excessively by scaling them to a fixed range.

get_swa

Create the equivalent attention mask for sliding window attention (SWA) in [seq_q, seq_kv] shape.

gemma2_layer_spec

Gemma2-specific layer specification.

API#

class bridge.models.gemma.gemma2_provider.Gemma2DotProductAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: float = None,
**kwargs,
)#

Bases: megatron.core.transformer.MegatronModule

Region where selective activation recomputation is applied. This region is memory-intensive but less compute-intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details.

We use the following notation:

  • h: hidden size

  • n: number of attention heads

  • p: number of tensor model parallel partitions

  • b: batch size

  • s: sequence length

Initialization

forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType = None,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
**kwargs,
)#

Forward pass. Modified from mcore.transformer.dot_product_attention to support Gemma2-specific attention logit softcapping (attn_logit_softcapping).
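
A minimal sketch of the attention-score capping this forward adds, under the assumptions that the cap comes from config.attn_logit_softcapping, that shapes are [b, n, s, d], and that a boolean mask marks positions to drop; the real implementation reuses Megatron Core's attention kernels and its query_pre_attn_scalar-based scaling:

import math
import torch

def capped_attention_probs_sketch(query, key, attention_mask, attn_logit_softcapping):
    # Raw scaled dot-product scores: [b, n, s_q, s_kv].
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    # Gemma2-specific step: squash scores into (-cap, cap) before masking and softmax.
    if attn_logit_softcapping is not None:
        scores = attn_logit_softcapping * torch.tanh(scores / attn_logit_softcapping)
    scores = scores.masked_fill(attention_mask, float("-inf"))
    return torch.softmax(scores, dim=-1)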

class bridge.models.gemma.gemma2_provider.TERowParallelLinearLayerNorm(
input_size: int,
output_size: int,
*,
config: megatron.core.transformer.TransformerConfig,
**kwargs,
)#

Bases: megatron.core.extensions.transformer_engine.TERowParallelLinear

Modified from TERowParallelLinear with an additional Post-LN.

Initialization

forward(x)#

Forward pass with an additional Post-LN applied to the output.
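
A minimal sketch of what the extra Post-LN amounts to, assuming the projection returns a plain tensor; the stand-in class, the use of torch.nn.LayerNorm (Gemma2 itself configures RMSNorm), and the epsilon default are illustrative assumptions, not the library's API:

import torch

class RowParallelLinearWithPostLNSketch(torch.nn.Linear):
    # Illustrative stand-in: run the usual projection, then apply an
    # additional layer norm to its output (the "Post-LN").
    def __init__(self, input_size, output_size, eps=1e-6):
        super().__init__(input_size, output_size, bias=False)
        self.post_layernorm = torch.nn.LayerNorm(output_size, eps=eps)

    def forward(self, x):
        return self.post_layernorm(super().forward(x))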

class bridge.models.gemma.gemma2_provider.Gemma2OutputLayer#

Bases: megatron.core.tensor_parallel.ColumnParallelLinear

Extends from ColumnParallelLinear with logit soft capping.

forward(*args, **kwargs)#

Forward with logit soft capping.
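
A minimal sketch of the wrapping this adds, assuming the parent ColumnParallelLinear forward returns an (output, bias) pair, that the cap value lives in config.final_logit_softcapping, and that the import path matches this page's module name; the helper name is hypothetical:

from bridge.models.gemma.gemma2_provider import logit_softcapping

def output_forward_sketch(parent_forward, config, *args, **kwargs):
    # Delegate to ColumnParallelLinear, then cap the resulting logits.
    logits, bias = parent_forward(*args, **kwargs)
    logits = logit_softcapping(logits, config.final_logit_softcapping)
    return logits, bias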

bridge.models.gemma.gemma2_provider.logit_softcapping(
logits: torch.Tensor,
scale: Optional[float],
) torch.Tensor#

Prevents logits from growing excessively by scaling them to a fixed range.
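
A minimal sketch of the capping itself, following the tanh-based scheme used throughout Gemma2; treating scale=None as a pass-through is an assumption about the edge case:

import torch
from typing import Optional

def logit_softcapping_sketch(logits: torch.Tensor, scale: Optional[float]) -> torch.Tensor:
    # Without a scale, pass logits through unchanged; otherwise squash them
    # smoothly into the open interval (-scale, scale).
    if scale is None:
        return logits
    return scale * torch.tanh(logits / scale)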

bridge.models.gemma.gemma2_provider.get_swa(
seq_q: int,
seq_kv: int,
window_size: tuple[int, int],
) torch.Tensor#

Create the equivalent attention mask for sliding window attention (SWA) in [seq_q, seq_kv] shape.
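
A minimal sketch of one way to build such a mask, assuming window_size = (left, right) counts of visible past and future tokens and a boolean convention where True marks positions to mask out; the real function's dtype and sign convention may differ:

import torch

def get_swa_sketch(seq_q: int, seq_kv: int, window_size: tuple[int, int]) -> torch.Tensor:
    left, right = window_size
    q_pos = torch.arange(seq_q).unsqueeze(1)    # [seq_q, 1]
    kv_pos = torch.arange(seq_kv).unsqueeze(0)  # [1, seq_kv]
    # A query at position q may attend to keys in [q - left, q + right].
    visible = (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)
    return ~visible  # True = masked out, shape [seq_q, seq_kv]

With Gemma2's default window_size of (4096, 0), this reduces to a causal mask restricted to the most recent 4096 positions.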

bridge.models.gemma.gemma2_provider.gemma2_layer_spec(
config: megatron.bridge.models.gpt_provider.GPTModelProvider,
) megatron.core.transformer.ModuleSpec#

Gemma2-specific layer specification.

class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider#

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Configuration class for Gemma2 models. Extends GPTModelProvider with specific settings optimized for Gemma2 architectures. Includes configurations for normalization, activation functions, and various Gemma2-specific options like attention logit softcapping and sliding window attention.

normalization: str#

‘RMSNorm’

activation_func: Callable#

None

gated_linear_unit: bool#

True

position_embedding_type: str#

‘rope’

add_bias_linear: bool#

False

seq_length: int#

8192

kv_channels: int#

256

attention_dropout: float#

0.0

hidden_dropout: float#

0.0

share_embeddings_and_output_weights: bool#

True

layernorm_zero_centered_gamma: bool#

True

layernorm_epsilon: float#

1e-06

rotary_base: float#

10000

window_size: tuple[int, int]#

(4096, 0)

vocab_size: int#

256000

gradient_accumulation_fusion: bool#

False

transformer_layer_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[megatron.bridge.models.gpt_provider.GPTModelProvider], megatron.core.transformer.ModuleSpec]]#

None

query_pre_attn_scalar: int#

224

attn_logit_softcapping: float#

50.0

final_logit_softcapping: float#

30.0

provide(
pre_process=None,
post_process=None,
vp_stage=None,
) megatron.core.models.gpt.GPTModel#

Configure and instantiate a Megatron Core Gemma2 model. Extends the base configuration with Gemma2-specific embedding scaling and output layer modifications.

Parameters:
  • pre_process – Whether to include pre-processing in the model

  • post_process – Whether to include post-processing in the model

  • vp_stage – Virtual pipeline stage

Returns:

Configured Megatron Core GPT model instance

Return type:

MCoreGPTModel
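
A hedged usage sketch: pick one of the size-specific providers and call provide() to build the Megatron Core model. It assumes torch.distributed and Megatron Core's parallel state are already initialized (not shown), that the import path matches this page's module name, and that defaults can be overridden by attribute assignment before provide() is called:

from bridge.models.gemma.gemma2_provider import Gemma2ModelProvider2B

provider = Gemma2ModelProvider2B()
provider.seq_length = 4096  # optional: override a default before building

# Builds a megatron.core.models.gpt.GPTModel configured for Gemma2 2B.
model = provider.provide(pre_process=True, post_process=True)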

class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider2B#

Bases: bridge.models.gemma.gemma2_provider.Gemma2ModelProvider

Configuration for a 2B parameter Gemma2 model. Specific configuration for the 2B Gemma2 model with 26 layers, 2304 hidden size, and 8 attention heads.

num_layers: int#

26

hidden_size: int#

2304

num_attention_heads: int#

8

num_query_groups: int#

4

ffn_hidden_size: int#

9216

query_pre_attn_scalar: int#

256

class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider9B#

Bases: bridge.models.gemma.gemma2_provider.Gemma2ModelProvider

Configuration for a 9B parameter Gemma2 model. Specific configuration for the 9B Gemma2 model with 42 layers, 3584 hidden size, and 16 attention heads.

num_layers: int#

42

hidden_size: int#

3584

num_attention_heads: int#

16

num_query_groups: int#

8

ffn_hidden_size: int#

14336

query_pre_attn_scalar: int#

256

class bridge.models.gemma.gemma2_provider.Gemma2ModelProvider27B#

Bases: bridge.models.gemma.gemma2_provider.Gemma2ModelProvider

Configuration for a 27B parameter Gemma2 model. Specific configuration for the 27B Gemma2 model with 46 layers, 4608 hidden size, and 32 attention heads.

num_layers: int#

46

hidden_size: int#

4608

num_attention_heads: int#

32

num_query_groups: int#

16

kv_channels: int#

128

ffn_hidden_size: int#

36864

query_pre_attn_scalar: int#

144