nemo_automodel.components.distributed.pipelining.hf_utils
Module Contents
Functions
Data
_PP_VLM_MODEL_TYPES_WITH_DEDICATED_FORWARD
API
Build a stage’s causal_mask_mapping, caching it per stage when safe.
Under pipeline parallelism the mask precomputed in the data pipeline only reaches
the first stage; non-first stages arrive with causal_mask_mapping=None and previously
recomputed it on every microbatch (slow, and a torch.compile graph break). When
no explicit attention_mask is provided — the common fixed-length / packed
training case, and exactly what non-first stages receive — the causal mask depends
only on (seq_len, dtype, device) and is constant across microbatches and steps,
so it is built once per stage and reused. With an explicit attention_mask (which
may encode per-batch padding) it is rebuilt each call. Behavior is identical to the
previous recompute; only the redundant recomputation is removed.
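The caching rule described above can be sketched in plain Python. This is only an illustration under stated assumptions: StageMaskCache and build_causal_mask are hypothetical names, the mask is a nested list rather than a tensor, and attention_mask is taken to be a single per-key keep/drop vector. The real helper may differ in all of these details.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Mask = List[List[bool]]


def build_causal_mask(seq_len: int) -> Mask:
    """Lower-triangular mask: position i may attend to positions j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


class StageMaskCache:
    """Hypothetical per-stage cache keyed on (seq_len, dtype, device)."""

    def __init__(self, builder: Callable[[int], Mask] = build_causal_mask) -> None:
        self._builder = builder
        self._cache: Dict[Tuple[int, str, str], Mask] = {}
        self.builds = 0  # counts real mask constructions, for illustration only

    def get(
        self,
        seq_len: int,
        dtype: str,
        device: str,
        attention_mask: Optional[Sequence[int]] = None,
    ) -> Mask:
        if attention_mask is not None:
            # An explicit mask may encode per-batch padding, so rebuild each call.
            self.builds += 1
            causal = self._builder(seq_len)
            return [
                [allowed and bool(keep) for allowed, keep in zip(row, attention_mask)]
                for row in causal
            ]
        key = (seq_len, dtype, device)
        if key not in self._cache:
            # First time this stage sees the shape: build once and remember it.
            self.builds += 1
            self._cache[key] = self._builder(seq_len)
        return self._cache[key]


cache = StageMaskCache()
first = cache.get(4, "bfloat16", "cuda:0")
second = cache.get(4, "bfloat16", "cuda:0")      # served from the cache
padded = cache.get(4, "bfloat16", "cuda:0", attention_mask=[1, 1, 1, 0])
```

Here the second call returns the same object as the first without rebuilding, while the call with an explicit attention_mask bypasses the cache entirely.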
Return True only for Gemma4 VLM variants.
model.model.language_model alone is not enough to identify Gemma4 —
Kimi VL, Mistral4, Qwen3 VL MoE, Llava OneVision and others share that
structure. Gate the Gemma4-specific PP forward on the HF model_type
so unrelated VLMs fall through to the generic CausalLM path instead of
receiving Gemma4’s sliding/full-attention and softcapping logic.
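A minimal sketch of this gating, assuming the check reads config.model_type and the nested model.model.language_model attribute. The function name looks_like_gemma4_vlm and the stand-in objects are hypothetical, and "some_other_vlm" is a placeholder rather than a real model_type.

```python
from types import SimpleNamespace


def looks_like_gemma4_vlm(model) -> bool:
    """Structure alone is ambiguous, so require the HF model_type as well."""
    config = getattr(model, "config", None)
    if getattr(config, "model_type", None) != "gemma4":
        return False
    inner = getattr(model, "model", None)
    return getattr(inner, "language_model", None) is not None


# Stand-in objects that mimic only the attributes the check reads.
gemma4_like = SimpleNamespace(
    config=SimpleNamespace(model_type="gemma4"),
    model=SimpleNamespace(language_model=object()),
)
same_structure_other_vlm = SimpleNamespace(
    config=SimpleNamespace(model_type="some_other_vlm"),  # placeholder value
    model=SimpleNamespace(language_model=object()),
)
```

The second object shares the nested structure but is rejected, so it would fall through to the generic CausalLM path.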
Return True for Mistral3ForConditionalGeneration (Pixtral + Ministral3).
Best-effort check for whether model is a vision-language model.
Looks at the standard VLM markers used elsewhere in the codebase: a nested
text_config, a vision_tower attribute on the outer model, or a
visual attribute on the inner model (Qwen-VL convention).
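The three markers could be checked roughly as follows. This is a sketch, not the module's code: looks_like_vlm is a hypothetical name, and reading text_config from model.config is an assumption about where the nested config lives.

```python
from types import SimpleNamespace


def looks_like_vlm(model) -> bool:
    """Best-effort VLM detection from three structural markers."""
    config = getattr(model, "config", None)
    # Marker 1: nested text_config (assumed here to hang off model.config).
    if getattr(config, "text_config", None) is not None:
        return True
    # Marker 2: vision_tower on the outer model.
    if getattr(model, "vision_tower", None) is not None:
        return True
    # Marker 3: visual on the inner model (Qwen-VL convention).
    inner = getattr(model, "model", None)
    return getattr(inner, "visual", None) is not None


text_only = SimpleNamespace(config=SimpleNamespace(), model=SimpleNamespace())
with_text_config = SimpleNamespace(config=SimpleNamespace(text_config=object()))
with_vision_tower = SimpleNamespace(vision_tower=object())
qwen_style = SimpleNamespace(model=SimpleNamespace(visual=object()))
```

Because the check is best-effort, a model that exposes none of these markers is treated as text-only.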
Create a pipeline-compatible forward method for causal LM wrappers.
Pipeline-compatible forward for the Gemma4 text decoder backbone.
Works for both HF Gemma4TextModel (dense path) and Gemma4MoETextModelBackend (MoE path). Handles:
- Optional embed_tokens (None on non-first PP stages; hidden states arrive in input_ids slot)
- Both full_attention and sliding_attention causal masks (Gemma4 uses mixed layer types)
- Per-layer-type position embeddings: Gemma4RotaryEmbedding.forward(x, pos_ids, layer_type)
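The first item in the list, an optional embed_tokens with hidden states arriving in the input_ids slot, can be illustrated with a toy stage. ToyStage and the list-based "tensors" are hypothetical; masks and rotary embeddings are deliberately left out.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


class ToyStage:
    """Toy pipeline stage: embeds only if it owns the embedding table."""

    def __init__(
        self,
        embed_tokens: Optional[Callable[[Sequence[int]], Vector]],
        layers: Sequence[Callable[[Vector], Vector]],
    ) -> None:
        self.embed_tokens = embed_tokens  # None on non-first stages
        self.layers = list(layers)

    def forward(self, input_ids):
        if self.embed_tokens is not None:
            hidden = self.embed_tokens(input_ids)
        else:
            # Non-first stage: the "input_ids" slot already carries hidden states.
            hidden = input_ids
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


first_stage = ToyStage(
    embed_tokens=lambda ids: [float(i) for i in ids],
    layers=[lambda h: [x + 1.0 for x in h]],
)
later_stage = ToyStage(embed_tokens=None, layers=[lambda h: [x * 2.0 for x in h]])

output = later_stage.forward(first_stage.forward([1, 2, 3]))
```

The same forward signature serves every stage, which is what lets the pipeline schedule pass activations positionally between them.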
Pipeline-compatible forward for Gemma4ForConditionalGeneration (VLM top-level).
Stage 0: embeds text tokens, merges image features from vision tower (if
pixel_values provided or stored in _vlm_pixel_values_chunks), then calls the
patched language model. Non-first stages: passes hidden states straight to
the patched language model. Last stage: applies lm_head and final-logit
softcapping.
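For readers unfamiliar with final-logit softcapping, the sketch below shows the tanh form used by earlier Gemma releases in Hugging Face Transformers. That Gemma4 uses the same formula, and the cap value of 30.0, are assumptions made for illustration; softcap_logits is a hypothetical helper operating on plain floats.

```python
import math
from typing import List


def softcap_logits(logits: List[float], cap: float) -> List[float]:
    """Squash logits smoothly into (-cap, cap): cap * tanh(logit / cap)."""
    return [cap * math.tanh(value / cap) for value in logits]


raw = [-500.0, -1.0, 0.0, 1.0, 500.0]
capped = softcap_logits(raw, cap=30.0)  # cap value chosen only for illustration
```

Small logits pass through almost unchanged, while extreme values saturate near the cap, so ordering is preserved but magnitudes stay bounded.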
Create a pipeline-compatible forward method for HuggingFace inner models.
Pipeline-compatible forward for Mistral3ForConditionalGeneration (VLM top-level).
Stage 0: embeds text tokens, runs vision_tower + multi_modal_projector for
image tokens, merges image features into inputs_embeds via
get_placeholder_mask/masked_scatter, then calls the patched language
model. Non-first stages: passes hidden states straight through the patched
language model. Last stage: applies lm_head.
Mirrors the generic CausalLM PP forward but adds the Mistral3 vision path
so pixel_values/image_sizes reach get_image_features on stage 0.
Without this, the generic CausalLM path never touches vision_tower and
image tokens are embedded as garbage text tokens.
Return the nested text/LLM module if present, else the model itself.
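One way such a lookup could work is sketched below. The candidate attribute paths (model.model.language_model, then model.language_model) and the name find_text_module are assumptions for illustration; the actual helper may probe different attributes.

```python
from types import SimpleNamespace


def find_text_module(model):
    """Return the nested text/LLM module if one is found, else the model itself."""
    inner = getattr(model, "model", None)
    # Assumed lookup order: model.model.language_model, then model.language_model.
    for owner in (inner, model):
        candidate = getattr(owner, "language_model", None)
        if candidate is not None:
            return candidate
    return model


text_backbone = object()
vlm_like = SimpleNamespace(model=SimpleNamespace(language_model=text_backbone))
plain_llm = SimpleNamespace(model=SimpleNamespace())
```

Falling back to the model itself means callers can use the result uniformly for VLMs and plain LLMs.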
Initialize HuggingFace model buffers needed before pipeline execution.
Return True when model opts out of pipeline-aware forward patching.
Used by the pipeline split call site to skip patch_hf_model_for_pp
entirely for models whose own forward is already PP-aware (typically
because it pulls pixel_values out of self._vlm_pixel_values_chunks
set by the training loop). Currently set on Qwen3-VL-MoE, Qwen3.5-MoE,
KimiVL, and Kimi-K2.5-VL.
Patch an HF model/module so that it exposes a pipeline-compatible forward.
The caller is responsible for skipping this function when the model
opts out via model_keeps_self_forward(model). This function itself
only branches on the patch flavor:
- Gemma4 VLM (config.model_type == 'gemma4' with a nested text backbone at model.model.language_model): patch the text backbone and VLM outer with Gemma4-specific VLM-aware forwards.
- Mistral3 VLM: patch the text backbone with the generic inner forward and the outer with the Mistral3-specific VLM forward.
- Other models with model.model (e.g., LlamaForCausalLM and other LLMs): patch inner and outer with the generic CausalLM forwards.
- Else: patch the module itself with the generic inner forward.
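The branch order above can be sketched as a small dispatcher. It only selects a flavor label and does not patch anything; select_patch_flavor and the label strings are hypothetical, and identifying Mistral3 by class name is an assumption based on the Mistral3ForConditionalGeneration check described earlier.

```python
from types import SimpleNamespace


def select_patch_flavor(model) -> str:
    """Pick which family of PP forwards would be applied to this model."""
    config = getattr(model, "config", None)
    inner = getattr(model, "model", None)
    is_gemma4 = getattr(config, "model_type", None) == "gemma4"
    if is_gemma4 and getattr(inner, "language_model", None) is not None:
        return "gemma4_vlm"
    # Assumed check: match on the class name of the outer model.
    if type(model).__name__ == "Mistral3ForConditionalGeneration":
        return "mistral3_vlm"
    if inner is not None:
        return "generic_causal_lm"
    return "generic_inner"


class Mistral3ForConditionalGeneration:  # stand-in, not the real HF class
    def __init__(self) -> None:
        self.model = SimpleNamespace(language_model=object())


gemma4_like = SimpleNamespace(
    config=SimpleNamespace(model_type="gemma4"),
    model=SimpleNamespace(language_model=object()),
)
llama_like = SimpleNamespace(model=SimpleNamespace())
bare_module = SimpleNamespace()
```

The order matters: the model-specific branches must be tested before the generic model.model branch, because both VLM families also expose an inner model.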
Validate whether a model is compatible with torch.distributed.pipelining.