nemo_automodel.components.models.deepseek_v4.cp
Context-parallel helpers for the DeepSeek V4 custom model.
This implements the Miles-style training path: each CP rank owns a contiguous query shard, while K/V and compressed K/V are all-gathered with autograd-aware collectives before DSV4 sparse attention consumes them.
Module Contents
Functions
API
Build local-query/global-key additive mask for Miles-style DSV4 CP.
position_ids are the local query positions after contiguous CP slicing.
Keys are in global sequence order because DSV4 gathers K/V along sequence.
padding_mask follows the internal convention True=padding.
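The mask semantics described above can be sketched with plain Python lists standing in for tensors. This is a minimal illustration, not the module's implementation: the function name build_local_query_global_key_mask is hypothetical, and causal visibility (a query sees keys at or before its own global position) is an assumption that the docstring does not state explicitly.

```python
NEG_INF = float("-inf")


def build_local_query_global_key_mask(position_ids, padding_mask):
    """Return an additive mask of shape [num_local_queries][global_seq_len].

    position_ids: global positions of this rank's local queries.
    padding_mask: one bool per global key position, True = padding.
    Causal visibility is assumed here for illustration.
    """
    mask = []
    for q_pos in position_ids:
        row = []
        for k_pos, is_pad in enumerate(padding_mask):
            # A key is visible if it is not padding and not in the query's future.
            visible = (k_pos <= q_pos) and not is_pad
            row.append(0.0 if visible else NEG_INF)
        mask.append(row)
    return mask


# Rank 1 of 2 with a global sequence of 4 tokens owns query positions 2 and 3.
local_mask = build_local_query_global_key_mask(
    position_ids=[2, 3],
    padding_mask=[False, True, False, False],
)
```

Rows are indexed by local query and columns by global key, which is why the mask is rectangular rather than square once the sequence is split across CP ranks.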
All-gather activation tensors across CP ranks and concatenate on dim.
The distributed.nn functional collective preserves autograd, so backward routes gradients for gathered remote slices back to their owning ranks.
All-gather non-differentiable metadata such as padding masks.
Return whether a real CP process group is active.
Required per-CP-rank sequence-length multiple for DSV4 Miles-style CP.
Compress-ratio layers constrain how the sequence may be split across CP ranks:
a ratio-R layer needs each local shard divisible by R, and ratio-4 layers use
cross-window overlap so they need 2*R. The returned value is the LCM across
all configured compress_ratios (1 when none are configured).
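The rule above amounts to a small LCM computation. The sketch below is an illustration under stated assumptions: local_seq_multiple is a hypothetical name (not the library function), and treating ratios of 0 or 1 as imposing no constraint is an assumption about how uncompressed layers are encoded.

```python
import math


def local_seq_multiple(compress_ratios):
    """LCM of the per-layer shard-length requirements; 1 if none are configured."""
    multiple = 1
    for ratio in compress_ratios:
        if ratio <= 1:
            # Assumption: 0 or 1 marks an uncompressed layer with no constraint.
            continue
        # Ratio-4 layers use cross-window overlap, so they need 2 * R.
        required = 2 * ratio if ratio == 4 else ratio
        multiple = math.lcm(multiple, required)
    return multiple
```

For example, a model mixing ratio-4 and ratio-128 layers would need each local shard to be a multiple of lcm(8, 128) = 128, while a model with only ratio-4 layers would need a multiple of 8.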
Return this rank’s index in the DSV4 CP group, or 0 without CP.
Return the DSV4 CP group size, or 1 without CP.
Contiguously shard a batch for DeepSeek V4 Miles-style context parallelism.
Attached to the batch as _cp_make_batch_fn (via functools.partial to bind
pad_multiple) and invoked by cp_utils.make_cp_batch_and_ctx. Each CP rank
keeps one seq_start:seq_end slice; DSV4 attention all-gathers K/V across CP
ranks during forward. No collective happens here; this is the batch-side
counterpart of cp.py’s activation gathers. Returns (nullcontext, batch).
pad_multiple is the required per-CP-rank shard multiple (from
dsv4_cp_local_seq_multiple); the global sequence is padded so it is divisible
by cp_size and each local shard is divisible by pad_multiple (>= 2).
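The padding and slicing arithmetic can be sketched on a plain token list. This is a simplified illustration, not the function documented here: shard_for_cp and pad_id are hypothetical, and a real batch would also carry labels, position ids, and padding masks that need the same treatment.

```python
def shard_for_cp(tokens, cp_rank, cp_size, pad_multiple, pad_id=0):
    """Pad a global sequence and return the contiguous slice owned by cp_rank."""
    # Padding the global length to a multiple of cp_size * pad_multiple makes
    # it divisible by cp_size and every local shard divisible by pad_multiple.
    chunk = cp_size * pad_multiple
    padded_len = -(-len(tokens) // chunk) * chunk  # ceiling division
    padded = tokens + [pad_id] * (padded_len - len(tokens))

    local_len = padded_len // cp_size
    seq_start = cp_rank * local_len
    seq_end = seq_start + local_len
    return padded[seq_start:seq_end]


# 10 tokens, 2 CP ranks, shard multiple 4: padded to 16, 8 tokens per rank.
tokens = list(range(1, 11))
rank0 = shard_for_cp(tokens, cp_rank=0, cp_size=2, pad_multiple=4)
rank1 = shard_for_cp(tokens, cp_rank=1, cp_size=2, pad_multiple=4)
```

No communication is needed for this step because every rank can compute its own seq_start:seq_end bounds from cp_rank, cp_size, and the padded length alone.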