image.embedders.base

Module Contents

Classes

ImageEmbedder

An abstract base class for generating image embeddings.

API

class image.embedders.base.ImageEmbedder(
model_name: str,
image_embedding_column: str,
classifiers: collections.abc.Iterable[nemo_curator.image.classifiers.ImageClassifier],
)

Bases: abc.ABC

An abstract base class for generating image embeddings.

Subclasses only need to define how the embedding model is loaded and how a dataset shard is read from a tar file. This class handles distributing the work across workers and saving the metadata to the dataset. The embedding model must fit on a single GPU.

Initialization

Constructs an image embedder.

Args:
model_name (str): A unique name to identify the model on each worker and in the logs.
image_embedding_column (str): The name of the column to add, in which the image embeddings will be saved.
classifiers (Iterable[ImageClassifier]): A collection of classifiers. If the iterable is non-empty, all classifiers are loaded onto the GPU at the same time and receive the image embeddings immediately after they are created.

abstractmethod load_dataset_shard(tar_path: str) -> collections.abc.Iterable

Loads images and metadata from a tarfile in the dataset.

Args: tar_path (str): The path to a tar file shard in the input WebDataset.

Returns: Iterable: An iterator over the dataset. Each iteration should produce an (images, metadata) tuple. The batch of images is passed directly to the model created by ImageEmbedder.load_embedding_model. The metadata must be a list of dictionaries, one per image, in the same order as the images in the batch. Each dictionary must contain a field with the same name as id_field in the dataset. This ID field is used to match each image to its record in the metadata (Parquet) files.
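The contract above can be sketched with the standard library alone. The following is a minimal illustration, not the library's reader: the function name iter_shard_batches, the batch_size parameter, the "key" ID field, and the "<stem>.jpg" / "<stem>.json" member layout are all assumptions made for the example, and images are left as raw bytes where a real implementation would decode them into a batch the model accepts.

```python
import io
import json
import tarfile
import tempfile
from collections.abc import Iterator
from pathlib import Path


def iter_shard_batches(
    tar_path: str, batch_size: int = 2, id_field: str = "key"
) -> Iterator[tuple[list[bytes], list[dict]]]:
    """Yield (images, metadata) batches from one tar shard.

    Assumption: each sample is a "<stem>.jpg" / "<stem>.json" pair.
    Images stay as raw bytes here; a real loader would decode them.
    """
    pending: dict[str, dict] = {}  # partially read samples, keyed by stem
    images: list[bytes] = []
    metadata: list[dict] = []
    with tarfile.open(tar_path) as tar:
        for member in tar:
            stem, _, ext = member.name.rpartition(".")
            if not member.isfile() or ext not in ("jpg", "json"):
                continue
            data = tar.extractfile(member).read()
            sample = pending.setdefault(stem, {})
            sample[ext] = json.loads(data) if ext == "json" else data
            if len(sample) < 2:
                continue  # still waiting for the other half of the pair
            del pending[stem]
            if id_field not in sample["json"]:
                raise KeyError(f"sample {stem!r}: metadata lacks {id_field!r}")
            images.append(sample["jpg"])
            metadata.append(sample["json"])
            if len(images) == batch_size:
                yield images, metadata
                images, metadata = [], []
    if images:  # final, possibly smaller batch
        yield images, metadata


# Build a three-sample demo shard with placeholder bytes instead of real JPEGs.
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "00000.tar"
    with tarfile.open(str(shard), "w") as tar:
        for i in range(3):
            files = {
                f"{i:04d}.jpg": b"fake-image-%d" % i,
                f"{i:04d}.json": json.dumps({"key": f"{i:04d}"}).encode(),
            }
            for name, payload in files.items():
                info = tarfile.TarInfo(name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
    batches = list(iter_shard_batches(str(shard)))
```

Each element of batches pairs a list of images with a list of metadata dictionaries of the same length and order, which is the shape the Returns description asks for. Checking for the ID field while reading surfaces a malformed shard early rather than at the later matching step.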

abstractmethod load_embedding_model(device: str) -> collections.abc.Callable

Loads the model used to generate image embeddings.

Args: device (str): A PyTorch device identifier that specifies which GPU to load the model on.

Returns: Callable: A callable model, usually a torch.nn.Module. The input to this model will be the batches of images produced by ImageEmbedder.load_dataset_shard.
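To show how the two abstract methods fit together, here is a hedged, dependency-free sketch. EmbedderInterface is a stand-in that only mirrors the signatures documented on this page; it is not the library's ImageEmbedder, and real code would subclass image.embedders.base.ImageEmbedder instead. ToyEmbedder, its placeholder data, the "key" ID field, and the toy "model" (which returns length and mean byte value rather than a learned embedding) are all hypothetical. The final loop only illustrates the data flow described above and is not how the base class distributes work.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable


class EmbedderInterface(ABC):
    """Stand-in mirroring the documented signatures; NOT the library class."""

    def __init__(self, model_name: str, image_embedding_column: str,
                 classifiers: Iterable):
        self.model_name = model_name
        self.image_embedding_column = image_embedding_column
        self.classifiers = list(classifiers)

    @abstractmethod
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        ...

    @abstractmethod
    def load_embedding_model(self, device: str) -> Callable:
        ...


class ToyEmbedder(EmbedderInterface):
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        # Placeholder: one batch of two "images" with position-matched metadata.
        yield [b"\x00\x01", b"\xff"], [{"key": "a"}, {"key": "b"}]

    def load_embedding_model(self, device: str) -> Callable:
        # A real subclass would typically return a torch.nn.Module on `device`.
        def model(images):
            # Toy 2-d "embedding": (length in bytes, mean byte value).
            return [[float(len(img)), sum(img) / len(img)] for img in images]

        return model


embedder = ToyEmbedder("toy-model", "image_embedding", classifiers=[])
model = embedder.load_embedding_model("cuda:0")  # device is ignored by the toy model

rows = []
for images, metadata in embedder.load_dataset_shard("unused.tar"):
    # Batches from load_dataset_shard go straight into the model callable.
    for record, vector in zip(metadata, model(images)):
        rows.append({**record, embedder.image_embedding_column: vector})
```

Because metadata and images share positions within a batch, zipping the metadata with the model output is enough to attach each embedding to the record carrying the matching ID.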