core.ssm.gated_delta_net#

Module Contents#

Classes#

GatedDeltaNetSubmodules

Contains the module specs for the input linear, output norm, and output linear layers.

GatedDeltaNet

Gated Delta Net (GDN) layer class

Functions#

_split_tensor_factory

Builds a factory that splits a given ShardedTensor into several independent chunks.

get_parameter_local_cp

Get the local parameter for the current context parallel rank.

tensor_a2a_cp2hp

All-to-all context parallel to hidden parallel.

tensor_a2a_hp2cp

All-to-all hidden parallel to context parallel.

torch_chunk_gated_delta_rule

Torch-native implementation of the chunked gated delta rule for deterministic mode. This is needed because the FLA (flash-linear-attention) kernels are not deterministic.

Data#

API#

core.ssm.gated_delta_net.logger#

'getLogger(...)'

class core.ssm.gated_delta_net.GatedDeltaNetSubmodules#

Contains the module specs for the input linear, output norm, and output linear layers.

in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

class core.ssm.gated_delta_net.GatedDeltaNet(
config: megatron.core.transformer.TransformerConfig,
submodules: core.ssm.gated_delta_net.GatedDeltaNetSubmodules,
layer_number: int = None,
bias: bool = False,
conv_bias: bool = False,
conv_init: Optional[float] = None,
use_qk_l2norm: bool = True,
A_init_range: Tuple[float, float] = (1, 16),
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Gated Delta Net (GDN) layer class

GDN layer takes input with size [s, b, h] and returns output of the same size.

Initialization

Parameters:
  • config – The config of the model.

  • submodules – Contains the module specs for the input linear, output norm, and output linear layers.

  • layer_number – The layer number of this GDN layer.

  • bias – Whether to use bias in the linear layers.

  • conv_bias – Whether to use bias in the causal convolution.

  • conv_init – The initialization range for the causal convolution weights.

  • use_qk_l2norm – Whether to apply L2 normalization to the query and key inside the gated delta rule kernel.

  • A_init_range – The initialization range for the A parameter that controls the state decay.

  • pg_collection – The required process groups to use for tensor model parallel and context parallel.

reset_parameters()#

Reset the parameters.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
**kwargs,
)#

Perform a forward pass through the GDN module.

Parameters:
  • hidden_states (Tensor) – Hidden states.

  • attention_mask (Tensor) – Attention mask.

  • inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.

  • packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple[Tensor, Tensor]) GDN output and bias.

_apply_gated_norm(x, gate)#

Apply the gated output normalization to x using gate.

_prepare_qkv_for_gated_delta_rule(
qkv,
gate,
beta,
alpha,
batch,
seq_len,
)#

Prepare the query, key, value, gate, beta, and alpha tensors for the gated delta rule. Fuses the split, reshape, L2 norm, repeat_interleave, and contiguous operations.

_compute_g_and_beta(A_log_local_cp, dt_bias_local_cp, alpha, beta)#

Compute g (decay) and beta (sigmoid) for gated delta rule. Fuses exp, softplus, mul, neg, and sigmoid operations.
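Assuming the standard GDN parameterization (as in the Qwen3-Next reference implementation this module links to), the fused computation amounts to the following. A pure-Python, scalar sketch of the math only, not the fused kernel:

```python
import math

def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def compute_g_and_beta(A_log, dt_bias, alpha, beta):
    # g (log-decay) is non-positive, so exp(g) is a decay factor in (0, 1].
    g = -math.exp(A_log) * softplus(alpha + dt_bias)
    # beta is squashed to (0, 1) and scales the delta-rule write strength.
    b = 1.0 / (1.0 + math.exp(-beta))
    return g, b

g, b = compute_g_and_beta(A_log=0.0, dt_bias=0.0, alpha=0.0, beta=0.0)
# g = -softplus(0) = -ln 2 ≈ -0.693; b = sigmoid(0) = 0.5
```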

sharded_state_dict(
prefix='',
sharded_offsets=(),
metadata=None,
tp_group=None,
)#

Provide a sharded state dictionary for distributed checkpointing.

backward_dw()#

Execute weight gradient computation for all linear layers.

_backward_in_proj()#

Computes weight gradients of input projection layer.

_backward_out_proj()#

Computes weight gradients of output projection layer.

core.ssm.gated_delta_net._split_tensor_factory(
orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
split_sections: List[int],
split_names: List[str],
split_dim: int,
) megatron.core.dist_checkpointing.mapping.ShardedTensorFactory#

Builds a factory that splits a given ShardedTensor into several independent chunks.
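The splitting idea can be sketched on a flat Python list (the names "q", "k", "v" below are purely illustrative; the real factory works on ShardedTensor objects so that each chunk becomes an independent checkpoint key):

```python
def split_tensor_factory(tensor, split_sections, split_names, split_dim=0):
    """Return a dict of independently named chunks of `tensor`.

    Sketch of the idea only: here `tensor` is a flat Python list and
    only split_dim == 0 is handled.
    """
    assert split_dim == 0, "sketch handles a flat list only"
    assert sum(split_sections) == len(tensor)
    chunks, offset = {}, 0
    for name, size in zip(split_names, split_sections):
        chunks[name] = tensor[offset:offset + size]
        offset += size
    return chunks

parts = split_tensor_factory(list(range(10)), [4, 4, 2], ["q", "k", "v"])
# parts["q"] == [0, 1, 2, 3]; parts["v"] == [8, 9]
```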

core.ssm.gated_delta_net.get_parameter_local_cp(
param: torch.Tensor,
dim: int,
cp_group: torch.distributed.ProcessGroup,
split_sections: Optional[List[int]] = None,
) torch.Tensor#

Get the local parameter for the current context parallel rank.

Parameters:
  • param (torch.Tensor) – The entire parameter to get the local parameter for.

  • dim (int) – The dimension to split the parameter along. Usually the dimension of head.

  • cp_group (torch.distributed.ProcessGroup) – The context parallel group.

  • split_sections (Optional[List[int]]) – If not None, the parameter is first split along dimension dim into these sections, the local shard of each section is extracted separately, and the local shards are then concatenated along dim.

Returns:

The local parameter for the current context parallel rank.

Return type:

torch.Tensor
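A pure-Python sketch of the sharding logic, with a flat list of per-head entries standing in for the tensor and explicit rank/size arguments standing in for the process group:

```python
def get_parameter_local_cp(param, cp_rank, cp_size, split_sections=None):
    """Slice a list of per-head weights down to this CP rank's share.

    Sketch only: the real function operates on a torch.Tensor and a
    torch.distributed.ProcessGroup.
    """
    def local_slice(section):
        per_rank = len(section) // cp_size
        return section[cp_rank * per_rank:(cp_rank + 1) * per_rank]

    if split_sections is None:
        return local_slice(param)
    # Split first, shard each section separately, then concatenate.
    out, offset = [], 0
    for size in split_sections:
        out += local_slice(param[offset:offset + size])
        offset += size
    return out

heads = ["h0", "h1", "h2", "h3", "h4", "h5", "h6", "h7"]
# Without sections, rank 1 of 2 takes the second half:
get_parameter_local_cp(heads, 1, 2)          # ["h4", "h5", "h6", "h7"]
# With sections [4, 4] (e.g. two parameter groups packed together),
# each section is sharded separately and the local pieces concatenated:
get_parameter_local_cp(heads, 1, 2, [4, 4])  # ["h2", "h3", "h6", "h7"]
```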

core.ssm.gated_delta_net.tensor_a2a_cp2hp(
tensor: torch.Tensor,
seq_dim: int,
head_dim: int,
cp_group: torch.distributed.ProcessGroup,
split_sections: Optional[List[int]] = None,
undo_attention_load_balancing: bool = True,
)#

All-to-all context parallel to hidden parallel.

Parameters:
  • tensor (torch.Tensor) – The tensor to exchange via all-to-all. Currently only tensors of shape (seq_len, batch, head_dim) are supported.

  • seq_dim (int) – The dimension of sequence length. Currently only supports seq_dim == 0.

  • head_dim (int) – The dimension of head. Currently only supports head_dim == -1 or 2.

  • cp_group (torch.distributed.ProcessGroup) – The context parallel group.

  • split_sections (Optional[List[int]]) – If not None, the tensor is first split along dimension head_dim into these sections, all-to-all is performed on each section separately, and the results are then concatenated along head_dim.

  • undo_attention_load_balancing (bool) – Whether to undo the attention load balancing of CP.

Returns:

The tensor after the all-to-all exchange.

Return type:

torch.Tensor
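The exchange can be illustrated on plain Python lists simulating the per-rank buffers: before, each rank holds a sequence chunk with all heads; after, each rank holds the full sequence for its head slice. (The zigzag attention load balancing handled by undo_attention_load_balancing is omitted here.)

```python
def a2a_cp2hp(per_rank_chunks):
    """Simulate the CP -> HP all-to-all on Python lists.

    Sketch only. per_rank_chunks[r] is rank r's local sequence chunk:
    a list of timesteps, each a list over all heads. Returns out[r],
    the full sequence restricted to rank r's head slice.
    """
    cp = len(per_rank_chunks)
    n_heads = len(per_rank_chunks[0][0])
    hp = n_heads // cp  # heads per rank after the exchange
    out = []
    for r in range(cp):
        full_seq = []
        for chunk in per_rank_chunks:            # concatenate seq chunks
            for step in chunk:
                full_seq.append(step[r * hp:(r + 1) * hp])
        out.append(full_seq)
    return out

# 2 ranks, seq len 4 (2 timesteps per rank), 4 heads; entry "s{t}h{h}"
ranks = [[[f"s{t}h{h}" for h in range(4)] for t in rng]
         for rng in ([0, 1], [2, 3])]
hp_view = a2a_cp2hp(ranks)
# hp_view[0] holds heads 0-1 for all four timesteps
```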

core.ssm.gated_delta_net.tensor_a2a_hp2cp(
tensor: torch.Tensor,
seq_dim: int,
head_dim: int,
cp_group: torch.distributed.ProcessGroup,
split_sections: Optional[List[int]] = None,
redo_attention_load_balancing: bool = True,
)#

All-to-all hidden parallel to context parallel.

Parameters:
  • tensor (torch.Tensor) – The tensor to exchange via all-to-all. Currently only tensors of shape (seq_len, batch, head_dim) are supported.

  • seq_dim (int) – The dimension of sequence length. Currently only supports seq_dim == 0.

  • head_dim (int) – The dimension of head. Currently only supports head_dim == -1 or 2.

  • cp_group (torch.distributed.ProcessGroup) – The context parallel group.

  • split_sections (Optional[List[int]]) – If not None, the tensor is first split along dimension head_dim into these sections, all-to-all is performed on each section separately, and the results are then concatenated along head_dim.

  • redo_attention_load_balancing (bool) – Whether to redo the attention load balancing of CP.

Returns:

The tensor after the all-to-all exchange.

Return type:

torch.Tensor

core.ssm.gated_delta_net.torch_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=64,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
)#

Torch-native implementation of the chunked gated delta rule for deterministic mode. This is needed because the FLA (flash-linear-attention) kernels are not deterministic.

Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
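The chunked kernel computes the same result as a step-by-step recurrence: decay the state, read the value it currently predicts for k_t, then write the prediction error back at rate beta_t. A pure-Python, single-head sketch of that semantics (chunking and the optional q/k L2 norm omitted):

```python
import math

def gated_delta_rule_recurrent(q, k, v, g, beta):
    """Step-by-step reference for the gated delta rule.

    Sketch only, on plain Python lists for a single head: q and k are
    lists of d_k-vectors, v a list of d_v-vectors, g per-step log-decays
    (<= 0), beta per-step write strengths in (0, 1).
    """
    d_k, d_v = len(k[0]), len(v[0])
    S = [[0.0] * d_v for _ in range(d_k)]   # recurrent state, d_k x d_v
    outputs = []
    for t in range(len(q)):
        a = math.exp(g[t])                  # decay factor in (0, 1]
        # Decay the state, then read the value it predicts for k_t.
        S = [[a * S[i][j] for j in range(d_v)] for i in range(d_k)]
        v_pred = [sum(S[i][j] * k[t][i] for i in range(d_k))
                  for j in range(d_v)]
        # Delta-rule write: move the prediction toward v_t at rate beta_t.
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += beta[t] * k[t][i] * (v[t][j] - v_pred[j])
        # Read out with the query.
        outputs.append([sum(S[i][j] * q[t][i] for i in range(d_k))
                        for j in range(d_v)])
    return outputs

# One step, beta = 1, no decay: the output is exactly v_0 when
# q_0 == k_0 is a unit vector.
out = gated_delta_rule_recurrent(q=[[1.0, 0.0]], k=[[1.0, 0.0]],
                                 v=[[2.0, 3.0]], g=[0.0], beta=[1.0])
# out == [[2.0, 3.0]]
```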