core.models.multimodal.context_parallel#

Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality.

Module Contents#

Classes#

GatherFromContextParallelRanks

Gather the input from context parallel ranks.

Functions#

get_padding

Calculate padding needed for SP, CP, TP comm overlap, and FP8.

get_packed_seq_params

Get PackedSeqParams for CP.

split_to_context_parallel_ranks

Split the tensor global_t into context parallel world size parts.

_gather_along_second_dim

_reduce_scatter_along_second_dim

gather_from_context_parallel_ranks

Gather local_t across CP ranks, removing global_pad trailing pad tokens.

gather_from_context_parallel_ranks_dynamic_res

Gather dynamic-resolution tensors (variable seq per rank) from CP ranks.

_compute_tubelet_aware_split_points

Compute frame-space split points that respect tubelet boundaries within videos.

_split_num_frames

Return per-media frame counts clipped to the frame range [lb, ub).

split_to_context_parallel_ranks_dynamic_res

Split patched vision input across CP ranks.

API#

core.models.multimodal.context_parallel.get_padding(
seq_len,
cp_size,
tp_size,
has_sp,
decoder_tp_comm_overlap=False,
decoder_seq_len=None,
fp8_enabled=False,
fp8_recipe=None,
)#

Calculate padding needed for SP, CP, TP comm overlap, and FP8.

Parameters:
  • seq_len (int) – Model sequence length.

  • cp_size (int) – Context parallel size.

  • tp_size (int) – Tensor parallel size.

  • has_sp (bool) – Model uses sequence parallelism.

  • decoder_tp_comm_overlap (bool) – Decoder (LLM) uses tensor parallel communication overlap.

  • decoder_seq_len (int) – Decoder (LLM) maximum sequence length.

  • fp8_enabled (bool) – FP8 is enabled.

  • fp8_recipe (str) – FP8 recipe. Affects required padding.

Returns:

padding (int) – Padding needed given the model configuration.
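As a rough illustration of the arithmetic, the sketch below rounds a sequence length up to a common multiple. The helper names and the way the constraints are combined are assumptions made for this example, not the library's implementation. In particular, the 2 * cp_size factor is a guess based on load-balanced CP sharding, and the TP comm overlap case (which involves decoder_seq_len) is left out. The FP8 multiples (16 by default, 32 for mxfp8) are the ones this module documents for split_to_context_parallel_ranks_dynamic_res.

```python
import math


def illustrative_multiple(cp_size, tp_size, has_sp, fp8_enabled=False, fp8_recipe=None):
    """Hypothetical helper: combine alignment constraints into one multiple."""
    multiple = 1
    if cp_size > 1:
        # Assumption: each CP rank holds two chunks, so align to 2 * cp_size.
        multiple = math.lcm(multiple, 2 * cp_size)
    if has_sp:
        # Sequence parallelism shards the sequence across TP ranks.
        multiple = math.lcm(multiple, tp_size)
    if fp8_enabled:
        # 16 by default, 32 for the "mxfp8" recipe.
        multiple = math.lcm(multiple, 32 if fp8_recipe == "mxfp8" else 16)
    return multiple


def padding_to_multiple(seq_len, multiple):
    """Number of tokens to append so that seq_len becomes a multiple."""
    return (-seq_len) % multiple


multiple = illustrative_multiple(cp_size=2, tp_size=4, has_sp=True, fp8_enabled=True)
padding = padding_to_multiple(1000, multiple)
```

With these inputs the combined multiple is 16 and the padding is 8, since 1008 is the next multiple of 16 after 1000.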

core.models.multimodal.context_parallel.get_packed_seq_params(
tokens,
img_seq_len,
padding_needed,
cp_size,
use_packed_sequence=False,
)#

Get PackedSeqParams for CP.

Parameters:
  • tokens (torch.Tensor) – [batch, seq_len] input tokens.

  • img_seq_len (int) – Image sequence length.

  • padding_needed (int) – Padding to add.

  • cp_size (int) – Context parallel size.

  • use_packed_sequence (bool) – Uses sequence packing.

Returns:

packed_seq_params (PackedSeqParams) – Parameters to be sent to Transformer Engine.
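Packed-sequence parameters describe where each sequence starts and ends through cumulative sequence lengths (for example cu_seqlens_q). The sketch below shows only that bookkeeping on plain Python lists. The function name is hypothetical, and how get_packed_seq_params derives the per-sample lengths from tokens, img_seq_len, and padding_needed is not covered here.

```python
from itertools import accumulate


def cumulative_seqlens(seq_lens):
    """Hypothetical helper: turn per-sequence lengths into cumulative offsets.

    The result has len(seq_lens) + 1 entries; sequence i occupies the
    half-open token range [result[i], result[i + 1]).
    """
    return [0, *accumulate(seq_lens)]


cu_seqlens = cumulative_seqlens([5, 3, 8])
```

Here cu_seqlens is [0, 5, 8, 16], so the second sequence covers tokens 5 through 7.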

core.models.multimodal.context_parallel.split_to_context_parallel_ranks(global_t, pad_value=0)#

Split the tensor global_t into context parallel world size parts.

Parameters:
  • global_t – [batch, …]

  • pad_value – Value to pad the last rank with.

Returns:

  • local_t – [samples_per_rank, …], where samples_per_rank is the number of samples per CP rank.

  • global_pad – Total padding added so that samples_per_rank is equal across context parallel ranks.
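The relationship between samples_per_rank, global_pad, and pad_value can be sketched on a plain list. This is a single-process illustration under the assumption that padding is appended at the end (so it lands on the last rank) and that ranks take contiguous slices; the function name and the explicit cp_size and cp_rank arguments are hypothetical.

```python
def split_samples(samples, cp_size, cp_rank, pad_value=0):
    """Hypothetical helper: return (local_samples, global_pad) for one CP rank."""
    # Ceiling division so every rank receives the same number of samples.
    samples_per_rank = -(-len(samples) // cp_size)
    global_pad = samples_per_rank * cp_size - len(samples)
    # Padding goes at the end, so it ends up on the last rank.
    padded = list(samples) + [pad_value] * global_pad
    start = cp_rank * samples_per_rank
    return padded[start:start + samples_per_rank], global_pad


rank0, pad0 = split_samples([1, 2, 3, 4, 5], cp_size=2, cp_rank=0)
rank1, pad1 = split_samples([1, 2, 3, 4, 5], cp_size=2, cp_rank=1)
```

With five samples and two ranks, rank 0 holds [1, 2, 3], rank 1 holds [4, 5, 0], and global_pad is 1 on both.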

core.models.multimodal.context_parallel._gather_along_second_dim(local_t)#

core.models.multimodal.context_parallel._reduce_scatter_along_second_dim(global_t)#

class core.models.multimodal.context_parallel.GatherFromContextParallelRanks#

Bases: torch.autograd.Function

Gather the input from context parallel ranks.

static symbolic(graph, input_)#

Symbolic forward used during torch.jit tracing.

static forward(ctx, input_)#

All-gather input_ along its second dimension across CP ranks.

static backward(ctx, grad_output)#

Reduce-scatter the gradient along the second dimension.
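The forward/backward pairing can be simulated in a single process with nested lists, where the outer list stands in for the CP ranks. This is only a model of the data movement, with hypothetical function names; the real class operates on torch tensors through collective communication.

```python
def all_gather_dim1(per_rank):
    """per_rank[r][b] is rank r's shard of batch row b.

    Returns the [batch][seq] result every rank would hold after the gather.
    """
    batch = len(per_rank[0])
    return [[x for shard in per_rank for x in shard[b]] for b in range(batch)]


def reduce_scatter_dim1(per_rank_grads, cp_size):
    """per_rank_grads[r] is the full [batch][seq] gradient computed on rank r.

    Sums the gradients over ranks, then hands each rank its slice of dim 1.
    """
    batch = len(per_rank_grads[0])
    seq = len(per_rank_grads[0][0])
    summed = [
        [sum(g[b][s] for g in per_rank_grads) for s in range(seq)]
        for b in range(batch)
    ]
    chunk = seq // cp_size
    return [[row[r * chunk:(r + 1) * chunk] for row in summed] for r in range(cp_size)]


gathered = all_gather_dim1([[[1, 2]], [[3, 4]]])
local_grads = reduce_scatter_dim1([[[1, 1, 1, 1]], [[2, 2, 2, 2]]], cp_size=2)
```

Reduce-scatter is the natural backward of all-gather: every rank used the full gathered tensor, so each shard's gradient is the sum of the contributions from all ranks for that shard's slice.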

core.models.multimodal.context_parallel.gather_from_context_parallel_ranks(local_t, global_pad)#

Gather local_t across CP ranks, removing global_pad trailing pad tokens.
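A minimal sketch of the trimming step, assuming the pad tokens sit at the very end of the gathered sequence. The function name is hypothetical and the per-rank shards are plain lists rather than tensors.

```python
def gather_and_trim(per_rank_tokens, global_pad):
    """Hypothetical helper: concatenate rank shards, then drop trailing padding."""
    gathered = [tok for shard in per_rank_tokens for tok in shard]
    # Slice by explicit length: gathered[:-0] would wrongly return an empty list.
    return gathered[:len(gathered) - global_pad]


trimmed = gather_and_trim([[1, 2, 3], [4, 5, 0]], global_pad=1)
untouched = gather_and_trim([[1, 2], [3, 4]], global_pad=0)
```

Slicing by the target length rather than with a negative index keeps the global_pad == 0 case correct.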

core.models.multimodal.context_parallel.gather_from_context_parallel_ranks_dynamic_res(
local_t,
num_padded_imgs=0,
)#

Gather dynamic-resolution tensors (variable seq per rank) from CP ranks.
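When each rank holds a different number of tokens, a common pattern is to exchange lengths, pad every shard to the maximum length, run a fixed-size gather, and strip the padding afterwards. The sketch below simulates that pattern on lists. Whether this function follows exactly these steps is an assumption, the function name is hypothetical, and the handling of num_padded_imgs is not shown.

```python
def gather_variable_length(per_rank_tokens, pad_value=0):
    """Hypothetical helper: simulate a variable-length gather across ranks."""
    # Step 1: every rank learns the length of every other rank's shard.
    lengths = [len(shard) for shard in per_rank_tokens]
    max_len = max(lengths)
    # Step 2: pad each shard to the common maximum so shapes match.
    padded = [shard + [pad_value] * (max_len - len(shard)) for shard in per_rank_tokens]
    # Step 3: a fixed-size all-gather would deliver `padded` to every rank.
    # Step 4: strip the per-rank padding using the exchanged lengths.
    return [tok for shard, n in zip(padded, lengths) for tok in shard[:n]]


merged = gather_variable_length([[1, 2, 3], [4], [5, 6]])
```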

core.models.multimodal.context_parallel._compute_tubelet_aware_split_points(
num_frames,
temporal_patch_size,
cp_size,
total_frames,
)#

Compute frame-space split points that respect tubelet boundaries within videos.

Returns cp_size + 1 split points in frame indices (not tubelet indices), since callers slice per-frame cu_seqlens and imgs_sizes with these bounds. Splits land on either media boundaries or tubelet boundaries inside a media so that no rank receives a partial tubelet.
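One way to satisfy these constraints is to list every legal cut position (media boundaries plus tubelet boundaries inside each media) and snap each ideal, evenly spaced split to the nearest legal position. The sketch below does that. The snapping strategy and the function name are assumptions for illustration and may differ from the actual implementation; total_frames is derived from num_frames here rather than passed in.

```python
from bisect import bisect_left


def tubelet_aware_split_points(num_frames, temporal_patch_size, cp_size):
    """Hypothetical helper: return cp_size + 1 frame indices on legal boundaries."""
    boundaries = [0]
    start = 0
    for n in num_frames:
        # Tubelet boundaries strictly inside this media.
        for offset in range(temporal_patch_size, n, temporal_patch_size):
            boundaries.append(start + offset)
        start += n
        boundaries.append(start)  # media boundary
    total_frames = start

    points = [0]
    for k in range(1, cp_size):
        target = k * total_frames / cp_size  # ideal, evenly spaced split
        i = bisect_left(boundaries, target)
        candidates = boundaries[max(i - 1, 0):i + 1]
        best = min(candidates, key=lambda b: abs(b - target))
        points.append(max(best, points[-1]))  # keep the points non-decreasing
    points.append(total_frames)
    return points


points = tubelet_aware_split_points([4, 1, 6], temporal_patch_size=2, cp_size=2)
```

For a 4-frame video, a single image, and a 6-frame video with tubelets of 2 frames, the legal cuts are 0, 2, 4, 5, 7, 9, and 11, and the midpoint 5.5 snaps to 5, giving [0, 5, 11]. When there are fewer legal cuts than ranks, some consecutive points coincide and the corresponding rank receives an empty range.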

core.models.multimodal.context_parallel._split_num_frames(num_frames, lb, ub)#

Return per-media frame counts clipped to the frame range [lb, ub).

lb and ub are frame indices (the same coordinate system used by _compute_tubelet_aware_split_points and the per-frame seqlens array in split_to_context_parallel_ranks_dynamic_res). The returned list has one entry per media that contributes at least one frame to the range; each entry is the number of that media's frames inside the range.
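The clipping described above amounts to intersecting each media's frame interval with [lb, ub). A minimal sketch, written from the description rather than from the source, with a hypothetical function name:

```python
def split_num_frames_sketch(num_frames, lb, ub):
    """Per-media frame counts restricted to the frame range [lb, ub)."""
    counts = []
    start = 0
    for n in num_frames:
        end = start + n
        # Overlap of this media's interval [start, end) with [lb, ub).
        overlap = min(end, ub) - max(start, lb)
        if overlap > 0:
            counts.append(overlap)
        start = end
    return counts


first_half = split_num_frames_sketch([4, 1, 6], 0, 5)
middle = split_num_frames_sketch([4, 1, 6], 2, 7)
```

With media of 4, 1, and 6 frames, the range [0, 5) yields [4, 1] and the range [2, 7) yields [2, 1, 2]. Media that fall entirely outside the range produce no entry.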

core.models.multimodal.context_parallel.split_to_context_parallel_ranks_dynamic_res(
global_t,
global_imgs_sizes,
global_packed_seq_params,
*,
patch_dim,
fp8_enabled=False,
fp8_recipe=None,
num_frames=None,
temporal_patch_size=1,
)#

Split patched vision input across CP ranks.

global_packed_seq_params provides per-image seqlens; the split respects them so each rank owns an integer number of images. When temporal_patch_size > 1, splits also respect tubelet boundaries and num_frames is required.

Parameters:
  • global_t – [1, total_patches, C * patch_dim * patch_dim] patched tokens (pre-embedder). The last dim must equal 3 * patch_dim * patch_dim (that is, C = 3).

  • global_imgs_sizes – [num_imgs, 2] per-image (H, W) in pixels.

  • global_packed_seq_params – PackedSeqParams with per-image cu_seqlens_q.

  • patch_dim – Patch size of the vision backbone (e.g. 14 for SigLIP, 16 for many ViTs). Required because dummy padding tensors are sized in patch units and the default would silently mismatch some backbones.

  • fp8_enabled – If True, pad each rank’s local sequence to the FP8 multiple (16 by default; 32 for mxfp8).

  • fp8_recipe – Forwarded to get_padding so the FP8 padding multiple matches the active recipe.

  • num_frames – Per-media frame count, required when temporal_patch_size > 1.

  • temporal_patch_size – Tubelet size for temporal compression.

Returns:

(local_t, local_imgs_sizes, local_packed_seq_params, has_padding, num_padded_ranks, local_num_frames)
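Because the split follows per-image cu_seqlens, a rank's token range and its local cumulative lengths can be derived from a pair of image indices. The sketch below shows that slicing and the optional FP8 padding amount. The function name and the lb, ub, and pad_multiple arguments are hypothetical, and the real function additionally slices global_t and global_imgs_sizes and builds a new PackedSeqParams.

```python
def local_image_slice(cu_seqlens, lb, ub, pad_multiple=1):
    """Hypothetical helper: describe the images [lb, ub) owned by one rank.

    Returns (token_start, token_end, local_cu_seqlens, pad_len).
    """
    token_start, token_end = cu_seqlens[lb], cu_seqlens[ub]
    # Rebase the cumulative lengths so the local sequence starts at 0.
    local_cu = [c - token_start for c in cu_seqlens[lb:ub + 1]]
    # Tokens to append so the local length is a multiple of pad_multiple.
    pad_len = (-(token_end - token_start)) % pad_multiple
    return token_start, token_end, local_cu, pad_len


cu = [0, 4, 10, 12, 20]  # four images with 4, 6, 2, and 8 patches
plain = local_image_slice(cu, 1, 3)
fp8 = local_image_slice(cu, 1, 3, pad_multiple=16)
```

A rank that owns images 1 and 2 reads tokens 4 through 11, sees local cumulative lengths [0, 6, 8], and needs 8 pad tokens to reach a multiple of 16.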