image.embedders.timm#

Module Contents#

Classes#

TimmImageEmbedder

PyTorch Image Models (timm) is a library containing SOTA computer vision models. Many of these models are useful in generating image embeddings for modules in NeMo Curator. This module can also automatically convert the image transformations from PyTorch transformations to DALI transformations in the supported models.

API#

class image.embedders.timm.TimmImageEmbedder(
model_name: str,
pretrained: bool = False,
batch_size: int = 1,
num_threads_per_worker: int = 4,
image_embedding_column: str = 'image_embedding',
normalize_embeddings: bool = True,
classifiers: collections.abc.Iterable = [],
autocast: bool = True,
use_index_files: bool = False,
)#

Bases: nemo_curator.image.embedders.base.ImageEmbedder

PyTorch Image Models (timm) is a library containing SOTA computer vision models. Many of these models are useful in generating image embeddings for modules in NeMo Curator. This module can also automatically convert the image transformations from PyTorch transformations to DALI transformations in the supported models.

Initialization

Constructs the embedder.

Args:
    model_name (str): The timm model to use. A list of available models can be found by running timm.list_models().
    pretrained (bool): If True, loads the pretrained weights of the model.
    batch_size (int): The number of images to run inference on in a single batch. If batch_size is larger than the number of elements in a shard, only the number of elements in the shard will be used.
    num_threads_per_worker (int): The number of threads per worker (GPU) to use for loading images with DALI.
    image_embedding_column (str): The output column where the embeddings will be stored in the dataset.
    normalize_embeddings (bool): Whether to normalize the embeddings output by the model. Defaults to True.
    classifiers (Iterable): A collection of classifiers to immediately apply on top of the image embeddings.
    autocast (bool): If True, runs the timm model using torch.autocast().
    use_index_files (bool): If True, tries to find and use index files generated by DALI at the same path as the tar file shards. The index files must be generated by DALI's wds2idx tool; see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index for more information. Each index file must be named "shard_id.idx", where shard_id is the same integer as the corresponding tar file, and must be in the same folder as the tar files.
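The "shard_id.idx" naming convention expected by use_index_files can be illustrated with a small standard-library sketch. The shard paths below are hypothetical; only the naming rule (same folder, same shard id, .idx extension) comes from the documentation above.

```python
from pathlib import Path

def expected_index_path(tar_path: str) -> str:
    """Return the index file path DALI expects next to a tar shard.

    For a hypothetical shard /data/shards/00042.tar, the matching index
    file must be /data/shards/00042.idx: same folder, same shard id,
    with the .tar extension replaced by .idx.
    """
    return str(Path(tar_path).with_suffix(".idx"))

# Hypothetical shard path, for illustration only:
print(expected_index_path("/data/shards/00042.tar"))  # /data/shards/00042.idx
```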

load_dataset_shard(tar_path: str) collections.abc.Iterable#

Loads a WebDataset tar shard using DALI.

Args:
    tar_path (str): The path of the tar shard to load.

Returns:
    Iterable: An iterator over the dataset. Each tar file must contain 3 files per record: a .jpg file with the image, a .txt file with the associated caption, and a .json file with the metadata for the record (including its ID). Images are loaded using DALI.
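The three-files-per-record layout described above can be sketched with only the standard library. This builds a tiny in-memory tar shard; the record ID and file contents are hypothetical placeholders, not real image data.

```python
import io
import json
import tarfile

# Build a minimal WebDataset-style tar shard with the 3 files expected
# per record: a .jpg (image), a .txt (caption), and a .json (metadata).
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    record_id = "000000001"  # hypothetical record id
    for name, payload in [
        (f"{record_id}.jpg", b"\xff\xd8\xff"),  # placeholder JPEG bytes
        (f"{record_id}.txt", b"a caption"),     # the associated caption
        (f"{record_id}.json", json.dumps({"id": record_id}).encode()),
    ]:
        info = tarfile.TarInfo(name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

# Re-open the shard and list its members to confirm the layout.
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tar:
    members = sorted(m.name for m in tar.getmembers())
print(members)  # ['000000001.jpg', '000000001.json', '000000001.txt']
```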

load_embedding_model(device: str = 'cuda') torch.nn.Module#

Loads the model used to generate image embeddings.

Args:
    device (str): A PyTorch device identifier that specifies which GPU to load the model on.

Returns:
    Callable: A timm model loaded on the specified device. The model's forward call may be augmented with torch.autocast() or embedding normalization if specified in the constructor.
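Embedding normalization (the normalize_embeddings option in the constructor) commonly means scaling each embedding vector to unit L2 norm; that exact scheme is an assumption here for illustration, shown as a pure-Python sketch.

```python
import math

def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm.

    Note: that TimmImageEmbedder's normalize_embeddings uses exactly this
    scheme is an assumption for illustration purposes.
    """
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        return vec  # leave the zero vector unchanged
    return [x / norm for x in vec]

emb = l2_normalize([3.0, 4.0])
print(emb)  # [0.6, 0.8]
```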