---
title: CompositeStage
description: >-
  API reference for CompositeStage - high-level stages that decompose into
  multiple execution stages
---

The `CompositeStage` class represents high-level, user-facing stages that decompose into multiple low-level execution stages.

## Import

```python
from nemo_curator.stages.base import CompositeStage
```

## When to Use CompositeStage

Use `CompositeStage` when you need to:

* Provide a simplified API while maintaining fine-grained execution control
* Bundle multiple related stages into a single logical operation
* Handle stages that require different resources (e.g., a CPU-based stage followed by a GPU-based one)

## Class Definition

```python
from dataclasses import dataclass, field
from typing import Generic, TypeVar

@dataclass
class CompositeStage(ProcessingStage[InputT, OutputT]):
    """High-level stage that decomposes into multiple execution stages.

    Composite stages are decomposed during pipeline planning, allowing
    each sub-stage to run with its own resource requirements.

    Attributes:
        stages: List of constituent ProcessingStage instances.
    """

    stages: list[ProcessingStage] = field(default_factory=list)
```

## Abstract Methods

### `decompose()`

Subclasses implement `decompose()` to return the list of execution stages that the composite expands into.

```python
def decompose(self) -> list[ProcessingStage]:
    """Decompose into constituent execution stages.

    Returns:
        List of ProcessingStage instances to execute.
    """
    return self.stages
```

## Creating a CompositeStage

```python
from dataclasses import dataclass, field
from nemo_curator.stages.base import CompositeStage, ProcessingStage
from nemo_curator.tasks import Task

@dataclass
class MyCompositeStage(CompositeStage[Task, Task]):
    """A composite stage that bundles multiple operations."""

    name: str = "MyCompositeStage"
    param1: str = ""
    param2: int = 0

    def __post_init__(self) -> None:
        super().__init__()
        # StageA, StageB, and StageC are placeholders for your own
        # ProcessingStage subclasses.
        self.stages = [
            StageA(param1=self.param1),
            StageB(param2=self.param2),
            StageC(),
        ]

    def inputs(self) -> tuple[list[str], list[str]]:
        # Return first stage's inputs
        return self.stages[0].inputs()

    def outputs(self) -> tuple[list[str], list[str]]:
        # Return last stage's outputs
        return self.stages[-1].outputs()

    def decompose(self) -> list[ProcessingStage]:
        return self.stages
```

## Configuration with `with_()`

On a `CompositeStage`, `with_()` takes a dictionary keyed by sub-stage name. Each value holds the settings to apply to that sub-stage:

```python
from nemo_curator.stages.resources import Resources

composite_stage = MyCompositeStage(param1="value", param2=10)

# Configure individual stages within the composite
stage_config = {
    "StageA": {"resources": Resources(cpus=4.0)},
    "StageB": {"resources": Resources(cpus=2.0, gpus=1.0)},
}
configured_stage = composite_stage.with_(stage_config)
```
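
The sketch below illustrates the idea behind name-keyed configuration without importing `nemo_curator`. `StubResources`, `StubStage`, and `apply_stage_config` are hypothetical stand-ins written for this example, not library APIs, and raising on unknown names is a choice made in this sketch rather than documented behavior of `with_()`.

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass(frozen=True)
class StubResources:
    """Hypothetical stand-in for Resources."""

    cpus: float = 1.0
    gpus: float = 0.0


@dataclass(frozen=True)
class StubStage:
    """Hypothetical stand-in for a leaf ProcessingStage."""

    name: str
    resources: StubResources = field(default_factory=StubResources)


def apply_stage_config(
    stages: list[StubStage], stage_config: dict[str, dict[str, Any]]
) -> list[StubStage]:
    """Return copies of `stages` with the per-name overrides applied."""
    names = [stage.name for stage in stages]
    # Name-keyed lookup is ambiguous if two sub-stages share a name.
    if len(set(names)) != len(names):
        raise ValueError(f"Stage names must be unique, got {names}")
    # Assumption of this sketch: a key that matches no sub-stage is an error.
    unknown = set(stage_config) - set(names)
    if unknown:
        raise ValueError(f"Unknown stage names in config: {sorted(unknown)}")
    # Stages without an entry are copied unchanged.
    return [replace(stage, **stage_config.get(stage.name, {})) for stage in stages]


stages = [StubStage("StageA"), StubStage("StageB"), StubStage("StageC")]
configured = apply_stage_config(
    stages,
    {
        "StageA": {"resources": StubResources(cpus=4.0)},
        "StageB": {"resources": StubResources(cpus=2.0, gpus=1.0)},
    },
)
```

Because overrides are looked up by name, two sub-stages with the same name could not be configured independently, which is why unique names matter for `with_()`.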

## Important Rules

1. **Decomposed stages cannot be CompositeStages** - `decompose()` must return only leaf `ProcessingStage` instances, never nested composites
2. **`inputs()` returns the first stage's inputs** - These are the composite's input requirements
3. **`outputs()` returns the last stage's outputs** - These are the composite's outputs
4. **Unique stage names** - All stages returned by `decompose()` must have unique names for `with_()` to work
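
Rules 1 and 4 are easy to check mechanically. The following is a minimal sketch of such a check, assuming hypothetical `StubStage` and `StubComposite` classes in place of the real `ProcessingStage` and `CompositeStage`. The helper `check_rules` is invented for this example and is not part of the library.

```python
from dataclasses import dataclass, field


@dataclass
class StubStage:
    """Hypothetical stand-in for a leaf ProcessingStage."""

    name: str


@dataclass
class StubComposite(StubStage):
    """Hypothetical stand-in for CompositeStage."""

    stages: list[StubStage] = field(default_factory=list)

    def decompose(self) -> list[StubStage]:
        return self.stages


def check_rules(composite: StubComposite) -> list[str]:
    """Return human-readable rule violations (empty list if none)."""
    problems: list[str] = []
    sub_stages = composite.decompose()
    # Rule 1: every decomposed stage must be a leaf, not another composite.
    for stage in sub_stages:
        if isinstance(stage, StubComposite):
            problems.append(f"rule 1: {stage.name!r} is itself a composite")
    # Rule 4: names must be unique so name-keyed configuration is unambiguous.
    names = [stage.name for stage in sub_stages]
    duplicates = sorted({n for n in names if names.count(n) > 1})
    if duplicates:
        problems.append(f"rule 4: duplicate stage names {duplicates}")
    return problems


ok = StubComposite("ok", [StubStage("a"), StubStage("b")])
bad = StubComposite("bad", [StubStage("a"), StubStage("a"), StubComposite("inner")])
ok_problems = check_rules(ok)
bad_problems = check_rules(bad)
```

A check like this could run in a unit test for each composite you define, so violations surface before a pipeline is built.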

## Usage in Pipelines

```python
from nemo_curator.pipeline import Pipeline

# Composite stages are automatically decomposed during build
pipeline = Pipeline(
    name="my_pipeline",
    stages=[
        MyCompositeStage(param1="test", param2=5),
        AnotherStage(),
    ],
)

# pipeline.build() replaces each composite with its decomposed sub-stages;
# decomposition_info tracks the mapping from composites to those sub-stages
results = pipeline.run()
```
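
To make the decomposition step concrete, here is a rough, self-contained sketch of how a build step might flatten composites and record which sub-stages each one produced. It does not use the real `Pipeline` class: `StubStage`, `StubComposite`, and `build` are hypothetical, and the dictionary shape used for `decomposition_info` here (composite name mapped to sub-stage names) is an assumption, not the library's actual data structure.

```python
from dataclasses import dataclass, field


@dataclass
class StubStage:
    """Hypothetical stand-in for a leaf ProcessingStage."""

    name: str


@dataclass
class StubComposite(StubStage):
    """Hypothetical stand-in for CompositeStage."""

    stages: list[StubStage] = field(default_factory=list)

    def decompose(self) -> list[StubStage]:
        return self.stages


def build(stages: list[StubStage]) -> tuple[list[StubStage], dict[str, list[str]]]:
    """Flatten composites and return (execution_stages, decomposition_info)."""
    execution_stages: list[StubStage] = []
    decomposition_info: dict[str, list[str]] = {}
    for stage in stages:
        if isinstance(stage, StubComposite):
            sub_stages = stage.decompose()
            execution_stages.extend(sub_stages)
            # Assumed shape: composite name -> names of its sub-stages.
            decomposition_info[stage.name] = [s.name for s in sub_stages]
        else:
            # Leaf stages pass through unchanged.
            execution_stages.append(stage)
    return execution_stages, decomposition_info


execution_stages, info = build(
    [
        StubComposite("MyCompositeStage", [StubStage("StageA"), StubStage("StageB")]),
        StubStage("AnotherStage"),
    ]
)
```

In this sketch the composite never appears in `execution_stages`. Only its sub-stages run, in the order `decompose()` returned them.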

## Source Code

[View source on GitHub](https://github.com/NVIDIA-NeMo/Curator/blob/main/nemo_curator/stages/base.py)
