core.ssm.mamba_context_parallel#
Module Contents#
Classes#
| Class | Description |
|---|---|
| MambaContextParallel | This class provides functionality related to Mamba “all-to-all” context parallel. |
Functions#
| Function | Description |
|---|---|
| _all_to_all_cp2hp | Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence/context-parallel-size, batch, local-hidden] to [global-sequence, batch, local-hidden/context-parallel-size]. |
| _all_to_all_hp2cp | Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence, batch, local-hidden/context-parallel-size] to [global-sequence/context-parallel-size, batch, local-hidden]. |
| _undo_attention_load_balancing | Undoes the context parallel attention load balancing. For example, for cp_size=3, converts 162534 to 123456 for sequential processing by the convolution and SSM. |
| _redo_attention_load_balancing | Redoes the context parallel attention load balancing. For example, for cp_size=3, converts 123456 to 162534 for efficient processing by attention. |
API#
- class core.ssm.mamba_context_parallel.MambaContextParallel(
- cp_group: torch.distributed.ProcessGroup,
- d_inner_local_tp: int,
- nheads_local_tp: int,
- ngroups_local_tp: int,
- d_state: int,
- conv1d_cp1: torch.nn.Conv1d,
- dt_bias_cp1: torch.Tensor,
- A_log_cp1: torch.Tensor,
- D_cp1: torch.Tensor,
- D_has_hdim: bool,
- )
This class provides the following functionality related to Mamba “all-to-all” context parallel:
- Error checking and creation of relevant parameters (e.g. nheads_local_tpcp)
- Collective operations on activations, on each context parallel rank, before and after the convolution and SSM
- A convolution operator that uses the correct slices of trainable variables on the current context parallel rank
- Sliced views of relevant trainable variables for the current context parallel rank
This class is intentionally not a subclass of MegatronModule; it contains no trainable variables of its own and should not be involved in any checkpoint loading or saving. (A usage sketch follows the parameter list below.)
- Parameters:
cp_group (torch.distributed.ProcessGroup) – The process group to use for context parallel.
d_inner_local_tp (int) – d_inner on the current tp rank
nheads_local_tp (int) – nheads on the current tp rank
ngroups_local_tp (int) – ngroups on the current tp rank
d_state (int) – Mamba d_state
conv1d_cp1 (nn.Conv1d) – The conv1d op which would be applied on this tp rank if cp_size were 1
dt_bias_cp1 (torch.Tensor) – The dt_bias parameter which would be used on this tp rank if cp_size were 1
A_log_cp1 (torch.Tensor) – The A_log parameter which would be used on this tp rank if cp_size were 1
D_cp1 (torch.Tensor) – The D parameter which would be used on this tp rank if cp_size were 1
D_has_hdim (bool) – Whether the D parameter is sized to the hidden dimension, rather than per-head
Initialization
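A minimal usage sketch, assuming the class is importable from megatron.core.ssm.mamba_context_parallel. The surrounding layer object, the run_ssm callback, and the tensor layout comments are illustrative assumptions, not the library's actual wiring; only MambaContextParallel and the methods documented below are taken from this page.

```python
import torch
from megatron.core.ssm.mamba_context_parallel import MambaContextParallel

def build_cp_helper(layer, cp_group):
    # Wrap the cp_size=1 parameters of a (hypothetical) Mamba mixer so that
    # this context parallel rank operates on the correct slices of them.
    return MambaContextParallel(
        cp_group=cp_group,
        d_inner_local_tp=layer.d_inner_local_tp,
        nheads_local_tp=layer.nheads_local_tp,
        ngroups_local_tp=layer.ngroups_local_tp,
        d_state=layer.d_state,
        conv1d_cp1=layer.conv1d,
        dt_bias_cp1=layer.dt_bias,
        A_log_cp1=layer.A_log,
        D_cp1=layer.D,
        D_has_hdim=False,
    )

def conv_ssm_forward(cp_helper, xbc, run_ssm):
    # xbc arrives sequence-sharded: [local-sequence, batch, local-hidden].
    # pre_conv_ssm gathers the full sequence (sharding the hidden dim instead)
    # so the convolution and SSM see sequentially contiguous tokens.
    xbc = cp_helper.pre_conv_ssm(xbc)
    # Convolution over this cp rank's channel slice (transpose/layout handling
    # omitted for brevity).
    xbc = cp_helper.conv1d(xbc)
    y = run_ssm(                      # run_ssm is a placeholder for the SSM kernel
        xbc,
        dt_bias=cp_helper.get_dt_bias(),
        A_log=cp_helper.get_A_log(),
        D=cp_helper.get_D(),
    )
    # post_conv_ssm returns to the sequence-sharded layout used elsewhere.
    return cp_helper.post_conv_ssm(y)
```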
- pre_conv_ssm(input_: torch.Tensor) → torch.Tensor#
Method to be applied before the convolution and SSM
- post_conv_ssm(input_: torch.Tensor) → torch.Tensor#
Method to be applied after the convolution and SSM
- conv1d(input_: torch.Tensor) → torch.Tensor#
Performs a conv1d on one context parallel rank, using slices of the weight and bias from the convolution that would be run when cp_size=1
- conv1d_channels()#
Returns the number of convolution channels on the current context parallel rank
- get_conv1d_weight() → torch.Tensor#
Returns a slice of the conv1d weight relevant to the current context parallel rank
- get_conv1d_bias() → torch.Tensor#
Returns a slice of the conv1d bias relevant to the current context parallel rank
- get_dt_bias() → torch.Tensor#
Returns a slice of dt_bias relevant to the current context parallel rank
- get_A_log() → torch.Tensor#
Returns a slice of A_log relevant to the current context parallel rank
- get_D() → torch.Tensor#
Returns a slice of D relevant to the current context parallel rank
- _slice_conv_param(param: torch.Tensor) → torch.Tensor#
Slices a cp_size=1 conv1d parameter (either weight or bias) along the first dimension, returning the parts of the parameter needed for convolution on the current context parallel rank. Parameter slicing is done in the forward path so that gradients will backpropagate to the cp_size=1 parameters.
- _slice_vector_param(
- param: torch.Tensor,
- has_hdim: bool = False,
- )
Slices a cp_size=1 vector parameter along the first dimension, returning the part of the parameter needed on the current context parallel rank. Parameter slicing is done in the forward path so that gradients will backpropagate to the cp_size=1 parameters.
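A small illustrative sketch (not the library code) of why slicing in the forward path matters: the slice stays in the autograd graph, so gradients accumulate into the full cp_size=1 parameter even though each rank only uses its own piece. The parameter name and shapes here are hypothetical.

```python
import torch

# Slicing the full (cp_size=1) parameter inside the forward pass keeps it in
# the autograd graph, so the gradient lands on the full parameter.
cp_size, cp_rank, nheads = 4, 1, 8
dt_bias_cp1 = torch.randn(nheads, requires_grad=True)   # hypothetical parameter

per_rank = nheads // cp_size
local_slice = dt_bias_cp1[cp_rank * per_rank : (cp_rank + 1) * per_rank]

loss = local_slice.sum()
loss.backward()
# Only this rank's slice of the full parameter receives a nonzero gradient.
assert dt_bias_cp1.grad.count_nonzero() == per_rank
```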
- core.ssm.mamba_context_parallel._all_to_all_cp2hp(
- input_: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )
Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence/context-parallel-size, batch, local-hidden] to [global-sequence, batch, local-hidden/context-parallel-size]. (A sketch of this layout change follows this entry.)
- Parameters:
input_ (torch.Tensor) – The input tensor, which is partitioned along the sequence dimension
cp_group (torch.distributed.ProcessGroup) – Process group to use for context parallel
- Returns:
The output tensor with shape [global-sequence, batch, local-hidden/context-parallel-size].
- Return type:
torch.Tensor
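A single-process sketch of the layout change performed by _all_to_all_cp2hp; the real function communicates with an all-to-all over cp_group, whereas torch.cat/torch.chunk here merely emulate the shape semantics from one rank's point of view.

```python
import torch

# Emulation of the _all_to_all_cp2hp layout change, with no real communication.
cp_size, seq, batch, hidden_local = 4, 16, 2, 8

# Each cp rank starts with a sequence shard and the full local hidden dim:
# [seq / cp_size, batch, hidden_local]
inputs = [torch.randn(seq // cp_size, batch, hidden_local) for _ in range(cp_size)]

# After the all-to-all, each rank holds the full sequence but only a slice of
# the hidden dim: [seq, batch, hidden_local / cp_size]
full_seq = torch.cat(inputs, dim=0)              # gather the sequence dimension
outputs = torch.chunk(full_seq, cp_size, dim=2)  # scatter the hidden dimension
assert outputs[0].shape == (seq, batch, hidden_local // cp_size)
```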
- core.ssm.mamba_context_parallel._all_to_all_hp2cp(
- input_: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )
Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence, batch, local-hidden/context-parallel-size] to [global-sequence/context-parallel-size, batch, local-hidden]. (A sketch of this layout change follows this entry.)
- Parameters:
input_ (torch.Tensor) – The input tensor, which is partitioned along the hidden dimension
cp_group (torch.distributed.ProcessGroup) – Process group to use for context parallel
- Returns:
The output tensor with shape [global-sequence/context-parallel-size, batch, local-hidden].
- Return type:
torch.Tensor
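The inverse layout change, again emulated locally without communication; in practice _all_to_all_hp2cp undoes _all_to_all_cp2hp, returning each rank to a sequence shard with the full local hidden dimension.

```python
import torch

# Emulation of the _all_to_all_hp2cp layout change, with no real communication.
cp_size, seq, batch, hidden_local = 4, 16, 2, 8

# Each cp rank starts with the full sequence and a hidden slice:
# [seq, batch, hidden_local / cp_size]
per_rank_hp = [torch.randn(seq, batch, hidden_local // cp_size) for _ in range(cp_size)]

full = torch.cat(per_rank_hp, dim=2)             # reassemble the hidden dimension
per_rank_cp = torch.chunk(full, cp_size, dim=0)  # re-shard the sequence dimension
assert per_rank_cp[0].shape == (seq // cp_size, batch, hidden_local)
```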
- core.ssm.mamba_context_parallel._undo_attention_load_balancing(
- input_: torch.Tensor,
- cp_size: int,
- )
Undoes the context parallel attention load balancing. For example, for cp_size=3, converts 162534 to 123456 for sequential processing by the convolution and SSM.
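An illustrative sketch (not the library code) of the permutation being undone, assuming the usual split into 2 * cp_size chunks used for attention load balancing: for cp_size=3 the load-balanced chunk order 1,6,2,5,3,4 is inverted back to 1,2,3,4,5,6. Chunk and token counts are arbitrary.

```python
import torch

# Undo the attention load-balanced chunk order for cp_size=3.
cp_size, tokens_per_chunk = 3, 4
chunks = 2 * cp_size

balanced_order = []
for r in range(cp_size):
    balanced_order += [r, 2 * cp_size - 1 - r]   # 0-based: [0, 5, 1, 4, 2, 3]
balanced_order = torch.tensor(balanced_order)

# A toy sequence laid out in load-balanced order ("162534" in the docstring).
seq = torch.arange(chunks).repeat_interleave(tokens_per_chunk)
balanced = seq.view(chunks, tokens_per_chunk)[balanced_order].reshape(-1)

# Apply the inverse permutation to restore sequential chunk order ("123456").
inverse = torch.argsort(balanced_order)
restored = balanced.view(chunks, tokens_per_chunk)[inverse].reshape(-1)
assert torch.equal(restored, seq)
```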
- core.ssm.mamba_context_parallel._redo_attention_load_balancing(
- input_: torch.Tensor,
- cp_size: int,
- )
Redoes the context parallel attention load balancing. For example, for cp_size=3, converts 123456 to 162534 for efficient processing by attention.
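The forward direction, sketched the same way: sequential chunks 1..6 are reordered to 1,6,2,5,3,4 so that each context parallel rank receives one early and one late chunk of the causal sequence, balancing attention work across ranks. Again illustrative, not the library implementation.

```python
import torch

# Re-apply the load-balanced chunk order for cp_size=3 ("123456" -> "162534").
cp_size, tokens_per_chunk = 3, 4
chunks = 2 * cp_size

order = []
for r in range(cp_size):
    order += [r, 2 * cp_size - 1 - r]            # 0-based: [0, 5, 1, 4, 2, 3]
order = torch.tensor(order)

seq = torch.arange(chunks).repeat_interleave(tokens_per_chunk)
rebalanced = seq.view(chunks, tokens_per_chunk)[order].reshape(-1)
# rebalanced now holds the chunks in order 0,5,1,4,2,3 (1-based: 1,6,2,5,3,4).
```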