nemo_automodel.components.models.gemma4_moe.cp_batch


Gemma4’s contiguous-shard context-parallel batch sharding.

Contiguously shards the sequence across CP ranks (each rank keeps one seq_start:seq_end slice) so that Gemma4 can run its own peer-to-peer (p2p) ring FlexAttention over the shards. This module performs no collective communication: the transport lives in Gemma4’s attention (see cp_attention.py). It is the batch-side counterpart of that attention, and the non-load-balanced peer of the context_parallel and TE/THD batch shardings.
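To make "contiguous" and "non-load-balanced" concrete, here is a minimal pure-Python sketch of the slice boundaries each rank would keep. The helper names are hypothetical, the sketch assumes the sequence is already padded to a divisible length, and the head-tail (zigzag) layout is included only as the common load-balancing scheme for causal attention, for contrast; it is not taken from this module.

```python
def contiguous_shard_bounds(seq_len: int, cp_size: int) -> list[tuple[int, int]]:
    """[seq_start, seq_end) for each CP rank under contiguous sharding."""
    if seq_len % cp_size != 0:
        raise ValueError("pad seq_len to a multiple of cp_size first")
    chunk = seq_len // cp_size
    return [(r * chunk, (r + 1) * chunk) for r in range(cp_size)]


def head_tail_shard_bounds(seq_len: int, cp_size: int) -> list[list[tuple[int, int]]]:
    """For contrast: head-tail (zigzag) load balancing.

    The sequence is cut into 2 * cp_size chunks and rank r keeps chunk r
    plus chunk 2 * cp_size - 1 - r, pairing a cheap early chunk with an
    expensive late one under causal attention.
    """
    if seq_len % (2 * cp_size) != 0:
        raise ValueError("pad seq_len to a multiple of 2 * cp_size first")
    chunk = seq_len // (2 * cp_size)
    bounds = []
    for r in range(cp_size):
        mirror = 2 * cp_size - 1 - r
        bounds.append([(r * chunk, (r + 1) * chunk),
                       (mirror * chunk, (mirror + 1) * chunk)])
    return bounds


print(contiguous_shard_bounds(8, 4))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print(head_tail_shard_bounds(8, 2))   # [[(0, 2), (6, 8)], [(2, 4), (4, 6)]]
```

Under causal attention the last contiguous shard attends to every earlier token while the first attends only to itself, so per-rank work grows with rank; that imbalance is what a head-tail layout evens out, and what the contiguous layout accepts in exchange for simple, gap-free slices.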

Gemma4’s prepare_model_inputs_for_cp attaches make_contiguous_shard_cp_batch_and_ctx to the batch as _cp_make_batch_fn; cp_utils.make_cp_batch_and_ctx then invokes that callable (model-agnostically) in place of the default load-balanced context_parallel path.
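The hook can be pictured as a small dispatch. The sketch below is hypothetical: the function names, the use of dict.pop (so that the callable does not travel on to the model), and forwarding extra keyword arguments are assumptions for illustration, not the cp_utils implementation.

```python
from typing import Any, Callable


def make_cp_batch_and_ctx_sketch(
    cp_mesh: Any,
    tp_mesh: Any,
    batch: dict,
    default_fn: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    # Hypothetical: a model may attach its own sharding callable to the batch.
    make_batch_fn = batch.pop("_cp_make_batch_fn", None)
    if make_batch_fn is not None:
        # Model-owned path; the dispatcher never checks which model this is.
        return make_batch_fn(cp_mesh, tp_mesh, batch, **kwargs)
    # Otherwise fall back to the default (load-balanced) path.
    return default_fn(cp_mesh, tp_mesh, batch, **kwargs)


def model_owned(cp_mesh, tp_mesh, batch, **kwargs):
    return "contiguous shard"


def load_balanced(cp_mesh, tp_mesh, batch, **kwargs):
    return "load-balanced"


print(make_cp_batch_and_ctx_sketch(None, None, {"_cp_make_batch_fn": model_owned}, load_balanced))
print(make_cp_batch_and_ctx_sketch(None, None, {}, load_balanced))
```

Keeping the callable on the batch lets the generic entry point stay model-agnostic: it only checks for the key, never for the model type.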

Module Contents

Functions

| Name | Description |
| --- | --- |
| _make_contiguous_shard_cp_batch | - |
| _pad_position_ids_seq_dim_ | - |
| _pad_tensor_seq_dim_ | - |
| _prepare_manual_cp_batch | Pre-shard prep for the model-owned CP path. |
| _synthesize_single_document_seq_ids | Materialize the trivial single-document _packed_seq_ids map for the manual CP path. |
| make_contiguous_shard_cp_batch_and_ctx | Prepare and contiguously shard a batch for Gemma4’s ring CP. |

API

nemo_automodel.components.models.gemma4_moe.cp_batch._make_contiguous_shard_cp_batch(
cp_mesh,
batch,
primary_key,
seq_len,
labels,
position_ids,
pos_seq_dim,
loss_mask,
padding_token_id
)
nemo_automodel.components.models.gemma4_moe.cp_batch._pad_position_ids_seq_dim_(
position_ids: torch.Tensor,
seq_dim: int,
pad_len: int
) -> torch.Tensor
nemo_automodel.components.models.gemma4_moe.cp_batch._pad_tensor_seq_dim_(
tensor: torch.Tensor,
seq_dim: int,
pad_len: int,
value: float | int = 0
) -> torch.Tensor
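_pad_tensor_seq_dim_ takes a seq_dim, a pad_len and a fill value. A minimal list-based sketch of that idea follows, under two assumptions not stated on this page: the sequence is padded up to a multiple of the CP size, and the fill value differs per tensor (for example padding_token_id for input_ids). The helper names are hypothetical, and the separate position_ids padding helper is not reproduced.

```python
def pad_len_for_cp(seq_len: int, cp_size: int) -> int:
    """Smallest pad_len that makes seq_len divisible by cp_size."""
    return (-seq_len) % cp_size


def pad_seq_dim(rows: list[list[int]], pad_len: int, value: int = 0) -> list[list[int]]:
    """Append pad_len copies of value to every row (sequence is the last dim)."""
    return [row + [value] * pad_len for row in rows]


input_ids = [[11, 12, 13, 14, 15, 16, 17]]       # seq_len = 7
pad_len = pad_len_for_cp(len(input_ids[0]), 4)   # 1
print(pad_seq_dim(input_ids, pad_len, value=0))  # [[11, 12, 13, 14, 15, 16, 17, 0]]
```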
nemo_automodel.components.models.gemma4_moe.cp_batch._prepare_manual_cp_batch(
cp_mesh,
tp_mesh,
batch,
loss_mask
)

Pre-shard prep for the model-owned CP path.

Kept here (rather than in cp_utils.make_cp_batch_and_ctx) so that the default, load-balanced path in that function stays identical to upstream. Converts attention_mask to a padding_mask (preserving padding semantics for modules such as MoE), selects the primary sequence tensor, injects or normalizes position_ids, and resolves labels (falling back to loss_mask).
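The four prep steps can be sketched on plain nested lists. This is an illustration under assumptions, not the module's code: padding_mask is taken to be True at pad positions, attention_mask is assumed to be replaced rather than kept, "falling back to loss_mask" is read as batch.get("labels", loss_mask), and the position_ids normalization step is not reproduced.

```python
def prepare_manual_cp_batch_sketch(batch: dict, loss_mask=None):
    """Hypothetical pre-shard prep on [batch][seq] nested lists; mutates batch."""
    # 1) attention_mask -> padding_mask (assumed: True marks a pad position).
    #    Packed masks may hold document ids > 1, so test for 0 rather than 1.
    attention_mask = batch.pop("attention_mask", None)
    if attention_mask is not None:
        batch["padding_mask"] = [[v == 0 for v in row] for row in attention_mask]
    # 2) Select the primary sequence tensor.
    primary_key = "input_ids" if "input_ids" in batch else "inputs_embeds"
    seq_len = len(batch[primary_key][0])
    # 3) Inject 0..seq_len-1 position_ids per row when the batch has none.
    if batch.get("position_ids") is None:
        batch["position_ids"] = [list(range(seq_len)) for _ in batch[primary_key]]
    # 4) Resolve labels, falling back to loss_mask when the batch has none.
    labels = batch.get("labels", loss_mask)
    return primary_key, seq_len, labels, batch["position_ids"]


b = {"input_ids": [[5, 6, 0]], "attention_mask": [[1, 1, 0]]}
print(prepare_manual_cp_batch_sketch(b, loss_mask=[[1, 1, 0]]))
```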

nemo_automodel.components.models.gemma4_moe.cp_batch._synthesize_single_document_seq_ids(
batch: dict,
primary_key: str,
seq_len: int
) -> None

Materialize the trivial single-document _packed_seq_ids map for the manual CP path.

The VLM/LLM collate functions emit _packed_seq_ids only when two or more documents are packed (attention_mask.max() > 1), so a single, unpacked sequence arrives without it. The manual CP attention-mask builder needs document boundaries even for a single document, so this function synthesizes the trivial map here (1 = real token, 0 = pad) instead of lowering each collate function’s threshold, which would change behavior for every non-CP consumer of _packed_seq_ids (e.g. SqrtCrossEntropy). The map is derived from padding_mask when present and is all-ones otherwise.

A no-op when _packed_seq_ids is already present (genuinely packed input).

Parameters:

batch (dict): The CP batch dict; mutated in place to add _packed_seq_ids.

primary_key (str): "input_ids" or "inputs_embeds"; selects the tensor that supplies the batch size and device.

seq_len (int): The pre-pad sequence length.
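A minimal sketch of the behavior described above, on nested lists. The function name is hypothetical, and it assumes padding_mask is True at pad positions; the real function's tensor types and device handling are not reproduced.

```python
def synthesize_seq_ids_sketch(batch: dict, primary_key: str, seq_len: int) -> None:
    """Hypothetical stand-in: add a one-document _packed_seq_ids map in place."""
    if "_packed_seq_ids" in batch:
        return  # genuinely packed input: keep the collate's map untouched
    padding_mask = batch.get("padding_mask")
    if padding_mask is not None:
        # Assumed convention: padding_mask is True at pad positions.
        batch["_packed_seq_ids"] = [
            [0 if is_pad else 1 for is_pad in row[:seq_len]] for row in padding_mask
        ]
    else:
        # No padding information: every position belongs to document 1.
        batch["_packed_seq_ids"] = [[1] * seq_len for _ in batch[primary_key]]


b = {"input_ids": [[7, 8, 0]], "padding_mask": [[False, False, True]]}
synthesize_seq_ids_sketch(b, "input_ids", 3)
print(b["_packed_seq_ids"])  # [[1, 1, 0]]
```

Because document id 1 marks every real token, an attention-mask builder that only lets tokens attend within the same document id reduces to ordinary causal attention over the non-pad positions.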

nemo_automodel.components.models.gemma4_moe.cp_batch.make_contiguous_shard_cp_batch_and_ctx(
cp_mesh,
tp_mesh,
batch,
loss_mask = None,
padding_token_id = 0
)

Prepare and contiguously shard a batch for Gemma4’s ring CP.

Gemma4 attaches this callable to the batch (as _cp_make_batch_fn) in its pre-embed step, prepare_model_inputs_for_cp, and cp_utils.make_cp_batch_and_ctx invokes it. It runs the shared pre-shard prep and then keeps one contiguous sequence slice per CP rank. No collective is issued here; the transport lives in Gemma4’s own ring attention.
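The per-rank slicing step can be sketched as follows. This is a hypothetical helper on nested lists, not the module's code: it assumes every sequence-dimension entry is already padded to a multiple of cp_size, and it omits the returned context object, whose contents this page does not describe.

```python
def contiguous_shard_batch(batch: dict, cp_rank: int, cp_size: int,
                           seq_keys: tuple[str, ...]) -> dict:
    """Shallow copy of batch with each seq_keys entry cut to this rank's slice.

    Entries are [batch][seq] nested lists; keys not in seq_keys pass through.
    """
    sharded = dict(batch)
    for key in seq_keys:
        rows = batch[key]
        seq_len = len(rows[0])
        if seq_len % cp_size != 0:
            raise ValueError(f"{key}: seq_len {seq_len} not divisible by cp_size {cp_size}")
        chunk = seq_len // cp_size
        start, end = cp_rank * chunk, (cp_rank + 1) * chunk
        sharded[key] = [row[start:end] for row in rows]
    return sharded


batch = {
    "input_ids": [[10, 11, 12, 13, 14, 15, 16, 17]],
    "position_ids": [[0, 1, 2, 3, 4, 5, 6, 7]],
}
shard = contiguous_shard_batch(batch, cp_rank=1, cp_size=4, seq_keys=("input_ids", "position_ids"))
print(shard)  # {'input_ids': [[12, 13]], 'position_ids': [[2, 3]]}
```

Slicing position_ids alongside input_ids, rather than renumbering them, leaves each shard carrying global positions, so position-dependent computation on a rank sees the same positions it would see without CP.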