image.classifiers.nsfw

Module Contents

Classes

NSFWModel

Normalization

NsfwClassifier

API

class image.classifiers.nsfw.NSFWModel

Bases: torch.nn.Module

forward(x: torch.Tensor) → torch.Tensor

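`NSFWModel` is listed here without a docstring. As a rough illustration of the kind of computation its `forward` performs, the sketch below maps a batch of 768-dimensional image embeddings (the embedding size of OpenAI's CLIP ViT-L/14) to one score in [0, 1] per image. The class name `NSFWHeadSketch`, the single 64-unit hidden layer, and the sigmoid output are assumptions for illustration, not necessarily the layout of the real model.

```python
import torch
from torch import nn


class NSFWHeadSketch(nn.Module):
    """Hypothetical MLP head: CLIP image embedding -> NSFW probability."""

    def __init__(self, embedding_dim: int = 768, hidden_dim: int = 64) -> None:
        super().__init__()
        # Layer sizes are illustrative assumptions, not the published model's.
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash the single logit into [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, embedding_dim) -> (batch, 1)
        return self.layers(x)


model = NSFWHeadSketch().eval()  # eval mode for inference
with torch.no_grad():
    scores = model(torch.randn(4, 768))
print(scores.shape)  # torch.Size([4, 1])
```

The (batch, 1) output shape is why a per-image score column can be built directly from the model output.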
class image.classifiers.nsfw.Normalization(shape: list[int])

Bases: torch.nn.Module

forward(x: torch.Tensor) → torch.Tensor

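`Normalization` takes a `shape` argument, which suggests a per-feature standardization layer applied to the embeddings. A minimal sketch, assuming it stores a per-feature mean and variance (in the style of Keras's `Normalization` layer) and computes `(x - mean) / sqrt(variance)`; the class name `NormalizationSketch` and the buffer names are hypothetical.

```python
import torch
from torch import nn


class NormalizationSketch(nn.Module):
    """Hypothetical per-feature standardization layer."""

    def __init__(self, shape: list[int]) -> None:
        super().__init__()
        # Buffers are saved in the state_dict but are not updated by an optimizer.
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("variance", torch.ones(shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasts over the batch dimension: (batch, *shape) -> (batch, *shape).
        return (x - self.mean) / torch.sqrt(self.variance)


norm = NormalizationSketch([3])
norm.mean.copy_(torch.tensor([1.0, 2.0, 3.0]))
norm.variance.copy_(torch.tensor([4.0, 4.0, 4.0]))
out = norm(torch.tensor([[3.0, 2.0, 1.0]]))
print(out.tolist())  # [[1.0, 0.0, -1.0]]
```

Storing the statistics as buffers rather than parameters keeps them fixed during inference while still letting them be restored from saved weights.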
class image.classifiers.nsfw.NsfwClassifier(
embedding_column: str = 'image_embedding',
pred_column: str = 'nsfw_score',
batch_size: int = -1,
model_path: str | None = None,
)

Bases: nemo_curator.image.classifiers.base.ImageClassifier

NSFW Classifier is a small MLP trained on top of OpenAI’s ViT-L CLIP image embeddings. It is used to assess the likelihood of images containing sexually explicit material. More information on the model can be found here: https://github.com/LAION-AI/CLIP-based-NSFW-Detector.

Initialization

Constructs the classifier.

Args:
  embedding_column (str): The name of the column that stores the image embeddings.
  pred_column (str): The name of the new column in which the NSFW scores are stored.
  batch_size (int): If greater than 0, the image embeddings are processed in batches of at most this size. If less than 0, all embeddings are processed at once.
  model_path (Optional[str]): If specified, the model is loaded from this path. If not specified, the model defaults to being stored in NEMO_CURATOR_HOME.
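The `batch_size` rule above can be sketched in plain Python. `iter_batches` is a hypothetical helper, not part of NeMo Curator, and it rejects `batch_size == 0` because the documented behavior only covers positive and negative values.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Hypothetical helper mirroring the documented batch_size behavior."""
    if batch_size > 0:
        # Slices of at most batch_size items; the last slice may be shorter.
        for start in range(0, len(items), batch_size):
            yield items[start:start + batch_size]
    elif batch_size < 0:
        # Negative sizes: process everything in a single batch.
        yield items
    else:
        # batch_size == 0 is not covered by the documentation, so refuse it.
        raise ValueError("batch_size must be positive or negative, not 0")


embeddings = list(range(10))
print([len(b) for b in iter_batches(embeddings, 4)])   # [4, 4, 2]
print([len(b) for b in iter_batches(embeddings, -1)])  # [10]
```

Smaller batches bound peak GPU memory at the cost of more model calls; the default of -1 favors a single call.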

load_model(device: str) → torch.nn.Module

Loads the classifier model.

Args: device (str): A PyTorch device identifier (for example, "cuda:0") specifying the GPU on which to load the model.

Returns: torch.nn.Module: The classifier model. Its input is a batch of image embeddings taken from embedding_column.
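Loading a PyTorch model onto a device typically means restoring saved weights and moving the module there. A minimal sketch, assuming the file at `model_path` holds a saved `state_dict`; the helper `load_weights`, the file name `weights.pt`, and the file format are assumptions, not NeMo Curator's actual loader.

```python
import os
import tempfile

import torch
from torch import nn


def load_weights(model: nn.Module, model_path: str, device: str) -> nn.Module:
    """Hypothetical loader: restore a saved state_dict and prepare for inference."""
    # map_location lets weights saved on a GPU be loaded onto any device.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Move parameters and buffers to the device; eval() disables dropout etc.
    return model.to(device).eval()


# Round-trip example on CPU; on a GPU machine, device could be "cuda:0".
net = nn.Linear(768, 1)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    torch.save(net.state_dict(), path)
    restored = load_weights(nn.Linear(768, 1), path, device="cpu")
print(torch.equal(restored.weight, net.weight))  # True
```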

postprocess(series: cudf.Series) → cudf.Series

Postprocesses the predictions of the classifier before saving them to the metadata.

Args: series (cudf.Series): The cuDF series of raw model predictions.

Returns: cudf.Series: The same series unmodified. Override in your classifier if needed.