nemo_automodel.components.models.gemma4_moe.cp_attention
Gemma4-specific context-parallel attention helpers.
Module Contents
Classes
Functions
Data
_GEMMA4_CP_FLEX_RING_OK_LOGGED
API
Inputs for Gemma4 manual ring CP attention (built by the run_cp_manual_attention seam).
Bases: Function
Reset the per-step block-mask cache when a new batch (new metadata) arrives.
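The reset-on-new-batch behaviour can be illustrated with a minimal sketch. It assumes the cache is invalidated when a different metadata object arrives; `BlockMaskCache`, `build_mask`, and the dictionary-shaped metadata are hypothetical names for illustration, not part of this module.

```python
class BlockMaskCache:
    """Hypothetical per-step cache: cleared whenever a new metadata object arrives."""

    def __init__(self):
        self._metadata = None  # keep a reference so identity comparison stays valid
        self._masks = {}

    def get(self, metadata, key, build_mask):
        # A new batch carries a new metadata object, so stale masks are dropped.
        if metadata is not self._metadata:
            self._masks.clear()
            self._metadata = metadata
        if key not in self._masks:
            self._masks[key] = build_mask(metadata, key)
        return self._masks[key]


build_calls = []


def build_mask(metadata, key):
    # Stand-in for an expensive block-mask construction.
    build_calls.append((metadata["step"], key))
    return (metadata["step"], key)


cache = BlockMaskCache()
batch_a = {"step": 0}
cache.get(batch_a, "sliding", build_mask)
cache.get(batch_a, "sliding", build_mask)  # same batch: served from the cache
batch_b = {"step": 1}
mask_b = cache.get(batch_b, "sliding", build_mask)  # new batch: rebuilt
```

Within one step every layer that asks for the same mask kind reuses the cached entry; the first request in the next step rebuilds it.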
Locally disable flex duck-shape specialization for the wrapped flex call.
With variable-length (unpacked) batches, the compiled flex kernel otherwise
guards on incidental dimension equalities (e.g. block_mask.kv_indices.size()[2] == key.size()[1]) and recompiles on every new sequence length, dropping
throughput to roughly warmup speed. Dynamo reads use_duck_shape at (re)trace
time, which happens inside the flex call, so scoping the override to the call window
is sufficient. Unlike setting it once at compile time, this does not leave the
process-global torch.fx config mutated for unrelated torch.compile users.
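The scoping idea can be sketched without torch. The example below is an illustration only: `fake_config` is a stand-in object for the real process-global config, and `scoped_flag` and `wrapped_call` are hypothetical helper names.

```python
import contextlib
import types

# Stand-in for a process-global config object (assumption for illustration;
# the real use_duck_shape flag lives in torch's fx config).
fake_config = types.SimpleNamespace(use_duck_shape=True)


@contextlib.contextmanager
def scoped_flag(config, name, value):
    """Temporarily set config.<name>, restoring the previous value on exit."""
    previous = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        # Restore even if the wrapped call raises.
        setattr(config, name, previous)


def wrapped_call(fn, *args, **kwargs):
    # The flag is changed only for the duration of the call, which is when
    # a (re)trace would read it.
    with scoped_flag(fake_config, "use_duck_shape", False):
        return fn(*args, **kwargs)


seen_inside = wrapped_call(lambda: fake_config.use_duck_shape)
seen_after = fake_config.use_duck_shape
```

Because the previous value is restored in a `finally` clause, other users of the same config see the flag unchanged once the call returns.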
Gemma4-owned manual ring CP attention entry.
Plugs into cp_utils’ generic run_cp_manual_attention seam: receives the
raw local (un-gathered) Q/K/V plus cp_mesh, builds the ring context, and
runs the p2p ring FlexAttention. K/V are rotated across CP ranks inside the
ring autograd function — they are never all-gathered.
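A pure-Python simulation can show why rotation makes an all-gather unnecessary. This is a sketch of a generic ring schedule, not this module's implementation: the rotation direction and the `ring_schedule` helper are assumptions, and the merging of partial attention results across steps is omitted.

```python
def ring_schedule(cp_size):
    """Return, per rank, the order in which K/V chunks are seen."""
    held = list(range(cp_size))  # rank r starts with its own local chunk r
    seen = [[] for _ in range(cp_size)]
    for _ in range(cp_size):
        # Each rank attends its local queries to the chunk it currently holds.
        for rank in range(cp_size):
            seen[rank].append(held[rank])
        # Every rank passes its chunk to the next rank (p2p send/recv).
        held = [held[(rank - 1) % cp_size] for rank in range(cp_size)]
    return seen


schedule = ring_schedule(4)
```

After `cp_size` steps each rank has attended to every chunk exactly once while holding only one remote chunk at a time.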
Replace F.scaled_dot_product_attention with Gemma4 ring CP attention on this module.
Gemma4 owns its CP attention end-to-end (it does not use cp_utils’ generic CP
SDPA hooks). It installs its own SDPA wrapper, decorated with
@torch._dynamo.disable, that runs the p2p ring FlexAttention. The wrapper is
installed on the inner attention module so that it also fires during
gradient-checkpointing recompute. The per-forward attention
kwargs the ring needs (mm_token_type_ids, packed-seq ids, padding/vision masks)
are captured off the forward kwargs into _cp_manual_metadata here, since the
swapped SDPA only receives Q/K/V.
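The capture-then-swap pattern can be sketched in plain Python. Everything here is a stand-in: `ToyAttention`, `attach_cp_attention`, and the `CP_METADATA_KEYS` list are hypothetical, and the sketch swaps an instance method rather than patching F.scaled_dot_product_attention.

```python
CP_METADATA_KEYS = ("mm_token_type_ids", "attention_mask")  # hypothetical key list


class ToyAttention:
    """Stand-in for an attention module; not the real Gemma4 class."""

    def forward(self, hidden_states, **kwargs):
        # The attention call itself only receives Q/K/V.
        return self.sdpa(hidden_states, hidden_states, hidden_states)

    def sdpa(self, q, k, v):
        return "default-sdpa"


def attach_cp_attention(module):
    original_forward = module.forward

    def forward_with_capture(hidden_states, **kwargs):
        # Stash the kwargs the ring needs before the attention call runs.
        module._cp_manual_metadata = {
            key: kwargs[key] for key in CP_METADATA_KEYS if key in kwargs
        }
        return original_forward(hidden_states, **kwargs)

    def ring_sdpa(q, k, v):
        # The swapped attention reads the stashed metadata from the module.
        return ("ring", sorted(module._cp_manual_metadata))

    module.forward = forward_with_capture
    module.sdpa = ring_sdpa
    return module


attn = attach_cp_attention(ToyAttention())
result = attn.forward("h", mm_token_type_ids=[0, 1], use_cache=False)
```

Only the declared keys are stashed, so unrelated forward kwargs such as `use_cache` never reach the attention path.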
Guard FSDPParam.to_accumulated_grad_if_needed against uninitialized params.
On some torch builds that method reads self._unsharded_param (the lazily
set unsharded tensor) without first checking it exists. In FSDP2 post-backward
under fp32 grad-reduce, frozen / never-unsharded params (e.g. the frozen Gemma4
vision tower and embeddings) have no _unsharded_param yet and it raises
AttributeError. Such params carry no grad to upcast anyway, so wrap the
method to skip them when uninitialized. The guard is a no-op once applied and on builds that already include the fix.
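An idempotent guard of this kind can be sketched as follows. `ToyParam` is a stand-in class rather than torch's FSDPParam, and `guard_uninitialized` and the `_guarded` marker are hypothetical names chosen for illustration.

```python
import functools


class ToyParam:
    """Stand-in for an FSDP param wrapper; not torch's FSDPParam."""

    def to_accumulated_grad_if_needed(self):
        # Mirrors the failure mode: reads the attribute without checking it.
        return self._unsharded_param


def guard_uninitialized(cls):
    original = cls.to_accumulated_grad_if_needed
    if getattr(original, "_guarded", False):
        return  # already applied: do nothing

    @functools.wraps(original)
    def guarded(self, *args, **kwargs):
        # Params that were never unsharded carry no grad to upcast; skip them.
        if not hasattr(self, "_unsharded_param"):
            return None
        return original(self, *args, **kwargs)

    guarded._guarded = True
    cls.to_accumulated_grad_if_needed = guarded


guard_uninitialized(ToyParam)
guard_uninitialized(ToyParam)  # second call is a no-op

frozen = ToyParam()
skipped = frozen.to_accumulated_grad_if_needed()

ready = ToyParam()
ready._unsharded_param = "tensor"
passed_through = ready.to_accumulated_grad_if_needed()
```

The marker attribute keeps repeated application from stacking wrappers, and initialized params still reach the original method unchanged.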
Run Gemma4 local-query/ring-key CP attention with FlexAttention.
Run Gemma4 local-query/ring-key CP attention forward with FlexAttention.
Register Gemma4’s model-owned p2p ring CP attention on a self-attention module.
Declares the metadata keys the ring needs and exposes setup_cp_attention(cp_mesh),
the model-owned CP-attention seam that the parallelizer calls with the CP mesh
instead of cp_utils’ generic SDPA hooks. run_cp_manual_attention is also bound
as the ring entry point.
Return per-image-block ids for Gemma4 vision tokens, or -1 for text/padding.
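One way to read "per-image-block ids" is sketched below. It assumes vision tokens are marked by a token-type value of 1 and that each contiguous run of vision tokens forms one block; both are assumptions for illustration, and `image_block_ids` is a hypothetical helper operating on a plain list rather than a tensor.

```python
def image_block_ids(token_types, vision_type=1):
    """Assign 0, 1, 2, ... to successive runs of vision tokens; -1 elsewhere."""
    ids = []
    block = -1
    prev_is_vision = False
    for token_type in token_types:
        is_vision = token_type == vision_type
        if is_vision and not prev_is_vision:
            block += 1  # a new run of vision tokens starts here
        ids.append(block if is_vision else -1)
        prev_is_vision = is_vision
    return ids


ids = image_block_ids([0, 1, 1, 0, 0, 1, 1, 1, 0])
# ids == [-1, 0, 0, -1, -1, 1, 1, 1, -1]
```

Tokens that share a non-negative id belong to the same image, which is the kind of grouping an attention mask can use to treat an image block as a unit.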