classifiers.base

Module Contents

Classes

API

class classifiers.base.DistributedDataClassifier(
model: str,
labels: list[str] | None,
filter_by: list[str] | None,
batch_size: int,
out_dim: int | None,
pred_column: str | list[str],
max_chars: int,
device_type: str,
autocast: bool,
)

Bases: nemo_curator.modules.base.BaseModule

Initialization

call(
dataset: nemo_curator.datasets.DocumentDataset,
) -> nemo_curator.datasets.DocumentDataset
get_labels() -> list[str]
class classifiers.base.HFDeberta(config: dataclasses.dataclass)

Bases: torch.nn.Module, huggingface_hub.PyTorchModelHubMixin

Initialization

forward(batch: dict[str, torch.Tensor]) -> torch.Tensor
set_autocast(autocast: bool) -> None