nemo_automodel.components.models.gemma4_moe.cp_batch
Contiguous-shard context-parallel (CP) batch sharding for Gemma4.
Shards the sequence contiguously across CP ranks (each rank keeps a single
seq_start:seq_end slice) so that Gemma4 can run its own point-to-point (p2p) ring
FlexAttention over the shards. The sharding itself performs no collective communication;
the transport lives in Gemma4’s attention (see cp_attention.py). This module is the
batch-side counterpart of that attention, and the non-load-balanced peer of the
context_parallel and TE/THD batch shardings.
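To make "contiguous" versus "load-balanced" concrete, the sketch below compares the two slice assignments on plain integers. The function names are hypothetical, and the load-balanced scheme shown (split into 2 × cp_size chunks and give rank r chunks r and 2 × cp_size − 1 − r) is the common head/tail pairing, used here only for illustration; it is not taken from this module. Under a causal mask the pairing evens out per-rank attention work, while the contiguous layout keeps each rank's tokens adjacent, which is what a ring over ordered shards expects.

```python
def contiguous_shard_bounds(seq_len: int, cp_size: int, cp_rank: int) -> tuple[int, int]:
    """Return the single [seq_start, seq_end) slice this rank keeps."""
    if seq_len % cp_size != 0:
        raise ValueError("seq_len must be divisible by cp_size (pad first)")
    chunk = seq_len // cp_size
    return cp_rank * chunk, (cp_rank + 1) * chunk


def load_balanced_chunks(seq_len: int, cp_size: int, cp_rank: int) -> list[tuple[int, int]]:
    """Head/tail pairing (illustrative): rank r keeps chunks r and 2*cp_size-1-r."""
    n_chunks = 2 * cp_size
    if seq_len % n_chunks != 0:
        raise ValueError("seq_len must be divisible by 2 * cp_size (pad first)")
    chunk = seq_len // n_chunks
    first, second = cp_rank, n_chunks - 1 - cp_rank
    return [(first * chunk, (first + 1) * chunk), (second * chunk, (second + 1) * chunk)]


if __name__ == "__main__":
    # 16 tokens over 4 CP ranks.
    for rank in range(4):
        print(rank, contiguous_shard_bounds(16, 4, rank), load_balanced_chunks(16, 4, rank))
```

With 16 tokens and 4 ranks, rank 0 keeps (0, 4) under contiguous sharding, but (0, 2) and (14, 16) under the head/tail pairing.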
Gemma4’s prepare_model_inputs_for_cp attaches
make_contiguous_shard_cp_batch_and_ctx to the batch as _cp_make_batch_fn;
cp_utils.make_cp_batch_and_ctx then invokes that callable (model-agnostically)
in place of the default load-balanced context_parallel path.
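The dispatch described above can be pictured as a small hook lookup. The sketch below is hypothetical: the function name, the keyword arguments, and the `default_path` parameter are illustrative stand-ins, not the real cp_utils signature. It shows the model-agnostic pattern: the generic code only checks whether the batch carries a `_cp_make_batch_fn` callable and, if so, defers to it. Popping the key (an assumption) keeps the callable from being treated as a sequence tensor later.

```python
from typing import Callable, Optional

# Hypothetical signature: (batch, cp_size=..., cp_rank=...) -> (batch, ctx)
MakeBatchFn = Callable[..., tuple]


def make_cp_batch_and_ctx_sketch(
    batch: dict, cp_size: int, cp_rank: int, default_path: MakeBatchFn
) -> tuple:
    # A model that owns its CP transport attaches its own sharding callable.
    make_fn: Optional[MakeBatchFn] = batch.pop("_cp_make_batch_fn", None)
    if make_fn is not None:
        # Model-owned path (e.g. Gemma4's contiguous sharding).
        return make_fn(batch, cp_size=cp_size, cp_rank=cp_rank)
    # Otherwise fall through to the default load-balanced path.
    return default_path(batch, cp_size=cp_size, cp_rank=cp_rank)
```

The generic caller never needs to know which model it is serving; a model opts in simply by placing a callable in the batch.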
Module Contents
Functions
API
Pre-shard preparation for the model-owned CP path.
This step lives here, rather than in cp_utils.make_cp_batch_and_ctx, so that the
default load-balanced path of that function stays identical to upstream. It converts
attention_mask to a padding_mask (preserving padding semantics for modules such as
MoE), selects the primary sequence tensor, injects or normalizes position_ids, and
resolves labels (falling back to loss_mask).
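A minimal sketch of the first few preparation steps, using nested Python lists in place of tensors. Several details are assumptions rather than facts about this module: the function name is hypothetical, padding_mask is taken to be True at pad positions, attention_mask is assumed to be replaced (popped) rather than kept, and position_ids are injected as 0..seq_len-1 per row. The labels/loss_mask resolution is omitted because its exact semantics are not described here.

```python
def prepare_for_cp_sketch(batch: dict) -> tuple[dict, str, int]:
    # Select the primary sequence tensor: token ids if present, else embeddings.
    key = "input_ids" if batch.get("input_ids") is not None else "inputs_embeds"
    primary = batch[key]
    bsz, seq_len = len(primary), len(primary[0])

    # ASSUMED convention: attention_mask uses 1 = real token; padding_mask uses True = pad.
    attn = batch.pop("attention_mask", None)
    if attn is not None:
        batch["padding_mask"] = [[v == 0 for v in row] for row in attn]

    # Inject position_ids when the collate did not provide them.
    if batch.get("position_ids") is None:
        batch["position_ids"] = [list(range(seq_len)) for _ in range(bsz)]

    return batch, key, seq_len
```

Converting to an explicit padding_mask before sharding means each rank still knows which of its local tokens are padding, which token-routing modules such as MoE can use.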
Materialize the trivial single-document _packed_seq_ids map for the manual CP path.
The VLM/LLM collate functions emit _packed_seq_ids only when two or more documents are
packed (attention_mask.max() > 1), so a single, unpacked sequence arrives without it.
The manual CP attention-mask builder needs document boundaries even for a single
document, so this function synthesizes the trivial map (1 = real token, 0 = pad)
instead of lowering each collate’s threshold, which would change behavior for every
non-CP consumer of _packed_seq_ids (e.g. SqrtCrossEntropy).
The map is derived from padding_mask when present; otherwise it is all ones.
This is a no-op when _packed_seq_ids is already present (genuinely packed input).
Parameters:
The CP batch dict; mutated in place to add _packed_seq_ids.
The key of the primary sequence tensor, "input_ids" or "inputs_embeds" (used to determine the batch size and device).
The sequence length before padding.
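The synthesis rule above (keep an existing map, otherwise derive it from padding_mask, otherwise all ones) fits in a few lines. This is a sketch on nested lists with a hypothetical function name. It assumes padding_mask is True at pad positions, so a real token maps to 1 and a pad to 0, matching the 1 = real token, 0 = pad encoding described above.

```python
def ensure_packed_seq_ids_sketch(batch: dict, primary_key: str, seq_len: int) -> None:
    # Genuinely packed input: the collate already produced a document map.
    if batch.get("_packed_seq_ids") is not None:
        return
    bsz = len(batch[primary_key])  # batch size comes from the primary tensor
    pad = batch.get("padding_mask")
    if pad is not None:
        # ASSUMED: padding_mask is True at pad positions -> 0, real tokens -> 1.
        ids = [[0 if is_pad else 1 for is_pad in row] for row in pad]
    else:
        # No padding information: every position belongs to the single document.
        ids = [[1] * seq_len for _ in range(bsz)]
    batch["_packed_seq_ids"] = ids  # mutate in place, as documented
```

Synthesizing the map at this point, instead of in the collates, leaves every non-CP consumer of _packed_seq_ids unaffected.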
Prepare and contiguously shard a batch for Gemma4’s ring CP.
Gemma4 attaches this callable to the batch (as _cp_make_batch_fn) during its
pre-embed step, and cp_utils.make_cp_batch_and_ctx invokes it. It runs the shared
pre-shard preparation, then keeps one contiguous sequence slice per CP rank (no
collective; the transport lives in Gemma4’s own ring attention).
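A minimal sketch of the slicing step on nested lists. Several details here are assumptions for illustration, not facts about this module: the function name, the fixed set of sequence keys, right-padding to a multiple of cp_size, the pad values (-100, the usual PyTorch cross-entropy ignore_index, for labels; True for padding_mask; 0 otherwise), and returning contextlib.nullcontext() as the context. A no-op context is plausible because this path performs no collective, but that choice is also an assumption.

```python
import contextlib

# Hypothetical set of per-token keys that must be sliced along the sequence dim.
SEQ_KEYS = ("input_ids", "labels", "position_ids", "padding_mask", "_packed_seq_ids")
# ASSUMED pad values; anything not listed pads with 0.
PAD_VALUES = {"labels": -100, "padding_mask": True}


def contiguous_shard_batch_sketch(batch: dict, cp_size: int, cp_rank: int):
    seq_len = len(batch["input_ids"][0])
    padded_len = -(-seq_len // cp_size) * cp_size  # round up to a multiple of cp_size
    chunk = padded_len // cp_size
    start, end = cp_rank * chunk, (cp_rank + 1) * chunk  # this rank's contiguous slice

    out = dict(batch)
    for key in SEQ_KEYS:
        if key not in batch:
            continue
        fill = PAD_VALUES.get(key, 0)
        # Right-pad each row, then keep only [start, end): no communication needed.
        out[key] = [(row + [fill] * (padded_len - len(row)))[start:end] for row in batch[key]]
    # Nothing to set up for the attention: the ring transport lives in the model.
    return out, contextlib.nullcontext()
```

For example, with 5 tokens and cp_size=2 the batch is padded to 6, rank 0 keeps positions 0 to 2, and rank 1 keeps positions 3 to 5, including the padded slot.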