image.classifiers.base

Module Contents

Classes

ImageClassifier

An abstract base class that represents a classifier on top of embeddings generated by a CLIP vision encoder.

API

class image.classifiers.base.ImageClassifier(
model_name: str,
embedding_column: str,
pred_column: str,
pred_type: str | type,
batch_size: int,
embedding_size: int,
)

Bases: abc.ABC

An abstract base class that represents a classifier on top of embeddings generated by a CLIP vision encoder.

Subclasses only need to define how a model is loaded. They may also override the postprocess method to modify the output series of predictions before it is combined into the dataset. The classifier must fit on a single GPU.

Initialization

Constructs an image classifier.

Args:
model_name (str): A unique name to identify the model on each worker and in the logs.
embedding_column (str): The column name that stores the image embeddings.
pred_column (str): The column name to be added where the classifier’s predictions will be stored.
pred_type (Union[str, type]): The datatype of the pred_column.
batch_size (int): If greater than 0, the image embeddings will be processed in batches of at most this size. If less than 0, all embeddings will be processed at once.
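The batch_size rule described above can be illustrated with a small pure-Python sketch. This is not the library's implementation: `run_in_batches` and `toy_model` are hypothetical names, plain lists stand in for GPU tensors, and treating a batch size of 0 the same as a negative value is an assumption of the sketch, since the documentation does not say what happens in that case.

```python
from typing import Callable, List, Sequence

Embedding = List[float]


def run_in_batches(
    model: Callable[[Sequence[Embedding]], List[float]],
    embeddings: Sequence[Embedding],
    batch_size: int,
) -> List[float]:
    """Apply `model` to `embeddings` following the documented batch_size rule."""
    if batch_size > 0:
        predictions: List[float] = []
        # Walk the embeddings in chunks of at most `batch_size` rows.
        for start in range(0, len(embeddings), batch_size):
            predictions.extend(model(embeddings[start:start + batch_size]))
        return predictions
    # Non-positive batch size: pass every embedding to the model in one call.
    # (Handling 0 this way is an assumption of this sketch.)
    return list(model(embeddings))


# Hypothetical toy model that records how many rows each call received.
batch_lengths: List[int] = []


def toy_model(batch: Sequence[Embedding]) -> List[float]:
    batch_lengths.append(len(batch))
    return [sum(row) for row in batch]


embeddings = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
batched = run_in_batches(toy_model, embeddings, batch_size=2)
```

With five embeddings and a batch size of 2, the toy model is called three times and the final call receives the single leftover row. A negative batch size results in exactly one call.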

abstractmethod load_model(device: str) → collections.abc.Callable

Loads the classifier model.

Args: device (str): A PyTorch device identifier that specifies which GPU to load the model on.

Returns: Callable: A callable model, usually a torch.nn.Module. The input to this model will be batches of image embeddings read from embedding_column.

postprocess(series: cudf.Series) → cudf.Series

Postprocesses the predictions of the classifier before saving them to the metadata.

Args: series (cudf.Series): The cuDF series of raw model predictions.

Returns: cudf.Series: The same series unmodified. Override in your classifier if needed.
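As an illustration of the subclassing contract described above, the following is a minimal sketch that runs without a GPU. It makes several assumptions: `ImageClassifierStandIn` is a hypothetical stand-in that only mirrors the documented constructor and method signatures (it is not the library class), `ThresholdClassifier` and its column names are invented for the example, a plain Python function stands in for a torch.nn.Module, and lists stand in for cudf.Series.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Union


class ImageClassifierStandIn(ABC):
    """Hypothetical stand-in mirroring the documented ImageClassifier signature."""

    def __init__(
        self,
        model_name: str,
        embedding_column: str,
        pred_column: str,
        pred_type: Union[str, type],
        batch_size: int,
        embedding_size: int,
    ) -> None:
        self.model_name = model_name
        self.embedding_column = embedding_column
        self.pred_column = pred_column
        self.pred_type = pred_type
        self.batch_size = batch_size
        self.embedding_size = embedding_size

    @abstractmethod
    def load_model(self, device: str) -> Callable:
        """Return a callable that maps a batch of embeddings to predictions."""

    def postprocess(self, series: List[float]) -> List[float]:
        # Documented default: return the predictions unmodified.
        return series


class ThresholdClassifier(ImageClassifierStandIn):
    """Hypothetical subclass: mean-pools each embedding, then thresholds it."""

    def __init__(self, threshold: float = 0.5) -> None:
        super().__init__(
            model_name="threshold_classifier",  # illustrative name
            embedding_column="image_embedding",  # illustrative column
            pred_column="is_positive",  # illustrative column
            pred_type=int,
            batch_size=-1,
            embedding_size=4,
        )
        self.threshold = threshold

    def load_model(self, device: str) -> Callable:
        # A real subclass would typically build a torch.nn.Module and move it
        # to `device`; this toy function ignores the device entirely.
        def model(batch: Sequence[Sequence[float]]) -> List[float]:
            return [sum(row) / len(row) for row in batch]

        return model

    def postprocess(self, series: List[float]) -> List[int]:
        # Turn raw scores into 0/1 labels before they are stored.
        return [int(score >= self.threshold) for score in series]


classifier = ThresholdClassifier()
model = classifier.load_model("cuda:0")
raw_scores = model([[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1.0, 0.0]])
labels = classifier.postprocess(raw_scores)
```

The subclass supplies only load_model and an optional postprocess override. Everything else, including the default pass-through behavior of postprocess, comes from the base class.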