Image Embedders#
Overview#
Many image curation features in NeMo Curator operate on image embeddings instead of images directly. Image embedders provide a scalable way of generating embeddings for each image in the dataset.
Use Cases#
Aesthetic and NSFW classification both use image embeddings generated from OpenAI’s CLIP ViT-L variant.
Semantic deduplication uses image embeddings to compute the similarity between datapoints.
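As a rough illustration of how embeddings support semantic deduplication, the sketch below flags pairs of embeddings whose cosine similarity exceeds a threshold. The function names, the toy vectors, and the 0.95 threshold are assumptions made for illustration only; this is not NeMo Curator's deduplication implementation.

```python
import math
from itertools import combinations


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def find_near_duplicates(embeddings, threshold=0.95):
    """Return ID pairs whose embeddings are at least `threshold` similar.

    `embeddings` maps a datapoint ID to its embedding vector.
    This brute-force comparison is O(n^2) and only meant for illustration.
    """
    pairs = []
    for (id_a, vec_a), (id_b, vec_b) in combinations(embeddings.items(), 2):
        if cosine_similarity(vec_a, vec_b) >= threshold:
            pairs.append((id_a, id_b))
    return pairs


# Toy 3-dimensional embeddings (hypothetical values).
toy_embeddings = {
    "img_0": [1.0, 0.0, 0.0],
    "img_1": [0.99, 0.05, 0.0],  # nearly parallel to img_0
    "img_2": [0.0, 1.0, 0.0],    # orthogonal to img_0
}
near_duplicates = find_near_duplicates(toy_embeddings)
```

Here only img_0 and img_1 point in nearly the same direction, so they are the only pair a threshold of 0.95 would flag.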
Prerequisites#
Before you begin, follow the image curation getting started page to install everything you need.
Timm Image Embedder#
PyTorch Image Models (timm) is a library containing state-of-the-art computer vision models. Many of these models are useful for generating image embeddings for modules in NeMo Curator.
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder

client = get_client(cluster_type="gpu")

dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
)

dataset_with_embeddings = embedding_model(dataset)

# Metadata will have a new column named "image_embedding"
dataset_with_embeddings.save_metadata()
Here, we load a dataset and compute the image embeddings using vit_large_patch14_clip_quickgelu_224.openai.
At the end of the process, our metadata files have a new column named “image_embedding” that contains the image embeddings for each datapoint.
Key Parameters#
pretrained=True ensures you download the pretrained weights of the model.
batch_size=1024 determines the number of images processed on each individual GPU at once.
num_threads_per_worker=16 determines the number of threads used by DALI for dataloading.
normalize_embeddings=True will normalize each embedding. NeMo Curator’s classifiers expect normalized embeddings as input.
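To make the effect of normalize_embeddings concrete, here is a minimal sketch of L2 normalization in plain Python. It assumes that normalization means scaling each vector to unit Euclidean length, which is the usual convention for CLIP-style embeddings. The helper name l2_normalize is hypothetical and not part of NeMo Curator.

```python
import math


def l2_normalize(vector, eps=1e-12):
    """Scale `vector` to unit Euclidean length (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, eps) for x in vector]


raw = [3.0, 4.0]                 # Euclidean length 5
unit = l2_normalize(raw)         # same direction, length 1
length = math.sqrt(sum(x * x for x in unit))
```

Once embeddings have unit length, the dot product of two embeddings equals their cosine similarity, so downstream consumers can compare them without rescaling.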
Performance Considerations#
Under the hood, the image embedding model performs the following operations:
Download the weights of the model.
Download the PyTorch image transformations (resize and center-crop for example).
Convert the PyTorch image transformations to DALI transformations.
Load a shard of metadata (a .parquet file) onto each GPU you have available using Dask-cuDF.
Load a copy of the model onto each GPU.
Repeatedly load images into batches of size batch_size onto each GPU, using DALI with the configured number of threads per worker (num_threads_per_worker).
Run the model on the batch (without torch.autocast(), since autocast=False).
Normalize the output embeddings of the model, since normalize_embeddings=True.
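The per-GPU loop in the last three steps can be pictured with the following simplified sketch. It uses plain Python lists in place of DALI and PyTorch, and the names batched, embed_shard, and toy_model are hypothetical stand-ins rather than NeMo Curator functions.

```python
import math
from typing import Callable, Iterable, Iterator, List


def batched(items: Iterable, batch_size: int) -> Iterator[List]:
    """Yield consecutive lists of at most `batch_size` items."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final, possibly smaller, batch
        yield batch


def embed_shard(images, model: Callable, batch_size: int, normalize: bool = True):
    """Run `model` batch by batch and optionally L2-normalize each embedding."""
    embeddings = []
    for batch in batched(images, batch_size):
        for vector in model(batch):
            if normalize:
                # Fall back to 1.0 so an all-zero vector is left unchanged.
                norm = math.sqrt(sum(x * x for x in vector)) or 1.0
                vector = [x / norm for x in vector]
            embeddings.append(vector)
    return embeddings


def toy_model(batch):
    """Stand-in for a real embedding model: maps each 'image' to a 2-d vector."""
    return [[float(len(image)), 1.0] for image in batch]


fake_images = ["aaa", "bb", "c", "dddd", "ee"]  # placeholders for decoded images
result = embed_shard(fake_images, toy_model, batch_size=2)
```

In the sketch, batch_size only changes how many items are handed to the model at once. In the real pipeline a larger value also raises GPU memory use.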
There are a couple of key performance considerations from this flow.
You must have an NVIDIA GPU that meets the requirements.
You can create .idx files in the same directory as the tar files to speed up dataloading times. See the DALI documentation for more information.
Custom Image Embedder#
To write your own custom embedder, inherit from nemo_curator.image.embedders.ImageEmbedder and override two methods as shown below:
from typing import Callable, Iterable

from nemo_curator.image.embedders import ImageEmbedder


class MyCustomEmbedder(ImageEmbedder):
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        # Implement me!
        pass

    def load_embedding_model(self, device: str) -> Callable:
        # Implement me!
        pass
load_dataset_shard() will take in a path to a tar file and return an iterable over the shard. The iterable should return a tuple of (a batch of data, metadata). The batch of data can be of any form. It will be directly passed to the model returned by load_embedding_model(). The metadata should be a dictionary of metadata, with a field corresponding to the id_col of the dataset. In our example, the metadata should include a value for "key".
load_embedding_model() will take a device and return a callable object. This callable will take as input a batch of data produced by load_dataset_shard().
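As an illustration of the contract described above, the following standalone sketch implements the two methods as plain functions using only the Python standard library. It assumes a WebDataset-style shard in which each sample has an image file and a JSON metadata file sharing the same base name, yields one sample per batch for simplicity, and uses a toy model that returns a byte count instead of a real embedding. None of these choices come from NeMo Curator; in a real embedder the two functions would become methods of an ImageEmbedder subclass.

```python
import io
import json
import os
import tarfile
import tempfile
from typing import Callable, Iterable


def load_dataset_shard(tar_path: str) -> Iterable:
    """Yield (batch_of_data, metadata) tuples from one tar shard.

    Assumes each sample is stored as <name>.jpg plus <name>.json, and that
    the JSON contains the dataset's id_col ("key" in this example).
    """
    images, metadata = {}, {}
    with tarfile.open(tar_path) as tar:
        for member in tar:
            if not member.isfile():
                continue
            name, ext = os.path.splitext(member.name)
            payload = tar.extractfile(member).read()
            if ext == ".json":
                metadata[name] = json.loads(payload)
            else:
                images[name] = payload
    for name in sorted(images):
        # One raw image per batch; a real loader would decode and stack many.
        yield [images[name]], metadata[name]


def load_embedding_model(device: str) -> Callable:
    """Return a callable that maps a batch to one embedding per item.

    `device` is ignored here; a real model would be moved to it.
    """
    def toy_model(batch):
        return [[float(len(raw_bytes))] for raw_bytes in batch]
    return toy_model


def add_file(tar: tarfile.TarFile, name: str, data: bytes) -> None:
    """Helper for the demo: write an in-memory file into the tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))


# Build a tiny two-sample shard and run it through both functions.
with tempfile.TemporaryDirectory() as tmp_dir:
    shard_path = os.path.join(tmp_dir, "00000.tar")
    with tarfile.open(shard_path, "w") as tar:
        for key, fake_image in [("0001", b"abc"), ("0002", b"abcde")]:
            add_file(tar, f"{key}.jpg", fake_image)
            add_file(tar, f"{key}.json", json.dumps({"key": key}).encode())

    model = load_embedding_model("cpu")
    outputs = [(meta["key"], model(batch)) for batch, meta in load_dataset_shard(shard_path)]
```

The point of the sketch is the shape of the data: each item from the iterable pairs a batch that the model callable can consume directly with metadata carrying the "key" field.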