nemo_automodel.components.models.glm_moe_dsa.cp

Context-parallel helpers for GLM MoE DSA TileLang attention.

Module Contents

Functions

Name                                    Description
_contiguous_cp_indices                  -
_slice_thd_chunk_for_cp                 -
glm_dsa_cp_all_gather                   All-gather activation tensors across CP ranks while preserving autograd.
glm_dsa_cp_enabled                      Return whether a real GLM DSA CP process group is active.
make_glm_dsa_packed_cp_batch_and_ctx    Convert packed GLM DSA batches to THD and keep a contiguous query shard per CP rank.

API

nemo_automodel.components.models.glm_moe_dsa.cp._contiguous_cp_indices(
total_tokens: int,
cp_size: int,
cp_rank: int,
device: torch.device
) -> torch.Tensor

nemo_automodel.components.models.glm_moe_dsa.cp._slice_thd_chunk_for_cp(
chunk: dict[str, torch.Tensor],
cp_group,
cp_size: int,
cp_rank: int,
padding_token_id: int
) -> dict[str, torch.Tensor]

nemo_automodel.components.models.glm_moe_dsa.cp.glm_dsa_cp_all_gather(
tensor: torch.Tensor,
dim: int,
cp_group
) -> torch.Tensor

All-gather activation tensors across CP ranks while preserving autograd.

nemo_automodel.components.models.glm_moe_dsa.cp.glm_dsa_cp_enabled(
cp_group
) -> bool

Return whether a real GLM DSA CP process group is active.
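
The description does not spell out what counts as a "real" group. As a rough sketch only, assuming it means a group that is not None and spans more than one rank, the check could look like the following. The name cp_group_is_active and the FakeGroup class are hypothetical and exist only for illustration; FakeGroup mimics the size() method of a torch.distributed process group so the snippet runs without initializing distributed state.

```python
def cp_group_is_active(cp_group) -> bool:
    """Hypothetical stand-in, not the library function.

    Assumption: a "real" CP group is one that exists and has more than one rank.
    """
    if cp_group is None:
        return False
    # A single-rank group has nothing to shard across, so treat it as inactive.
    return cp_group.size() > 1


class FakeGroup:
    """Illustrative object exposing size() like a process group would."""

    def __init__(self, world_size: int) -> None:
        self._world_size = world_size

    def size(self) -> int:
        return self._world_size


if __name__ == "__main__":
    print(cp_group_is_active(None))
    print(cp_group_is_active(FakeGroup(1)))
    print(cp_group_is_active(FakeGroup(4)))
```

The actual function may apply different or additional conditions; consult the source before relying on this behavior.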

nemo_automodel.components.models.glm_moe_dsa.cp.make_glm_dsa_packed_cp_batch_and_ctx(
cp_mesh,
tp_mesh,
batch,
loss_mask = None,
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000
)

Convert packed GLM DSA batches to THD and keep a contiguous query shard per CP rank.

GLM DSA sparse attention gathers K/V activations inside the model, so the batch side only slices each rank's local query tokens. Alongside that query shard it carries the full packed-sequence cu_seqlens and the per-query global token indices that TileLang's causal top-k window needs.
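
To make the roles of the contiguous query shard, the global token indices, and cu_seqlens concrete, here is a minimal pure-Python sketch. It is not the library implementation: the function names are hypothetical, plain lists stand in for tensors, and it assumes the packed token count divides evenly by the CP size and that a query's causal window spans from the start of its own packed sequence up to and including the query itself.

```python
from bisect import bisect_right


def contiguous_shard_indices(total_tokens: int, cp_size: int, cp_rank: int) -> list[int]:
    """Global token indices of the contiguous query shard owned by one CP rank.

    Assumption: total_tokens is already padded to a multiple of cp_size.
    """
    if total_tokens % cp_size != 0:
        raise ValueError("total_tokens must be divisible by cp_size")
    shard_len = total_tokens // cp_size
    start = cp_rank * shard_len
    return list(range(start, start + shard_len))


def causal_key_range(global_idx: int, cu_seqlens: list[int]) -> tuple[int, int]:
    """Half-open range [lo, hi) of key positions a query may attend to.

    cu_seqlens holds cumulative sequence boundaries, e.g. [0, 5, 8] for two
    packed sequences of lengths 5 and 3.
    """
    # Locate the packed sequence that contains this token.
    seq = bisect_right(cu_seqlens, global_idx) - 1
    # Causal: keys run from the sequence start through the query position.
    return cu_seqlens[seq], global_idx + 1


if __name__ == "__main__":
    cu_seqlens = [0, 5, 8]
    local_queries = contiguous_shard_indices(total_tokens=8, cp_size=2, cp_rank=1)
    for q in local_queries:
        print(q, causal_key_range(q, cu_seqlens))
```

In this sketch a rank holds only its own queries, yet each query's key range can reach into tokens owned by other ranks, which is why the full cu_seqlens and the global indices travel with the local shard while K/V are gathered elsewhere.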