bridge.models.gemma.gemma3_provider#

Module Contents#

Classes#

Gemma3ModelProvider

Configuration and provider for Megatron Core Gemma3 models.

Gemma3ModelProvider1B

Gemma3 1B config

Gemma3ModelProvider4B

Gemma3 4B config

Gemma3ModelProvider12B

Gemma3 12B config

Gemma3ModelProvider27B

Gemma3 27B config

Gemma3SelfAttention

Gemma3 self attention.

Gemma3TEDotProductAttention

Gemma3 core attention.

Gemma3LanguageModelEmbedding

Gemma3 language token embedding.

Gemma3RotaryEmbedding

Gemma3 rotary position (RoPE) embedding.

TERowParallelLinearLayerNorm

Modified from TERowParallelLinear with an additional Post-LN.

Functions#

gemma3_layer_spec

Gemma3 custom layer spec.

_is_local_attn_layer

Return whether a layer uses local sliding-window attention, based on the interleaved layer pattern.

API#

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider#

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Configuration and provider for Megatron Core Gemma3 models.

seq_length: int#

131072

position_embedding_type: str#

‘rope’

rotary_base: tuple#

(10000, 1000000)

share_embeddings_and_output_weights: bool#

True

normalization: str#

‘RMSNorm’

layernorm_zero_centered_gamma: bool#

True

layernorm_epsilon: float#

1e-06

window_size: int#

512

interleaved_attn_pattern: tuple#

(5, 1)

attention_dropout: float#

0.0

hidden_dropout: float#

0.0

rope_scaling_factor: float#

1.0

attention_backend: megatron.core.transformer.enums.AttnBackend#

None

softmax_scale: float#

None

gated_linear_unit: bool#

True

add_bias_linear: bool#

False

activation_func: Callable#

‘field(…)’

is_vision_language: bool#

False

flash_decode: bool#

False

gradient_accumulation_fusion: bool#

False

transformer_layer_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[bridge.models.gemma.gemma3_provider.Gemma3ModelProvider], megatron.core.transformer.ModuleSpec]]#

‘field(…)’

scatter_embedding_sequence_parallel: bool#

True

apply_rope_fusion: bool#

‘field(…)’

masked_softmax_fusion: bool#

‘field(…)’

bf16: bool#

True

fp16: bool#

False

params_dtype: torch.dtype#

None

autocast_dtype: torch.dtype#

None

provide(
pre_process=None,
post_process=None,
vp_stage=None,
) → megatron.core.models.gpt.GPTModel#

Configure and instantiate a Megatron Core Gemma3 model.

Replaces the model’s embedding and rope with customized Gemma3 ones.

Parameters:
  • pre_process – Whether to include pre-processing in the model

  • post_process – Whether to include post-processing in the model

  • vp_stage – Virtual pipeline stage

Returns:

Configured Megatron Core GPT model instance

Return type:

MCoreGPTModel
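
A minimal usage sketch, assuming the module is importable as `megatron.bridge.models.gemma.gemma3_provider` (suggested by the base-class path above) and that `torch.distributed` and Megatron's model-parallel state have already been initialized before `provide()` is called:

```python
import torch
from megatron.bridge.models.gemma.gemma3_provider import Gemma3ModelProvider1B

# Pick a size variant and override documented fields before building the model.
provider = Gemma3ModelProvider1B()
provider.seq_length = 8192
provider.params_dtype = torch.bfloat16

# provide() returns a configured Megatron Core GPTModel. In a real run this is
# called only after the distributed and model-parallel groups are set up.
model = provider.provide(pre_process=True, post_process=True)
```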

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider1B#

Bases: bridge.models.gemma.gemma3_provider.Gemma3ModelProvider

Gemma3 1B config

is_vision_language: bool#

False

num_layers: int#

26

hidden_size: int#

1152

num_attention_heads: int#

4

num_query_groups: int#

1

kv_channels: int#

256

ffn_hidden_size: int#

6912

window_size: int#

512

rope_scaling_factor: float#

1.0

seq_length: int#

32768

bf16: bool#

True

vocab_size: int#

262144

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider4B#

Bases: bridge.models.gemma.gemma3_provider.Gemma3ModelProvider

Gemma3 4B config

is_vision_language: bool#

True

num_layers: int#

34

hidden_size: int#

2560

num_attention_heads: int#

8

num_query_groups: int#

4

kv_channels: int#

256

ffn_hidden_size: int#

10240

window_size: int#

1024

rope_scaling_factor: float#

8.0

vocab_size: int#

262208

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider12B#

Bases: bridge.models.gemma.gemma3_provider.Gemma3ModelProvider

Gemma3 12B config

is_vision_language: bool#

True

num_layers: int#

48

hidden_size: int#

3840

num_attention_heads: int#

16

num_query_groups: int#

8

kv_channels: int#

256

ffn_hidden_size: int#

15360

window_size: int#

1024

rope_scaling_factor: float#

8.0

vocab_size: int#

262208

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider27B#

Bases: bridge.models.gemma.gemma3_provider.Gemma3ModelProvider

Gemma3 27B config

is_vision_language: bool#

True

num_layers: int#

62

hidden_size: int#

5376

num_attention_heads: int#

32

num_query_groups: int#

16

kv_channels: int#

128

softmax_scale: float#

None

ffn_hidden_size: int#

21504

window_size: int#

1024

rope_scaling_factor: float#

8.0

vocab_size: int#

262208
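
The size variants differ only in the fields they override, so comparing them is a matter of instantiating each class and reading its attributes. A small sketch (import path assumed as above):

```python
from megatron.bridge.models.gemma.gemma3_provider import (
    Gemma3ModelProvider1B,
    Gemma3ModelProvider4B,
    Gemma3ModelProvider12B,
    Gemma3ModelProvider27B,
)

# Print the headline dimensions of each documented variant.
for cls in (
    Gemma3ModelProvider1B,
    Gemma3ModelProvider4B,
    Gemma3ModelProvider12B,
    Gemma3ModelProvider27B,
):
    p = cls()
    print(
        f"{cls.__name__}: layers={p.num_layers} hidden={p.hidden_size} "
        f"heads={p.num_attention_heads} window={p.window_size} "
        f"vision_language={p.is_vision_language}"
    )
```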

bridge.models.gemma.gemma3_provider.gemma3_layer_spec(config) → megatron.core.transformer.ModuleSpec#

Gemma3 custom layer spec.
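
Since `transformer_layer_spec` accepts either a `ModuleSpec` or a callable taking the provider, `gemma3_layer_spec` can be wired in either way. A hedged sketch (import path assumed):

```python
from megatron.bridge.models.gemma.gemma3_provider import (
    Gemma3ModelProvider1B,
    gemma3_layer_spec,
)

provider = Gemma3ModelProvider1B()

# Option 1: hand the provider the factory and let it build the spec lazily.
provider.transformer_layer_spec = gemma3_layer_spec

# Option 2: build the ModuleSpec eagerly from the provider's own config.
layer_spec = gemma3_layer_spec(provider)
```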

class bridge.models.gemma.gemma3_provider.Gemma3SelfAttention#

Bases: megatron.core.transformer.attention.SelfAttention

Gemma3 self attention.

Uses the local rope embedding for local (sliding-window) layers and the global rope embedding for global layers.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → Tuple[torch.Tensor, torch.Tensor]#

Switches to the local or global rope embedding for this layer before running the standard attention forward pass.

class bridge.models.gemma.gemma3_provider.Gemma3TEDotProductAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
**kwargs,
)#

Bases: bridge.models.gemma.gemma3_provider.TEDotProductAttention

Gemma3 core attention.

Switches between global attention and local sliding-window attention based on the layer_number and the predefined interleaved layer pattern.

Initialization
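
A stand-in illustration of the pattern check (not the library's internal code): with `interleaved_attn_pattern = (5, 1)`, every block of six layers consists of five local sliding-window layers followed by one global layer, assuming 1-based layer numbering.

```python
from typing import Tuple

def is_local_attn_layer(layer_number: int, layer_pattern: Tuple[int, int] = (5, 1)) -> bool:
    """Return True if the layer falls in the local part of the repeating pattern."""
    n_local, n_global = layer_pattern
    # Position of the (1-based) layer inside each repeating block of layers.
    return (layer_number - 1) % (n_local + n_global) < n_local

# Layers 1-5 are local, layer 6 is global, and the pattern then repeats.
print([is_local_attn_layer(n) for n in range(1, 13)])
```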

class bridge.models.gemma.gemma3_provider.Gemma3LanguageModelEmbedding#

Bases: megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding

Gemma3 language token embedding.

Adds a normalization to the embedding.

forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
tokentype_ids: int = None,
) → torch.Tensor#

Calculate the token embedding and normalize it.
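
A minimal sketch of the idea, assuming the added "normalization" is the Gemma-style scaling of token embeddings by the square root of the hidden size; the real class builds on Megatron's parallel `LanguageModelEmbedding`, so this toy module is illustrative only.

```python
import math
import torch
import torch.nn as nn

class ToyGemmaEmbedding(nn.Module):
    def __init__(self, vocab_size: int = 262144, hidden_size: int = 1152):
        super().__init__()
        self.hidden_size = hidden_size
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Embed, then scale by sqrt(hidden_size) as in the Gemma family.
        return self.word_embeddings(input_ids) * math.sqrt(self.hidden_size)

out = ToyGemmaEmbedding()(torch.tensor([[1, 2, 3]]))
print(out.shape)  # torch.Size([1, 3, 1152])
```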

class bridge.models.gemma.gemma3_provider.Gemma3RotaryEmbedding(
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
rotary_base: int = 1000000,
rotary_base_local: int = 10000,
**kwargs,
)#

Bases: megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding

Gemma3 rotary position (RoPE) embedding.

Calculates rope embeddings for both local and global attention layers.

Initialization

forward(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) → torch.Tensor#

Get the global and local rope embeddings.
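
A sketch of the dual-base idea only: one set of rotary inverse frequencies for global layers (base `rotary_base`, optionally stretched by `rope_scaling_factor`) and one for local sliding-window layers (base `rotary_base_local`). Tensor shapes and the exact scaling rule are assumptions, not the library's implementation.

```python
from typing import Tuple
import torch

def dual_rope_inv_freq(
    dim: int,
    rotary_base: float = 1_000_000.0,
    rotary_base_local: float = 10_000.0,
    rope_scaling_factor: float = 8.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    # Global layers: large base, with simple linear scaling of the frequencies.
    inv_freq_global = 1.0 / (rotary_base ** exponents) / rope_scaling_factor
    # Local sliding-window layers: smaller base, no scaling.
    inv_freq_local = 1.0 / (rotary_base_local ** exponents)
    return inv_freq_global, inv_freq_local

inv_g, inv_l = dual_rope_inv_freq(dim=256)
print(inv_g.shape, inv_l.shape)  # torch.Size([128]) torch.Size([128])
```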

bridge.models.gemma.gemma3_provider._is_local_attn_layer(
layer_number: int,
layer_pattern: Tuple[int, int],
) → bool#

Return whether a layer uses local sliding-window attention, based on the interleaved layer pattern.

class bridge.models.gemma.gemma3_provider.TERowParallelLinearLayerNorm(
input_size: int,
output_size: int,
*,
config: megatron.core.transformer.TransformerConfig,
**kwargs,
)#

Bases: bridge.models.gemma.gemma3_provider.TERowParallelLinear

Modified from TERowParallelLinear with an additional Post-LN.

Initialization

forward(x)#

Forward pass with an additional Post-LN applied to the output.
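
A toy equivalent of the "row-parallel linear followed by a post-LN" pattern, with a plain `nn.Linear` standing in for `TERowParallelLinear` and RMSNorm matching the provider's `normalization` setting; dimensions and norm placement are illustrative assumptions.

```python
import torch
import torch.nn as nn

class RowLinearWithPostLN(nn.Module):
    def __init__(self, input_size: int, output_size: int, eps: float = 1e-6):
        super().__init__()
        # Stand-in for the tensor-parallel TERowParallelLinear projection.
        self.linear = nn.Linear(input_size, output_size, bias=False)
        # Additional post-layer-norm applied to the projection output.
        self.post_ln = nn.RMSNorm(output_size, eps=eps)  # requires PyTorch >= 2.4

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.post_ln(self.linear(x))

layer = RowLinearWithPostLN(input_size=6912, output_size=1152)
print(layer(torch.randn(2, 4, 6912)).shape)  # torch.Size([2, 4, 1152])
```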