---
title: DocumentBatch
description: API reference for DocumentBatch - the task type for text document processing
---

`DocumentBatch` is the primary task type for text document processing in NeMo Curator. It holds a batch of documents as either a pandas DataFrame or a PyArrow Table.

## Import

```python
from nemo_curator.tasks import DocumentBatch
```

## Class Definition

```python
from dataclasses import dataclass

import pandas as pd
import pyarrow as pa

from nemo_curator.tasks import Task  # generic base class for task types

@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    task_id: str
    dataset_name: str
    data: pa.Table | pd.DataFrame
```

## Expected Data Schema

The `data` attribute typically contains:

| Column             | Type    | Description                  |
| ------------------ | ------- | ---------------------------- |
| `text`             | `str`   | The document text content    |
| `id`               | `str`   | Optional document identifier |
| `url`              | `str`   | Optional source URL          |
| Additional columns | Various | Task-specific metadata       |
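
As a quick pre-flight check, the sketch below verifies that a DataFrame carries the expected `text` column before it is wrapped in a `DocumentBatch`. The helper `missing_columns` and the `REQUIRED_COLUMNS` tuple are hypothetical names used for illustration, not part of NeMo Curator, and treating only `text` as required is an assumption based on the table above.

```python
import pandas as pd

# Assumption: only "text" is mandatory; "id" and "url" are optional.
REQUIRED_COLUMNS = ("text",)


def missing_columns(df: pd.DataFrame, required: tuple = REQUIRED_COLUMNS) -> list:
    """Return the required column names that are absent from df (hypothetical helper)."""
    return [name for name in required if name not in df.columns]


good = pd.DataFrame({"text": ["Document 1 content..."], "id": ["doc_001"]})
bad = pd.DataFrame({"body": ["Document 1 content..."]})

print(missing_columns(good))  # []
print(missing_columns(bad))   # ['text']
```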

## Properties

### `num_items`

Get the number of documents in the batch.

```python
@property
def num_items(self) -> int:
    """Returns the number of documents in this batch."""
```

## Methods

### `to_pyarrow()`

Convert the batch data to a PyArrow `Table`.

```python
def to_pyarrow(self) -> pa.Table:
    """Convert data to PyArrow table."""
```

### `to_pandas()`

Convert the batch data to a pandas `DataFrame`.

```python
def to_pandas(self) -> pd.DataFrame:
    """Convert data to Pandas DataFrame."""
```

### `get_columns()`

Get column names from the data.

```python
def get_columns(self) -> list[str]:
    """Get column names from the data."""
```

### `validate()`

Validate the batch structure.

```python
def validate(self) -> bool:
    """Validate that the batch has required structure.

    Returns:
        True if valid, False if empty or has no columns (logs warning).
    """
```
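
The documented contract (return `False` and log a warning when the batch is empty or has no columns) can be sketched as follows. This is an illustration on plain pandas objects using the standard `logging` module. `looks_valid` is a hypothetical helper, and the library's actual checks and logger may differ.

```python
import logging

import pandas as pd

logger = logging.getLogger(__name__)


def looks_valid(df: pd.DataFrame) -> bool:
    """Mimic the documented contract: False (with a warning) if empty or column-less."""
    if len(df.columns) == 0:
        logger.warning("Batch has no columns")
        return False
    if len(df) == 0:
        logger.warning("Batch is empty")
        return False
    return True


print(looks_valid(pd.DataFrame({"text": ["hello"]})))  # True
print(looks_valid(pd.DataFrame({"text": []})))         # False
print(looks_valid(pd.DataFrame()))                     # False
```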

## Creating DocumentBatch

```python
import pandas as pd
from nemo_curator.tasks import DocumentBatch

# Create from DataFrame
df = pd.DataFrame({
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
    "url": ["https://example.com/1", "https://example.com/2"],
})

batch = DocumentBatch(
    task_id="batch_001",
    dataset_name="my_dataset",
    data=df,
)

print(f"Batch contains {batch.num_items} documents")
```

## Usage in Stages

```python
from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch

@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length."""

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        # data may be a PyArrow Table; convert so the pandas operations below work
        df = task.to_pandas()

        # Filter by text length
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        if filtered_df.empty:
            return None

        return DocumentBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered_df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
```
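
One detail of the length filter is worth knowing: for rows where the text is missing, `str.len()` yields a missing value and the `>=` comparison does not select the row, so documents without text are dropped along with short ones. A minimal check on a plain DataFrame, with no NeMo Curator objects involved:

```python
import pandas as pd

df = pd.DataFrame({"text": ["short", "x" * 150, None]})

# Same mask as in TextFilterStage.process, with min_length = 100
mask = df["text"].str.len() >= 100
filtered_df = df[mask]

print(len(filtered_df))  # 1
```

If missing text should be handled differently, it has to be addressed before the mask is computed, for example by filling or reporting those rows explicitly.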

## Common Patterns

### Adding Columns

```python
def process(self, task: DocumentBatch) -> DocumentBatch:
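    # Inside a ProcessingStage subclass. Convert first, since data may be a
    # PyArrow Table; copy so the input task's data is not modified.
    df = task.to_pandas().copy()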
    df["word_count"] = df["text"].str.split().str.len()

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```
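
To make the `word_count` expression concrete, here is the same pandas chain on a small standalone DataFrame. `str.split()` with no arguments splits on runs of whitespace, so repeated spaces do not inflate the count.

```python
import pandas as pd

df = pd.DataFrame({"text": ["one two three", "single", "spaced   out"]})

# Split on whitespace, then count the tokens in each resulting list
df["word_count"] = df["text"].str.split().str.len()

print(df["word_count"].tolist())  # [3, 1, 2]
```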

### Filtering Rows

```python
def process(self, task: DocumentBatch) -> DocumentBatch | None:
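    # Inside a ProcessingStage subclass that defines a `threshold` field and
    # receives batches with a numeric `score` column.
    df = task.to_pandas()
    filtered = df[df["score"] > self.threshold]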

    if filtered.empty:
        return None  # Filter out entire batch

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```

## Source Code

[View source on GitHub](https://github.com/NVIDIA-NeMo/Curator/blob/main/nemo_curator/tasks/document.py)
