core.models.multimodal.llava_model#
Module Contents#
Classes#
LLaVAModel | LLaVA multi-modal model.
Functions#
_load_state_dict_hook_ignore_param_names | Hook to ignore missing keys during checkpoint loading.
_load_state_dict_hook_ignore_extra_state | Hook to ignore Transformer Engine _extra_state used for FP8.
pixel_shuffle | Pixel shuffle based on InternVL but adapted for our use case.
Data#
IGNORE_INDEX
DEFAULT_IMAGE_TOKEN_INDEX
IMAGE_TOKEN
VIDEO_TOKEN
API#
- core.models.multimodal.llava_model.IGNORE_INDEX#
-100
- core.models.multimodal.llava_model.DEFAULT_IMAGE_TOKEN_INDEX#
-200
- core.models.multimodal.llava_model.IMAGE_TOKEN#
'<image>'
- core.models.multimodal.llava_model.VIDEO_TOKEN#
'<video>'
- class core.models.multimodal.llava_model.LLaVAModel(
- language_transformer_config: megatron.core.transformer.transformer_config.TransformerConfig,
- language_transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
- language_vocab_size: int,
- language_max_sequence_length: int,
- vision_transformer_config: megatron.core.transformer.transformer_config.TransformerConfig,
- vision_transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
- drop_vision_class_token: bool,
- vision_projection_config: megatron.core.transformer.transformer_config.TransformerConfig,
- vision_projection_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
- vision_projection_type: str = 'mlp',
- allow_missing_vision_projection_checkpoint: bool = False,
- parallel_output: bool = True,
- share_embeddings_and_output_weights: bool = False,
- language_position_embedding_type: str = 'learned_absolute',
- language_rotary_percent: float = 1.0,
- pre_process: bool = True,
- post_process: bool = True,
- add_encoder: bool = True,
- add_decoder: bool = True,
- img_h: int = 336,
- img_w: int = 336,
- patch_dim: int = 14,
- language_rotary_base: int = 10000,
- language_rope_scaling: bool = False,
- language_rope_scaling_factor: float = 8.0,
- hybrid_attention_ratio: float = 1.0,
- hybrid_mlp_ratio: float = 1.0,
- hybrid_override_pattern: str = None,
- fp16_lm_cross_entropy: bool = False,
- image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
- pixel_shuffle: bool = False,
- tile_tags: Optional[list] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- max_num_tiles: int = 0,
- tokenizer_type: str = '',
- vp_stage: Optional[int] = None,
- use_vision_backbone_fp8_arch: bool = False,
Bases: megatron.core.transformer.MegatronModule
LLaVA multi-modal model.
- Parameters:
language_transformer_config (TransformerConfig) – Transformer config for the language model.
language_transformer_layer_spec (ModuleSpec) – Language model spec.
language_vocab_size (int) – Language model vocabulary size.
language_max_sequence_length (int) – Language model maximum sequence length.
vision_transformer_config (TransformerConfig) – Transformer config for the vision model.
vision_transformer_layer_spec (ModuleSpec) – Vision model spec.
drop_vision_class_token (bool) – Drop vision class token(s) before the language model.
vision_projection_config (TransformerConfig) – Vision projection config.
vision_projection_layer_spec (ModuleSpec) – Vision projection spec.
vision_projection_type (str) – Type of the vision projection. Default: 2-layer MLP.
allow_missing_vision_projection_checkpoint (bool) – Allow vision projection weights to be missing when loading a checkpoint. Default False.
parallel_output (bool) – Keep outputs split across tensor parallel ranks. This is typically True for training and False for inference.
share_embeddings_and_output_weights (bool) – Input embedding and output layer share weights.
language_position_embedding_type (str) – Language model position embedding type.
language_rotary_percent (float) – RoPE percent. Defaults to 1.0.
pre_process (bool) – Include embedding layer in the decoder (used with pipeline parallel).
post_process (bool) – Include output layer in the decoder (used with pipeline parallel).
add_encoder (bool) – Construct the encoder (used with pipeline parallel). When we use pipelining, the encoder will live on only the first stage.
add_decoder (bool) – Construct the decoder (used with pipeline parallel). When we use pipelining, the decoder will live on every stage after the first one.
img_h (int) – Input image height.
img_w (int) – Input image width.
patch_dim (int) – The size of each image patch side.
language_rotary_base (int) – RoPE base.
language_rope_scaling (bool) – Toggle RoPE scaling.
language_rope_scaling_factor (float) – RoPE scaling factor. Defaults to 8.
image_token_index (int) – Token ID for the image token, such as <image>.
pixel_shuffle (bool) – Enable pixel shuffle.
tile_tags (list) – Optional tile tags.
pg_collection (ProcessGroupCollection) – Model communication process groups.
vp_stage (int) – Virtual pipeline stage.
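As a rough illustration of how img_h, img_w, and patch_dim relate, the sketch below counts the embeddings a ViT-style encoder would produce per image tile. The helper name `num_image_patches` and its `class_token_len` argument are hypothetical, and the class-token handling is an assumption about how drop_vision_class_token affects the sequence length, not the library's exact logic.

```python
def num_image_patches(img_h=336, img_w=336, patch_dim=14,
                      class_token_len=1, drop_class_token=True):
    """Hypothetical helper: embeddings per image tile for a ViT-style encoder."""
    if img_h % patch_dim or img_w % patch_dim:
        raise ValueError("image sides must be divisible by patch_dim")
    patches = (img_h // patch_dim) * (img_w // patch_dim)
    # Assumption: class tokens are prepended by the encoder and optionally dropped.
    return patches if drop_class_token else patches + class_token_len


print(num_image_patches())                        # 576 with the default sizes
print(num_image_patches(drop_class_token=False))  # 577 when one class token is kept
```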
Initialization
- shared_embedding_or_output_weight()#
This is a convenience method that surfaces the language model’s word embeddings, which finalize_model_grads._allreduce_word_embedding_grads requires.
- set_input_tensor(input_tensor) None#
Set model chunk input tensor.
- freeze(
- freeze_language_model: bool,
- freeze_vision_model: bool,
- freeze_vision_projection: bool,
Freeze model modules.
Make specific modules non-trainable by setting requires_grad to False.
- Parameters:
freeze_language_model (bool) – Freeze the language model module.
freeze_vision_model (bool) – Freeze the vision model module.
freeze_vision_projection (bool) – Freeze the vision projection module.
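Freezing by setting requires_grad to False can be sketched as follows. The function `freeze_modules` and the `FakeParam`/`FakeModule` stand-ins are hypothetical and exist only so the example runs without torch; the loop relies on nothing beyond a `parameters()` method, which torch.nn.Module also provides.

```python
def freeze_modules(modules):
    """Set requires_grad to False on every parameter of the given modules."""
    for module in modules:
        if module is None:  # e.g. a module that was not built on this rank
            continue
        for param in module.parameters():
            param.requires_grad = False


# Tiny stand-ins (hypothetical) so the sketch runs without torch.
class FakeParam:
    def __init__(self):
        self.requires_grad = True


class FakeModule:
    def __init__(self, num_params):
        self._params = [FakeParam() for _ in range(num_params)]

    def parameters(self):
        return iter(self._params)


vision, projection = FakeModule(2), FakeModule(1)
freeze_modules([vision, None])  # freeze the vision model only
print([p.requires_grad for p in vision.parameters()])
print([p.requires_grad for p in projection.parameters()])
```

Only the modules passed to the function are affected, which mirrors the three independent flags of freeze.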
- _preprocess_data(
- image_embeddings,
- language_embeddings,
- input_ids,
- loss_mask,
- labels,
- use_inference_kv_cache,
- inference_context,
- image_token_index,
- num_image_tiles,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Preprocess input data before passing it to the language model.
This function is adapted from https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409 for our input data conventions.
image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3] and labels = [1, -200, 2, 3, 4], for example. We want to replace the image position (-200) with image_embeddings and return the following:
final_embeddings = [0, 1, image_embeddings, 2, 3],
final_labels = [1, -100, 2, 3, 4]
final_loss_mask = [1, 0, 0, 1, 1]
This function handles samples without images (text-only samples). It also handles samples with images that are split into multiple tiles.
If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update the input embeddings, labels, and loss masks (if available).
If pipeline parallelism is used, then we do the following:
The first language model chunk has self.pre_process = True and self.post_process = False. We update input embeddings.
The middle language model chunk(s) have self.pre_process = False and self.post_process = False. We don’t need to update anything.
The last language model chunk has self.pre_process = False and self.post_process = True. We update labels and loss mask.
TODO: This function should adjust the attention mask too. Currently, we assume the language model uses a causal mask.
- Returns:
final_embedding (torch.Tensor): image and text embeddings [combined_seq_len, b, h].
final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len].
final_loss_mask (torch.Tensor): loss mask [b, combined_seq_len].
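The replacement described above can be illustrated with a simplified, list-based sketch for a single sample. It is not the library's tensor implementation: the helper name `merge_image_positions` is hypothetical, the string "img" stands in for one image embedding, and placing the original label on the last image slot is an assumption chosen so that a one-slot image reproduces the example in this docstring.

```python
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200


def merge_image_positions(input_ids, labels, loss_mask, num_image_slots):
    """Expand each image token into num_image_slots slots (hypothetical helper)."""
    slots, final_labels, final_loss_mask = [], [], []
    for token, label, mask in zip(input_ids, labels, loss_mask):
        # A label that asks the model to predict an image token is ignored.
        if label == IMAGE_TOKEN_INDEX:
            label = IGNORE_INDEX
        if token == IMAGE_TOKEN_INDEX:
            slots.extend(["img"] * num_image_slots)
            # Assumption: only the last image slot keeps the original label.
            final_labels.extend([IGNORE_INDEX] * (num_image_slots - 1) + [label])
            final_loss_mask.extend([0] * num_image_slots)  # no loss on image slots
        else:
            slots.append(token)
            final_labels.append(label)
            final_loss_mask.append(0 if label == IGNORE_INDEX else mask)
    return slots, final_labels, final_loss_mask


slots, new_labels, new_mask = merge_image_positions(
    input_ids=[0, 1, -200, 2, 3],
    labels=[1, -200, 2, 3, 4],
    loss_mask=[1, 1, 1, 1, 1],
    num_image_slots=1,
)
print(slots)       # [0, 1, 'img', 2, 3]
print(new_labels)  # [1, -100, 2, 3, 4]
print(new_mask)    # [1, 0, 0, 1, 1]
```

With more than one slot per image the sequence simply grows: the same inputs with num_image_slots=3 yield seven positions, five of which are masked out of the loss.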
- _process_embedding_token_parallel(
- combined_embeddings,
- new_labels,
- new_loss_mask,
- packed_seq_params,
Processes the input data for model parallelism support.
When using sequence parallelism (SP) or context parallelism (CP), the sequence is sharded across different GPUs. This function performs the sharding and distributes the sequence across GPUs for SP and CP.
Context parallelism improves memory efficiency for long-sequence training by distributing the sequence across CP ranks. It requires the token length to be divisible by (CP size * 2) to ensure proper load balance.
Sequence parallelism improves memory efficiency for long-sequence training by distributing the sequence across TP ranks. It requires the token length to be divisible by the TP size.
- Returns:
combined_embeddings (torch.Tensor): image and text embeddings combined and distributed.
new_labels (torch.Tensor): Distributed labels for image and text positions.
new_loss_mask (torch.Tensor): Distributed loss mask.
packed_seq_params (PackedSeqParams): Dict with padded token information.
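The divisibility requirements above can be turned into a small padding calculation. This is a minimal sketch with a hypothetical helper name, `padded_seq_len`; multiplying the two factors when SP and CP are enabled together is an assumption, not something this page states.

```python
def padded_seq_len(seq_len, tp_size=1, cp_size=1, sequence_parallel=False):
    """Round seq_len up to the next length that SP and CP can shard evenly."""
    factor = 1
    if sequence_parallel and tp_size > 1:
        factor *= tp_size        # SP: divisible by TP size
    if cp_size > 1:
        factor *= cp_size * 2    # CP: divisible by CP size * 2
    # Ceiling division, then scale back up to a multiple of factor.
    return -(-seq_len // factor) * factor


print(padded_seq_len(1001, cp_size=2))                          # 1004
print(padded_seq_len(1001, tp_size=8, sequence_parallel=True))  # 1008
print(padded_seq_len(1001, tp_size=4, cp_size=2, sequence_parallel=True))  # 1008
```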
- _apply_tile_tagging(image_embeddings, num_image_tiles)#
Apply tile tagging.
The image embeddings of multiple tiles are prepended with tile tags such as <tile_1>. This implements the method used in NVLM https://arxiv.org/pdf/2409.11402.
- Parameters:
image_embeddings (torch.Tensor) – [img_seq_len, num_tiles, h_language].
num_image_tiles (torch.Tensor) – Number of tiles for each input image [num_images].
- Returns:
Tile tags prepended to image embeddings. [tile_seq_len (=5) + img_seq_len, num_tiles, h_language]
- Return type:
torch.Tensor
- forward(
- images: torch.Tensor,
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- labels: Optional[torch.Tensor] = None,
- loss_mask: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- num_image_tiles: Optional[List[int]] = None,
- image_token_index: Optional[int] = None,
- runtime_gather_output: Optional[bool] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Forward function of the LLaVA model.
- Parameters:
images (torch.Tensor) – input images of shape [num_tiles, img_h, img_w]. num_tiles means the number of image tiles in this batch. num_tiles = 0 if the batch doesn’t contain images.
input_ids (torch.Tensor) – input text ids [batch, text_seq_len].
position_ids (torch.Tensor) – input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor) – Language model attention mask [batch, 1, 1, combined_seq_len]. NOTE: attention_mask is typically None and attn_mask_type in layer specs determines the attention mask used.
labels (torch.Tensor) – Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor) – Text loss mask [batch, text_seq_len].
inference_context (BaseInferenceContext) – Inference-time parameters including KV cache.
num_image_tiles (list of int) – Number of tiles per image. Default 1 tile per image.
image_token_index (int) – ID for input images. Default None means the image_token_index arg in the constructor will be used.
runtime_gather_output (bool) – Gather output at runtime. Default None means the parallel_output arg in the constructor will be used.
packed_seq_params (PackedSeqParams) – 1) If using sequence packing, must contain subsample length information. 2) If using SP/CP with padding mask type, must contain padded token information.
- Returns:
output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s].
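Several shapes above refer to combined_seq_len rather than text_seq_len. A minimal sketch of how the two might relate for one sample is shown below; the helper `combined_seq_len` is hypothetical, and it assumes one image token in input_ids per image and `img_seq_len` embeddings per tile.

```python
def combined_seq_len(text_seq_len, num_image_tiles, img_seq_len):
    """Sequence length after image tokens are replaced by tile embeddings."""
    num_images = len(num_image_tiles)
    # Each image removes one placeholder token and adds tiles * img_seq_len slots.
    return text_seq_len - num_images + sum(num_image_tiles) * img_seq_len


print(combined_seq_len(5, [1], 576))      # 580
print(combined_seq_len(10, [1, 4], 576))  # 2888
print(combined_seq_len(7, [], 576))       # 7 for a text-only sample
```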
- core.models.multimodal.llava_model._load_state_dict_hook_ignore_param_names(
- param_names: List[str],
- module: torch.nn.Module,
- incompatible_keys: collections.namedtuple,
Hook to ignore missing keys during checkpoint loading.
This hook should not be used by default, so that weights missing from a checkpoint are not silently ignored.
Example use case: Use this if you want to load a checkpoint that contains vision and language model weights but not the vision projection weights.
- Parameters:
param_names (list of str) – Parameter names allowed to be missing when calling load_state_dict.
module (torch.nn.Module) – The torch module this hook applies to. Required by the torch API.
incompatible_keys (namedtuple) – Namedtuple with fields missing_keys and unexpected_keys, which collect the missing and unexpected keys, respectively.
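The behaviour of such a hook can be sketched without torch. The function `ignore_param_names`, the `IncompatibleKeys` stand-in, and the parameter names used below are hypothetical; the sketch only assumes that missing_keys is a mutable list that the hook edits in place.

```python
from collections import namedtuple
from functools import partial

# Stand-in for the namedtuple that load_state_dict passes to post hooks.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def ignore_param_names(param_names, module, incompatible_keys):
    """Drop the allowed names from missing_keys, in place."""
    for name in param_names:
        if name in incompatible_keys.missing_keys:
            incompatible_keys.missing_keys.remove(name)


# Hypothetical parameter names, for illustration only.
keys = IncompatibleKeys(
    missing_keys=["vision_projection.fc1.weight", "language_model.embedding.weight"],
    unexpected_keys=[],
)
hook = partial(ignore_param_names, ["vision_projection.fc1.weight"])
hook(None, keys)  # torch would pass the module as the first argument
print(keys.missing_keys)
```

With torch, a hook bound this way could be attached through `module.register_load_state_dict_post_hook(hook)`. Keys that are not on the allow list stay in missing_keys, so a strict load still reports them.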
- core.models.multimodal.llava_model._load_state_dict_hook_ignore_extra_state(
- module: torch.nn.Module,
- incompatible_keys: collections.namedtuple,
Hook to ignore Transformer Engine _extra_state used for FP8.
This is for backwards-compatibility. Newer TE versions add _extra_state keys to the state dict, while older models might not have those keys. Those keys can be ignored when not using FP8.
- Parameters:
module (torch.nn.Module) – The torch module this hook applies to. Required by the torch API.
incompatible_keys (namedtuple) – Namedtuple with fields missing_keys and unexpected_keys, which collect the missing and unexpected keys, respectively.
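A minimal sketch of the filtering this hook performs is shown below. The function name `ignore_extra_state`, the `IncompatibleKeys` stand-in, and the key names are hypothetical, and matching on the substring "_extra_state" is an assumption about how the keys are recognised.

```python
from collections import namedtuple

# Stand-in for the namedtuple that load_state_dict passes to post hooks.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def ignore_extra_state(module, incompatible_keys):
    """Remove _extra_state entries from both key lists, in place."""
    for key_list in (incompatible_keys.missing_keys, incompatible_keys.unexpected_keys):
        # Slice assignment mutates the list object that the caller will inspect.
        key_list[:] = [k for k in key_list if "_extra_state" not in k]


# Hypothetical key names, for illustration only.
state_keys = IncompatibleKeys(
    missing_keys=["decoder.layers.0.linear_qkv._extra_state", "decoder.final_norm.weight"],
    unexpected_keys=["decoder.layers.0.linear_proj._extra_state"],
)
ignore_extra_state(None, state_keys)
print(state_keys.missing_keys)
print(state_keys.unexpected_keys)
```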
- core.models.multimodal.llava_model.pixel_shuffle(x, scale_factor=0.5, version=2)#
Pixel shuffle based on InternVL but adapted for our use case.
- Parameters:
x (torch.Tensor) – Vision model outputs [num_tiles, img_seq_len, h_vision]
version (int) – Implementation version.
- Returns:
Shuffled vision model outputs [num_tiles, (sq ** 2) * (scale ** 2), h_vision / (scale ** 2)]
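The output shape formula can be checked with a shape-only sketch. The helper `pixel_shuffle_output_shape` is hypothetical and does not move any data; it assumes sq is the side length of the square patch grid, that is, the square root of img_seq_len.

```python
import math


def pixel_shuffle_output_shape(num_tiles, img_seq_len, h_vision, scale_factor=0.5):
    """Shape after pixel shuffle: fewer tokens, proportionally wider features."""
    sq = math.isqrt(img_seq_len)
    if sq * sq != img_seq_len:
        raise ValueError("img_seq_len must be a perfect square")
    out_seq_len = int(sq * sq * scale_factor ** 2)
    out_hidden = int(h_vision / scale_factor ** 2)
    return (num_tiles, out_seq_len, out_hidden)


print(pixel_shuffle_output_shape(2, 576, 1024))  # (2, 144, 4096)
```

With scale_factor=0.5 the token count drops by a factor of four while the hidden size grows by the same factor, so the total number of values per tile is unchanged.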