bridge.models.gemma.gemma3_provider

Module Contents

Classes

Gemma3ModelProvider

Configuration and provider for Megatron Core Gemma3 models.

Gemma3SelfAttention

Gemma3 self attention.

Gemma3TEDotProductAttention

Gemma3 core attention.

Gemma3LanguageModelEmbedding

Gemma3 language token embedding.

Gemma3RotaryEmbedding

Gemma3 position rope embedding.

TERowParallelLinearLayerNorm

Modified from TERowParallelLinear with an additional Post-LN.

Functions

gemma3_layer_spec

Gemma3 custom layer spec.

_is_local_attn_layer

API

class bridge.models.gemma.gemma3_provider.Gemma3ModelProvider

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Configuration and provider for Megatron Core Gemma3 models.

seq_length: int

131072

position_embedding_type: str

'rope'

rotary_base: tuple

(10000, 1000000)

share_embeddings_and_output_weights: bool

True

normalization: str

'RMSNorm'

layernorm_zero_centered_gamma: bool

True

layernorm_epsilon: float

1e-06

qk_layernorm: bool

True

window_size: tuple

512

interleaved_attn_pattern: tuple

(5, 1)

attention_dropout: float

0.0

hidden_dropout: float

0.0

rope_scaling_factor: float

1.0

attention_backend: megatron.core.transformer.enums.AttnBackend

None

softmax_scale: float

None

gated_linear_unit: bool

True

add_bias_linear: bool

False

activation_func: Callable

None

is_vision_language: bool

False

flash_decode: bool

False

transformer_layer_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[bridge.models.gemma.gemma3_provider.Gemma3ModelProvider], megatron.core.transformer.ModuleSpec]]

'field(…)'

scatter_embedding_sequence_parallel: bool

True

bf16: bool

True

fp16: bool

False

params_dtype: torch.dtype

None

autocast_dtype: torch.dtype

None

provide(
pre_process=None,
post_process=None,
vp_stage=None,
) → megatron.core.models.gpt.GPTModel

Configure and instantiate a Megatron Core Gemma3 model.

Replaces the model's standard embedding and rotary position embedding (RoPE) with the Gemma3-specific versions.

Parameters:
  • pre_process – Whether to include pre-processing in the model

  • post_process – Whether to include post-processing in the model

  • vp_stage – Virtual pipeline stage

Returns:

Configured Megatron Core GPT model instance

Return type:

MCoreGPTModel

bridge.models.gemma.gemma3_provider.gemma3_layer_spec(config) → megatron.core.transformer.ModuleSpec

Build the custom Gemma3 transformer layer spec for the given config.

class bridge.models.gemma.gemma3_provider.Gemma3SelfAttention

Bases: megatron.core.transformer.attention.SelfAttention

Gemma3 self attention.

Uses the local RoPE embedding for local (sliding-window) attention layers and the global RoPE embedding for global attention layers.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → Tuple[torch.Tensor, torch.Tensor]

Switch to either the local or the global RoPE embedding, depending on the layer type, before running the forward pass.

class bridge.models.gemma.gemma3_provider.Gemma3TEDotProductAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
**kwargs,
)

Bases: bridge.models.gemma.gemma3_provider.TEDotProductAttention

Gemma3 core attention.

Switches between global and local sliding window attention based on the layer_number and pre-defined layer pattern.
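A pure-Python sketch of the two mask shapes this class switches between: global layers use an ordinary causal mask, while local layers further restrict each query to a sliding window of recent keys. It assumes each query sees itself plus the previous `window_size - 1` tokens; the exact window convention handed to Transformer Engine is not documented on this page, and `causal_mask` and `sliding_window_mask` are hypothetical helpers, not part of this module.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    """True where query i may attend to key j (global causal attention)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def sliding_window_mask(seq_len: int, window_size: int) -> List[List[bool]]:
    """Causal mask further restricted to the most recent window_size keys.

    Assumption: query i sees keys j with i - window_size < j <= i,
    i.e. itself plus the previous window_size - 1 tokens.
    """
    return [
        [i - window_size < j <= i for j in range(seq_len)]
        for i in range(seq_len)
    ]


if __name__ == "__main__":
    # 6 tokens with a window of 3 (the provider default window_size is 512).
    for row in sliding_window_mask(6, 3):
        print("".join("x" if allowed else "." for allowed in row))
```

The printed rows form a band of width 3 along the diagonal. Because each local-layer query attends to at most `window_size` keys, its attention cost grows linearly with sequence length instead of quadratically.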

Initialization

class bridge.models.gemma.gemma3_provider.Gemma3LanguageModelEmbedding

Bases: megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding

Gemma3 language token embedding.

Adds a normalization to the embedding.

forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
tokentype_ids: int = None,
) → torch.Tensor

Calculate the token embedding and normalize it.
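The docstrings do not say which normalization is applied. In the reference Gemma implementations (for example Hugging Face Transformers), the embedding output is multiplied by `sqrt(hidden_size)`, a scaling rather than a LayerNorm. The sketch below assumes Gemma3LanguageModelEmbedding follows that convention; `gemma_embed` is a hypothetical helper that works on plain lists instead of tensors.

```python
import math
from typing import List, Sequence


def gemma_embed(input_ids: Sequence[int], table: List[List[float]]) -> List[List[float]]:
    """Look up each token's vector and scale it by sqrt(hidden_size).

    Assumption: the Gemma "normalizer" is applied right after the lookup.
    """
    hidden_size = len(table[0])
    normalizer = math.sqrt(hidden_size)
    return [[value * normalizer for value in table[token]] for token in input_ids]


if __name__ == "__main__":
    # Vocabulary of 2 tokens with hidden_size = 4, so the normalizer is 2.0.
    table = [[0.5, -0.25, 1.0, 0.0], [0.1, 0.2, 0.3, 0.4]]
    print(gemma_embed([0, 0, 1], table))
```

In this sketch the stored table stays unscaled and only the lookup result is multiplied, which keeps the table usable as-is when the embedding and output weights are shared (`share_embeddings_and_output_weights = True`).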

class bridge.models.gemma.gemma3_provider.Gemma3RotaryEmbedding(
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
rotary_base: int = 1000000,
rotary_base_local: int = 10000,
**kwargs,
)

Bases: megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding

Gemma3 position rope embedding.

Calculates rope embeddings for both local and global attention layers.
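A pure-Python sketch of computing both angle tables from the standard RoPE inverse frequencies `1 / base**(2k/dim)`. The bases mirror the constructor defaults (`rotary_base=1000000` for global layers, `rotary_base_local=10000` for local layers). It assumes linear scaling is implemented by dividing the global inverse frequencies by `rope_scaling_factor` (the provider default of 1.0 leaves them unchanged); that is an assumption about this class, and `inverse_frequencies`, `rope_angles`, `RopePair`, and `gemma3_rope_pair` are hypothetical names. The real class works with tensors rather than lists.

```python
from typing import List, NamedTuple


def inverse_frequencies(dim: int, base: float) -> List[float]:
    """Standard RoPE inverse frequencies 1 / base**(2k / dim), one per channel pair."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def rope_angles(max_seq_len: int, inv_freq: List[float], offset: int = 0) -> List[List[float]]:
    """Angle table: angles[p][k] = (p + offset) * inv_freq[k]."""
    return [[(pos + offset) * f for f in inv_freq] for pos in range(max_seq_len)]


class RopePair(NamedTuple):
    global_angles: List[List[float]]
    local_angles: List[List[float]]


def gemma3_rope_pair(
    max_seq_len: int,
    dim: int,
    rope_scaling_factor: float,
    rotary_base: float = 1_000_000,
    rotary_base_local: float = 10_000,
    offset: int = 0,
) -> RopePair:
    # Assumption: linear scaling divides only the global inverse frequencies,
    # which stretches global positions by rope_scaling_factor.
    global_inv = [f / rope_scaling_factor for f in inverse_frequencies(dim, rotary_base)]
    local_inv = inverse_frequencies(dim, rotary_base_local)
    return RopePair(
        global_angles=rope_angles(max_seq_len, global_inv, offset),
        local_angles=rope_angles(max_seq_len, local_inv, offset),
    )
```

Returning a named pair sidesteps any assumption about the order in which the real forward pass returns the two embeddings. The larger global base yields lower frequencies, so global layers can distinguish positions across much longer distances than local ones.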

Initialization

forward(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
cp_group: torch.distributed.ProcessGroup | None = None,
) → torch.Tensor

Get the global and local RoPE embeddings.

Note: Caching is bypassed when cp_group is provided since ProcessGroup is unhashable.
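The note above describes a common memoization pattern: calls with only hashable arguments go through a cached helper, while calls that carry an unhashable process group fall back to direct computation. A minimal sketch with `functools.lru_cache`; `CachedRope`, `_compute`, and `FakeProcessGroup` are hypothetical stand-ins, and the placeholder computation returns positions instead of real embeddings.

```python
from functools import lru_cache
from typing import Optional, Tuple


class FakeProcessGroup:
    """Stand-in for torch.distributed.ProcessGroup, made unhashable on purpose."""
    __hash__ = None


class CachedRope:
    def __init__(self) -> None:
        self.compute_calls = 0  # counts how often the expensive path runs

    def _compute(self, max_seq_len: int, offset: int, packed_seq: bool) -> Tuple[int, ...]:
        # Placeholder for the real embedding computation: just the positions.
        self.compute_calls += 1
        return tuple(range(offset, offset + max_seq_len))

    @lru_cache(maxsize=32)
    def _forward_cached(self, max_seq_len: int, offset: int = 0,
                        packed_seq: bool = False) -> Tuple[int, ...]:
        # Only hashable arguments reach this method, so lru_cache can key on them.
        return self._compute(max_seq_len, offset, packed_seq)

    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False,
                cp_group: Optional[FakeProcessGroup] = None) -> Tuple[int, ...]:
        if cp_group is not None:
            # The group cannot be part of a cache key, so skip the cache entirely.
            return self._compute(max_seq_len, offset, packed_seq)
        return self._forward_cached(max_seq_len, offset, packed_seq)


if __name__ == "__main__":
    rope = CachedRope()
    rope.forward(8)
    rope.forward(8)  # served from the cache
    rope.forward(8, cp_group=FakeProcessGroup())  # bypasses the cache
    print(rope.compute_calls)
```

Decorating a method with `lru_cache` includes `self` in the cache key, so the cache lives at class level and holds references to instances; a per-instance dictionary is an alternative when that matters.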

_forward_cached(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) → torch.Tensor

Cached forward pass; accepts only hashable parameters so that its result can be memoized.

bridge.models.gemma.gemma3_provider._is_local_attn_layer(
layer_number: int,
layer_pattern: Tuple[int, int],
) → bool
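This function has no docstring here. Judging by its name and the provider's `interleaved_attn_pattern = (5, 1)` default, it decides whether a layer uses local (sliding-window) attention, which is what lets the attention classes above pick the local or global RoPE embedding and mask per layer. The sketch assumes 1-based layer numbers and a repeating pattern of `num_local` layers followed by `num_global` layers, so that with `(5, 1)` every sixth layer is global. `is_local_attn_layer` and `global_layers` are hypothetical names for an illustrative reimplementation, not the module's source.

```python
from typing import List, Tuple


def is_local_attn_layer(layer_number: int, layer_pattern: Tuple[int, int]) -> bool:
    """Return True if the 1-based layer_number falls in the local part of the pattern.

    Assumption: layer_pattern = (num_local, num_global) repeats through the
    stack with local layers first, e.g. (5, 1) -> L L L L L G L L L L L G ...
    """
    num_local, num_global = layer_pattern
    position_in_cycle = (layer_number - 1) % (num_local + num_global)
    return position_in_cycle < num_local


def global_layers(num_layers: int, layer_pattern: Tuple[int, int]) -> List[int]:
    """List the 1-based indices of global-attention layers in a stack."""
    return [
        n for n in range(1, num_layers + 1)
        if not is_local_attn_layer(n, layer_pattern)
    ]


if __name__ == "__main__":
    # With the default (5, 1) pattern, a 26-layer stack has 4 global layers.
    print(global_layers(26, (5, 1)))
```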
class bridge.models.gemma.gemma3_provider.TERowParallelLinearLayerNorm(
input_size: int,
output_size: int,
*,
config: megatron.core.transformer.TransformerConfig,
**kwargs,
)

Bases: bridge.models.gemma.gemma3_provider.TERowParallelLinear

Modified from TERowParallelLinear with an additional post-layer normalization (Post-LN) on the output.

Initialization

forward(x)

Forward pass with an additional Post-LN applied to the output.
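A single-device sketch of the idea: project the input with a bias-free linear layer (the provider sets `add_bias_linear = False`), then normalize the projection's output. It assumes the Post-LN follows the provider's normalization settings (`normalization='RMSNorm'`, `layernorm_zero_centered_gamma=True`, `layernorm_epsilon=1e-06`), where a zero-centered weight means the effective scale is `1 + gamma`. Tensor parallelism and the bias output that Megatron linear layers typically return are omitted, and `rms_norm` and `linear_then_post_ln` are hypothetical helpers.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], gamma: Sequence[float], eps: float = 1e-6,
             zero_centered_gamma: bool = True) -> List[float]:
    """RMSNorm of one vector; with zero-centered gamma the scale is (1 + gamma)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [v / rms * s for v, s in zip(x, scale)]


def linear_then_post_ln(x: Sequence[float], weight: Sequence[Sequence[float]],
                        gamma: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Bias-free projection y = W @ x (one weight row per output feature),
    followed by RMSNorm applied to y (the Post-LN)."""
    y = [sum(w * v for w, v in zip(row, x)) for row in weight]
    return rms_norm(y, gamma, eps)


if __name__ == "__main__":
    # Identity projection; gamma = 0 means unit scale under zero-centered gamma.
    print(linear_then_post_ln([3.0, 4.0], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))
```

Storing gamma as an offset from 1 means a freshly zero-initialized norm starts as a plain RMS normalization.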