bridge.models.gemma.modules

Module Contents

Classes

EmbeddingScalingMixin

A mixin class for scaling embeddings in Megatron GPT. The scaling is applied only if the configuration (accessible via self.config) has apply_embedding_scaling set to True.

Functions

extend_instance

Apply a mixin to a class instance after it has been created.

API

bridge.models.gemma.modules.extend_instance(obj, mixin)

Apply a mixin to a class instance after it has been created.
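A common way to add a mixin to an object that already exists is to build a new class on the fly, with the mixin placed ahead of the object's original class in the method resolution order, and then rebind the object's __class__. The sketch below illustrates that pattern only. It is an assumption about how such a helper can work, not the library's implementation, and extend_instance_sketch, Greeter and ShoutMixin are hypothetical names.

```python
def extend_instance_sketch(obj, mixin):
    """Hypothetical sketch: give `obj` a new class with `mixin` first in the MRO."""
    base_cls = obj.__class__
    # Build a subclass dynamically. The mixin comes first, so its methods take
    # precedence and can delegate to the original class through super().
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})


class Greeter:
    """Hypothetical class standing in for an already-constructed module."""

    def greet(self):
        return "hello"


class ShoutMixin:
    """Hypothetical mixin that post-processes the original method's result."""

    def greet(self):
        return super().greet().upper() + "!"


greeter = Greeter()
extend_instance_sketch(greeter, ShoutMixin)
print(greeter.greet())  # prints HELLO!
```

Because the object is modified in place rather than rebuilt, any state it already holds (for a module, its loaded weights, for example) is kept.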

class bridge.models.gemma.modules.EmbeddingScalingMixin

Bases: torch.nn.Module

A mixin class for scaling embeddings in Megatron GPT. The scaling is applied only if the configuration (accessible via self.config) has apply_embedding_scaling set to True.

forward(**kwargs)

Forward pass that calls the superclass's forward method and multiplies the returned embeddings by the square root of the hidden size specified in the configuration, i.e. embeddings * sqrt(hidden_size), when apply_embedding_scaling is True.
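The behaviour described above can be sketched as follows. This is a minimal illustration, not the library's code. It assumes that the wrapped module stores its configuration on self.config, that the configuration exposes hidden_size and apply_embedding_scaling, and that the superclass forward returns a tensor. EmbeddingScalingMixinSketch, ToyEmbedding and ScaledToyEmbedding are hypothetical stand-ins for the real Megatron embedding classes. The sketch requires PyTorch.

```python
import math
from types import SimpleNamespace

import torch
from torch import nn


class EmbeddingScalingMixinSketch(nn.Module):
    """Hypothetical re-creation of the documented scaling behaviour."""

    def forward(self, **kwargs):
        # Run the wrapped embedding module's own forward pass first.
        embeddings = super().forward(**kwargs)
        # Scale only when the configuration opts in.
        if getattr(self.config, "apply_embedding_scaling", False):
            embeddings = embeddings * math.sqrt(self.config.hidden_size)
        return embeddings


class ToyEmbedding(nn.Module):
    """Hypothetical stand-in for a Megatron embedding module."""

    def __init__(self, config, vocab_size=10):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(vocab_size, config.hidden_size)

    def forward(self, input_ids):
        return self.word_embeddings(input_ids)


class ScaledToyEmbedding(EmbeddingScalingMixinSketch, ToyEmbedding):
    """Mixin first in the MRO, so its forward wraps ToyEmbedding.forward."""


torch.manual_seed(0)
config = SimpleNamespace(hidden_size=16, apply_embedding_scaling=True)
plain = ToyEmbedding(config)
scaled = ScaledToyEmbedding(config)
scaled.load_state_dict(plain.state_dict())  # identical weights in both modules

ids = torch.tensor([[1, 2, 3]])
# With hidden_size=16 the scale factor is sqrt(16) = 4.
print(torch.allclose(scaled(input_ids=ids), plain(input_ids=ids) * 4.0))  # prints True
```

Placing the mixin before the embedding class in the bases is what lets its forward intercept the call and delegate to the original implementation through super().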