core.models.vision.clip_vit_model#

Module Contents#

Classes#

CLIPViTModel

CLIP ViT vision model.

Functions#

get_num_image_embeddings

Get the number of image embeddings per image tile.

API#

class core.models.vision.clip_vit_model.CLIPViTModel(
transformer_config: megatron.core.transformer.transformer_config.TransformerConfig,
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
ln_pre_impl: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = NORM_IMPL,
ln_post_impl: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = NORM_IMPL,
add_class_token: bool = True,
class_token_len: int = 1,
patch_dim: int = 14,
img_h: int = 336,
img_w: int = 336,
model_subtype: str = 'clip',
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
)#

Bases: megatron.core.models.common.vision_module.vision_module.VisionModule

CLIP ViT vision model.

Parameters:
  • transformer_config (TransformerConfig) – Transformer config.

  • transformer_layer_spec (ModuleSpec) – Specifies module to use for transformer layers.

  • ln_pre_impl (ModuleSpec or type) – Specifies the layer norm type to use for ln_pre.

  • ln_post_impl (ModuleSpec or type) – Specifies the layer norm type to use for ln_post.

  • add_class_token (bool, optional) – Include a class token. Defaults to True.

  • class_token_len (int) – Class token length. Defaults to 1 but 8 may be faster.

  • patch_dim (int) – Image patch size.

  • img_h (int) – Input image height.

  • img_w (int) – Input image width.

  • model_subtype (str) – Vision model subtype. Defaults to 'clip'.

  • pg_collection (ProcessGroupCollection, optional) – Model communication process groups.

  • vp_stage (int, optional) – Virtual pipeline stage.

Initialization
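
Example (a minimal construction sketch, not a definitive recipe). The config values below are illustrative, the layer-spec helper get_vit_layer_with_local_spec is an assumption that may differ across Megatron-Core versions, and the model is assumed to be built after Megatron's model-parallel state has been initialized (or with an explicit pg_collection):

from megatron.core.models.vision.clip_vit_model import CLIPViTModel
from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_local_spec  # assumed helper
from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative ViT-L/14-style settings; adjust to your checkpoint and setup.
config = TransformerConfig(
    num_layers=24,
    hidden_size=1024,
    num_attention_heads=16,
    ffn_hidden_size=4096,
    use_cpu_initialization=True,
)

model = CLIPViTModel(
    transformer_config=config,
    transformer_layer_spec=get_vit_layer_with_local_spec(),
    add_class_token=True,
    class_token_len=1,
    patch_dim=14,
    img_h=336,
    img_w=336,
)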

set_input_tensor(input_tensor: torch.Tensor) None#

Sets the input tensor for the model.

Parameters:

input_tensor (Tensor) – The input tensor to set.
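
A hedged usage sketch: with pipeline parallelism, Megatron's schedules normally call this on intermediate stages rather than user code. The shape below is illustrative only, and model refers to the construction sketch above:

import torch

# Illustrative: hand the previous pipeline stage's activations to the model.
# Megatron's internal activation layout is [sequence, batch, hidden].
prev_stage_output = torch.empty(577, 2, 1024)
model.set_input_tensor(prev_stage_output)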

forward(
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) torch.Tensor#

Forward function of the CLIP ViT Model. This function passes the input tensors through the embedding layer and then the transformer.

Parameters:
  • x (torch.Tensor) – input image batch of shape [batch, channels, img_h, img_w]

  • attention_mask (torch.Tensor with dtype=bool) – Attention mask to use.

Returns:

Output of the final transformer block, of shape [b, s, h].

Return type:

torch.Tensor
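
Example (continuing the construction sketch above; hedged, since the exact input layout can vary by model subtype). A standard 3-channel image batch is assumed:

import torch

images = torch.randn(2, 3, 336, 336)  # assumed [batch, channels, img_h, img_w] with 3-channel input
hidden_states = model(images)         # [b, s, h]; s = number of patches (+ class token if enabled)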

core.models.vision.clip_vit_model.get_num_image_embeddings(
img_h,
img_w,
patch_dim,
vision_model_type,
disable_vision_class_token,
class_token_len,
pixel_shuffle,
use_tile_tags=False,
max_num_tiles=0,
tokenizer_type=None,
)#

Get the number of image embeddings per image tile.
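
For the plain CLIP case, the count is essentially the patch-grid size plus the class token. The sketch below shows only that arithmetic and is illustrative, not the implementation; the real function also handles pixel shuffle, tile tags, and non-CLIP vision model types:

from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings

# Illustrative arithmetic for a 336x336 tile with 14x14 patches and a class token.
img_h, img_w, patch_dim = 336, 336, 14
class_token_len = 1
disable_vision_class_token = False

num_patches = (img_h // patch_dim) * (img_w // patch_dim)  # 24 * 24 = 576
num_embeddings = num_patches + (0 if disable_vision_class_token else class_token_len)
print(num_embeddings)  # 577

# Equivalent call to the function above (positional arguments follow its signature);
# expected to give the same result for these CLIP defaults.
print(get_num_image_embeddings(336, 336, 14, "clip", False, 1, False))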