---
title: DocumentBatch
description: API reference for DocumentBatch - the task type for text document processing
---

`DocumentBatch` is the primary task type for text document processing in NeMo Curator.

## Import

```python
from nemo_curator.tasks import DocumentBatch
```

## Class Definition

```python
from dataclasses import dataclass

import pandas as pd
import pyarrow as pa


@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    task_id: str
    dataset_name: str
    data: pa.Table | pd.DataFrame
```

## Expected Data Schema

The `data` attribute typically contains:

| Column             | Type    | Description                  |
| ------------------ | ------- | ---------------------------- |
| `text`             | `str`   | The document text content    |
| `id`               | `str`   | Optional document identifier |
| `url`              | `str`   | Optional source URL          |
| Additional columns | Various | Task-specific metadata       |

## Properties

### `num_items`

Get the number of documents in the batch.

```python
@property
def num_items(self) -> int:
    """Returns the number of documents in this batch."""
```

## Methods

### `to_pyarrow()`

Convert data to PyArrow table.

```python
def to_pyarrow(self) -> pa.Table:
    """Convert data to PyArrow table."""
```

### `to_pandas()`

Convert data to Pandas DataFrame.

```python
def to_pandas(self) -> pd.DataFrame:
    """Convert data to Pandas DataFrame."""
```

### `get_columns()`

Get column names from the data.

```python
def get_columns(self) -> list[str]:
    """Get column names from the data."""
```

### `validate()`

Validate the batch structure.

```python
def validate(self) -> bool:
    """Validate that the batch has required structure.

    Returns:
        True if valid, False if empty or has no columns (logs warning).
    """
```

## Creating DocumentBatch

```python
import pandas as pd
from nemo_curator.tasks import DocumentBatch

# Create from DataFrame
df = pd.DataFrame({
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
    "url": ["https://example.com/1", "https://example.com/2"],
})

batch = DocumentBatch(
    task_id="batch_001",
    dataset_name="my_dataset",
    data=df,
)

print(f"Batch contains {batch.num_items} documents")
```

## Usage in Stages

```python
from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch


@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length."""

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        df = task.data

        # Filter by text length
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        if filtered_df.empty:
            return None

        return DocumentBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered_df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
```

## Common Patterns

### Adding Columns

```python
def process(self, task: DocumentBatch) -> DocumentBatch:
    df = task.data.copy()
    df["word_count"] = df["text"].str.split().str.len()

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```

### Filtering Rows

```python
def process(self, task: DocumentBatch) -> DocumentBatch | None:
    df = task.data
    filtered = df[df["score"] > self.threshold]

    if filtered.empty:
        return None  # Filter out entire batch

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```

## Source Code

[View source on GitHub](https://github.com/NVIDIA-NeMo/Curator/blob/main/nemo_curator/tasks/document.py)
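
For readers who want to try the documented contract without installing NeMo Curator, the following is a minimal, library-free sketch. `MiniBatch` and `filter_by_length` are hypothetical stand-ins introduced only for illustration; they are not part of NeMo Curator and do not reproduce its implementation. The sketch assumes columns are stored as plain Python lists, and it mirrors two behaviors described in this reference: `validate()` returning `False` for a batch that is empty or has no columns, and a filter returning `None` when no rows survive.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class MiniBatch:
    """Hypothetical stand-in for DocumentBatch; columns are plain lists."""

    task_id: str
    dataset_name: str
    data: dict[str, list] = field(default_factory=dict)

    @property
    def num_items(self) -> int:
        # Row count is the length of the first column; zero with no columns.
        return len(next(iter(self.data.values()), []))

    def get_columns(self) -> list[str]:
        return list(self.data)

    def validate(self) -> bool:
        # Mirrors the documented contract: invalid if empty or column-less.
        return bool(self.get_columns()) and self.num_items > 0


def filter_by_length(task: MiniBatch, min_length: int) -> MiniBatch | None:
    """Keep rows whose text has at least min_length characters."""
    keep = [i for i, text in enumerate(task.data["text"]) if len(text) >= min_length]
    if not keep:
        return None  # Drop the whole batch, as in the filtering pattern
    data = {name: [values[i] for i in keep] for name, values in task.data.items()}
    return MiniBatch(f"{task.task_id}_filtered", task.dataset_name, data)


batch = MiniBatch(
    task_id="batch_001",
    dataset_name="my_dataset",
    data={"text": ["short", "a much longer document"], "id": ["doc_001", "doc_002"]},
)
filtered = filter_by_length(batch, min_length=10)
```

Here `filtered` holds only the second document, while a threshold that no document meets yields `None` rather than an empty batch. In a real pipeline the same shape would apply to a `DocumentBatch` wrapping a DataFrame or PyArrow table.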