API Reference

ProcessingStage

View as Markdown

The ProcessingStage class is the base class for all data processing stages in NeMo Curator. Each stage defines a single step in a data curation pipeline.

Import

1from nemo_curator.stages.base import ProcessingStage

Class Definition

1from dataclasses import dataclass, field
2from typing import Generic, TypeVar
3
4InputT = TypeVar("InputT", bound=Task)
5OutputT = TypeVar("OutputT", bound=Task)
6
7@dataclass
8class ProcessingStage(Generic[InputT, OutputT]):
9 """Base class for all processing stages.
10
11 Type Parameters:
12 InputT: The input task type this stage accepts.
13 OutputT: The output task type this stage produces.
14
15 Class Attributes:
16 name: String identifier for the stage.
17 resources: Resources configuration (CPUs, GPUs).
18 batch_size: Number of tasks to process per batch.
19 """
20
21 name: str = "ProcessingStage"
22 resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))
23 batch_size: int = 1

Abstract Methods

inputs()

Declare the task attributes and data attributes that this stage requires on its input.

1def inputs(self) -> tuple[list[str], list[str]]:
2 """Define required task and data attributes.
3
4 Returns:
5 Tuple of (required_task_attributes, required_data_attributes).
6 """

outputs()

Declare the task attributes and data attributes that this stage produces on its output.

1def outputs(self) -> tuple[list[str], list[str]]:
2 """Define output task and data attributes.
3
4 Returns:
5 Tuple of (output_task_attributes, output_data_attributes).
6 """

process()

Process a single task.

1def process(self, task: InputT) -> OutputT | list[OutputT] | None:
2 """Process a single task.
3
4 Args:
5 task: The input task to process.
6
7 Returns:
8 - Single task: For 1-to-1 transformations
9 - List of tasks: For splitting/reading operations
10 - None: To filter out the task
11 """

Optional Lifecycle Methods

setup_on_node()

Node-level initialization (e.g., download models).

1def setup_on_node(
2 self,
3 node_info: NodeInfo,
4 worker_metadata: dict[str, Any],
5) -> None:
6 """Initialize resources on a compute node.
7
8 Called once per node before any workers start.
9 """

setup()

Worker-level initialization (e.g., load models).

1def setup(self, worker_metadata: dict[str, Any]) -> None:
2 """Initialize resources for a worker.
3
4 Called once per worker before processing begins.
5 """

teardown()

Cleanup after processing.

1def teardown(self) -> None:
2 """Clean up resources after processing completes."""

process_batch()

Vectorized batch processing for better performance.

1def process_batch(self, tasks: list[InputT]) -> list[OutputT | None]:
2 """Process a batch of tasks.
3
4 Override for vectorized operations.
5
6 Args:
7 tasks: List of input tasks.
8
9 Returns:
10 List of output tasks (None entries are filtered out).
11 """

Creating Custom Stages

1from dataclasses import dataclass, field
2from nemo_curator.stages.base import ProcessingStage
3from nemo_curator.stages.resources import Resources
4from nemo_curator.tasks import DocumentBatch
5
6@dataclass
7class MyCustomStage(ProcessingStage[DocumentBatch, DocumentBatch]):
8 """Custom stage that processes documents."""
9
10 name: str = "MyCustomStage"
11 resources: Resources = field(default_factory=lambda: Resources(cpus=2.0))
12
13 # Custom parameters
14 threshold: float = 0.5
15
16 def inputs(self) -> tuple[list[str], list[str]]:
17 return ["data"], ["text"]
18
19 def outputs(self) -> tuple[list[str], list[str]]:
20 return ["data"], ["text", "score"]
21
22 def process(self, task: DocumentBatch) -> DocumentBatch | None:
23 # Process the task
24 df = task.data
25 df["score"] = df["text"].apply(self._compute_score)
26
27 # Filter based on threshold
28 if df["score"].mean() < self.threshold:
29 return None
30
31 return DocumentBatch(
32 task_id=f"{task.task_id}_{self.name}",
33 dataset_name=task.dataset_name,
34 data=df,
35 _metadata=task._metadata,
36 _stage_perf=task._stage_perf,
37 )
38
39 def _compute_score(self, text: str) -> float:
40 # Custom scoring logic
41 return len(text) / 1000.0

Configuration with with_()

Stage attributes such as resources can be overridden with the with_() method, which returns the configured stage:

1from nemo_curator.stages.resources import Resources
2
3stage = MyCustomStage(threshold=0.7)
4configured_stage = stage.with_(resources=Resources(cpus=4.0, gpus=1.0))

Source Code

View source on GitHub