API Reference › Tasks

ImageBatch

View as Markdown

`ImageBatch` is the task type that carries a batch of images between image-processing stages in NeMo Curator.

Import

from nemo_curator.tasks import ImageBatch

Class Definition

from dataclasses import dataclass

from nemo_curator.tasks import Task
from nemo_curator.tasks.image import ImageObject


@dataclass
class ImageBatch(Task[list[ImageObject]]):
    """Task containing a batch of images.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: List of ImageObject instances.

    Only the three fields above are declared here. Any other constructor
    keyword (for example ``_metadata`` or ``_stage_perf``, which stages pass
    through when building a new batch) is not declared in this class and
    must therefore come from the ``Task`` base class.
    """

    task_id: str
    dataset_name: str
    data: list[ImageObject]

ImageObject

Each image in the batch is represented by an ImageObject:

1@dataclass
2class ImageObject:
3 """Represents a single image with metadata.
4
5 Attributes:
6 path: Path to the image file.
7 caption: Optional text caption for the image.
8 metadata: Additional metadata dictionary.
9 embeddings: Optional embedding vector.
10 """
11
12 path: str
13 caption: str | None = None
14 metadata: dict[str, Any] = field(default_factory=dict)
15 embeddings: np.ndarray | None = None

Properties

num_items

Returns the number of images in the batch.

@property
def num_items(self) -> int:
    """Returns the number of images in this batch.

    Each entry of ``self.data`` is one image, so this is ``len(self.data)``.
    """
    return len(self.data)

Creating an ImageBatch

from nemo_curator.tasks import ImageBatch
from nemo_curator.tasks.image import ImageObject

# Create image objects; `metadata` is a free-form dict attached to each image.
images = [
    ImageObject(
        path="/data/images/image1.jpg",
        caption="A cat sitting on a couch",
        metadata={"source": "dataset_a"},
    ),
    ImageObject(
        path="/data/images/image2.jpg",
        caption="A dog playing in the park",
        metadata={"source": "dataset_a"},
    ),
]

# Group the images into a single batch task.
batch = ImageBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=images,
)

Usage in Stages

from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import ImageBatch


@dataclass
class ImageFilterStage(ProcessingStage[ImageBatch, ImageBatch]):
    """Keep only images whose width and height meet a minimum resolution.

    The size of each image is read from its ``metadata`` dict under the
    ``"width"`` and ``"height"`` keys. An image missing either key is
    treated as having size 0 on that axis and is therefore dropped
    (assuming a positive ``min_resolution``).

    Attributes:
        name: Stage name; also used as the ``task_id`` suffix of output batches.
        min_resolution: Inclusive lower bound that both the ``"width"`` and
            ``"height"`` metadata values must meet.
    """

    name: str = "ImageFilter"
    min_resolution: int = 256

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def process(self, task: ImageBatch) -> ImageBatch | None:
        """Return a new batch without undersized images, or None if none remain."""
        filtered = [
            img
            for img in task.data
            if img.metadata.get("width", 0) >= self.min_resolution
            and img.metadata.get("height", 0) >= self.min_resolution
        ]

        # No image survived: return None rather than an empty batch.
        if not filtered:
            return None

        return ImageBatch(
            task_id=f"{task.task_id}_{self.name}",
            dataset_name=task.dataset_name,
            data=filtered,
            # Carry the input task's bookkeeping over to the new batch.
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )

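Common Operations

The snippets below are `process` method implementations for a `ProcessingStage` subclass such as `ImageFilterStage` above. They assume the stage defines the attributes they reference (for example `self.model` or `self.threshold`).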

Adding Embeddings

def process(self, task: ImageBatch) -> ImageBatch:
    """Attach an embedding to every image in the batch.

    Assumes the stage defines ``self.model`` with an ``encode`` method that
    takes an image path; adapt the call to your model's API.

    Note: the ``ImageObject`` instances in ``task.data`` are updated in
    place, and the returned batch reuses the same ``data`` list, so the
    input batch also sees the new embeddings.
    """
    for img in task.data:
        img.embeddings = self.model.encode(img.path)

    return ImageBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=task.data,
        # Carry the input task's bookkeeping over to the new batch.
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Filtering by Score

def process(self, task: ImageBatch) -> ImageBatch | None:
    """Keep images whose ``aesthetic_score`` metadata is at least ``self.threshold``.

    Images without an ``"aesthetic_score"`` metadata entry are scored as 0,
    so they are kept only when ``self.threshold <= 0``.

    Returns:
        A new batch containing the passing images, or None if none pass.
    """
    filtered = [
        img
        for img in task.data
        if img.metadata.get("aesthetic_score", 0) >= self.threshold
    ]

    # No image passed: return None rather than an empty batch.
    if not filtered:
        return None

    return ImageBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        # Carry the input task's bookkeeping over to the new batch.
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Source Code

View source on GitHub