classifiers.prompt_task_complexity#

Module Contents#

Classes#

CustomHFDeberta

MeanPooling

MulticlassHead

PromptTaskComplexityClassifier

PromptTaskComplexityClassifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. Further information on the taxonomies can be found on the NemoCurator Prompt Task and Complexity Hugging Face page: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier. This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.

PromptTaskComplexityConfig

PromptTaskComplexityModel

Data#

PROMPT_TASK_COMPLEXITY_IDENTIFIER

API#

class classifiers.prompt_task_complexity.CustomHFDeberta(config: dataclasses.dataclass)#

Bases: torch.nn.Module, huggingface_hub.PyTorchModelHubMixin

Initialization

compute_results(
preds: torch.Tensor,
target: str,
decimal: int = 4,
) -> tuple[list[str], list[str], list[float]]#
forward(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]#
process_logits(logits: list[torch.Tensor]) -> dict[str, torch.Tensor]#
set_autocast(autocast: bool) -> None#
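The return type of compute_results (two lists of strings and a list of floats) suggests a top-1 label, a top-2 label, and a rounded top-1 probability per sample. The sketch below illustrates that kind of post-processing without any dependencies. It is an assumption about the behaviour, not the library's implementation: the helper name `top2_from_logits` and the label names are hypothetical, and the real method works on torch.Tensor inputs.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top2_from_logits(
    batch_logits: list[list[float]],
    labels: list[str],
    decimal: int = 4,
) -> tuple[list[str], list[str], list[float]]:
    """Hypothetical helper: best label, runner-up label, rounded best probability."""
    first, second, first_prob = [], [], []
    for logits in batch_logits:
        probs = softmax(logits)
        # Class indices ordered from most to least probable.
        order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
        first.append(labels[order[0]])
        second.append(labels[order[1]])
        first_prob.append(round(probs[order[0]], decimal))
    return first, second, first_prob


# Illustrative labels only; the real taxonomy is on the Hugging Face page.
example_labels = ["A", "B", "C"]
top1, top2, prob = top2_from_logits([[2.0, 1.0, 0.0], [0.0, 1.0, 5.0]], example_labels)
```

Rounding to `decimal` places mirrors the `decimal: int = 4` default in the documented signature.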
class classifiers.prompt_task_complexity.MeanPooling#

Bases: torch.nn.Module

Initialization

forward(
last_hidden_state: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor#
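Going by its name and its two inputs, MeanPooling.forward presumably averages token embeddings while ignoring padded positions. The following sketch shows masked mean pooling on nested Python lists so that it runs without PyTorch. The function name `mean_pool` is hypothetical, and the real layer operates on torch.Tensor inputs and may handle edge cases differently.

```python
def mean_pool(
    last_hidden_state: list[list[list[float]]],
    attention_mask: list[list[int]],
) -> list[list[float]]:
    """Average token vectors per sequence, skipping positions where the mask is 0.

    last_hidden_state is indexed [batch][token][hidden];
    attention_mask is indexed [batch][token] and holds 0 or 1.
    """
    pooled = []
    for token_vectors, mask in zip(last_hidden_state, attention_mask):
        hidden_size = len(token_vectors[0])
        sums = [0.0] * hidden_size
        for vector, keep in zip(token_vectors, mask):
            if keep:
                for i, value in enumerate(vector):
                    sums[i] += value
        # Guard against division by zero when a row is entirely padding.
        count = max(sum(mask), 1)
        pooled.append([total / count for total in sums])
    return pooled


# One sequence of three tokens; the last token is padding and is ignored.
hidden = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]
mask = [[1, 1, 0]]
pooled = mean_pool(hidden, mask)
```

Without the mask, the padding vector would dominate the average, which is why the attention mask is passed alongside the hidden states.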
class classifiers.prompt_task_complexity.MulticlassHead(input_size: int, num_classes: int)#

Bases: torch.nn.Module

Initialization

forward(x: torch.Tensor) -> torch.Tensor#
classifiers.prompt_task_complexity.PROMPT_TASK_COMPLEXITY_IDENTIFIER#

'nvidia/prompt-task-and-complexity-classifier'

class classifiers.prompt_task_complexity.PromptTaskComplexityClassifier(
batch_size: int = 256,
text_field: str = 'text',
max_chars: int = 2000,
device_type: str = 'cuda',
autocast: bool = True,
max_mem_gb: int | None = None,
)#

Bases: nemo_curator.classifiers.base.DistributedDataClassifier

PromptTaskComplexityClassifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. Further information on the taxonomies can be found on the NemoCurator Prompt Task and Complexity Hugging Face page: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier. This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.

Attributes:
    batch_size (int): The number of samples per batch for inference. Defaults to 256.
    text_field (str): The field in the dataset that should be classified. Defaults to "text".
    max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
    device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
    autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
    max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, it defaults to the available GPU memory minus 4 GB.

Initialization

Constructs a Module.

Args:
    input_backend (Literal["pandas", "cudf", "any"]): The backend the input dataframe must be on for the module to work.
    name (str, optional): The name of the module. If None, defaults to self.__class__.__name__.

get_labels() -> list[str]#
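A minimal usage sketch for PromptTaskComplexityClassifier follows. It rests on several assumptions: that the class is importable from `nemo_curator.classifiers`, that `DocumentDataset.read_json` and `DocumentDataset.to_json` are available in the installed NeMo Curator release, that a Dask GPU cluster has already been set up, and that the input is JSONL with a `text` field. The function name `classify_prompts` and both paths are hypothetical.

```python
# Constructor arguments copied from the documented defaults above.
CLASSIFIER_KWARGS = {
    "batch_size": 256,
    "text_field": "text",
    "max_chars": 2000,
    "device_type": "cuda",
    "autocast": True,
    "max_mem_gb": None,  # None means: available GPU memory minus 4 GB
}


def classify_prompts(input_path: str, output_path: str) -> None:
    """Hypothetical helper: read prompts, classify them, write the results."""
    # Imported lazily so the sketch can be loaded without the GPU stack installed.
    # The import paths are assumptions about the installed NeMo Curator release.
    from nemo_curator.classifiers import PromptTaskComplexityClassifier
    from nemo_curator.datasets import DocumentDataset

    # The classifier runs on GPU, so the data is loaded with the cuDF backend.
    dataset = DocumentDataset.read_json(input_path, backend="cudf")
    classifier = PromptTaskComplexityClassifier(**CLASSIFIER_KWARGS)
    result = classifier(dataset=dataset)
    result.to_json(output_path)


# Example call (paths are placeholders):
# classify_prompts("prompts/", "classified_prompts/")
```

The names of the columns added to the output can be inspected with get_labels() on the classifier instance.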
class classifiers.prompt_task_complexity.PromptTaskComplexityConfig#
base_model: str#

'microsoft/DeBERTa-v3-base'

max_len: int#

512

model_output_type: dict#

‘field(…)’
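The `field(…)` default shown for model_output_type is what a dataclass needs for a mutable default such as a dict. The sketch below mirrors the three documented attributes in a stand-in class to show the pattern. The class name `ExampleConfig` is hypothetical, and because the real default for model_output_type is elided in this reference, an empty dict is used purely as a placeholder.

```python
from dataclasses import dataclass, field


@dataclass
class ExampleConfig:
    """Hypothetical stand-in that mirrors the documented attribute names."""

    base_model: str = "microsoft/DeBERTa-v3-base"
    max_len: int = 512
    # Placeholder only: the actual default is not shown in this reference.
    # default_factory builds a fresh dict for every instance.
    model_output_type: dict = field(default_factory=dict)


first = ExampleConfig()
second = ExampleConfig()
first.model_output_type["example_key"] = "example_value"
```

Writing `model_output_type: dict = {}` instead would raise a ValueError at class creation, because dataclasses reject mutable defaults; `default_factory` also keeps instances from sharing one dict.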

class classifiers.prompt_task_complexity.PromptTaskComplexityModel(
config: classifiers.prompt_task_complexity.PromptTaskComplexityConfig,
autocast: bool,
max_mem_gb: int | None,
)#

Bases: crossfit.backend.torch.hf.model.HFModel

Initialization

load_config() -> transformers.AutoConfig#
load_model(
device: str = 'cuda',
) -> classifiers.prompt_task_complexity.CustomHFDeberta#
load_tokenizer() -> transformers.AutoTokenizer#