image.embedders.base
Module Contents
Classes
ImageEmbedder: An abstract base class for generating image embeddings.
API
- class image.embedders.base.ImageEmbedder(model_name: str, image_embedding_column: str, classifiers: collections.abc.Iterable[nemo_curator.image.classifiers.ImageClassifier])
Bases: abc.ABC
An abstract base class for generating image embeddings.
Subclasses need only define how the embedding model is loaded and how a dataset shard is read from a tar file. This class handles distributing the work across workers and saving the resulting metadata to the dataset. The embedding model must fit on a single GPU.
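The division of labor described above can be illustrated with a minimal, self-contained sketch. `SketchImageEmbedder`, `embed_shard`, and `ToyEmbedder` are hypothetical names. This is not the `nemo_curator` implementation: it runs one shard in-process instead of distributing shards across GPU workers, and it only shows which hooks a subclass fills in and how their outputs are combined.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable


class SketchImageEmbedder(ABC):
    """Hypothetical stand-in that mirrors the documented ImageEmbedder contract."""

    def __init__(self, model_name: str, image_embedding_column: str) -> None:
        self.model_name = model_name
        self.image_embedding_column = image_embedding_column

    @abstractmethod
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        """Subclass hook: yield (images, metadata) batches from one tar shard."""

    @abstractmethod
    def load_embedding_model(self, device: str) -> Callable:
        """Subclass hook: return a callable mapping an image batch to embeddings."""

    def embed_shard(self, tar_path: str, device: str) -> list[dict]:
        # Hypothetical driver. In the real class, work like this is spread
        # across workers (one GPU each); here it runs in-process for one shard.
        model = self.load_embedding_model(device)
        records = []
        for images, metadata in self.load_dataset_shard(tar_path):
            embeddings = model(images)
            # metadata[i] describes images[i], so zip keeps them aligned.
            for meta, emb in zip(metadata, embeddings):
                records.append({**meta, self.image_embedding_column: emb})
        return records


class ToyEmbedder(SketchImageEmbedder):
    def load_dataset_shard(self, tar_path: str) -> Iterable:
        # Fake shard: two batches of "images" (strings) with matching metadata.
        yield ["cat", "dog"], [{"key": "000"}, {"key": "001"}]
        yield ["bird"], [{"key": "002"}]

    def load_embedding_model(self, device: str) -> Callable:
        # Fake model: the "embedding" of an image is [len(image)].
        return lambda batch: [[float(len(img))] for img in batch]


embedder = ToyEmbedder("toy-model", "image_embedding")
rows = embedder.embed_shard("shard-00000.tar", device="cuda:0")
print(rows)
```

Because both hooks are abstract, `SketchImageEmbedder` itself cannot be instantiated, which mirrors the `abc.ABC` base listed above.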
Initialization
Constructs an image embedder.
Args:
  model_name (str): A unique name that identifies the model on each worker and in the logs.
  image_embedding_column (str): The name of the column to be added in which the image embeddings are saved.
  classifiers (Iterable[ImageClassifier]): A collection of classifiers. If the iterable is non-empty, all classifiers are loaded onto the GPU at the same time and receive the image embeddings immediately after they are created.
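The classifier behavior described for the `classifiers` argument (embeddings handed to every classifier as soon as they are computed) can be sketched as follows. `ToyClassifier`, `embed_and_classify`, the `pred_column` attribute, and the column names `score_a`/`score_b` are all hypothetical and are not the `ImageClassifier` API; the point is only that each batch is embedded once and then scored by every classifier before the next batch is processed.

```python
from collections.abc import Callable, Sequence
from dataclasses import dataclass


@dataclass
class ToyClassifier:
    """Hypothetical classifier stand-in: an output column plus a scoring function."""

    pred_column: str
    score: Callable[[Sequence[float]], float]

    def __call__(self, embeddings: list[list[float]]) -> list[float]:
        # Score every embedding in the batch.
        return [self.score(e) for e in embeddings]


def embed_and_classify(model, classifiers, images, metadata, embedding_column):
    # Compute the embeddings once per batch...
    embeddings = model(images)
    # ...then hand them straight to every classifier while they are still in
    # memory, so classification needs no second pass over the data.
    predictions = {c.pred_column: c(embeddings) for c in classifiers}
    rows = []
    for i, meta in enumerate(metadata):
        row = {**meta, embedding_column: embeddings[i]}
        for column, preds in predictions.items():
            row[column] = preds[i]
        rows.append(row)
    return rows


def fake_model(batch):
    # Hypothetical embedding model: [length of the input, 1.0].
    return [[float(len(img)), 1.0] for img in batch]


classifiers = [
    ToyClassifier("score_a", lambda e: e[0] / 10),
    ToyClassifier("score_b", lambda e: e[1] * 2),
]
rows = embed_and_classify(
    fake_model, classifiers, ["cat", "bird"], [{"key": "0"}, {"key": "1"}], "image_embedding"
)
for row in rows:
    print(row)
```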
- abstractmethod load_dataset_shard(tar_path: str) → collections.abc.Iterable
Loads images and metadata from a tarfile in the dataset.
Args: tar_path (str): The path to a tar file shard in the input WebDataset.
Returns: Iterable: An iterator over the dataset. Each iteration should yield a tuple (images, metadata). The batch of images is passed directly to the model returned by ImageEmbedder.load_embedding_model. The metadata must be a list of dictionaries, where each element corresponds to the image at the same position in the batch. Each dictionary must contain a field with the same name as the dataset's id_field; this ID is used to match each image to its record in the metadata (Parquet) files.
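A `load_dataset_shard` implementation has to group the files of each WebDataset sample and emit `(images, metadata)` batches that satisfy the contract above. The sketch below uses only the standard library and rests on several assumptions: each sample is stored as `<key>.jpg` with an optional `<key>.json`, the sample key is the value stored under `id_field` (here defaulting to a hypothetical `"key"`), and images are returned as raw bytes rather than decoded tensors. `load_dataset_shard_sketch` and `_add_file` are hypothetical names; a real subclass would decode and preprocess the images into whatever batch format its model expects.

```python
import io
import json
import os
import tarfile
import tempfile
from collections import defaultdict
from collections.abc import Iterator


def load_dataset_shard_sketch(
    tar_path: str, id_field: str = "key", batch_size: int = 2
) -> Iterator[tuple[list[bytes], list[dict]]]:
    # Group tar members into samples: WebDataset stores one sample as several
    # files sharing a key, e.g. 000000001.jpg and 000000001.json.
    samples: dict[str, dict[str, bytes]] = defaultdict(dict)
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            key, _, ext = member.name.partition(".")
            samples[key][ext] = tar.extractfile(member).read()

    images: list[bytes] = []
    metadata: list[dict] = []
    for key in sorted(samples):
        parts = samples[key]
        if "jpg" not in parts:
            continue  # skip samples that have no image
        meta = json.loads(parts["json"]) if "json" in parts else {}
        # Assumption: the sample key doubles as the ID in the Parquet metadata.
        meta[id_field] = key
        images.append(parts["jpg"])
        metadata.append(meta)
        if len(images) == batch_size:
            yield images, metadata  # metadata[i] describes images[i]
            images, metadata = [], []
    if images:
        yield images, metadata  # final, possibly smaller batch


def _add_file(tar: tarfile.TarFile, name: str, payload: bytes) -> None:
    # Demo helper: write an in-memory payload into the tar archive.
    info = tarfile.TarInfo(name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


with tempfile.TemporaryDirectory() as tmp:
    shard = os.path.join(tmp, "00000.tar")
    with tarfile.open(shard, "w") as tar:
        for i in range(3):
            key = f"{i:09d}"
            _add_file(tar, f"{key}.jpg", b"fake-jpeg-bytes")
            _add_file(tar, f"{key}.json", json.dumps({"caption": f"sample {i}"}).encode())
    batches = list(load_dataset_shard_sketch(shard))

for images, metadata in batches:
    print(len(images), [m["key"] for m in metadata])
```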
- abstractmethod load_embedding_model(device: str) → collections.abc.Callable
Loads the model used to generate image embeddings.
Args: device (str): A PyTorch device identifier that specifies which GPU to load the model on.
Returns: Callable: A callable model, usually a torch.nn.Module. Its input is each batch of images yielded by ImageEmbedder.load_dataset_shard.
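The model returned by `load_embedding_model` only needs to be callable on the batches produced by `load_dataset_shard`. To keep the example runnable without PyTorch or a GPU, the sketch below substitutes a hypothetical `ToyEmbeddingModel` for a `torch.nn.Module`: it records the requested device and maps each raw-bytes image to an L2-normalized byte histogram. A real implementation would typically load pretrained weights, move the module to `device`, and switch it to evaluation mode.

```python
import math
from collections.abc import Callable, Sequence


class ToyEmbeddingModel:
    """Hypothetical stand-in for the torch.nn.Module returned by load_embedding_model."""

    def __init__(self, device: str, dim: int = 4) -> None:
        self.device = device  # a real model would be moved to this device
        self.dim = dim

    def __call__(self, images: Sequence[bytes]) -> list[list[float]]:
        embeddings = []
        for img in images:
            # Bucket byte values (0-255) into `dim` bins as a crude feature vector.
            vec = [0.0] * self.dim
            for b in img:
                vec[b * self.dim // 256] += 1.0
            # L2-normalize; fall back to 1.0 to avoid dividing by zero on empty input.
            norm = math.sqrt(sum(v * v for v in vec)) or 1.0
            embeddings.append([v / norm for v in vec])
        return embeddings


def load_embedding_model(device: str) -> Callable:
    # Hypothetical implementation of the abstract hook: build and return the callable.
    return ToyEmbeddingModel(device)


model = load_embedding_model("cuda:0")
batch = [b"\x00\x00", b"\x00\xff"]
print(model(batch))
```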