image.embedders.base
Module Contents
Classes
| Class | Description |
| --- | --- |
| ImageEmbedder | An abstract base class for generating image embeddings. |
API
- class image.embedders.base.ImageEmbedder(model_name: str, image_embedding_column: str, classifiers: collections.abc.Iterable[nemo_curator.image.classifiers.ImageClassifier])
- Bases: abc.ABC

  An abstract base class for generating image embeddings.

  Subclasses only need to define how the embedding model is loaded and how a dataset shard is read from a tar file. This class handles distributing the tasks across workers and saving the metadata to the dataset. The embedding model must fit on a single GPU.

  Initialization

  Constructs an image embedder.

  Args:
  - model_name (str): A unique name that identifies the model on each worker and in the logs.
  - image_embedding_column (str): The name of the column that will be added to store the image embeddings.
  - classifiers (Iterable[ImageClassifier]): A collection of classifiers. If the iterable is non-empty, all classifiers are loaded onto the GPU at the same time and receive the image embeddings immediately after they are created.
- abstract load_dataset_shard(tar_path: str) → collections.abc.Iterable
  Loads images and metadata from a tar file in the dataset.

  Args:
  - tar_path (str): The path to a tar file shard in the input WebDataset.

  Returns:
  Iterable: An iterator over the dataset. Each iteration should produce a tuple (images, metadata), where images is a batch of images that will be passed directly to the model created by ImageEmbedder.load_embedding_model. The metadata must be a list of dictionaries in which each element corresponds to the image at the same position in the batch. Each dictionary must contain a field with the same name as the dataset's id_field; this ID is used to match the image to its record in the metadata (Parquet) files.
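For illustration, here is a minimal, standard-library-only sketch of a generator that satisfies the return contract described above. It assumes a WebDataset-style shard in which each sample is a `.jpg`/`.json` pair sharing a key. The helper name `iter_shard_batches`, the default `id_field` value `"key"`, and the batch size are hypothetical, and images are yielded as raw bytes rather than the decoded tensors a real model would expect. A subclass could delegate to something like this from its `load_dataset_shard` override.

```python
import json
import tarfile
from collections.abc import Iterator


def iter_shard_batches(
    tar_path: str, id_field: str = "key", batch_size: int = 2
) -> Iterator[tuple[list[bytes], list[dict]]]:
    """Yield (images, metadata) batches from a WebDataset-style tar shard.

    Hypothetical helper: images are raw JPEG bytes, and each metadata dict is
    parsed from the sample's .json member.
    """
    # Group tar members by sample key, e.g. "000000.jpg" -> key "000000", ext "jpg".
    samples: dict[str, dict] = {}
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            key, _, ext = member.name.partition(".")
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()

    images: list[bytes] = []
    metadata: list[dict] = []
    for key in sorted(samples):
        sample = samples[key]
        if "jpg" not in sample:
            continue  # skip samples that have no image
        meta = json.loads(sample["json"]) if "json" in sample else {}
        # Fall back to the tar key when the JSON lacks the ID field.
        meta.setdefault(id_field, key)
        images.append(sample["jpg"])
        metadata.append(meta)  # same position as its image in the batch
        if len(images) == batch_size:
            yield images, metadata
            images, metadata = [], []
    if images:
        yield images, metadata  # final partial batch
```

Keeping images and metadata in two parallel lists mirrors the positional correspondence the docstring requires between each image and its metadata dictionary.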
- abstract load_embedding_model(device: str) → collections.abc.Callable
  Loads the model used to generate image embeddings.

  Args:
  - device (str): A PyTorch device identifier that specifies which GPU to load the model on.

  Returns:
  Callable: A callable model, usually a torch.nn.Module. The input to this model will be the batches of images output by ImageEmbedder.load_dataset_shard.
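The two abstract methods are the whole subclass contract. The sketch below illustrates it with a hypothetical stand-in base class, `EmbedderContract`, and a toy subclass, `ToyEmbedder`. Neither is part of nemo_curator. The "model" is a plain Python function instead of a `torch.nn.Module`, and the `embed_shard` driver is a simplified, single-process assumption that omits worker distribution and classifiers.

```python
import abc
from collections.abc import Callable, Iterable


class EmbedderContract(abc.ABC):
    """Hypothetical stand-in mirroring ImageEmbedder's two abstract methods.

    This is NOT nemo_curator's class; it only illustrates the subclass contract.
    """

    def __init__(self, model_name: str, image_embedding_column: str) -> None:
        self.model_name = model_name
        self.image_embedding_column = image_embedding_column

    @abc.abstractmethod
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        """Yield (images, metadata) batches from one tar shard."""

    @abc.abstractmethod
    def load_embedding_model(self, device: str) -> Callable:
        """Return a callable that maps a batch of images to embeddings."""

    def embed_shard(self, tar_path: str, device: str = "cpu") -> list[dict]:
        # Simplified single-process flow (assumption): no worker distribution
        # and no classifiers; each embedding is attached to its metadata record.
        model = self.load_embedding_model(device)
        records: list[dict] = []
        for images, metadata in self.load_dataset_shard(tar_path):
            embeddings = model(images)
            for meta, emb in zip(metadata, embeddings):
                records.append({**meta, self.image_embedding_column: emb})
        return records


class ToyEmbedder(EmbedderContract):
    """Toy subclass: images are byte strings, embeddings are [length, mean byte]."""

    def load_dataset_shard(self, tar_path: str) -> Iterable:
        # Ignores tar_path and yields a single in-memory batch for illustration.
        yield [b"ab", b"xyz"], [{"key": "0"}, {"key": "1"}]

    def load_embedding_model(self, device: str) -> Callable:
        # A real subclass would typically return a torch.nn.Module placed on `device`.
        def model(images: list[bytes]) -> list[list[float]]:
            return [[float(len(img)), sum(img) / len(img)] for img in images]

        return model
```

For example, `ToyEmbedder("toy", "image_embedding").embed_shard("unused.tar")` returns `[{"key": "0", "image_embedding": [2.0, 97.5]}, {"key": "1", "image_embedding": [3.0, 121.0]}]`. Instantiating `EmbedderContract` directly raises `TypeError`, because `abc.ABC` refuses to instantiate a class with unimplemented abstract methods.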