core.ssm.gated_delta_net#
Module Contents#
Classes#
- GatedDeltaNetSubmodules – Contains the module specs for the input linear, output norm, and output linear layers.
- GatedDeltaNet – Gated Delta Net (GDN) layer class
Functions#
- _split_tensor_factory – Builds a factory that splits a given ShardedTensor into several independent chunks.
- torch_chunk_gated_delta_rule – Torch-native implementation of the chunked gated delta rule for deterministic mode; needed because the FLA (flash-linear-attention) kernels are not deterministic.
Data#
API#
- core.ssm.gated_delta_net.logger#
'getLogger(...)'
- class core.ssm.gated_delta_net.GatedDeltaNetSubmodules#
Contains the module specs for the input linear, output norm, and output linear layers.
- in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- out_norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- class core.ssm.gated_delta_net.GatedDeltaNet(
- config: megatron.core.transformer.TransformerConfig,
- submodules: core.ssm.gated_delta_net.GatedDeltaNetSubmodules,
- layer_number: int = None,
- bias: bool = False,
- conv_bias: bool = False,
- conv_init: Optional[float] = None,
- use_qk_l2norm: bool = True,
- A_init_range: Tuple[float, float] = (1, 16),
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#
Bases:
megatron.core.transformer.module.MegatronModule

Gated Delta Net (GDN) layer class
The GDN layer takes an input of size [s, b, h] and returns an output of the same size.
Initialization
- Parameters:
config – The config of the model.
submodules – Contains the module specs for the input linear, output norm, and output linear layers.
layer_number – The layer number of this GDN layer.
bias – Whether to use bias in the linear layers.
conv_bias – Whether to use bias in the causal convolution.
conv_init – The initialization range for the causal convolution weights.
use_qk_l2norm – Whether to apply L2 normalization to the query and key inside the gated delta rule kernel.
A_init_range – The initialization range for the A parameter.
pg_collection – The required process groups for tensor model parallelism and context parallelism.
- reset_parameters()#
Reset the parameters.
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- **kwargs,
)#
Perform a forward pass through the GDN module.
- Parameters:
hidden_states (Tensor) – Hidden states.
attention_mask (Tensor) – Attention mask.
key_value_states (Optional[Tensor]) – Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.
attention_bias (Optional[Tensor]) – Attention bias.
packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.
sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.
- Returns:
(Tuple[Tensor, Tensor]) GDN output and bias.
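As a point of reference, a minimal forward-call sketch (hypothetical: it assumes `gdn` is an already-constructed GatedDeltaNet and `config` is its TransformerConfig; only the [s, b, h] input/output contract and the (output, bias) return come from the documentation above):

```python
import torch

# Hypothetical sketch: `gdn` is assumed to be a constructed GatedDeltaNet inside an
# initialized model-parallel environment, and `config` its TransformerConfig.
s, b, h = 512, 2, config.hidden_size                  # sequence, batch, hidden
hidden_states = torch.randn(s, b, h, device="cuda", dtype=torch.bfloat16)

# attention_mask is part of the shared mixer interface; None is used here purely
# for the sketch. The layer returns its output together with an (optional) bias.
output, bias = gdn(hidden_states, attention_mask=None)
assert output.shape == (s, b, h)                      # same size as the input
```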
- _apply_gated_norm(x, gate)#
- sharded_state_dict(
- prefix='',
- sharded_offsets=(),
- metadata=None,
- tp_group=None,
)#
Provide a sharded state dictionary for distributed checkpointing.
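A hedged sketch of how this plugs into distributed checkpointing (the prefix and checkpoint path below are illustrative only, and `gdn` is again assumed to be a constructed layer in an initialized model-parallel environment):

```python
from megatron.core import dist_checkpointing

# Hypothetical prefix and path. sharded_state_dict() yields ShardedTensor (and
# factory) entries that the distributed checkpointing machinery can save and
# later reload under a different tensor/pipeline parallel layout.
sharded_sd = gdn.sharded_state_dict(prefix="decoder.layers.0.mixer.")
dist_checkpointing.save(sharded_sd, "/tmp/gdn_ckpt")
```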
- core.ssm.gated_delta_net._split_tensor_factory(
- orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
- split_sections: List[int],
- split_names: List[str],
- split_dim: int,
)#
Builds a factory that splits a given ShardedTensor into several independent chunks.
- core.ssm.gated_delta_net.torch_chunk_gated_delta_rule(
- query,
- key,
- value,
- g,
- beta,
- chunk_size=64,
- initial_state=None,
- output_final_state=False,
- use_qk_l2norm_in_kernel=False,
)#
Torch-native implementation of the chunked gated delta rule for deterministic mode; needed because the FLA (flash-linear-attention) kernels are not deterministic.
Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
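For intuition, below is a non-chunked reference sketch of the gated delta rule recurrence that the chunked implementation evaluates block-wise in chunks of chunk_size tokens. The shapes and gating convention are assumptions for illustration (they are not taken from this module); with use_qk_l2norm_in_kernel enabled, query and key would additionally be L2-normalized before the recurrence.

```python
import torch

def gated_delta_rule_recurrent(q, k, v, g, beta):
    """Naive per-token reference. Assumed shapes: q/k [b, h, s, d_k], v [b, h, s, d_v],
    g and beta [b, h, s], where g is a log-decay and beta a per-token write strength."""
    b, h, s, d_k = q.shape
    d_v = v.shape[-1]
    state = q.new_zeros(b, h, d_k, d_v)             # fast-weight memory S
    outputs = []
    for t in range(s):
        q_t, k_t, v_t = q[..., t, :], k[..., t, :], v[..., t, :]
        alpha = g[..., t].exp()[..., None, None]    # gated decay, in (0, 1] for g <= 0
        b_t = beta[..., t][..., None, None]
        state = alpha * state                       # decay the old memory
        # Delta rule: overwrite the value currently associated with k_t by v_t.
        pred = torch.einsum("bhkv,bhk->bhv", state, k_t)
        state = state + b_t * torch.einsum("bhk,bhv->bhkv", k_t, v_t - pred)
        outputs.append(torch.einsum("bhkv,bhk->bhv", state, q_t))
    return torch.stack(outputs, dim=2)              # [b, h, s, d_v]
```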