bridge.models.gemma.gemma3_provider#
Module Contents#
Classes#

| Class | Description |
|---|---|
| Gemma3ModelProvider | Configuration and provider for Megatron Core Gemma3 models. |
| Gemma3ModelProvider1B | Gemma3 1B config |
| Gemma3ModelProvider4B | Gemma3 4B config |
| Gemma3ModelProvider12B | Gemma3 12B config |
| Gemma3ModelProvider27B | Gemma3 27B config |
| Gemma3SelfAttention | Gemma3 self attention. |
| Gemma3TEDotProductAttention | Gemma3 core attention. |
| Gemma3LanguageModelEmbedding | Gemma3 language token embedding. |
| Gemma3RotaryEmbedding | Gemma3 position rope embedding. |
| TERowParallelLinearLayerNorm | Modified from TERowParallelLinear with an additional Post-LN. |

Functions#

| Function | Description |
|---|---|
| gemma3_layer_spec | Gemma3 custom layer spec. |
API#
- class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider#
Bases:
megatron.bridge.models.gpt_provider.GPTModelProvider
Configuration and provider for Megatron Core Gemma3 models.
- seq_length: int#
131072
- position_embedding_type: str#
‘rope’
- rotary_base: tuple#
(10000, 1000000)
True
- normalization: str#
‘RMSNorm’
- layernorm_zero_centered_gamma: bool#
True
- layernorm_epsilon: float#
1e-06
- window_size: tuple#
512
- interleaved_attn_pattern: tuple#
(5, 1)
- attention_dropout: float#
0.0
- hidden_dropout: float#
0.0
- rope_scaling_factor: float#
1.0
- attention_backend: megatron.core.transformer.enums.AttnBackend#
None
- softmax_scale: float#
None
- gated_linear_unit: bool#
True
- add_bias_linear: bool#
False
- activation_func: Callable#
‘field(…)’
- is_vision_language: bool#
False
- flash_decode: bool#
False
- gradient_accumulation_fusion: bool#
False
- transformer_layer_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[bridge.models.gemma.gemma3_provider.Gemma3ModelProvider], megatron.core.transformer.ModuleSpec]]#
‘field(…)’
- scatter_embedding_sequence_parallel: bool#
True
- apply_rope_fusion: bool#
‘field(…)’
- masked_softmax_fusion: bool#
‘field(…)’
- bf16: bool#
True
- fp16: bool#
False
- params_dtype: torch.dtype#
None
- autocast_dtype: torch.dtype#
None
- provide(
- pre_process=None,
- post_process=None,
- vp_stage=None,
- )#
Configure and instantiate a Megatron Core Gemma3 model.
Replaces the model’s embedding and rope with customized Gemma3 ones.
- Parameters:
pre_process – Whether to include pre-processing in the model
post_process – Whether to include post-processing in the model
vp_stage – Virtual pipeline stage
- Returns:
Configured Megatron Core GPT model instance
- Return type:
MCoreGPTModel
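A minimal usage sketch is shown below. It assumes the package is importable as `megatron.bridge` (the fully qualified prefix used for the base classes above) and that Megatron Core's model-parallel state has already been initialized elsewhere; the keyword arguments mirror the `provide()` signature documented here.

```python
# Sketch only: assumes megatron.bridge is installed and Megatron Core
# model-parallel state has been initialized elsewhere in the program.
from megatron.bridge.models.gemma.gemma3_provider import Gemma3ModelProvider1B

provider = Gemma3ModelProvider1B()  # defaults: 26 layers, hidden_size 1152, seq_length 32768
model = provider.provide(pre_process=True, post_process=True)  # returns an MCoreGPTModel
```

The larger variants (Gemma3ModelProvider4B, Gemma3ModelProvider12B, Gemma3ModelProvider27B) are used the same way; they only override the architecture fields listed below.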
- class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider1B#
Bases:
bridge.models.gemma.gemma3_provider.Gemma3ModelProvider
Gemma3 1B config
- is_vision_language: bool#
False
- num_layers: int#
26
- hidden_size: int#
1152
- num_attention_heads: int#
4
- num_query_groups: int#
1
- kv_channels: int#
256
- ffn_hidden_size: int#
6912
- window_size: int#
512
- rope_scaling_factor: float#
1.0
- seq_length: int#
32768
- bf16: bool#
True
- vocab_size: int#
262144
- class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider4B#
Bases:
bridge.models.gemma.gemma3_provider.Gemma3ModelProvider
Gemma3 4B config
- is_vision_language: bool#
True
- num_layers: int#
34
- hidden_size: int#
2560
- num_attention_heads: int#
8
- num_query_groups: int#
4
- kv_channels: int#
256
- ffn_hidden_size: int#
10240
- window_size: int#
1024
- rope_scaling_factor: float#
8.0
- vocab_size: int#
262208
- class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider12B#
Bases:
bridge.models.gemma.gemma3_provider.Gemma3ModelProvider
Gemma3 12B config
- is_vision_language: bool#
True
- num_layers: int#
48
- hidden_size: int#
3840
- num_attention_heads: int#
16
- num_query_groups: int#
8
- kv_channels: int#
256
- ffn_hidden_size: int#
15360
- window_size: int#
1024
- rope_scaling_factor: float#
8.0
- vocab_size: int#
262208
- class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider27B#
Bases:
bridge.models.gemma.gemma3_provider.Gemma3ModelProvider
Gemma3 27B config
- is_vision_language: bool#
True
- num_layers: int#
62
- hidden_size: int#
5376
- num_attention_heads: int#
32
- num_query_groups: int#
16
- kv_channels: int#
128
- softmax_scale: int#
None
- ffn_hidden_size: int#
21504
- window_size: int#
1024
- rope_scaling_factor: float#
8.0
- vocab_size: int#
262208
- bridge.models.gemma.gemma3_provider.gemma3_layer_spec(config) → megatron.core.transformer.ModuleSpec#
Gemma3 custom layer spec.
- class bridge.models.gemma.gemma3_provider.Gemma3SelfAttention#
Bases:
megatron.core.transformer.attention.SelfAttention
Gemma3 self attention.
Uses local rope embedding for local layers, global rope embedding for global layers.
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- )#
Switch to either the local or global rope embedding before running the parent forward pass.
- class bridge.models.gemma.gemma3_provider.Gemma3TEDotProductAttention(
- config: megatron.core.transformer.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- **kwargs,
- )#
Bases:
bridge.models.gemma.gemma3_provider.TEDotProductAttention
Gemma3 core attention.
Switches between global and local sliding window attention based on the layer_number and pre-defined layer pattern.
Initialization
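For intuition, the sketch below materializes the boolean mask implied by a causal sliding window of width `window_size` (512 for the 1B config, 1024 for the larger variants). It only illustrates the masking pattern; the actual class delegates to Transformer Engine's fused attention and typically never builds such a mask explicitly.

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """True where a query position may attend to a key position:
    causal (no future keys) and at most `window - 1` tokens back.
    The exact inclusive/exclusive convention used by TE may differ."""
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]  # query index minus key index
    return (rel >= 0) & (rel < window)

print(sliding_window_causal_mask(seq_len=6, window=3).int())
```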
- class bridge.models.gemma.gemma3_provider.Gemma3LanguageModelEmbedding#
Bases:
megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding
Gemma3 language token embedding.
Adds a normalization to the embedding.
- forward(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- tokentype_ids: int = None,
- )#
Calculate the embedding and apply the additional normalization.
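Gemma-family models scale token embeddings by the square root of the hidden size. The sketch below shows that scaling step in isolation, as a rough approximation of what this class adds on top of the standard LanguageModelEmbedding; the exact normalization it applies may differ.

```python
import torch

def scale_embeddings(embeddings: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Gemma-style embedding normalization: multiply by sqrt(hidden_size).
    # The normalizer is cast to the embedding dtype (relevant under bf16).
    normalizer = torch.tensor(hidden_size ** 0.5, dtype=embeddings.dtype)
    return embeddings * normalizer
```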
- class bridge.models.gemma.gemma3_provider.Gemma3RotaryEmbedding(
- rope_scaling: bool = False,
- rope_scaling_factor: float = 8.0,
- rotary_base: int = 1000000,
- rotary_base_local: int = 10000,
- **kwargs,
- )#
Bases:
megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding
Gemma3 position rope embedding.
Calculates rope embeddings for both local and global attention layers.
Initialization
- forward(
- max_seq_len: int,
- offset: int = 0,
- packed_seq: bool = False,
- )#
Get the global and local rope embeddings.
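The constructor defaults above (rotary_base=1000000 for global layers, rotary_base_local=10000 for local layers, rope_scaling_factor=8.0) can be illustrated with a plain inverse-frequency computation. The sketch assumes simple linear scaling of the global-layer frequencies; the scaling scheme actually implemented by Gemma3RotaryEmbedding may differ.

```python
import torch

def rope_inv_freq(base: float, dim: int, scaling_factor: float = 1.0) -> torch.Tensor:
    """Standard RoPE inverse frequencies for head dimension `dim`,
    optionally divided by a linear scaling factor (assumed scheme)."""
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return (1.0 / (base ** exponents)) / scaling_factor

head_dim = 256  # kv_channels in the 1B/4B/12B configs
global_inv_freq = rope_inv_freq(base=1_000_000, dim=head_dim, scaling_factor=8.0)
local_inv_freq = rope_inv_freq(base=10_000, dim=head_dim)  # local layers are unscaled
```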
- bridge.models.gemma.gemma3_provider._is_local_attn_layer(
- layer_number: int,
- layer_pattern: Tuple[int, int],
- )#
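This helper encodes the interleaved_attn_pattern default of (5, 1): blocks of five local sliding-window layers followed by one global-attention layer. A standalone illustration of that rule follows (the 0-based layer numbering used here is an assumption; the real helper may count layers from 1).

```python
from typing import Tuple

def is_local_attn_layer(layer_number: int, layer_pattern: Tuple[int, int]) -> bool:
    """Illustrative reimplementation: with pattern (local, global) = (5, 1),
    every period of 6 layers has 5 local layers followed by 1 global layer."""
    num_local, num_global = layer_pattern
    return (layer_number % (num_local + num_global)) < num_local

print([is_local_attn_layer(i, (5, 1)) for i in range(12)])
# five True, one False, repeated (under the assumed 0-based numbering)
```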
- class bridge.models.gemma.gemma3_provider.TERowParallelLinearLayerNorm(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.transformer.TransformerConfig,
- **kwargs,
- )#
Bases:
bridge.models.gemma.gemma3_provider.TERowParallelLinear
Modified from TERowParallelLinear with an additional Post-LN.
Initialization
- forward(x)#
Forward pass with an additional Post-LN applied to the output.
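Ignoring tensor parallelism and Transformer Engine, the idea reduces to a linear projection whose output passes through an extra RMSNorm. A single-device PyTorch sketch follows (plain RMSNorm here; the real layer may follow the zero-centered-gamma convention configured on the provider).

```python
import torch
from torch import nn

class RowLinearWithPostLN(nn.Module):
    """Single-device sketch of a linear layer with a post layer norm on its output."""

    def __init__(self, input_size: int, output_size: int, eps: float = 1e-6):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size, bias=False)
        self.ln_weight = nn.Parameter(torch.ones(output_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        # RMSNorm applied after the projection ("Post-LN")
        rms = y.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return y * rms * self.ln_weight
```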