core.ssm.gated_delta_net#

Module Contents#

Classes#

GatedDeltaNetSubmodules

Contains the module specs for the input linear, output norm, and output linear layers.

GatedDeltaNet

Gated Delta Net (GDN) layer class

Functions#

_split_tensor_factory

Builds a factory that splits a given ShardedTensor into several independent chunks.

torch_chunk_gated_delta_rule

Torch-native implementation of the chunked gated delta rule for deterministic mode; this is needed because the FLA kernel is not deterministic.

Data#

API#

core.ssm.gated_delta_net.logger#

'getLogger(…)'

class core.ssm.gated_delta_net.GatedDeltaNetSubmodules#

Contains the module specs for the input linear, output norm, and output linear layers.

in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
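As an illustrative sketch only (not the canonical layer spec shipped with Megatron-Core), the container might be filled with linear-layer classes or ModuleSpec entries along these lines; IdentityOp is a stand-in for whatever norm implementation the real spec uses, and the import path is assumed to be megatron.core.ssm.gated_delta_net:

```python
from megatron.core.ssm.gated_delta_net import GatedDeltaNetSubmodules
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.identity_op import IdentityOp

# Hypothetical spec: the real layer spec may use different linear/norm implementations
# (e.g. Transformer Engine variants and a gated RMSNorm for out_norm).
gdn_submodules = GatedDeltaNetSubmodules(
    in_proj=ColumnParallelLinear,   # fused input projection
    out_norm=IdentityOp,            # placeholder; a gated norm in practice
    out_proj=RowParallelLinear,     # output projection back to hidden size
)
```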

class core.ssm.gated_delta_net.GatedDeltaNet(
config: megatron.core.transformer.TransformerConfig,
submodules: core.ssm.gated_delta_net.GatedDeltaNetSubmodules,
layer_number: int = None,
bias: bool = False,
conv_bias: bool = False,
conv_init: Optional[float] = None,
use_qk_l2norm: bool = True,
A_init_range: Tuple[float, float] = (1, 16),
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Gated Delta Net (GDN) layer class

The GDN layer takes input of size [s, b, h] (sequence length, batch, hidden size) and returns output of the same size.

Initialization

Parameters:
  • config – The config of the model.

  • submodules – Contains the module specs for the input linear, output norm, and output linear layers.

  • layer_number – The layer number of this GDN layer.

  • bias – Whether to use bias in the linear layers.

  • conv_bias – Whether to use bias in the causal convolution.

  • conv_init – The initialization range for the causal convolution weights.

  • use_qk_l2norm – Whether to apply L2 normalization to the queries and keys inside the gated delta rule kernel.

  • A_init_range – The initialization range for the A (decay) parameter.

  • pg_collection – The process groups required for tensor model parallelism and context parallelism.
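
A minimal construction sketch, assuming the usual required TransformerConfig fields and the submodules container from the sketch above. In a real run the tensor/context parallel process groups (or megatron.core.parallel_state) must be initialized first, and additional GDN-specific config fields may be required; treat the values below as illustrative:

```python
from megatron.core.transformer import TransformerConfig
from megatron.core.ssm.gated_delta_net import GatedDeltaNet

# Hypothetical config values for illustration only.
config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    use_cpu_initialization=True,
)

gdn = GatedDeltaNet(
    config=config,
    submodules=gdn_submodules,   # see the GatedDeltaNetSubmodules sketch above
    layer_number=1,
    use_qk_l2norm=True,
    A_init_range=(1, 16),
)
```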

reset_parameters()#

Reset the parameters.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
**kwargs,
)#

Perform a forward pass through the GDN module.

Parameters:
  • hidden_states (Tensor) – Hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Optional[Tensor]) – Key/value states (for cross attention).

  • inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.

  • attention_bias (Optional[Tensor]) – Attention bias.

  • packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple[Tensor, Tensor]) GDN output and bias.
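
As a usage sketch continuing the construction example above (shapes are hypothetical, a GPU is assumed, and attention_mask is passed as None on the assumption that the linear-attention path ignores it):

```python
import torch

s, b, h = 128, 2, config.hidden_size   # [sequence, batch, hidden]
hidden_states = torch.randn(s, b, h, device='cuda', dtype=torch.bfloat16)

output, bias = gdn(hidden_states=hidden_states, attention_mask=None)
assert output.shape == (s, b, h)        # GDN preserves the [s, b, h] shape
```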

_apply_gated_norm(x, gate)#

Apply the gated output normalization to x.

sharded_state_dict(
prefix='',
sharded_offsets=(),
metadata=None,
tp_group=None,
)#

Provide a sharded state dictionary for distributed checkpointing.
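
For example, continuing the sketch above (the prefix shown is hypothetical and depends on where the layer sits in the model):

```python
# Collect ShardedTensors for this layer under a (hypothetical) module prefix.
layer_sharded_sd = gdn.sharded_state_dict(prefix='decoder.layers.0.mixer.')
```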

core.ssm.gated_delta_net._split_tensor_factory(
orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
split_sections: List[int],
split_names: List[str],
split_dim: int,
) → megatron.core.dist_checkpointing.mapping.ShardedTensorFactory#

Builds a factory that splits a given ShardedTensor into several independent chunks.
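
The factory itself operates on ShardedTensor metadata, but the split it encodes is equivalent to a plain torch.split along split_dim; a minimal sketch of that underlying operation, with hypothetical section sizes and names:

```python
import torch

# A fused parameter whose checkpoint representation should be split into named chunks.
fused = torch.randn(48, 16)
split_sections = [16, 16, 16]      # chunk sizes along split_dim
split_names = ['q', 'k', 'v']      # hypothetical names for the resulting shards
split_dim = 0

chunks = dict(zip(split_names, torch.split(fused, split_sections, dim=split_dim)))
# Each chunk corresponds to an independent ShardedTensor in the checkpoint mapping.
```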

core.ssm.gated_delta_net.torch_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=64,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
)#

Torch-native implementation of the chunked gated delta rule for deterministic mode; this is needed because the FLA kernel is not deterministic.

Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
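
For intuition, the quantity the chunked implementation computes can also be written as a simple (and much slower) per-token recurrence. The sketch below is a reference only, not this function's interface: the [batch, heads, seq, dim] layout and the exp(g) decay convention are assumptions.

```python
import torch
import torch.nn.functional as F

def recurrent_gated_delta_rule_reference(query, key, value, g, beta, use_qk_l2norm=False):
    """Naive per-token gated delta rule. Assumed shapes: q/k [b, h, t, d_k],
    v [b, h, t, d_v], g/beta [b, h, t]. For illustration only."""
    if use_qk_l2norm:
        query = F.normalize(query, p=2, dim=-1)
        key = F.normalize(key, p=2, dim=-1)
    b, h, t, d_k = key.shape
    d_v = value.shape[-1]
    state = key.new_zeros(b, h, d_k, d_v)                  # recurrent state S_t
    outputs = []
    for i in range(t):
        q_t, k_t, v_t = query[:, :, i], key[:, :, i], value[:, :, i]
        alpha_t = g[:, :, i].exp()[..., None, None]        # gated decay on the state
        beta_t = beta[:, :, i][..., None]
        state = alpha_t * state                            # decay old associations
        v_old = torch.einsum('bhkv,bhk->bhv', state, k_t)  # current prediction for k_t
        state = state + torch.einsum('bhk,bhv->bhkv', k_t, beta_t * (v_t - v_old))
        outputs.append(torch.einsum('bhkv,bhk->bhv', state, q_t))
    return torch.stack(outputs, dim=2)                     # [b, h, t, d_v]
```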