core.models.multimodal.context_parallel
Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality.
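Module Contents
Classes
GatherFromContextParallelRanks | Gather the input from context parallel ranks.
Functions
get_padding | Calculate padding needed for SP, CP, TP comm overlap, and FP8.
get_packed_seq_params | Get PackedSeqParams for CP.
split_to_context_parallel_ranks | Split the tensor global_t into context parallel world size parts.
gather_from_context_parallel_ranks | Gather local_t across CP ranks, removing global_pad trailing pad tokens.
gather_from_context_parallel_ranks_dynamic_res | Gather dynamic-resolution tensors (variable seq per rank) from CP ranks.
_compute_tubelet_aware_split_points | Compute frame-space split points that respect tubelet boundaries within videos.
_split_num_frames | Return per-media frame counts clipped to the frame range [lb, ub).
split_to_context_parallel_ranks_dynamic_res | Split patched vision input across CP ranks.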
API
- core.models.multimodal.context_parallel.get_padding(seq_len, cp_size, tp_size, has_sp, decoder_tp_comm_overlap=False, decoder_seq_len=None, fp8_enabled=False, fp8_recipe=None)
Calculate padding needed for SP, CP, TP comm overlap, and FP8.
- Parameters:
seq_len (int) – Model sequence length.
cp_size (int) – Context parallel size.
tp_size (int) – Tensor parallel size.
has_sp (bool) – Model uses sequence parallelism.
decoder_tp_comm_overlap (bool) – Decoder (LLM) uses tensor parallel communication overlap.
decoder_seq_len (int) – Decoder (LLM) maximum sequence length.
fp8_enabled (bool) – FP8 is enabled.
fp8_recipe (str) – FP8 recipe. Affects required padding.
- Returns:
Padding needed given model configuration.
- Return type:
padding (int)
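As a rough illustration of how such a padding value can be derived, the sketch below rounds seq_len up to a multiple implied by the parallel configuration. The name sketch_get_padding, the factor rules (2 * cp_size for CP, tp_size for SP), the handling of decoder_seq_len, and the use of a least common multiple are assumptions made for illustration and are not taken from the module. The FP8 multiples (16, or 32 for mxfp8) follow the parameter notes for split_to_context_parallel_ranks_dynamic_res.

```python
import math


def sketch_get_padding(seq_len, cp_size, tp_size, has_sp,
                       decoder_tp_comm_overlap=False, decoder_seq_len=None,
                       fp8_enabled=False, fp8_recipe=None):
    """Hypothetical stand-in for get_padding; the rules below are assumptions."""
    # Assumption: with TP comm overlap the decoder expects a fixed length,
    # so the sequence is padded all the way up to decoder_seq_len.
    if has_sp and decoder_tp_comm_overlap:
        if decoder_seq_len is None:
            raise ValueError("decoder_seq_len is required with TP comm overlap")
        return decoder_seq_len - seq_len

    multiple = 1
    if cp_size > 1:
        # Assumption: CP splits each sequence into 2 * cp_size chunks.
        multiple *= 2 * cp_size
    if has_sp:
        # Assumption: SP shards the sequence evenly across TP ranks.
        multiple *= tp_size
    if fp8_enabled:
        fp8_multiple = 32 if fp8_recipe == "mxfp8" else 16
        multiple = math.lcm(multiple, fp8_multiple)

    # Distance from seq_len to the next multiple (0 if already aligned).
    return -seq_len % multiple
```

Under these assumptions, a sequence of 1000 tokens with cp_size=2, tp_size=4 and sequence parallelism needs 8 padding tokens to reach 1008, the next multiple of 16.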
- core.models.multimodal.context_parallel.get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, use_packed_sequence=False)
Get PackedSeqParams for CP.
- Parameters:
tokens (torch.Tensor) – [batch, seq_len] input tokens.
img_seq_len (int) – Image sequence length.
padding_needed (int) – Padding to add.
cp_size (int) – Context parallel size.
use_packed_sequence (bool) – Uses sequence packing.
- Returns:
Parameters to be sent to Transformer Engine.
- Return type:
packed_seq_params (PackedSeqParams)
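PackedSeqParams carries cumulative sequence lengths (cu_seqlens), which mark where each sample starts and ends in a flattened batch. The pure-Python sketch below shows how such offsets could be laid out. The helper name, its arguments, and the assumption that every sample has the same combined text-plus-image length with padding_needed padded positions at its end are illustrative only and are not taken from the implementation.

```python
def sketch_cu_seqlens(batch_size, text_seq_len, img_seq_len, padding_needed):
    """Hypothetical helper: cumulative offsets for equal-length samples."""
    # Assumed combined length of one sample, including padded positions.
    padded_len = text_seq_len + img_seq_len
    # Assumed number of positions that hold real (non-pad) tokens.
    valid_len = padded_len - padding_needed
    # Offsets have batch_size + 1 entries: entry i is where sample i starts.
    cu_seqlens = [i * valid_len for i in range(batch_size + 1)]
    cu_seqlens_padded = [i * padded_len for i in range(batch_size + 1)]
    return cu_seqlens, cu_seqlens_padded


valid, padded = sketch_cu_seqlens(batch_size=2, text_seq_len=100,
                                  img_seq_len=28, padding_needed=8)
print(valid)   # [0, 120, 240]
print(padded)  # [0, 128, 256]
```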
- core.models.multimodal.context_parallel.split_to_context_parallel_ranks(global_t, pad_value=0)
Split the tensor global_t into context parallel world size parts.
- Parameters:
global_t – [batch, …]
pad_value – Value to pad the last rank with.
- Returns:
local_t: [samples_per_rank, …], where samples_per_rank is the number of samples per CP rank.
global_pad: Total padding added so that samples_per_rank is equal across context parallel ranks.
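The splitting and padding arithmetic can be sketched in a single process with plain lists. This is a minimal sketch, not the module's code: sketch_split_to_cp_ranks is a hypothetical name, and cp_size and cp_rank are passed explicitly here, whereas the documented function takes only global_t and pad_value.

```python
def sketch_split_to_cp_ranks(global_samples, cp_size, cp_rank, pad_value=0):
    """Hypothetical single-process version of the split-and-pad step."""
    # Ceil division so that every rank receives the same number of samples.
    samples_per_rank = -(-len(global_samples) // cp_size)
    # Total padding appended to make the sample count divisible by cp_size.
    global_pad = samples_per_rank * cp_size - len(global_samples)
    padded = list(global_samples) + [pad_value] * global_pad
    start = cp_rank * samples_per_rank
    return padded[start:start + samples_per_rank], global_pad


samples = [1, 2, 3, 4, 5]
print(sketch_split_to_cp_ranks(samples, cp_size=2, cp_rank=0))  # ([1, 2, 3], 1)
print(sketch_split_to_cp_ranks(samples, cp_size=2, cp_rank=1))  # ([4, 5, 0], 1)
```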
- core.models.multimodal.context_parallel._gather_along_second_dim(local_t)
- core.models.multimodal.context_parallel._reduce_scatter_along_second_dim(global_t)
- class core.models.multimodal.context_parallel.GatherFromContextParallelRanks
Bases: torch.autograd.Function
Gather the input from context parallel ranks.
- static symbolic(graph, input_)
Symbolic forward used during torch.jit tracing.
- static forward(ctx, input_)
All-gather input_ along its second dimension across CP ranks.
- static backward(ctx, grad_output)
Reduce-scatter the gradient along the second dimension.
- core.models.multimodal.context_parallel.gather_from_context_parallel_ranks(local_t, global_pad)
Gather local_t across CP ranks, removing global_pad trailing pad tokens.
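The effect of gathering along the second dimension and then dropping the trailing pad can be imitated without any distributed setup. In the sketch below, sketch_gather_from_cp_ranks is a hypothetical name, the per-rank shards are passed in as nested lists of shape [batch][local_seq], and concatenation in rank order stands in for the all-gather. All of these are assumptions for illustration.

```python
def sketch_gather_from_cp_ranks(per_rank_shards, global_pad):
    """Hypothetical single-process stand-in for all-gather plus un-padding."""
    batch_size = len(per_rank_shards[0])
    gathered = []
    for b in range(batch_size):
        row = []
        # Assumption: rank order matches sequence order.
        for shard in per_rank_shards:
            row.extend(shard[b])
        # Drop the trailing pad tokens (keeps the full row when global_pad == 0).
        gathered.append(row[:len(row) - global_pad])
    return gathered


rank0 = [[1, 2], [5, 6]]
rank1 = [[3, 0], [7, 0]]   # last position on the last rank is padding
print(sketch_gather_from_cp_ranks([rank0, rank1], global_pad=1))
# [[1, 2, 3], [5, 6, 7]]
```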
- core.models.multimodal.context_parallel.gather_from_context_parallel_ranks_dynamic_res(local_t, num_padded_imgs=0)
Gather dynamic-resolution tensors (variable seq per rank) from CP ranks.
- core.models.multimodal.context_parallel._compute_tubelet_aware_split_points(num_frames, temporal_patch_size, cp_size, total_frames)
Compute frame-space split points that respect tubelet boundaries within videos.
Returns cp_size + 1 split points in frame indices (not tubelet indices), since callers slice per-frame cu_seqlens and imgs_sizes with these bounds. Splits land on either media boundaries or tubelet boundaries inside a media so that no rank receives a partial tubelet.
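One way to obtain split points with these properties is to list every legal cut (media boundaries plus tubelet boundaries inside each media) and snap each ideal equal-share target to the nearest legal cut. The sketch below does that. The function name and the nearest-cut selection rule are assumptions for illustration, and the private helper may balance ranks differently.

```python
def sketch_tubelet_split_points(num_frames, temporal_patch_size, cp_size, total_frames):
    """Hypothetical sketch: cp_size + 1 frame-space cuts on tubelet boundaries."""
    # Legal cuts: the start of every tubelet in every media, plus the very end.
    legal_cuts = []
    media_start = 0
    for n in num_frames:
        legal_cuts.extend(range(media_start, media_start + n, temporal_patch_size))
        media_start += n
    legal_cuts.append(media_start)
    assert media_start == total_frames, "num_frames must sum to total_frames"

    points = [0]
    for rank in range(1, cp_size):
        # Ideal boundary if frames could be divided perfectly evenly.
        target = rank * total_frames / cp_size
        # Nearest legal cut that does not move backwards.
        candidates = [c for c in legal_cuts if c >= points[-1]]
        points.append(min(candidates, key=lambda c: abs(c - target)))
    points.append(total_frames)
    return points


# One 12-frame video and one 4-frame video, tubelets of 4 frames, 2 CP ranks.
print(sketch_tubelet_split_points([12, 4], 4, 2, 16))  # [0, 8, 16]
```

In this example the cut at frame 8 falls inside the first video but on a tubelet boundary, so neither rank receives a partial tubelet. With few legal cuts relative to cp_size, some ranks in this sketch can end up with an empty range.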
- core.models.multimodal.context_parallel._split_num_frames(num_frames, lb, ub)
Return per-media frame counts clipped to the frame range [lb, ub).
lb and ub are frame indices (the same coordinate system used by _compute_tubelet_aware_split_points and the per-frame seqlens array in split_to_context_parallel_ranks_dynamic_res). The returned list has one entry per media that contributes at least one frame to the range, with the value being the number of frames of that media in the range.
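The clipping behaviour described here amounts to an interval intersection per media. The following is a minimal re-implementation written only from that description; sketch_split_num_frames is a hypothetical name and the private helper may differ in detail.

```python
def sketch_split_num_frames(num_frames, lb, ub):
    """Hypothetical re-implementation based only on the documented behaviour."""
    clipped = []
    media_start = 0
    for n in num_frames:
        media_end = media_start + n
        # Frames of this media that fall inside the half-open range [lb, ub).
        overlap = min(media_end, ub) - max(media_start, lb)
        if overlap > 0:
            clipped.append(overlap)
        media_start = media_end
    return clipped


# Media occupy frames [0, 3), [3, 8) and [8, 10).
print(sketch_split_num_frames([3, 5, 2], lb=2, ub=9))   # [1, 5, 1]
print(sketch_split_num_frames([12, 4], lb=0, ub=8))     # [8]
```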
- core.models.multimodal.context_parallel.split_to_context_parallel_ranks_dynamic_res(global_t, global_imgs_sizes, global_packed_seq_params, *, patch_dim, fp8_enabled=False, fp8_recipe=None, num_frames=None, temporal_patch_size=1)
Split patched vision input across CP ranks.
global_packed_seq_params provides per-image seqlens; the split respects them so each rank owns an integer number of images. When temporal_patch_size > 1, splits also respect tubelet boundaries and num_frames is required.
- Parameters:
global_t – [1, total_patches, C * patch_dim * patch_dim] patched tokens (pre-embedder). The last dim must equal 3 * patch_dim * patch_dim.
global_imgs_sizes – [num_imgs, 2] per-image (H, W) in pixels.
global_packed_seq_params – PackedSeqParams with per-image cu_seqlens_q.
patch_dim – Patch size of the vision backbone (e.g. 14 for SigLIP, 16 for many ViTs). Required because dummy padding tensors are sized in patch units and the default would silently mismatch some backbones.
fp8_enabled – If True, pad each rank’s local sequence to the FP8 multiple (16 by default; 32 for mxfp8).
fp8_recipe – Forwarded to get_padding so the FP8 padding multiple matches the active recipe.
num_frames – Per-media frame count, required when temporal_patch_size > 1.
temporal_patch_size – Tubelet size for temporal compression.
- Returns:
(local_t, local_imgs_sizes, local_packed_seq_params, has_padding, num_padded_ranks, local_num_frames)