core.ssm.gated_delta_net#
Module Contents#
Classes#
- GatedDeltaNetSubmodules – Contains the module specs for the input linear, output norm, and output linear layers.
- GatedDeltaNet – Gated Delta Net (GDN) layer class.
Functions#
- _split_tensor_factory – Builds a factory that splits a given ShardedTensor into several independent chunks.
- get_parameter_local_cp – Get the local parameter for the current context parallel rank.
- tensor_a2a_cp2hp – All-to-all context parallel to hidden parallel.
- tensor_a2a_hp2cp – All-to-all hidden parallel to context parallel.
- torch_chunk_gated_delta_rule – Torch-native implementation of the chunked gated delta rule for deterministic mode; needed because FLA is not deterministic.
Data#
API#
- core.ssm.gated_delta_net.logger#
‘getLogger(…)’
- class core.ssm.gated_delta_net.GatedDeltaNetSubmodules#
Contains the module specs for the input linear, output norm, and output linear layers.
- class core.ssm.gated_delta_net.GatedDeltaNet(
- config: megatron.core.transformer.TransformerConfig,
- submodules: core.ssm.gated_delta_net.GatedDeltaNetSubmodules,
- layer_number: int = None,
- bias: bool = False,
- conv_bias: bool = False,
- conv_init: Optional[float] = None,
- use_qk_l2norm: bool = True,
- A_init_range: Tuple[float, float] = (1, 16),
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#
Bases: megatron.core.transformer.module.MegatronModule
Gated Delta Net (GDN) layer class.
GDN layer takes input with size [s, b, h] and returns output of the same size.
Initialization
- Parameters:
config – The config of the model.
submodules – Contains the module specs for the input linear, output norm, and output linear layers.
layer_number – The layer number of this GDN layer.
bias – Whether to use bias in the linear layers.
conv_bias – Whether to use bias in the causal convolution.
conv_init – The initialization range for the causal convolution weights.
use_qk_l2norm – Whether to use L2 normalization in the kernel of the gated delta rule.
A_init_range – The initialization range for the A parameter, which controls the state decay in the gated delta rule.
pg_collection – The required process groups to use for tensor model parallel and context parallel.
- reset_parameters()#
Reset the parameters.
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- **kwargs,
)#
Perform a forward pass through the GDN module.
- Parameters:
hidden_states (Tensor) – Hidden states.
attention_mask (Tensor) – Attention mask.
inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.
packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.
sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.
- Returns:
(Tuple[Tensor, Tensor]) GDN output and bias.
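A minimal usage sketch of the call contract, assuming a GatedDeltaNet instance has already been constructed; the shapes, device, dtype, and the claim that the mask can be None are illustrative assumptions, not values taken from this module:

```python
import torch

# Assume a GatedDeltaNet layer has already been built from a TransformerConfig and the
# GDN submodule spec (construction details omitted; this sketch only shows the call contract):
#     gdn = GatedDeltaNet(config, submodules, layer_number=1, pg_collection=pg_collection)

s, b, h = 1024, 2, 4096                                  # sequence length, micro-batch, hidden size
hidden_states = torch.randn(s, b, h, dtype=torch.bfloat16, device="cuda")

# Passing attention_mask=None is an assumption of this sketch.
output, out_bias = gdn(hidden_states, attention_mask=None)
assert output.shape == (s, b, h)                         # output keeps the [s, b, h] layout
```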
- _apply_gated_norm(x, gate)#
- _prepare_qkv_for_gated_delta_rule(
- qkv,
- gate,
- beta,
- alpha,
- batch,
- seq_len,
)#
Prepare query, key, value, gate, beta, alpha tensors for gated delta rule. Fuses split, reshape, L2 norm, repeat_interleave, and contiguous operations.
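The sketch below illustrates, under assumed shapes and head counts and with hypothetical names, what such a preparation step typically does: split the fused projection, reshape into heads, optionally L2-normalize q/k, and repeat the q/k heads to match the value heads. It is not the module's actual implementation.

```python
import torch
import torch.nn.functional as F


def prepare_qkv_sketch(qkv, num_k_heads, num_v_heads, head_k_dim, head_v_dim, use_qk_l2norm=True):
    """Illustrative q/k/v preparation: split, reshape, L2 norm, repeat_interleave, contiguous."""
    s, b, _ = qkv.shape
    q, k, v = torch.split(
        qkv,
        [num_k_heads * head_k_dim, num_k_heads * head_k_dim, num_v_heads * head_v_dim],
        dim=-1,
    )
    q = q.reshape(s, b, num_k_heads, head_k_dim)
    k = k.reshape(s, b, num_k_heads, head_k_dim)
    v = v.reshape(s, b, num_v_heads, head_v_dim)
    if use_qk_l2norm:                                     # optional L2 normalization of q and k
        q, k = F.normalize(q, p=2, dim=-1), F.normalize(k, p=2, dim=-1)
    repeats = num_v_heads // num_k_heads                  # expand q/k heads to match the value heads
    q = q.repeat_interleave(repeats, dim=2).contiguous()
    k = k.repeat_interleave(repeats, dim=2).contiguous()
    return q, k, v
```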
- _compute_g_and_beta(A_log_local_cp, dt_bias_local_cp, alpha, beta)#
Compute g (decay) and beta (sigmoid) for gated delta rule. Fuses exp, softplus, mul, neg, and sigmoid operations.
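For reference, an unfused sketch of this computation, assuming the common gated-delta-net parameterization (A = exp(A_log), g = -A · softplus(alpha + dt_bias), beta = sigmoid(beta)); the actual method fuses these element-wise operations:

```python
import torch
import torch.nn.functional as F


def compute_g_and_beta_reference(A_log, dt_bias, alpha, beta):
    """Unfused sketch of the decay g and update gate beta of the gated delta rule."""
    g = -A_log.float().exp() * F.softplus(alpha.float() + dt_bias.float())  # exp, softplus, mul, neg
    return g, beta.float().sigmoid()                                        # sigmoid
```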
- sharded_state_dict(
- prefix='',
- sharded_offsets=(),
- metadata=None,
- tp_group=None,
)#
Provide a sharded state dictionary for distributed checkpointing.
- backward_dw()#
Execute weight gradient computation for all linear layers.
- _backward_in_proj()#
Computes weight gradients of input projection layer.
- _backward_out_proj()#
Computes weight gradients of output projection layer.
- core.ssm.gated_delta_net._split_tensor_factory(
- orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
- split_sections: List[int],
- split_names: List[str],
- split_dim: int,
)#
Builds a factory that splits a given ShardedTensor into several independent chunks.
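A plain-tensor analogue of the split the factory performs; the real function operates on ShardedTensor objects for distributed checkpointing, and the helper name and sizes below are illustrative:

```python
import torch


def split_into_named_chunks(tensor, split_sections, split_names, split_dim):
    """Cut a fused tensor into independently named chunks along split_dim."""
    chunks = torch.split(tensor, split_sections, dim=split_dim)
    return dict(zip(split_names, chunks))


# e.g. a fused projection weight split into separately checkpointable q/k/v blocks
fused = torch.randn(96, 128)
named = split_into_named_chunks(fused, [32, 32, 32], ["q", "k", "v"], split_dim=0)
```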
- core.ssm.gated_delta_net.get_parameter_local_cp(
- param: torch.Tensor,
- dim: int,
- cp_group: torch.distributed.ProcessGroup,
- split_sections: Optional[List[int]] = None,
)#
Get the local parameter for the current context parallel rank.
- Parameters:
param (torch.Tensor) – The entire parameter to get the local parameter for.
dim (int) – The dimension to split the parameter along. Usually the dimension of head.
cp_group (torch.distributed.ProcessGroup) – The context parallel group.
split_sections (Optional[List[int]]) – If not None, first split the parameter along dimension dim into sections, then get each section's local hidden parallel weights separately, and finally concatenate the local weights along dimension dim.
- Returns:
The local parameter for the current context parallel rank.
- Return type:
torch.Tensor
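A single-process sketch of the slicing behaviour described above; the rank handling and names are assumptions (the real function reads the rank from cp_group):

```python
import torch


def local_cp_slice(param, dim, cp_rank, cp_size, split_sections=None):
    """Return the shard of `param` along `dim` that belongs to `cp_rank`."""
    if split_sections is None:
        return torch.chunk(param, cp_size, dim=dim)[cp_rank]
    # Split into sections first, take each section's local shard, then re-concatenate.
    local_parts = [
        torch.chunk(section, cp_size, dim=dim)[cp_rank]
        for section in torch.split(param, split_sections, dim=dim)
    ]
    return torch.cat(local_parts, dim=dim)
```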
- core.ssm.gated_delta_net.tensor_a2a_cp2hp(
- tensor: torch.Tensor,
- seq_dim: int,
- head_dim: int,
- cp_group: torch.distributed.ProcessGroup,
- split_sections: Optional[List[int]] = None,
- undo_attention_load_balancing: bool = True,
)#
All-to-all context parallel to hidden parallel.
- Parameters:
tensor (torch.Tensor) – The tensor to all-to-all. Currently only supports a (seq_len, batch, head_dim) shaped tensor.
seq_dim (int) – The dimension of sequence length. Currently only supports seq_dim == 0.
head_dim (int) – The dimension of head. Currently only supports head_dim == -1 or 2.
cp_group (torch.distributed.ProcessGroup) – The context parallel group.
split_sections (Optional[List[int]]) – If not None, first split the tensor along dimension head_dim into sections, then perform the all-to-all for each section separately, and finally concatenate the results along dimension head_dim.
undo_attention_load_balancing (bool) – Whether to undo the attention load balancing of CP.
- Returns:
The all-to-all tensor.
- Return type:
torch.Tensor
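The single-process sketch below shows the data movement this all-to-all performs, ignoring the load-balancing (zigzag) reordering controlled by undo_attention_load_balancing: each CP rank starts with a sequence shard carrying all heads and ends with the full sequence carrying only its slice of the head/hidden dimension. The helper name and list-of-shards interface are assumptions for illustration only.

```python
import torch


def simulate_cp_to_hp(seq_shards):
    """seq_shards: one [seq/cp, batch, hidden] tensor per CP rank (all heads present).
    Returns one [seq, batch, hidden/cp] tensor per rank (full sequence, local heads)."""
    cp_size = len(seq_shards)
    per_rank_hidden = seq_shards[0].shape[-1] // cp_size
    return [
        torch.cat(
            [s[..., r * per_rank_hidden:(r + 1) * per_rank_hidden] for s in seq_shards], dim=0
        )
        for r in range(cp_size)
    ]
```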
- core.ssm.gated_delta_net.tensor_a2a_hp2cp(
- tensor: torch.Tensor,
- seq_dim: int,
- head_dim: int,
- cp_group: torch.distributed.ProcessGroup,
- split_sections: Optional[List[int]] = None,
- redo_attention_load_balancing: bool = True,
)#
All-to-all hidden parallel to context parallel.
- Parameters:
tensor (torch.Tensor) – The tensor to all-to-all. Currently only supports a (seq_len, batch, head_dim) shaped tensor.
seq_dim (int) – The dimension of sequence length. Currently only supports seq_dim == 0.
head_dim (int) – The dimension of head. Currently only supports head_dim == -1 or 2.
cp_group (torch.distributed.ProcessGroup) – The context parallel group.
split_sections (Optional[List[int]]) – If not None, first split the tensor along dimension head_dim into sections, then perform the all-to-all for each section separately, and finally concatenate the results along dimension head_dim.
redo_attention_load_balancing (bool) – Whether to redo the attention load balancing of HP.
- Returns:
The all-to-all tensor.
- Return type:
torch.Tensor
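Continuing the sketch from tensor_a2a_cp2hp above, hp2cp is the inverse mapping, so a round trip reproduces the original sequence shards (again ignoring the load-balancing reordering; helper names are illustrative):

```python
import torch

# simulate_cp_to_hp is the helper sketched under tensor_a2a_cp2hp above.


def simulate_hp_to_cp(hp_shards):
    """Inverse of simulate_cp_to_hp: back to per-rank sequence shards carrying all heads."""
    cp_size = len(hp_shards)
    per_rank_seq = hp_shards[0].shape[0] // cp_size
    return [
        torch.cat([s[r * per_rank_seq:(r + 1) * per_rank_seq] for s in hp_shards], dim=-1)
        for r in range(cp_size)
    ]


shards = [torch.randn(4, 2, 8) for _ in range(2)]        # 2 CP ranks, seq/cp = 4, hidden = 8
assert all(
    torch.equal(a, b)
    for a, b in zip(shards, simulate_hp_to_cp(simulate_cp_to_hp(shards)))
)
```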
- core.ssm.gated_delta_net.torch_chunk_gated_delta_rule(
- query,
- key,
- value,
- g,
- beta,
- chunk_size=64,
- initial_state=None,
- output_final_state=False,
- use_qk_l2norm_in_kernel=False,
)#
Torch-native implementation of the chunked gated delta rule for deterministic mode; needed because FLA is not deterministic.
Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
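To make the computation concrete, here is an unchunked, step-by-step sketch of the recurrence that the chunked kernel evaluates block-wise; the shapes, the absence of a query scale factor, and the handling of initial/final state follow common gated-delta-rule references and are assumptions rather than the exact signature above:

```python
import torch
import torch.nn.functional as F


def recurrent_gated_delta_rule_reference(query, key, value, g, beta,
                                         initial_state=None, use_qk_l2norm=True):
    """Step-by-step gated delta rule (no chunking).

    Assumed shapes:
        query, key: [batch, heads, seq_len, d_k]
        value:      [batch, heads, seq_len, d_v]
        g, beta:    [batch, heads, seq_len]  (g = log decay, beta = update gate)
    """
    b, h, t, d_k = key.shape
    d_v = value.shape[-1]
    if use_qk_l2norm:
        query = F.normalize(query, p=2, dim=-1)
        key = F.normalize(key, p=2, dim=-1)
    state = initial_state if initial_state is not None else key.new_zeros(b, h, d_k, d_v)
    outputs = []
    for i in range(t):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        state = state * g[:, :, i, None, None].exp()               # gated decay of the memory
        v_old = torch.einsum('bhk,bhkv->bhv', k_i, state)          # value currently stored under k_i
        delta = beta[:, :, i, None] * (v_i - v_old)                # beta-weighted correction (delta rule)
        state = state + torch.einsum('bhk,bhv->bhkv', k_i, delta)  # rank-1 memory update
        outputs.append(torch.einsum('bhk,bhkv->bhv', q_i, state))  # read-out with the query
    return torch.stack(outputs, dim=2), state
```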