Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
Embedders
Base Class
- class nemo_curator.image.embedders.ImageEmbedder(model_name: str, image_embedding_column: str, classifiers: Iterable[ImageClassifier])
An abstract base class for generating image embeddings.
Subclasses only need to define how the model is loaded and how a dataset is read from a tar file shard. This class handles distributing the tasks across workers and saving the metadata to the dataset. The embedding model must fit on a single GPU.
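The division of labor described above can be sketched as follows. `EmbedderSketch` is a hypothetical, standalone stand-in that only mirrors the two documented abstract hooks; it is not the real `ImageEmbedder`, and its `embed_shard` method (also hypothetical) processes one shard in-process instead of distributing work across workers. The ID field name `key` is an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Tuple


class EmbedderSketch(ABC):
    """Hypothetical stand-in mirroring the two abstract hooks of ImageEmbedder.

    The real base class also distributes shards across workers and writes
    results to the dataset's metadata; here a single shard is processed locally.
    """

    def __init__(self, id_field: str = "key"):
        # Assumed name of the ID field present in every metadata dict.
        self.id_field = id_field

    @abstractmethod
    def load_dataset_shard(self, tar_path: str) -> Iterable[Tuple[Any, List[Dict]]]:
        """Yield (batch_of_images, list_of_metadata_dicts) tuples."""

    @abstractmethod
    def load_embedding_model(self, device: str) -> Callable:
        """Return a callable that maps a batch of images to embeddings."""

    def embed_shard(self, tar_path: str, device: str = "cuda") -> Dict[Any, Any]:
        """Embed every image in one shard and key the results by record ID."""
        model = self.load_embedding_model(device)
        results: Dict[Any, Any] = {}
        for images, metadata in self.load_dataset_shard(tar_path):
            embeddings = model(images)
            # Row i of the model output belongs to metadata[i].
            for meta, embedding in zip(metadata, embeddings):
                results[meta[self.id_field]] = embedding
        return results
```

A subclass only has to implement the two abstract methods; for instance, a toy `load_embedding_model` returning `lambda batch: [[float(len(img))] for img in batch]` would map each record ID to a one-element "embedding". Trying to instantiate `EmbedderSketch` directly raises `TypeError`, as with any ABC that has unimplemented abstract methods.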
- abstract load_dataset_shard(tar_path: str) → Iterable
Loads images and metadata from a tarfile in the dataset.
- Parameters:
tar_path (str) – The path to a tar file shard in the input WebDataset.
- Returns:
- An iterator over the dataset. Each iteration should produce an (images, metadata) tuple. The batch of images will be passed directly to the model created by ImageEmbedder.load_embedding_model. The metadata must be a list of dictionaries, where each element corresponds to the image at the same position in the batch. Each dictionary must contain a field with the same name as id_field in the dataset. This ID field is used to match each image to its record in the metadata (Parquet) files.
- Return type:
Iterable
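The return contract above can be illustrated with a standard-library-only sketch. `iter_shard_batches` is a hypothetical helper, not part of NeMo Curator: it groups tar members by their WebDataset key (everything before the first dot, assuming member names have no leading `./`), returns images as undecoded JPEG bytes rather than tensors, and assumes the ID field is named `key`.

```python
import json
import tarfile
from typing import Dict, Iterator, List, Tuple


def iter_shard_batches(
    tar_path: str, batch_size: int = 2, id_field: str = "key"
) -> Iterator[Tuple[List[bytes], List[Dict]]]:
    """Yield (images, metadata) batches from a WebDataset-style tar shard.

    Hypothetical helper: images are raw JPEG bytes (no decoding), and each
    record's metadata dict is parsed from its .json member.
    """
    # Group member contents by record key, e.g. "000.jpg" -> records["000"]["jpg"].
    records: Dict[str, Dict[str, bytes]] = {}
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            key, _, ext = member.name.partition(".")
            records.setdefault(key, {})[ext] = tar.extractfile(member).read()

    images: List[bytes] = []
    metadata: List[Dict] = []
    for key, parts in records.items():
        if "jpg" not in parts or "json" not in parts:
            continue  # skip records without an image or metadata
        meta = json.loads(parts["json"])
        if id_field not in meta:
            raise KeyError(f"record {key!r} has no {id_field!r} field")
        # images[i] and metadata[i] always describe the same record.
        images.append(parts["jpg"])
        metadata.append(meta)
        if len(images) == batch_size:
            yield images, metadata
            images, metadata = [], []
    if images:
        yield images, metadata  # final, possibly smaller batch
```

Keeping `images` and `metadata` as parallel lists, appended together, is what guarantees the positional correspondence the documentation requires.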
- abstract load_embedding_model(device: str) → Callable
Loads the model used to generate image embeddings.
- Parameters:
device (str) – A PyTorch device identifier that specifies which GPU to load the model on.
- Returns:
- A callable model, usually a torch.nn.Module. Its input will be the batches of images output by ImageEmbedder.load_dataset_shard.
- Return type:
Callable
Timm
- class nemo_curator.image.embedders.TimmImageEmbedder(model_name: str, pretrained: bool = False, batch_size: int = 1, num_threads_per_worker: int = 4, image_embedding_column: str = 'image_embedding', normalize_embeddings: bool = True, classifiers: Iterable = [], autocast: bool = True, use_index_files: bool = False)
PyTorch Image Models (timm) is a library of state-of-the-art computer vision models. Many of these models are useful for generating image embeddings for modules in NeMo Curator. For supported models, this class can also automatically convert the model's image transformations from PyTorch transformations to DALI transformations.
- load_dataset_shard(tar_path: str)
Loads a WebDataset tar shard using DALI.
- Parameters:
tar_path (str) – The path of the tar shard to load.
- Returns:
An iterator over the dataset. Each tar file must contain three files per record: a .jpg file, a .txt file, and a .json file. The .jpg file must contain the image, the .txt file must contain the associated caption, and the .json file must contain the metadata for the record (including its ID). Images are loaded using DALI.
- Return type:
Iterable
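Because every record needs all three files, it can help to check a shard before embedding it. The sketch below is a hypothetical pre-flight check, not a NeMo Curator function: `find_incomplete_records` uses only the standard library and assumes WebDataset-style member names such as `000123.jpg`, where the record key is everything before the first dot.

```python
import tarfile
from typing import Dict, List, Set

# The three per-record files listed in the documentation above.
REQUIRED_EXTENSIONS = {"jpg", "txt", "json"}


def find_incomplete_records(tar_path: str) -> Dict[str, List[str]]:
    """Return {record_key: sorted missing extensions} for incomplete records."""
    seen: Dict[str, Set[str]] = {}
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if member.isfile():
                key, _, ext = member.name.partition(".")
                seen.setdefault(key, set()).add(ext)
    # Keep only records whose extension set is missing a required file.
    return {
        key: sorted(REQUIRED_EXTENSIONS - exts)
        for key, exts in seen.items()
        if not REQUIRED_EXTENSIONS <= exts
    }
```

An empty result means every record in the shard has its .jpg, .txt, and .json files.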
- load_embedding_model(device='cuda')
Loads the model used to generate image embeddings.
- Parameters:
device (str) – A PyTorch device identifier that specifies which GPU to load the model on.
- Returns:
- A timm model loaded on the specified device. If enabled in the constructor (autocast, normalize_embeddings), the model's forward call may be augmented with torch.autocast() or embedding normalization.
- Return type:
Callable
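The idea of augmenting a model's forward call can be sketched as a wrapper. This is a hypothetical illustration, not the class's actual implementation: it assumes "embedding normalization" means scaling each embedding to unit L2 norm, omits torch.autocast(), and uses plain Python lists so it runs without PyTorch. `with_l2_normalization` is an illustrative name.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Model = Callable[[Sequence], List[Vector]]


def with_l2_normalization(model: Model) -> Model:
    """Wrap a model so that every output embedding has unit L2 norm."""

    def forward(batch: Sequence) -> List[Vector]:
        normalized: List[Vector] = []
        for vec in model(batch):
            norm = math.sqrt(sum(x * x for x in vec))
            # Leave all-zero vectors unchanged to avoid dividing by zero.
            normalized.append([x / norm for x in vec] if norm > 0 else list(vec))
        return normalized

    return forward
```

With unit-norm embeddings, the dot product of two embeddings equals their cosine similarity, which is a common reason to normalize before similarity search or classification.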