# TimmImageEmbedder

The `TimmImageEmbedder` is the primary image embedding option in NeMo Curator. It leverages the PyTorch Image Models (timm) library to provide access to a wide range of state-of-the-art computer vision models for generating image embeddings, which are used for downstream tasks such as classification, filtering, and duplicate removal.
## How it Works

The `TimmImageEmbedder` generates image embeddings in the following steps:

1. **Model Selection**: You specify a model name from the timm library (for example, CLIP, ViT, ResNet). The embedder loads the model, optionally with pretrained weights.
2. **Data Loading and Preprocessing**: The embedder uses NVIDIA DALI for efficient, GPU-accelerated loading and preprocessing of images from WebDataset tar files. DALI-generated `.idx` files can be used to further speed up loading.
3. **Batch Processing**: Images are processed in batches, with the batch size and number of data-loading threads configurable for optimal GPU utilization.
4. **Embedding Generation**: Each batch of images is passed through the selected model to generate embeddings. Optionally, embeddings are normalized, which is recommended for most downstream tasks (a minimal sketch of this step follows the list).
5. **Distributed Execution**: The process can be distributed across multiple GPUs or nodes, enabling large-scale embedding generation.
6. **Output**: The resulting embeddings are added as a new column (default: `image_embedding`) in the dataset metadata, which can then be saved for downstream tasks such as classification, filtering, or duplicate removal.
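To make steps 1 and 4 concrete, here is a minimal standalone sketch using timm and PyTorch directly rather than NeMo Curator's pipeline. The random tensor stands in for a DALI-preprocessed batch, and the L2 normalization mirrors what `normalize_embeddings=True` is expected to do:

```python
import timm
import torch
import torch.nn.functional as F

# Load a timm model as a feature extractor (num_classes=0 drops the
# classification head and returns pooled features).
model = timm.create_model(
    "vit_large_patch14_clip_quickgelu_224.openai",
    pretrained=True,
    num_classes=0,
)
model = model.eval().cuda()

# Stand-in for one DALI-preprocessed batch of images.
batch = torch.randn(4, 3, 224, 224, device="cuda")

# Mixed precision, as with autocast=True.
with torch.no_grad(), torch.autocast("cuda"):
    embeddings = model(batch)

# L2-normalize so each embedding has unit length.
embeddings = F.normalize(embeddings.float(), dim=-1)
print(embeddings.shape)  # (4, embedding_dim)
```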
## Usage

### Python Example

```python
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder

client = get_client(cluster_type="gpu")

dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",  # Any timm model name
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
)

dataset_with_embeddings = embedding_model(dataset)

# Metadata will have a new column named "image_embedding"
dataset_with_embeddings.save_metadata()
```
## Key Parameters

- `model_name`: Name of the timm model to use (see `timm.list_models()`; an example follows this list)
- `pretrained`: Whether to use pretrained weights (recommended)
- `batch_size`: Number of images processed per GPU per batch
- `num_threads_per_worker`: Number of threads for DALI data loading
- `normalize_embeddings`: Normalize output embeddings (required for most classifiers)
- `autocast`: Use mixed precision for faster inference (default: `True`)
- `use_index_files`: Use DALI-generated `.idx` files for faster loading
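If you are unsure which model names are valid, timm can enumerate them. A small sketch (the wildcard pattern here is just an illustration):

```python
import timm

# List pretrained timm models matching a wildcard pattern.
print(timm.list_models("vit_large_patch14_clip*", pretrained=True))

# Or count everything timm knows about (thousands of names).
print(len(timm.list_models()))
```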
## Performance Tips

- Place `.idx` files in your dataset directory to speed up DALI data loading (see the DALI docs; a sketch for generating them follows this list)
- Choose a batch size that fits your GPU memory for optimal throughput
- Use multiple GPUs for distributed embedding generation
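One way to produce the index files is DALI's `wds2idx` command-line utility. A hypothetical helper that indexes every shard in a directory, assuming `wds2idx` is on your `PATH` and the shards follow a `*.tar` naming scheme:

```python
import subprocess
from pathlib import Path

# Generate a DALI .idx file next to each WebDataset tar shard.
dataset_dir = Path("/path/to/dataset")  # same directory the embedder reads from
for tar_path in sorted(dataset_dir.glob("*.tar")):
    idx_path = tar_path.with_suffix(".idx")
    if not idx_path.exists():  # skip shards that are already indexed
        subprocess.run(["wds2idx", str(tar_path), str(idx_path)], check=True)
```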
## Best Practices

- Always use pretrained weights unless you have a specific reason to train from scratch
- Normalize embeddings if you plan to use them for classification or similarity search
- Monitor GPU memory usage and adjust `batch_size` accordingly
- For large datasets, generate and use `.idx` files to accelerate data loading
- Review the output embeddings for expected dimensionality and normalization (a sanity-check sketch follows this list)
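A quick way to perform that last check is to load the saved metadata and inspect the embedding column. This sketch assumes `save_metadata()` wrote Parquet files under the dataset directory; the exact file path and layout may differ in your setup:

```python
import numpy as np
import pandas as pd

# Load one saved metadata shard (hypothetical path).
df = pd.read_parquet("/path/to/dataset/00000.parquet")

# Stack the per-image embedding lists into a 2-D array.
emb = np.stack(df["image_embedding"].to_numpy())
print(emb.shape)  # (num_images, embedding_dim)

# With normalize_embeddings=True, every row should be (near) unit length.
norms = np.linalg.norm(emb, axis=1)
print(norms.min(), norms.max())  # both should be close to 1.0
```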