nemo_automodel.components.distributed.pipelining.hf_utils

Module Contents

Functions

| Name | Description |
| --- | --- |
| _build_or_reuse_pp_causal_mask | Build a stage’s causal_mask_mapping, caching it per stage when safe. |
| _is_gemma4_vlm | Return True only for Gemma4 VLM variants. |
| _is_mistral3_vlm | Return True for Mistral3ForConditionalGeneration (Pixtral + Ministral3). |
| _is_vlm | Best-effort check for whether model is a vision-language model. |
| create_pipeline_forward_causal_lm | Create a pipeline-compatible forward method for causal LM wrappers. |
| create_pipeline_forward_gemma4_text | Pipeline-compatible forward for the Gemma4 text decoder backbone. |
| create_pipeline_forward_gemma4_vlm | Pipeline-compatible forward for Gemma4ForConditionalGeneration (VLM top-level). |
| create_pipeline_forward_inner | Create a pipeline-compatible forward method for HuggingFace inner models. |
| create_pipeline_forward_mistral3_vlm | Pipeline-compatible forward for Mistral3ForConditionalGeneration (VLM top-level). |
| get_text_module | Return the nested text/LLM module if present, else the model itself. |
| init_hf_model_buffers | Initialize HuggingFace model buffers needed before pipeline execution. |
| model_keeps_self_forward | Return True when model opts out of pipeline-aware forward patching. |
| patch_hf_model_for_pp | Patch a HF model/module to produce pipeline-compatible forward. |
| validate_hf_model_for_pipeline_support | Validate if a model is compatible with torch.distributed.pipelining. |

Data

MULTIMODAL_SUFFIXES

TEXT_MODULE_ATTRS

_PP_VLM_MODEL_TYPES_WITH_DEDICATED_FORWARD

logger

API

nemo_automodel.components.distributed.pipelining.hf_utils._build_or_reuse_pp_causal_mask(
module,
inputs_embeds,
attention_mask,
cache_position,
position_ids
)

Build a stage’s causal_mask_mapping, caching it per stage when safe.

Under pipeline parallelism, the mask precomputed in the data pipeline reaches only the first stage. Non-first stages arrive with causal_mask_mapping=None and previously recomputed the mask on every microbatch, which is slow and causes a torch.compile graph break. When no explicit attention_mask is provided (the common fixed-length / packed training case, and exactly what non-first stages receive), the causal mask depends only on (seq_len, dtype, device) and is constant across microbatches and steps, so it is built once per stage and reused. With an explicit attention_mask, which may encode per-batch padding, the mask is rebuilt on each call. The result is identical to the previous per-call recompute; only the redundant recomputation is removed.
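The caching rule above can be illustrated with a framework-free sketch. It is not the library's implementation: StageMaskCache, its builds counter, and the nested-list mask are hypothetical stand-ins for the real causal_mask_mapping tensors, and dtype/device are plain strings here.

```python
from typing import Dict, List, Optional, Tuple

Mask = List[List[bool]]


class StageMaskCache:
    """Hypothetical per-stage cache; not the library's actual data structure."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[int, str, str], Mask] = {}
        self.builds = 0  # counts how many times a mask was actually constructed

    def _build(self, seq_len: int, attention_mask: Optional[List[bool]]) -> Mask:
        self.builds += 1
        # Position i may attend to position j when j <= i and j is not padding.
        return [
            [j <= i and (attention_mask is None or attention_mask[j]) for j in range(seq_len)]
            for i in range(seq_len)
        ]

    def get(
        self,
        seq_len: int,
        dtype: str,
        device: str,
        attention_mask: Optional[List[bool]] = None,
    ) -> Mask:
        if attention_mask is not None:
            # Padding can differ per batch, so a cached mask is never reused here.
            return self._build(seq_len, attention_mask)
        key = (seq_len, dtype, device)
        if key not in self._cache:
            self._cache[key] = self._build(seq_len, None)
        return self._cache[key]


cache = StageMaskCache()
first = cache.get(4, "bfloat16", "cuda:0")
second = cache.get(4, "bfloat16", "cuda:0")  # served from the cache
padded = cache.get(4, "bfloat16", "cuda:0", attention_mask=[True, True, True, False])
```

In this sketch the second call returns the same object as the first, while the call with an explicit attention_mask always triggers a fresh build.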

nemo_automodel.components.distributed.pipelining.hf_utils._is_gemma4_vlm(
model: torch.nn.Module
) -> bool

Return True only for Gemma4 VLM variants.

model.model.language_model alone is not enough to identify Gemma4: Kimi VL, Mistral4, Qwen3 VL MoE, Llava OneVision and others share that structure. The Gemma4-specific PP forward is therefore gated on the HF model_type, so that unrelated VLMs fall through to the generic CausalLM path instead of receiving Gemma4’s sliding/full-attention and softcapping logic.
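A minimal sketch of this kind of gate, using SimpleNamespace objects in place of real models. The function name looks_like_gemma4_vlm is hypothetical, and the exact attributes the real check reads are an assumption based on the description of patch_hf_model_for_pp in this module.

```python
from types import SimpleNamespace


def looks_like_gemma4_vlm(model) -> bool:
    """Hypothetical stand-in: require both the model_type and the nested backbone."""
    config = getattr(model, "config", None)
    if getattr(config, "model_type", None) != "gemma4":
        return False  # structure alone is shared by many unrelated VLMs
    inner = getattr(model, "model", None)
    return getattr(inner, "language_model", None) is not None


text_backbone = object()
gemma4 = SimpleNamespace(
    config=SimpleNamespace(model_type="gemma4"),
    model=SimpleNamespace(language_model=text_backbone),
)
# Same nesting, different model_type: must not take the Gemma4 path.
other_vlm = SimpleNamespace(
    config=SimpleNamespace(model_type="some_other_vlm"),
    model=SimpleNamespace(language_model=text_backbone),
)

is_gemma4 = looks_like_gemma4_vlm(gemma4)
is_other = looks_like_gemma4_vlm(other_vlm)
```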

nemo_automodel.components.distributed.pipelining.hf_utils._is_mistral3_vlm(
model: torch.nn.Module
) -> bool

Return True for Mistral3ForConditionalGeneration (Pixtral + Ministral3).

nemo_automodel.components.distributed.pipelining.hf_utils._is_vlm(
model: torch.nn.Module
) -> bool

Best-effort check for whether model is a vision-language model.

Looks at the standard VLM markers used elsewhere in the codebase: a nested text_config, a vision_tower attribute on the outer model, or a visual attribute on the inner model (Qwen-VL convention).
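The three markers can be sketched as follows. This is an illustration only: has_vlm_markers is a hypothetical name, and it assumes that text_config lives on model.config and that the "inner model" is model.model, neither of which the text above states explicitly.

```python
from types import SimpleNamespace


def has_vlm_markers(model) -> bool:
    """Hypothetical best-effort check mirroring the three markers listed above."""
    config = getattr(model, "config", None)
    if getattr(config, "text_config", None) is not None:
        return True  # nested text_config (assumed to sit on model.config)
    if getattr(model, "vision_tower", None) is not None:
        return True  # vision_tower on the outer model
    inner = getattr(model, "model", None)
    return getattr(inner, "visual", None) is not None  # Qwen-VL style inner attribute


plain_llm = SimpleNamespace(config=SimpleNamespace(), model=SimpleNamespace())
with_text_config = SimpleNamespace(config=SimpleNamespace(text_config=object()))
with_tower = SimpleNamespace(vision_tower=object())
qwen_style = SimpleNamespace(model=SimpleNamespace(visual=object()))

results = [has_vlm_markers(m) for m in (plain_llm, with_text_config, with_tower, qwen_style)]
```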

nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_causal_lm() -> typing.Callable

Create a pipeline-compatible forward method for causal LM wrappers.

nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_gemma4_text() -> typing.Callable

Pipeline-compatible forward for the Gemma4 text decoder backbone.

Works for both HF Gemma4TextModel (dense path) and Gemma4MoETextModelBackend (MoE path). Handles:

  • Optional embed_tokens (None on non-first PP stages; hidden states arrive in input_ids slot)
  • Both full_attention and sliding_attention causal masks (Gemma4 uses mixed layer types)
  • Per-layer-type position embeddings: Gemma4RotaryEmbedding.forward(x, pos_ids, layer_type)
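
The first two items in the list above can be sketched without any framework. Everything here is a toy stand-in: run_stage, make_layer, and the layer_type attribute are hypothetical names, masks are plain strings, and rotary embeddings are left out.

```python
from types import SimpleNamespace


def run_stage(layers, stage_input, masks, embed_tokens=None):
    # First stage: stage_input holds token ids and is embedded.
    # Later stages: embed_tokens is None and stage_input already holds hidden states.
    hidden = embed_tokens(stage_input) if embed_tokens is not None else stage_input
    for layer in layers:
        # Each layer picks the mask that matches its own attention type.
        hidden = layer.fn(hidden, masks[layer.layer_type])
    return hidden


def make_layer(layer_type):
    # Toy layer: appends the mask it received so the routing is visible in the result.
    return SimpleNamespace(layer_type=layer_type, fn=lambda hidden, mask: hidden + [mask])


masks = {"full_attention": "full-mask", "sliding_attention": "sliding-mask"}
layers = [make_layer("sliding_attention"), make_layer("full_attention")]

first_stage = run_stage(layers, [1, 2], masks, embed_tokens=lambda ids: [f"emb{i}" for i in ids])
later_stage = run_stage(layers, ["h0", "h1"], masks)  # no embed_tokens on this stage
```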

nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_gemma4_vlm() -> typing.Callable

Pipeline-compatible forward for Gemma4ForConditionalGeneration (VLM top-level).

  • Stage 0: embeds text tokens, merges image features from the vision tower (if pixel_values is provided or stored in _vlm_pixel_values_chunks), then calls the patched language model.
  • Non-first stages: passes hidden states straight to the patched language model.
  • Last stage: applies lm_head and final-logit softcapping.

nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_inner(
model_class_name: str = 'AutoModel'
) -> typing.Callable

Create a pipeline-compatible forward method for HuggingFace inner models.

nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_mistral3_vlm() -> typing.Callable

Pipeline-compatible forward for Mistral3ForConditionalGeneration (VLM top-level).

  • Stage 0: embeds text tokens, runs vision_tower + multi_modal_projector for image tokens, merges image features into inputs_embeds via get_placeholder_mask/masked_scatter, then calls the patched language model.
  • Non-first stages: passes hidden states straight through the patched language model.
  • Last stage: applies lm_head.

Mirrors the generic CausalLM PP forward but adds the Mistral3 vision path so pixel_values/image_sizes reach get_image_features on stage 0. Without this, the generic CausalLM path never touches vision_tower and image tokens are embedded as garbage text tokens.
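The stage 0 merge step described above can be illustrated with a row-wise, framework-free sketch of what a masked scatter does: image feature rows are copied, in order, into the positions of the image placeholder tokens. IMAGE_TOKEN_ID and merge_image_features are hypothetical names for this example, and strings stand in for embedding vectors.

```python
from typing import List

IMAGE_TOKEN_ID = 10  # hypothetical placeholder id, chosen for this example only


def merge_image_features(
    input_ids: List[int], text_embeds: List[str], image_features: List[str]
) -> List[str]:
    n_slots = sum(1 for token_id in input_ids if token_id == IMAGE_TOKEN_ID)
    if n_slots != len(image_features):
        raise ValueError(f"{n_slots} placeholder tokens but {len(image_features)} image features")
    feature_iter = iter(image_features)
    merged = []
    for token_id, embed in zip(input_ids, text_embeds):
        # Placeholder positions take the next image feature; others keep the text embedding.
        merged.append(next(feature_iter) if token_id == IMAGE_TOKEN_ID else embed)
    return merged


input_ids = [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]
text_embeds = ["t1", "placeholder", "placeholder", "t2"]
merged = merge_image_features(input_ids, text_embeds, ["img0", "img1"])
```

If this step is skipped, the placeholder rows keep their text embeddings, which is the failure mode the paragraph above describes for the generic CausalLM path.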

nemo_automodel.components.distributed.pipelining.hf_utils.get_text_module(
model: torch.nn.Module
) -> torch.nn.Module

Return the nested text/LLM module if present, else the model itself.
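A minimal sketch of this lookup, using the attribute names from TEXT_MODULE_ATTRS in the Data section of this page. find_text_module is a hypothetical name, and the sketch checks only top-level attributes; the real function may search differently (for example, one level deeper).

```python
from types import SimpleNamespace

# Value copied from the TEXT_MODULE_ATTRS entry documented on this page.
TEXT_MODULE_ATTRS = ("language_model", "text_model", "text_decoder")


def find_text_module(model):
    """Hypothetical stand-in: first matching nested attribute, else the model itself."""
    for name in TEXT_MODULE_ATTRS:
        nested = getattr(model, name, None)
        if nested is not None:
            return nested
    return model


backbone = object()
vlm = SimpleNamespace(vision_tower=object(), language_model=backbone)
llm = SimpleNamespace(layers=[])

vlm_text = find_text_module(vlm)
llm_text = find_text_module(llm)
```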

nemo_automodel.components.distributed.pipelining.hf_utils.init_hf_model_buffers(
model: torch.nn.Module,
device: torch.device
) -> None

Initialize HuggingFace model buffers needed before pipeline execution.

nemo_automodel.components.distributed.pipelining.hf_utils.model_keeps_self_forward(
model: torch.nn.Module
) -> bool

Return True when model opts out of pipeline-aware forward patching.

Used by the pipeline split call site to skip patch_hf_model_for_pp entirely for models whose own forward is already PP-aware (typically because it pulls pixel_values out of self._vlm_pixel_values_chunks set by the training loop). Currently set on Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, and Kimi-K2.5-VL.

nemo_automodel.components.distributed.pipelining.hf_utils.patch_hf_model_for_pp(
model,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True
) -> None

Patch a HF model/module to produce pipeline-compatible forward.

The caller is responsible for skipping this function when the model opts out via model_keeps_self_forward(model). This function itself only branches on the patch flavor:

  • Gemma4 VLM (config.model_type == 'gemma4' with a nested text backbone at model.model.language_model): patch the text backbone and VLM outer with Gemma4-specific VLM-aware forwards.
  • Mistral3 VLM: patch the text backbone with the generic inner forward and the outer with the Mistral3-specific VLM forward.
  • Other models with model.model (e.g., LlamaForCausalLM and other LLMs): patch inner and outer with the generic CausalLM forwards.
  • Else: patch the module itself with the generic inner forward.
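
The branch order above, together with the caller-side opt-out, can be sketched as a small dispatcher that returns a label instead of patching anything. The names choose_patch_flavor and maybe_patch are hypothetical, and identifying Mistral3 by model_type == "mistral3" is an assumption that the text above does not confirm.

```python
from types import SimpleNamespace


def choose_patch_flavor(model) -> str:
    """Hypothetical dispatcher mirroring the four branches listed above."""
    model_type = getattr(getattr(model, "config", None), "model_type", None)
    inner = getattr(model, "model", None)
    has_text_backbone = getattr(inner, "language_model", None) is not None
    if model_type == "gemma4" and has_text_backbone:
        return "gemma4_vlm"
    if model_type == "mistral3":  # assumption: Mistral3 detected via model_type
        return "mistral3_vlm"
    if inner is not None:
        return "generic_causal_lm"
    return "generic_inner"


def maybe_patch(model, keeps_self_forward: bool):
    # Caller-side guard: models that keep their own forward are never patched.
    if keeps_self_forward:
        return None
    return choose_patch_flavor(model)


gemma4 = SimpleNamespace(
    config=SimpleNamespace(model_type="gemma4"),
    model=SimpleNamespace(language_model=object()),
)
mistral3 = SimpleNamespace(config=SimpleNamespace(model_type="mistral3"), model=SimpleNamespace())
llama = SimpleNamespace(config=SimpleNamespace(model_type="llama"), model=SimpleNamespace())
bare_module = SimpleNamespace()

flavors = [maybe_patch(m, keeps_self_forward=False) for m in (gemma4, mistral3, llama, bare_module)]
skipped = maybe_patch(gemma4, keeps_self_forward=True)
```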

nemo_automodel.components.distributed.pipelining.hf_utils.validate_hf_model_for_pipeline_support(
model: torch.nn.Module
) -> None

Validate if a model is compatible with torch.distributed.pipelining.

nemo_automodel.components.distributed.pipelining.hf_utils.MULTIMODAL_SUFFIXES = ('vision_tower', 'visual', 'vision_model', 'image_encoder', 'vision_encoder', 'e...
nemo_automodel.components.distributed.pipelining.hf_utils.TEXT_MODULE_ATTRS = ('language_model', 'text_model', 'text_decoder')
nemo_automodel.components.distributed.pipelining.hf_utils._PP_VLM_MODEL_TYPES_WITH_DEDICATED_FORWARD: tuple[str, ...] = ('gemma4', 'mistral3')
nemo_automodel.components.distributed.pipelining.hf_utils.logger = logging.getLogger(__name__)