core.ssm.mamba_context_parallel#

Module Contents#

Classes#

MambaContextParallel

This class provides functionality related to Mamba “all-to-all” context parallel.

Functions#

_all_to_all_cp2hp

Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence/context-parallel-size, batch, local-hidden] to [global-sequence, batch, local-hidden/context-parallel-size].

_all_to_all_hp2cp

Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence, batch, local-hidden/context-parallel-size] to [global-sequence/context-parallel-size, batch, local-hidden].

_undo_attention_load_balancing

Undoes the context parallel attention load balancing. For example, for cp_size=3, converts chunk order 162534 to 123456 for sequential processing by the convolution and SSM.

_redo_attention_load_balancing

Redoes the context parallel attention load balancing. For example, for cp_size=3, converts chunk order 123456 to 162534 for efficient processing by attention.

API#

class core.ssm.mamba_context_parallel.MambaContextParallel(
cp_group: torch.distributed.ProcessGroup,
d_inner_local_tp: int,
nheads_local_tp: int,
ngroups_local_tp: int,
d_state: int,
conv1d_cp1: torch.nn.Conv1d,
dt_bias_cp1: torch.Tensor,
A_log_cp1: torch.Tensor,
D_cp1: torch.Tensor,
D_has_hdim: bool,
)#

This class provides the following functionality related to Mamba “all-to-all” context parallel:

  1. Error checking and creation of relevant parameters (e.g. nheads_local_tpcp)

  2. Collective operations on activations, on each context parallel rank, before and after the convolution and SSM

  3. A convolution operator that uses the correct slices of trainable variables on the current context parallel rank

  4. Sliced views of relevant trainable variables for the current context parallel rank

This class is intentionally not a sub-class of MegatronModule. This class does not contain any trainable variables of its own and should not be involved in any checkpoint loading or saving.

Parameters:
  • cp_group (torch.distributed.ProcessGroup) – The process group to use for context parallel.

  • d_inner_local_tp (int) – d_inner on the current tp rank

  • nheads_local_tp (int) – nheads on the current tp rank

  • ngroups_local_tp (int) – ngroups on the current tp rank

  • d_state (int) – Mamba d_state

  • conv1d_cp1 (nn.Conv1d) – The conv1d op which would be applied on this tp rank if cp_size was 1

  • dt_bias_cp1 (torch.Tensor) – The dt_bias parameter which would be used on this tp rank if cp_size was 1

  • A_log_cp1 (torch.Tensor) – The A_log parameter which would be used on this tp rank if cp_size was 1

  • D_cp1 (torch.Tensor) – The D parameter which would be used on this tp rank if cp_size was 1

  • D_has_hdim (bool) – Whether the D parameter is sized to the hidden dimension, rather than being per-head

Initialization
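
A rough construction sketch, assuming torch.distributed is initialized and cp_group already names the context parallel process group; the sizes, the import path (inferred from the module name above), and the Mamba2-style channel count for the convolution are illustrative assumptions, not values taken from this page:

import torch
# Import path assumed from the module name above.
from megatron.core.ssm.mamba_context_parallel import MambaContextParallel

# Hypothetical per-tensor-parallel-rank sizes, for illustration only.
d_inner_local_tp, nheads_local_tp, ngroups_local_tp, d_state, d_conv = 1024, 16, 2, 128, 4
# Channels entering the causal conv1d when cp_size == 1 (assumed x, B, C streams).
conv_channels_cp1 = d_inner_local_tp + 2 * ngroups_local_tp * d_state

cp = MambaContextParallel(
    cp_group=cp_group,  # existing context parallel torch.distributed.ProcessGroup
    d_inner_local_tp=d_inner_local_tp,
    nheads_local_tp=nheads_local_tp,
    ngroups_local_tp=ngroups_local_tp,
    d_state=d_state,
    conv1d_cp1=torch.nn.Conv1d(
        conv_channels_cp1, conv_channels_cp1, kernel_size=d_conv,
        groups=conv_channels_cp1, padding=d_conv - 1,
    ),
    dt_bias_cp1=torch.empty(nheads_local_tp),
    A_log_cp1=torch.empty(nheads_local_tp),
    D_cp1=torch.empty(nheads_local_tp),
    D_has_hdim=False,
)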

pre_conv_ssm(input_: torch.Tensor) torch.Tensor#

Method to be applied before the convolution and SSM

post_conv_ssm(input_: torch.Tensor) torch.Tensor#

Method to be applied after the convolution and SSM

conv1d(input_: torch.Tensor) torch.Tensor#

Performs a conv1d on one context parallel rank, using slices of the weight and bias from the convolution that would be run when cp_size=1
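
Continuing the sketch above, these three methods are intended to bracket the rank-local convolution and SSM. In the sketch below, run_ssm is a hypothetical stand-in for the layer's own scan, and activation layout details are omitted:

def mixer_forward_sketch(cp: MambaContextParallel, xBC: torch.Tensor) -> torch.Tensor:
    # xBC: activations feeding the convolution on this context parallel rank.
    x = cp.pre_conv_ssm(xBC)    # collective ops before the convolution and SSM
    x = cp.conv1d(x)            # rank-local conv1d built from sliced cp_size=1 weights
    y = run_ssm(x, A_log=cp.get_A_log(), dt_bias=cp.get_dt_bias(), D=cp.get_D())  # hypothetical SSM step
    return cp.post_conv_ssm(y)  # collective ops after the convolution and SSM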

conv1d_channels()#

Returns the number of convolution channels on the current context parallel rank

get_conv1d_weight() torch.Tensor#

Returns a slice of the conv1d weight relevant to the current context parallel rank

get_conv1d_bias() torch.Tensor#

Returns a slice of the conv1d bias relevant to the current context parallel rank

get_dt_bias() torch.Tensor#

Returns a slice of dt_bias relevant to the current context parallel rank

get_A_log() torch.Tensor#

Returns a slice of A_log relevant to the current context parallel rank

get_D() torch.Tensor#

Returns a slice of D relevant to the current context parallel rank

_slice_conv_param(param: torch.Tensor) torch.Tensor#

Slices a cp_size=1 conv1d parameter (either weight or bias) along the first dimension, returning the parts of the parameter needed for convolution on the current context parallel rank. Parameter slicing is done in the forward path so that gradients will backpropagate to the cp_size=1 parameters.

_slice_vector_param(
param: torch.Tensor,
has_hdim: bool = False,
) torch.Tensor#

Slices a cp_size=1 vector parameter along the first dimension, returning the part of the parameter needed on the current context parallel rank. Parameter slicing is done in the forward path so that gradients will backpropagate to the cp_size=1 parameters.
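
A minimal sketch of the vector case, under the assumption that the current rank's share is one contiguous block of the cp_size=1 parameter; plain tensor indexing keeps the slice in the autograd graph, so gradients reach the original parameter as described above:

def slice_vector_param_sketch(param: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    # param: a cp_size=1 parameter such as dt_bias, shape [nheads_local_tp, ...].
    # Assumes the first dimension divides evenly by cp_size.
    per_rank = param.shape[0] // cp_size
    return param[cp_rank * per_rank : (cp_rank + 1) * per_rank]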

core.ssm.mamba_context_parallel._all_to_all_cp2hp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) torch.Tensor#

Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence/context-parallel-size, batch, local-hidden] to [global-sequence, batch, local-hidden/context-parallel-size].

Parameters:
  • input_ (torch.Tensor) – The input tensor, which is partitioned along the sequence dimension

  • cp_group (torch.distributed.ProcessGroup) – Process group to use for context parallel

Returns:

The output tensor with shape [global-sequence, batch, local-hidden/context-parallel-size].

Return type:

torch.Tensor
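
One common way to realize this transform, sketched here as an assumption rather than a description of the actual implementation: split the local hidden dimension into one chunk per rank, exchange the chunks with torch.distributed.all_to_all_single, and let the received sequence shards stack up in global order:

import torch
import torch.distributed as dist

def all_to_all_cp2hp_sketch(input_: torch.Tensor, cp_group: dist.ProcessGroup) -> torch.Tensor:
    # input_: [global-sequence/cp, batch, local-hidden]; both the sequence and the
    # local hidden size are assumed divisible by the context parallel size.
    cp_size = dist.get_world_size(group=cp_group)
    # Hidden chunk j is destined for rank j; stack the chunks along dim 0 so that
    # all_to_all_single can split the send buffer into equal per-rank pieces.
    send = torch.cat([c.contiguous() for c in torch.chunk(input_, cp_size, dim=2)], dim=0)
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=cp_group)
    # The piece received from rank j occupies slot j of dim 0, so the sequence shards
    # arrive already in global order: [global-sequence, batch, local-hidden/cp].
    return recv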

core.ssm.mamba_context_parallel._all_to_all_hp2cp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) torch.Tensor#

Perform AlltoAll communication on a context parallel group, transforming the input tensor from shape [global-sequence, batch, local-hidden/context-parallel-size] to [global-sequence/context-parallel-size, batch, local-hidden].

Parameters:
  • input_ (torch.Tensor) – The input tensor, which is partitioned along the hidden dimension

  • cp_group (torch.distributed.ProcessGroup) – Process group to use for context parallel

Returns:

The output tensor with shape [global-sequence/context-parallel-size, batch, local-hidden].

Return type:

torch.Tensor
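
The inverse direction under the same assumptions (and the same imports as the previous sketch): the input is already ordered by destination sequence shard, so only the received hidden chunks need regrouping:

def all_to_all_hp2cp_sketch(input_: torch.Tensor, cp_group: dist.ProcessGroup) -> torch.Tensor:
    # input_: [global-sequence, batch, local-hidden/cp] on each context parallel rank.
    cp_size = dist.get_world_size(group=cp_group)
    s, b, h = input_.shape
    send = input_.contiguous()   # sequence shard j (rows j*s/cp : (j+1)*s/cp) goes to rank j
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=cp_group)
    # Rank j sent back this rank's sequence shard carrying hidden chunk j; regroup the
    # chunks onto the hidden dimension: [global-sequence/cp, batch, local-hidden].
    return torch.cat(list(recv.reshape(cp_size, s // cp_size, b, h)), dim=2)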

core.ssm.mamba_context_parallel._undo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
) torch.Tensor#

Undoes the context parallel attention load balancing. For example, for cp_size=3, converts chunk order 162534 to 123456 for sequential processing by the convolution and SSM.
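
Assuming the usual scheme in which a sequence is split into 2*cp_size chunks and rank r holds chunks r+1 and 2*cp_size-r (consistent with the 162534 example), the undo step is a pure reordering along the sequence dimension; a sketch:

def undo_attention_load_balancing_sketch(input_: torch.Tensor, cp_size: int) -> torch.Tensor:
    # input_: [sequence, ...] in load-balanced chunk order, e.g. 162534 for cp_size=3;
    # the sequence length is assumed divisible by 2 * cp_size.
    num_chunks = 2 * cp_size
    chunks = torch.chunk(input_, num_chunks, dim=0)
    reordered = [None] * num_chunks
    for r in range(cp_size):
        reordered[r] = chunks[2 * r]                       # rank r's front chunk
        reordered[num_chunks - 1 - r] = chunks[2 * r + 1]  # rank r's back chunk
    return torch.cat(reordered, dim=0)                     # sequential order, e.g. 123456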

core.ssm.mamba_context_parallel._redo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
) torch.Tensor#

Redoes the context parallel attention load balancing. For example, for cp_size=3, converts chunk order 123456 to 162534 for efficient processing by attention.
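
The reverse reordering, again only a sketch under the same chunking assumption:

def redo_attention_load_balancing_sketch(input_: torch.Tensor, cp_size: int) -> torch.Tensor:
    # input_: [sequence, ...] in sequential chunk order, e.g. 123456 for cp_size=3.
    num_chunks = 2 * cp_size
    chunks = torch.chunk(input_, num_chunks, dim=0)
    reordered = []
    for r in range(cp_size):
        reordered.append(chunks[r])                   # rank r's front chunk
        reordered.append(chunks[num_chunks - 1 - r])  # rank r's back chunk
    return torch.cat(reordered, dim=0)                # load-balanced order, e.g. 162534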