DocumentBatch
DocumentBatch is the primary task type for text document processing in NeMo Curator.
Import
```python
from nemo_curator.tasks import DocumentBatch
```
Class Definition
```python
from dataclasses import dataclass
import pandas as pd
import pyarrow as pa

@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    task_id: str
    dataset_name: str
    data: pa.Table | pd.DataFrame
```
Expected Data Schema
The data attribute typically contains:
| Column | Type | Description |
|---|---|---|
| text | str | The document text content |
| id | str | Optional document identifier |
| url | str | Optional source URL |
| Additional columns | Various | Task-specific metadata |
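A quick way to check a DataFrame against this schema before wrapping it in a DocumentBatch is sketched below. The helper `missing_columns`, and the choice of `text` as the only required column, are assumptions made for illustration; neither is part of NeMo Curator.

```python
import pandas as pd

# Hypothetical helper (not a NeMo Curator API): list the expected columns
# that are absent from a DataFrame.
def missing_columns(df: pd.DataFrame, required: tuple[str, ...] = ("text",)) -> list[str]:
    return [col for col in required if col not in df.columns]

df = pd.DataFrame({
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
})

# Only "text" is checked by default.
print(missing_columns(df))

# "url" is optional in the schema above and is absent from this frame.
print(missing_columns(df, ("text", "url")))
```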
Properties
num_items
Get the number of documents in the batch.
```python
@property
def num_items(self) -> int:
    """Returns the number of documents in this batch."""
```
Methods
to_pyarrow()
Convert data to PyArrow table.
```python
def to_pyarrow(self) -> pa.Table:
    """Convert data to PyArrow table."""
```
to_pandas()
Convert data to Pandas DataFrame.
```python
def to_pandas(self) -> pd.DataFrame:
    """Convert data to Pandas DataFrame."""
```
get_columns()
Get column names from the data.
```python
def get_columns(self) -> list[str]:
    """Get column names from the data."""
```
validate()
Validate the batch structure.
```python
def validate(self) -> bool:
    """Validate that the batch has required structure.

    Returns:
        True if valid, False if empty or has no columns (logs warning).
    """
```
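The documented return values can be illustrated on a plain DataFrame. The function `looks_valid` below is a hypothetical stand-in written for this page; it mirrors the behavior described in the docstring (False for an empty frame or one with no columns, with a logged warning) and is not NeMo Curator's code.

```python
import logging

import pandas as pd

logger = logging.getLogger(__name__)

# Hypothetical stand-in for the documented behavior, not the library's code.
def looks_valid(df: pd.DataFrame) -> bool:
    if len(df.columns) == 0:
        logger.warning("Batch has no columns")
        return False
    if len(df) == 0:
        logger.warning("Batch is empty")
        return False
    return True

with_rows = looks_valid(pd.DataFrame({"text": ["some content"]}))
no_rows = looks_valid(pd.DataFrame({"text": []}))
no_columns = looks_valid(pd.DataFrame())

print(with_rows, no_rows, no_columns)
```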
Creating DocumentBatch
```python
import pandas as pd
from nemo_curator.tasks import DocumentBatch

# Create from DataFrame
df = pd.DataFrame({
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
    "url": ["https://example.com/1", "https://example.com/2"],
})

batch = DocumentBatch(
    task_id="batch_001",
    dataset_name="my_dataset",
    data=df,
)

print(f"Batch contains {batch.num_items} documents")
```
Usage in Stages
```python
from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch

@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length."""

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        # data may be a PyArrow Table, so convert before using pandas operations
        df = task.to_pandas()

        # Filter by text length
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        # Returning None drops the whole batch
        if filtered_df.empty:
            return None

        return DocumentBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered_df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
```
Common Patterns
Adding Columns
```python
def process(self, task: DocumentBatch) -> DocumentBatch:
    # Convert first (data may be a PyArrow Table), then copy so the
    # input batch is not modified in place
    df = task.to_pandas().copy()
    df["word_count"] = df["text"].str.split().str.len()

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```
Filtering Rows
```python
def process(self, task: DocumentBatch) -> DocumentBatch | None:
    # data may be a PyArrow Table, so convert before boolean indexing
    df = task.to_pandas()
    filtered = df[df["score"] > self.threshold]

    if filtered.empty:
        return None  # Filter out entire batch

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```
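To try the row-filtering logic without a pipeline, the same pattern can be run on a plain DataFrame. The helper `filter_by_score` and the sample scores are hypothetical and exist only for this sketch; as in the stage above, it returns None when no rows survive.

```python
from typing import Optional

import pandas as pd

# Hypothetical helper mirroring the stage logic on a plain DataFrame.
def filter_by_score(df: pd.DataFrame, threshold: float) -> Optional[pd.DataFrame]:
    filtered = df[df["score"] > threshold]
    # An empty result signals that the whole batch should be dropped.
    return None if filtered.empty else filtered

df = pd.DataFrame({
    "text": ["a", "b", "c"],
    "score": [0.2, 0.7, 0.9],
})

kept = filter_by_score(df, threshold=0.5)     # rows "b" and "c" remain
dropped = filter_by_score(df, threshold=1.0)  # nothing remains, so None

print(kept["text"].tolist())
print(dropped)
```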