TimmImageEmbedder#
The TimmImageEmbedder is the primary image embedding option in NeMo Curator. It leverages the PyTorch Image Models (timm) library to provide access to a wide range of state-of-the-art computer vision models for generating image embeddings. These embeddings are used for downstream tasks such as classification, filtering, and duplicate removal.
How it Works#
The TimmImageEmbedder generates image embeddings in the following steps:
- Model Selection: You specify a model name from the timm library (for example, CLIP, ViT, ResNet). The embedder loads the model, optionally with pretrained weights.
- Data Loading and Preprocessing: The embedder uses NVIDIA DALI for efficient, GPU-accelerated loading and preprocessing of images from WebDataset tar files. DALI-generated .idx files can be used to further speed up loading.
- Batch Processing: Images are processed in batches, with the batch size and number of data loading threads configurable for optimal GPU utilization. 
- Embedding Generation: Each batch of images is passed through the selected model to generate embeddings. Optionally, embeddings are normalized (recommended for most downstream tasks). 
- Distributed Execution: The process can be distributed across multiple GPUs or nodes, enabling large-scale embedding generation. 
- Output: The resulting embeddings are added as a new column (default: image_embedding) in the dataset metadata, which can then be saved for downstream tasks such as classification, filtering, or duplicate removal.
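The batching, embedding, normalization, and output steps above can be illustrated with a minimal, library-free sketch. It does not call NeMo Curator, timm, or DALI: `toy_model` is a hypothetical stand-in for a real network, and `batched`, `l2_normalize`, and `embed_records` are illustrative names, not part of any library. The point is only to show the shape of the data flow and why normalized embeddings are convenient (every output vector has unit length, so a dot product between two of them is their cosine similarity).

```python
import math
from typing import Callable, Dict, Iterator, List

Vector = List[float]


def batched(items: List[dict], batch_size: int) -> Iterator[List[dict]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def l2_normalize(vec: Vector) -> Vector:
    """Scale a vector to unit length; an all-zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def toy_model(pixels: List[float]) -> Vector:
    """Hypothetical stand-in for a timm model: maps pixels to a 2-d vector."""
    return [sum(pixels), max(pixels) - min(pixels)]


def embed_records(
    records: List[Dict],
    model: Callable[[List[float]], Vector],
    batch_size: int = 2,
    normalize: bool = True,
    column: str = "image_embedding",
) -> List[Dict]:
    """Embed records batch by batch and attach the result as a new column."""
    out = []
    for batch in batched(records, batch_size):
        for record in batch:
            emb = model(record["pixels"])
            if normalize:
                emb = l2_normalize(emb)
            # Copy the record so the input metadata is left untouched.
            out.append({**record, column: emb})
    return out


records = [
    {"key": "a", "pixels": [3.0, 0.0, 1.0]},
    {"key": "b", "pixels": [1.0, 1.0, 1.0]},
    {"key": "c", "pixels": [0.0, 2.0, 2.0]},
]
embedded = embed_records(records, toy_model)
```

In the real embedder, the per-record loop is replaced by a single forward pass over each batch on the GPU, and the records come from DALI rather than an in-memory list.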
Usage#
Python Example#
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder

# Start (or connect to) a GPU cluster for distributed execution
client = get_client(cluster_type="gpu")

# Load a WebDataset-format dataset; "key" is the ID column
dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",  # Any timm model name
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
)

# Generate embeddings for every image in the dataset
dataset_with_embeddings = embedding_model(dataset)

# Metadata will have a new column named "image_embedding"
dataset_with_embeddings.save_metadata()
Key Parameters#
- model_name: Name of the timm model to use (see timm.list_models())
- pretrained: Whether to use pretrained weights (recommended)
- batch_size: Number of images processed per GPU per batch
- num_threads_per_worker: Number of threads for DALI data loading
- normalize_embeddings: Normalize output embeddings (required for most classifiers)
- autocast: Use mixed precision for faster inference (default: True)
- use_index_files: Use DALI-generated .idx files for faster loading
Performance Tips#
- Use .idx files in your dataset directory to speed up DALI data loading (see the DALI documentation)
- Choose a batch size that fits your GPU memory for optimal throughput 
- Use multiple GPUs for distributed embedding generation 
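One common way to choose a batch size that fits GPU memory is to start high and halve it until a trial batch succeeds. The sketch below shows that back-off loop in isolation. It is not part of NeMo Curator: `find_batch_size` and `fake_run_batch` are hypothetical names, and the built-in `MemoryError` stands in for whatever out-of-memory error your framework raises.

```python
from typing import Callable


def find_batch_size(
    run_batch: Callable[[int], None],
    start: int = 1024,
    minimum: int = 1,
) -> int:
    """Halve the batch size until run_batch completes without running out of memory."""
    if minimum < 1:
        raise ValueError("minimum must be at least 1")
    batch_size = start
    while batch_size >= minimum:
        try:
            run_batch(batch_size)
            return batch_size
        except MemoryError:
            # Too large: try half as many images next time.
            batch_size //= 2
    raise RuntimeError("no batch size >= minimum fits in memory")


def fake_run_batch(batch_size: int) -> None:
    """Hypothetical probe: pretend only batches of 300 images or fewer fit."""
    if batch_size > 300:
        raise MemoryError(f"batch of {batch_size} does not fit")


chosen = find_batch_size(fake_run_batch, start=1024)
```

In practice the probe would run one real batch through the model. The value found this way is an upper bound, so leaving some headroom below it is a reasonable precaution.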
Best Practices#
- Always use pretrained weights unless you have a specific reason to train from scratch 
- Normalize embeddings if you plan to use them for classification or similarity search 
- Monitor GPU memory usage and adjust batch_size accordingly
- For large datasets, generate and use .idx files to accelerate data loading
- Review the output embeddings for expected dimensionality and normalization
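The last check can be automated. The following sketch assumes the embeddings have already been pulled out of the metadata as plain Python lists of floats. The helper name `check_embeddings` and the tolerance of `1e-3` are illustrative choices, not NeMo Curator defaults. A looser tolerance may be appropriate when mixed precision is enabled.

```python
import math
from typing import List, Sequence


def check_embeddings(
    embeddings: Sequence[Sequence[float]],
    expected_dim: int,
    tol: float = 1e-3,
) -> List[str]:
    """Return a description of every row with the wrong size or a non-unit L2 norm."""
    problems = []
    for i, vec in enumerate(embeddings):
        if len(vec) != expected_dim:
            problems.append(
                f"row {i}: expected {expected_dim} dimensions, got {len(vec)}"
            )
            continue
        norm = math.sqrt(sum(x * x for x in vec))
        if abs(norm - 1.0) > tol:
            problems.append(f"row {i}: L2 norm is {norm:.4f}, expected 1.0")
    return problems


# Rows 0 and 1 are valid; row 2 is not normalized; row 3 has the wrong size.
sample = [[0.6, 0.8], [1.0, 0.0], [3.0, 4.0], [1.0]]
report = check_embeddings(sample, expected_dim=2)
```

An empty report means every sampled row has the expected dimensionality and unit length. Running the check on a small sample of rows is usually enough to catch a misconfigured model name or a forgotten normalize_embeddings setting.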