API Reference › Tasks

DocumentBatch

View as Markdown

DocumentBatch is the primary task type for text document processing in NeMo Curator.

Import

1from nemo_curator.tasks import DocumentBatch

Class Definition

from dataclasses import dataclass

import pandas as pd
import pyarrow as pa

from nemo_curator.tasks import Task


@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    # `_metadata` and `_stage_perf`, which the stage examples on this page pass
    # to the constructor, are not declared here; they come from the `Task`
    # base class (not shown).
    task_id: str
    dataset_name: str
    data: pa.Table | pd.DataFrame

Expected Data Schema

The `data` attribute typically contains the following columns:

| Column | Type | Description |
|---|---|---|
| `text` | `str` | The document text content |
| `id` | `str` | Optional document identifier |
| `url` | `str` | Optional source URL |
| Additional columns | Various | Task-specific metadata |

Properties

num_items

Get the number of documents in the batch.

@property
def num_items(self) -> int:
    """Returns the number of documents in this batch."""
    ...  # implementation elided; see the source on GitHub

Methods

to_pyarrow()

Convert data to PyArrow table.

def to_pyarrow(self) -> pa.Table:
    """Convert data to PyArrow table."""
    ...  # implementation elided; see the source on GitHub

to_pandas()

Convert data to Pandas DataFrame.

def to_pandas(self) -> pd.DataFrame:
    """Convert data to Pandas DataFrame."""
    ...  # implementation elided; see the source on GitHub

get_columns()

Get column names from the data.

def get_columns(self) -> list[str]:
    """Get column names from the data."""
    ...  # implementation elided; see the source on GitHub

validate()

Validate the batch structure.

def validate(self) -> bool:
    """Validate that the batch has required structure.

    Returns:
        True if valid, False if empty or has no columns (logs warning).
    """
    ...  # implementation elided; see the source on GitHub

Creating DocumentBatch

import pandas as pd

from nemo_curator.tasks import DocumentBatch

# Create from DataFrame
df = pd.DataFrame({
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
    "url": ["https://example.com/1", "https://example.com/2"],
})

batch = DocumentBatch(
    task_id="batch_001",
    dataset_name="my_dataset",
    data=df,
)

print(f"Batch contains {batch.num_items} documents")

Usage in Stages

from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch


@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length.

    Attributes:
        name: Stage name; also used as the suffix of output task IDs.
        min_length: Minimum number of characters a document's ``text``
            must have for the document to be kept.
    """

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        # Presumably (task attributes, data columns) this stage reads.
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        # Presumably (task attributes, data columns) this stage produces.
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        # `task.data` may be a pa.Table or a pd.DataFrame; normalize to
        # pandas so the `.str` accessor below is always available.
        df = task.to_pandas()

        # Keep documents whose text has at least `min_length` characters.
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        if filtered_df.empty:
            return None  # Filter out entire batch

        return DocumentBatch(
            task_id=f"{task.task_id}_{self.name}",
            dataset_name=task.dataset_name,
            data=filtered_df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )

Common Patterns

Adding Columns

def process(self, task: DocumentBatch) -> DocumentBatch:
    # `task.data` may be a pa.Table (which has no `.copy()`), so normalize to
    # pandas first; copy so adding a column does not mutate the input task.
    df = task.to_pandas().copy()
    df["word_count"] = df["text"].str.split().str.len()

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Filtering Rows

def process(self, task: DocumentBatch) -> DocumentBatch | None:
    # `task.data` may be a pa.Table; normalize to pandas before boolean
    # indexing.
    df = task.to_pandas()
    filtered = df[df["score"] > self.threshold]

    if filtered.empty:
        return None  # Filter out entire batch

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Source Code

View source on GitHub