nemo_automodel.components.models.deepseek_v4.cp


Context-parallel helpers for the DeepSeek V4 custom model.

This module implements the Miles-style context-parallel (CP) training path: each CP rank owns a contiguous shard of the query sequence, while K/V and compressed K/V are all-gathered across CP ranks with autograd-aware collectives before DSV4 sparse attention consumes them.
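The reason this split is correct can be checked in a single process: concatenating every rank's K/V shard in rank order rebuilds the global sequence, so each rank's local queries see exactly the keys they would see without CP. The sketch below is a simplified, hypothetical model of that idea (`causal_attention` and `simulate_contiguous_cp` are illustrative names, not part of this module). It uses dense causal softmax attention on plain Python lists and ignores compressed K/V, sparsity, and real collectives.

```python
import math


def causal_attention(q_rows, k_rows, v_rows, q_positions):
    """Dense causal softmax attention for a subset of query rows.

    q_positions[i] is the global index of q_rows[i]; that query may attend
    to keys 0..q_positions[i] of the global K/V sequence.
    """
    scale = 1.0 / math.sqrt(len(q_rows[0]))
    out = []
    for q, pos in zip(q_rows, q_positions):
        visible = k_rows[: pos + 1]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in visible]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        # zip stops at len(weights), so only visible values contribute.
        out.append([
            sum(w * v[d] for w, v in zip(weights, v_rows)) / total
            for d in range(len(v_rows[0]))
        ])
    return out


def simulate_contiguous_cp(q, k, v, cp_size):
    """Run every CP rank's share of attention in one process."""
    seq_len = len(q)
    assert seq_len % cp_size == 0, "sequence must split evenly across ranks"
    local = seq_len // cp_size
    # Each rank starts with only its own contiguous K/V shard.
    k_shards = [k[r * local:(r + 1) * local] for r in range(cp_size)]
    v_shards = [v[r * local:(r + 1) * local] for r in range(cp_size)]
    # The all-gather concatenates shards in rank order, which reproduces
    # the global sequence order on every rank.
    k_global = [row for shard in k_shards for row in shard]
    v_global = [row for shard in v_shards for row in shard]
    outputs = []
    for rank in range(cp_size):
        start = rank * local
        q_local = q[start:start + local]  # this rank's contiguous query shard
        local_positions = list(range(start, start + local))
        outputs.extend(causal_attention(q_local, k_global, v_global, local_positions))
    return outputs
```

Running `simulate_contiguous_cp` with any `cp_size` that divides the sequence length should reproduce `causal_attention` over the full sequence, which is the property the real CP path relies on.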

Module Contents

Functions

| Name | Description |
| --- | --- |
| `_lcm` | - |
| `_pad_position_ids_seq_dim_` | - |
| `_pad_tensor_seq_dim_` | - |
| `build_dsv4_cp_causal_padding_mask` | Build local-query/global-key additive mask for Miles-style DSV4 CP. |
| `dsv4_cp_all_gather` | All-gather activation tensors across CP ranks and concatenate on dim. |
| `dsv4_cp_all_gather_metadata` | All-gather non-differentiable metadata such as padding masks. |
| `dsv4_cp_enabled` | Return whether a real CP process group is active. |
| `dsv4_cp_local_seq_multiple` | Required per-CP-rank sequence-length multiple for DSV4 Miles-style CP. |
| `dsv4_cp_rank` | Return this rank’s index in the DSV4 CP group, or 0 without CP. |
| `dsv4_cp_size` | Return the DSV4 CP group size, or 1 without CP. |
| `make_dsv4_contiguous_shard_cp_batch_and_ctx` | Contiguously shard a batch for DeepSeek V4 Miles-style context parallelism. |

API

nemo_automodel.components.models.deepseek_v4.cp._lcm(
a: int,
b: int
) -> int
nemo_automodel.components.models.deepseek_v4.cp._pad_position_ids_seq_dim_(
position_ids: torch.Tensor,
seq_dim: int,
pad_len: int
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.cp._pad_tensor_seq_dim_(
tensor: torch.Tensor,
seq_dim: int,
pad_len: int,
value
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.cp.build_dsv4_cp_causal_padding_mask(
position_ids: torch.Tensor,
key_len: int,
dtype: torch.dtype,
device: torch.device,
cp_group,
padding_mask: torch.Tensor | None = None,
sliding_window: int | None = None
) -> torch.Tensor

Build a local-query/global-key additive attention mask for Miles-style DSV4 CP.

position_ids are the positions of this rank's local queries after contiguous CP slicing. Keys are in global sequence order because DSV4 all-gathers K/V along the sequence dimension. padding_mask follows the internal convention that True marks a padding position.
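To make the local-query/global-key layout concrete, the following is a minimal pure-Python sketch of such a mask, not the module's implementation. It assumes a single unpacked sequence (so position_ids equal global token indices), a padding_mask that is already in global key order, and a sliding window of size w that lets a query at position p see keys p - w + 1 through p; the window semantics in particular are an assumption. The name `cp_causal_padding_mask` is hypothetical, and the real function returns a torch tensor of the requested dtype and device.

```python
NEG_INF = float("-inf")


def cp_causal_padding_mask(position_ids, key_len, padding_mask=None, sliding_window=None):
    """Additive mask with one row per local query and one column per global key.

    Assumes position_ids are global token indices (one unpacked sequence),
    padding_mask is a per-key list in global order with True = padding, and
    sliding_window w lets a query at position p see keys p - w + 1 .. p.
    """
    mask = []
    for pos in position_ids:
        row = []
        for key in range(key_len):
            visible = key <= pos  # causal: no future keys
            if sliding_window is not None and pos - key >= sliding_window:
                visible = False  # outside the (assumed) local window
            if padding_mask is not None and padding_mask[key]:
                visible = False  # never attend to padding keys
            row.append(0.0 if visible else NEG_INF)
        mask.append(row)
    return mask


# Rank 1 of a 2-way CP split over 4 tokens owns queries at positions 2 and 3,
# but its mask still spans all 4 gathered keys.
print(cp_causal_padding_mask([2, 3], key_len=4))
```

Because the rows are local and the columns are global, the mask shape is (local_len, global_len) rather than square, which is what lets each rank attend over the gathered K/V without materializing other ranks' queries.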

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_all_gather(
tensor: torch.Tensor,
dim: int,
cp_group
) -> torch.Tensor

All-gather activation tensors across CP ranks and concatenate the gathered shards along dim.

The collective from torch.distributed.nn.functional preserves autograd, so the backward pass routes gradients for gathered remote slices back to the ranks that own them.
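Mathematically, the backward of an all-gather is a reduce-scatter: every rank holds a copy of every shard, so the gradient that reaches a shard's owner is the sum, over all ranks, of that rank's gradient for the owner's slice. The pure-Python sketch below simulates this for 1-D shards in one process; `all_gather_forward` and `all_gather_backward` are hypothetical names used only for illustration, and no real process group is involved.

```python
def all_gather_forward(shards):
    """Every rank receives the concatenation of all shards (1-D here)."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]


def all_gather_backward(grads_per_rank, shard_len):
    """Gradient each owning rank receives for its own shard.

    grads_per_rank[r] is rank r's gradient w.r.t. its full gathered tensor.
    Owner o's shard appears in every rank's gathered copy, so its gradient
    is the sum over ranks of slice o (a reduce-scatter with SUM).
    """
    world = len(grads_per_rank)
    return [
        [
            sum(grads_per_rank[r][o * shard_len + i] for r in range(world))
            for i in range(shard_len)
        ]
        for o in range(world)
    ]


# Rank 0's loss only touches the remote slice (positions 2 and 3), yet rank 1,
# which owns that slice, still receives those gradients.
print(all_gather_backward([[0.0, 0.0, 1.0, 2.0], [5.0, 0.0, 0.0, 3.0]], shard_len=2))
```

This is why sparse attention on one rank can depend on K/V produced by another rank without breaking training: the remote contributions are summed back onto the owning rank during backward.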

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_all_gather_metadata(
tensor: torch.Tensor | None,
dim: int,
cp_group
) -> torch.Tensor | None

All-gather non-differentiable metadata such as padding masks.

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_enabled(
cp_group
) -> bool

Return whether a real CP process group is active.

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_local_seq_multiple(
model_or_config
) -> int

Return the required per-CP-rank sequence-length multiple for DSV4 Miles-style CP.

Layers with a compress ratio constrain how the sequence may be split across CP ranks: a ratio-R layer needs each local shard length to be divisible by R, and ratio-4 layers use cross-window overlap, so they need 2*R (that is, 8). The returned value is the LCM of these requirements across all configured compress_ratios (1 when none are configured).
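The rule above reduces to a small LCM fold. The sketch below re-derives it under stated assumptions: compress_ratios is a plain list of integers, and a ratio of 0 or 1 is treated as an uncompressed layer that adds no constraint (that treatment is an assumption, not something documented here). `local_seq_multiple` is a hypothetical name; the real function reads the ratios from a model or config object, and the ratio values in the example are made up for illustration.

```python
from math import gcd


def lcm(a, b):
    """Least common multiple of two positive integers."""
    return a * b // gcd(a, b)


def local_seq_multiple(compress_ratios):
    """Per-rank shard multiple implied by a list of compress ratios."""
    multiple = 1  # no configured ratios -> no constraint
    for ratio in compress_ratios or []:
        if ratio <= 1:
            continue  # assumption: 0/1 means the layer is not compressed
        # Ratio-4 layers overlap adjacent windows, so they need 2 * R.
        requirement = 2 * ratio if ratio == 4 else ratio
        multiple = lcm(multiple, requirement)
    return multiple


# Hypothetical config mixing ratio-4 and ratio-128 layers: 8 divides 128.
print(local_seq_multiple([4, 128, 4, 128]))
```

For example, with ratios 4 and 6 the requirements are 8 and 6, so each rank's shard must be a multiple of 24; with cp_size = 4 the global sequence would then need to be padded to a multiple of 96.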

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_rank(
cp_group
) -> int

Return this rank’s index in the DSV4 CP group, or 0 without CP.

nemo_automodel.components.models.deepseek_v4.cp.dsv4_cp_size(
cp_group
) -> int

Return the DSV4 CP group size, or 1 without CP.

nemo_automodel.components.models.deepseek_v4.cp.make_dsv4_contiguous_shard_cp_batch_and_ctx(
cp_mesh,
tp_mesh,
batch,
loss_mask = None,
padding_token_id: int = 0,
pad_multiple: int | None = None
)

Contiguously shard a batch for DeepSeek V4 Miles-style context parallelism.

Attached to the batch as _cp_make_batch_fn (via functools.partial, to bind pad_multiple) and invoked by cp_utils.make_cp_batch_and_ctx. Each CP rank keeps one seq_start:seq_end slice; DSV4 attention all-gathers K/V across CP ranks during the forward pass. No collective communication happens here: this function is the batch-side counterpart of the activation gathers in cp.py. Returns (nullcontext, batch).

pad_multiple is the required per-CP-rank shard multiple (as returned by dsv4_cp_local_seq_multiple). The global sequence is padded so that its length is divisible by cp_size and each local shard length is divisible by pad_multiple (>= 2).
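The padding and slicing arithmetic can be sketched in a few lines. This is a simplified, hypothetical stand-in (`shard_batch_contiguous` is not part of the module): it takes integer cp_size and cp_rank instead of device meshes, handles only a list-of-lists input_ids field, ignores loss_mask and tensor parallelism, and assumes position_ids should carry global positions for the local slice. It does keep the documented shape of the result, a (nullcontext, batch) pair, and shows pad_multiple being bound with functools.partial.

```python
from contextlib import nullcontext
from functools import partial


def shard_batch_contiguous(batch, cp_size, cp_rank, padding_token_id=0, pad_multiple=2):
    """Pad input_ids and keep this rank's contiguous seq_start:seq_end slice."""
    input_ids = batch["input_ids"]  # list of equal-length token lists
    seq_len = len(input_ids[0])
    # Padding to a multiple of cp_size * pad_multiple makes the global length
    # divisible by cp_size and every local shard divisible by pad_multiple.
    chunk = cp_size * pad_multiple
    padded_len = -(-seq_len // chunk) * chunk  # ceiling to a multiple of chunk
    local_len = padded_len // cp_size
    seq_start = cp_rank * local_len
    seq_end = seq_start + local_len
    padded = [row + [padding_token_id] * (padded_len - seq_len) for row in input_ids]
    local_batch = {
        "input_ids": [row[seq_start:seq_end] for row in padded],
        # Assumption: global positions of the local queries.
        "position_ids": [list(range(seq_start, seq_end)) for _ in padded],
    }
    # No collective here; K/V are gathered later inside attention.
    return nullcontext(), local_batch


# Bind the per-rank multiple once, as the docstring describes for _cp_make_batch_fn.
make_batch_fn = partial(shard_batch_contiguous, pad_multiple=4)
ctx, local = make_batch_fn({"input_ids": [list(range(1, 11))]}, cp_size=2, cp_rank=1)
print(local["input_ids"])
```

With 10 tokens, cp_size = 2, and pad_multiple = 4, the sequence is padded to 16, so rank 1 keeps positions 8 through 15: two real tokens followed by six padding tokens.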