nemo_automodel.components.models.gemma4_moe.cp_attention


Gemma4-specific context-parallel attention helpers.

Module Contents

Classes

| Name | Description |
| --- | --- |
| CPRingAttentionContext | Inputs for Gemma4 manual ring CP attention (built by the run_cp_manual_attention seam). |
| _Gemma4FlexRingAttention | - |

Functions

| Name | Description |
| --- | --- |
| _base_gemma4_cp_mask | - |
| _block_mask_set_generation | Reset the per-step block-mask cache when a new batch (new metadata) arrives. |
| _cached_block_mask | - |
| _collect_ring_kv_chunks | - |
| _compiled_flex_attention | - |
| _detach_metadata | - |
| _direct_exchange | - |
| _duck_shape_disabled | Locally disable flex duck-shape specialization for the wrapped flex call. |
| _gemma4_cp_manual_attention | Gemma4-owned manual ring CP attention entry. |
| _install_gemma4_cp_ring_sdpa | Swap F.scaled_dot_product_attention -> Gemma4 ring CP attention on this module. |
| _merge_flex_chunk | - |
| _metadata_like | - |
| _patch_fsdp_accumulated_grad_guard | Guard FSDPParam.to_accumulated_grad_if_needed against uninitialized params. |
| _ring_exchange | - |
| _run_gemma4_cp_ring_attention | Run Gemma4 local-query/ring-key CP attention with FlexAttention. |
| _run_gemma4_cp_ring_attention_forward | Run Gemma4 local-query/ring-key CP attention forward with FlexAttention. |
| _run_gemma4_flex_chunk | - |
| _zero_if_none | - |
| attach_gemma4_cp_ring_attention | Register Gemma4’s model-owned p2p ring CP attention on a self-attention module. |
| gemma4_vision_group_ids | Return per-image-block ids for Gemma4 vision tokens, or -1 for text/padding. |

Data

_BLOCK_MASK_CACHE

_BLOCK_MASK_GEN

_GEMMA4_CP_FLEX_RING_OK_LOGGED

logger

API

class nemo_automodel.components.models.gemma4_moe.cp_attention.CPRingAttentionContext(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cp_mesh: typing.Any,
cp_group: typing.Any,
cp_size: int,
cp_rank: int,
seq_local: int,
seq_full: int,
seq_global_start: int,
attn_mask: typing.Any,
dropout_p: float,
is_causal: bool,
scale: typing.Any,
enable_gqa: bool,
kwargs: dict[str, typing.Any],
metadata: dict[str, torch.Tensor | None],
metadata_seq_dims: dict[str, int]
)
Dataclass

Inputs for Gemma4 manual ring CP attention (built by the run_cp_manual_attention seam).

cp_rank: int
cp_size: int
dropout_p: float
enable_gqa: bool
is_causal: bool
key: Tensor
kwargs: dict[str, Any]
metadata: dict[str, Tensor | None]
metadata_seq_dims: dict[str, int]
module: Module
query: Tensor
seq_full: int
seq_global_start: int
seq_local: int
value: Tensor
class nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention()

Bases: Function

nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention.backward(
autograd_ctx,
grad_output: torch.Tensor
)
staticmethod
nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention.forward(
autograd_ctx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
ring_ctx: typing.Any
)
staticmethod
nemo_automodel.components.models.gemma4_moe.cp_attention._base_gemma4_cp_mask(
attention_module: torch.nn.Module,
ctx: typing.Any,
q_idx,
kv_idx,
kv_global_start: int = 0
)
nemo_automodel.components.models.gemma4_moe.cp_attention._block_mask_set_generation(
gen_tensor
) -> None

Reset the per-step block-mask cache when a new batch (new metadata) arrives.
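The cache reset and the build-on-miss lookup work as a pair. The module's actual bookkeeping is not shown (for example, what the second slot of _BLOCK_MASK_GEN holds), so the following is only a minimal sketch under stated assumptions: a "generation" is identified by object identity of the metadata tensor, and any new generation flushes every cached mask. The names mask_cache, current_gen, set_generation, and get_cached_mask are hypothetical.

```python
# Hypothetical sketch of a generation-keyed block-mask cache (not the module's code).
mask_cache: dict = {}
current_gen: list = [None]  # assumption: one slot holding the last-seen generation object


def set_generation(gen_tensor) -> None:
    """Flush the cache when a different generation object (new batch) arrives."""
    if current_gen[0] is not gen_tensor:
        mask_cache.clear()
        current_gen[0] = gen_tensor


def get_cached_mask(key, build):
    """Return the mask cached under key, calling build() only on a cache miss."""
    if key not in mask_cache:
        mask_cache[key] = build()
    return mask_cache[key]
```

Within one step, every layer and ring chunk that asks for the same key reuses one mask. The next batch's metadata then invalidates all of them at once, so stale masks are never served.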

nemo_automodel.components.models.gemma4_moe.cp_attention._cached_block_mask(
key,
build
)
nemo_automodel.components.models.gemma4_moe.cp_attention._collect_ring_kv_chunks(
ctx: typing.Any
) -> list[tuple[int, torch.Tensor, torch.Tensor, dict[str, torch.Tensor | None]]]
nemo_automodel.components.models.gemma4_moe.cp_attention._compiled_flex_attention(
attention_module: torch.nn.Module
)
nemo_automodel.components.models.gemma4_moe.cp_attention._detach_metadata(
metadata: dict[str, torch.Tensor | None]
) -> dict[str, torch.Tensor | None]
nemo_automodel.components.models.gemma4_moe.cp_attention._direct_exchange(
tensors: list[tuple[torch.Tensor, torch.Tensor]],
cp_group: typing.Any,
cp_rank: int,
send_cp_rank: int,
recv_cp_rank: int
) -> None
nemo_automodel.components.models.gemma4_moe.cp_attention._duck_shape_disabled()

Locally disable flex duck-shape specialization for the wrapped flex call.

With variable-length (unpacked) batches, the compiled flex kernel otherwise guards on incidental dimension equalities (for example, block_mask.kv_indices.size()[2] == key.size()[1]) and recompiles for every new sequence length, which collapses throughput to roughly warmup speed. Dynamo reads use_duck_shape at (re)trace time, and that tracing happens inside the flex call. Scoping the setting to the call window is therefore sufficient. Unlike setting it once at compile time, it also does not leave the process-global torch.fx config mutated for unrelated torch.compile users.
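The scoping described above is a save/override/restore pattern around a global flag. A minimal sketch follows. It uses a SimpleNamespace stand-in instead of the real PyTorch config module, and the assumption that the PyTorch flag lives at torch.fx.experimental._config.use_duck_shape should be checked against your torch build. duck_shape_disabled and fx_config are hypothetical names.

```python
import contextlib
from types import SimpleNamespace

# Stand-in for the process-global config. Assumption: in PyTorch the real flag is
# torch.fx.experimental._config.use_duck_shape.
fx_config = SimpleNamespace(use_duck_shape=True)


@contextlib.contextmanager
def duck_shape_disabled(config=fx_config):
    """Temporarily set config.use_duck_shape = False for the wrapped call only."""
    previous = config.use_duck_shape
    config.use_duck_shape = False
    try:
        yield
    finally:
        # Restore even if the wrapped call raises, so other compile users are unaffected.
        config.use_duck_shape = previous
```

A call site would wrap only the flex call, for example `with duck_shape_disabled(): out = run_flex(...)`, where run_flex is hypothetical. Any (re)trace triggered inside that call sees the overridden value.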

nemo_automodel.components.models.gemma4_moe.cp_attention._gemma4_cp_manual_attention(
attention_module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cp_mesh,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
kwargs
) -> torch.Tensor

Gemma4-owned manual ring CP attention entry.

Plugs into cp_utils’ generic run_cp_manual_attention seam: it receives the raw local (un-gathered) Q/K/V plus cp_mesh, builds the ring context, and runs the p2p ring FlexAttention. K/V are rotated across CP ranks inside the ring autograd function; they are never all-gathered.
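The rotation means each rank keeps its local queries fixed and sees every rank's K/V chunk once, one hop at a time, without an all-gather. The following pure-Python sketch assumes each step sends the resident chunk to rank + 1 and receives from rank - 1. The exact direction used by _ring_exchange is not documented here. Strings stand in for K/V tensors, and simulate_ring and ring_schedule are hypothetical helpers.

```python
def simulate_ring(chunks: list) -> list[list]:
    """Return seen[rank] = K/V chunks visited by that rank, in ring order."""
    cp_size = len(chunks)
    held = list(chunks)  # held[rank] = chunk currently resident on that rank
    seen = [[] for _ in range(cp_size)]
    for step in range(cp_size):
        for rank in range(cp_size):
            seen[rank].append(held[rank])  # local Q attends to this chunk
        if step < cp_size - 1:  # cp_size - 1 exchanges suffice
            # Rank r receives from (r - 1) % cp_size: every chunk moves one hop.
            held = [held[(rank - 1) % cp_size] for rank in range(cp_size)]
    return seen


def ring_schedule(cp_size: int) -> list[list[int]]:
    """Closed form: schedule[step][rank] = source rank of the chunk held at that step."""
    return [[(rank - step) % cp_size for rank in range(cp_size)] for step in range(cp_size)]
```

Each rank holds only one remote chunk at a time, so peak memory stays at one local shard plus the chunk in flight rather than the full gathered sequence.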

nemo_automodel.components.models.gemma4_moe.cp_attention._install_gemma4_cp_ring_sdpa(
attention_module: torch.nn.Module,
cp_mesh
) -> None

Swap F.scaled_dot_product_attention -> Gemma4 ring CP attention on this module.

Gemma4 owns its CP attention end-to-end (it does not use cp_utils’ generic CP SDPA hooks). It installs its own @torch._dynamo.disable SDPA wrapper that runs the p2p ring FlexAttention. The wrapper sits on the inner attention module so that it also fires during gradient-checkpointing recompute. Because the swapped SDPA receives only Q/K/V, this function also captures the per-forward attention kwargs the ring needs (mm_token_type_ids, packed-seq ids, padding/vision masks) from the forward kwargs into _cp_manual_metadata.
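The key data-flow point is that the kernel call sees only Q/K/V, so the extra metadata has to be stashed on the module before the kernel runs. The following toy shows that pattern with plain Python callables. It is a hypothetical illustration, not this module's mechanism: it does not model @torch._dynamo.disable, gradient checkpointing, or how the real wrapper replaces F.scaled_dot_product_attention. ToyAttention, sdpa, install_ring_sdpa, and METADATA_KEYS (including the packed_seq_ids key name) are all invented for the example.

```python
# Hypothetical toy; none of these names are the module's API.
METADATA_KEYS = ("mm_token_type_ids", "packed_seq_ids")


class ToyAttention:
    def __init__(self):
        self._cp_manual_metadata = {}
        self.sdpa = self.default_sdpa  # the attention kernel this module calls

    def default_sdpa(self, q, k, v):
        return ("dense", q, k, v)

    def forward(self, q, k, v, **kwargs):
        # Like SDPA, the kernel call only receives Q/K/V, never the extra kwargs.
        return self.sdpa(q, k, v)


def install_ring_sdpa(module, cp_mesh):
    original_forward = module.forward

    def forward_capturing_metadata(q, k, v, **kwargs):
        # Stash the per-forward kwargs the ring needs before the kernel runs.
        module._cp_manual_metadata = {name: kwargs.get(name) for name in METADATA_KEYS}
        return original_forward(q, k, v, **kwargs)

    def ring_sdpa(q, k, v):
        # A real implementation would run p2p ring attention over cp_mesh here.
        return ("ring", cp_mesh, module._cp_manual_metadata)

    module.forward = forward_capturing_metadata
    module.sdpa = ring_sdpa
```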

nemo_automodel.components.models.gemma4_moe.cp_attention._merge_flex_chunk(
out_acc: torch.Tensor | None,
lse_acc: torch.Tensor | None,
out_step: torch.Tensor,
lse_step: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
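The signature suggests the usual ring/flash-attention accumulation: each chunk yields a softmax-normalized partial output plus the log-sum-exp (LSE) of its scores, and partials are combined by reweighting with the LSEs. The body of _merge_flex_chunk is not shown, so the following is a sketch of that standard merge, not the function itself. It works on a single query row with plain floats instead of tensors, and merge_chunk and attend are hypothetical names.

```python
import math


def attend(q, keys, values):
    """Softmax attention of one query row over a key chunk; returns (out, lse)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(len(values[0]))]
    return out, m + math.log(z)


def merge_chunk(out_acc, lse_acc, out_step, lse_step):
    """Combine partial results over disjoint key chunks using their LSEs."""
    if out_acc is None:  # first chunk: nothing accumulated yet
        return list(out_step), lse_step
    # Numerically stable logaddexp(lse_acc, lse_step).
    lse_new = max(lse_acc, lse_step) + math.log1p(math.exp(-abs(lse_acc - lse_step)))
    w_acc = math.exp(lse_acc - lse_new)    # share of total softmax mass in acc
    w_step = math.exp(lse_step - lse_new)  # share of total softmax mass in step
    return [w_acc * a + w_step * b for a, b in zip(out_acc, out_step)], lse_new
```

Because each weight equals that chunk's share of the total softmax normalizer, the merged result matches attention over the full key set exactly (up to rounding), regardless of chunk order.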
nemo_automodel.components.models.gemma4_moe.cp_attention._metadata_like(
metadata: dict[str, torch.Tensor | None]
) -> dict[str, torch.Tensor | None]
nemo_automodel.components.models.gemma4_moe.cp_attention._patch_fsdp_accumulated_grad_guard() -> None

Guard FSDPParam.to_accumulated_grad_if_needed against uninitialized params.

On some torch builds, that method reads self._unsharded_param (the lazily set unsharded tensor) without first checking that it exists. In FSDP2 post-backward under fp32 grad-reduce, frozen or never-unsharded params (for example, the frozen Gemma4 vision tower and embeddings) have no _unsharded_param yet, so the method raises AttributeError. Such params carry no grad to upcast anyway, so the patch wraps the method to skip them while they are uninitialized. It is a no-op if already applied and on builds where the method is already fixed.
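The guard is an idempotent method wrap. The following sketch patches an arbitrary class passed in as cls rather than FSDPParam, and the _guarded marker attribute is a hypothetical way to make re-application a no-op. Detection of already-fixed torch builds is not modeled. patch_accumulated_grad_guard is a hypothetical name.

```python
import functools


def patch_accumulated_grad_guard(cls) -> None:
    """Wrap cls.to_accumulated_grad_if_needed to skip params never unsharded."""
    original = cls.to_accumulated_grad_if_needed
    if getattr(original, "_guarded", False):
        return  # Already patched: applying twice is a no-op.

    @functools.wraps(original)
    def guarded(self, *args, **kwargs):
        if not hasattr(self, "_unsharded_param"):
            return None  # Frozen / never-unsharded param: no grad to upcast.
        return original(self, *args, **kwargs)

    guarded._guarded = True  # hypothetical marker for idempotence
    cls.to_accumulated_grad_if_needed = guarded
```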

nemo_automodel.components.models.gemma4_moe.cp_attention._ring_exchange(
tensors: list[tuple[torch.Tensor, torch.Tensor]],
cp_group: typing.Any,
cp_rank: int,
cp_size: int
) -> None
nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_cp_ring_attention(
attention_module: torch.nn.Module,
ctx: typing.Any
) -> torch.Tensor

Run Gemma4 local-query/ring-key CP attention with FlexAttention.

nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_cp_ring_attention_forward(
attention_module: torch.nn.Module,
ctx: typing.Any
) -> torch.Tensor

Run Gemma4 local-query/ring-key CP attention forward with FlexAttention.

nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_flex_chunk(
attention_module: torch.nn.Module,
ctx: typing.Any,
key_chunk: torch.Tensor,
value_chunk: torch.Tensor,
metadata_chunk: dict[str, torch.Tensor | None],
kv_global_start: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, int]
nemo_automodel.components.models.gemma4_moe.cp_attention._zero_if_none(
grad: torch.Tensor | None,
like: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.gemma4_moe.cp_attention.attach_gemma4_cp_ring_attention(
attention_module: torch.nn.Module
) -> None

Register Gemma4’s model-owned p2p ring CP attention on a self-attention module.

Declares the metadata keys the ring needs and exposes setup_cp_attention(cp_mesh), the model-owned CP-attention seam that the parallelizer calls (with the CP mesh) instead of cp_utils’ generic SDPA hooks. It also binds run_cp_manual_attention as the ring entry point.
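The seam works as a capability check: a module that carries setup_cp_attention gets the model-owned path, and everything else falls back to the generic hooks. The following sketch illustrates that dispatch under stated assumptions. attach_ring_attention, setup_context_parallel, the cp_metadata_keys attribute name, and the generic_hook callback are all hypothetical, and the real parallelizer logic may differ.

```python
import types


def attach_ring_attention(module, install_fn, ring_entry_fn, metadata_keys) -> None:
    """Bind a model-owned CP seam onto one attention module instance."""
    module.cp_metadata_keys = tuple(metadata_keys)  # hypothetical attribute name
    module.setup_cp_attention = types.MethodType(install_fn, module)
    module.run_cp_manual_attention = types.MethodType(ring_entry_fn, module)


def setup_context_parallel(modules, cp_mesh, generic_hook) -> None:
    """Parallelizer-side dispatch: prefer the model-owned seam when present."""
    for m in modules:
        if hasattr(m, "setup_cp_attention"):
            m.setup_cp_attention(cp_mesh)
        else:
            generic_hook(m, cp_mesh)
```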

nemo_automodel.components.models.gemma4_moe.cp_attention.gemma4_vision_group_ids(
mm_token_type_ids: torch.Tensor
) -> torch.Tensor

Return per-image-block ids for Gemma4 vision tokens, or -1 for text/padding.
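A minimal sketch of the id assignment follows. It assumes that mm_token_type_ids marks vision tokens with a nonzero value and text/padding with 0, that each contiguous run of vision tokens is one image block, and that ids restart at 0 for each sequence row. It uses nested lists instead of a tensor. The allowed rule shows one hypothetical use of the ids: causal attention plus bidirectional attention within the same image block. It is not taken from _base_gemma4_cp_mask. vision_group_ids and allowed are hypothetical names.

```python
def vision_group_ids(mm_token_type_ids: list[list[int]]) -> list[list[int]]:
    """Per-token image-block id for vision tokens, -1 for text/padding."""
    result = []
    for row in mm_token_type_ids:
        ids, next_id, prev_vision = [], -1, False
        for t in row:
            is_vision = t != 0  # assumption: 0 = text/padding
            if is_vision and not prev_vision:
                next_id += 1  # a new contiguous image block starts here
            ids.append(next_id if is_vision else -1)
            prev_vision = is_vision
        result.append(ids)
    return result


def allowed(q_pos: int, kv_pos: int, group_ids_row: list[int]) -> bool:
    """Hypothetical mask rule: causal, plus bidirectional within one image block."""
    g = group_ids_row[q_pos]
    return kv_pos <= q_pos or (g != -1 and g == group_ids_row[kv_pos])
```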

nemo_automodel.components.models.gemma4_moe.cp_attention._BLOCK_MASK_CACHE: dict = {}
nemo_automodel.components.models.gemma4_moe.cp_attention._BLOCK_MASK_GEN: list = [None, None]
nemo_automodel.components.models.gemma4_moe.cp_attention._GEMMA4_CP_FLEX_RING_OK_LOGGED = False
nemo_automodel.components.models.gemma4_moe.cp_attention.logger = logging.getLogger(__name__)