API Reference › Tasks
ImageBatch
ImageBatch is the task type for image processing in NeMo Curator. It holds a list of ImageObject instances that processing stages receive, transform, and return.
Import
1 from nemo_curator.tasks import ImageBatch
Class Definition
from dataclasses import dataclass

from nemo_curator.tasks import Task
from nemo_curator.tasks.image import ImageObject


@dataclass
class ImageBatch(Task[list[ImageObject]]):
    """Task containing a batch of images.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: List of ImageObject instances.

    ``_metadata`` and ``_stage_perf`` are not declared here. They come from
    the ``Task`` base class and are accepted as constructor keywords, which
    is how stages carry them over when they build a new batch.
    """

    task_id: str
    dataset_name: str
    data: list[ImageObject]

    @property
    def num_items(self) -> int:
        """Returns the number of images in this batch."""
        return len(self.data)
ImageObject
Each image in the batch is represented by an ImageObject:
1 @dataclass 2 class ImageObject: 3 """Represents a single image with metadata. 4 5 Attributes: 6 path: Path to the image file. 7 caption: Optional text caption for the image. 8 metadata: Additional metadata dictionary. 9 embeddings: Optional embedding vector. 10 """ 11 12 path: str 13 caption: str | None = None 14 metadata: dict[str, Any] = field(default_factory=dict) 15 embeddings: np.ndarray | None = None
Properties
num_items
Read-only property that returns the number of images in the batch, i.e. the length of `data`.
@property
def num_items(self) -> int:
    """Returns the number of images in this batch.

    This is the length of ``data``, the list of ImageObject instances.
    """
    return len(self.data)
Creating ImageBatch
from nemo_curator.tasks import ImageBatch
from nemo_curator.tasks.image import ImageObject

# Describe each image on its own: file location, a caption, and free-form
# metadata recording which dataset it came from.
cat_image = ImageObject(
    path="/data/images/image1.jpg",
    caption="A cat sitting on a couch",
    metadata={"source": "dataset_a"},
)
dog_image = ImageObject(
    path="/data/images/image2.jpg",
    caption="A dog playing in the park",
    metadata={"source": "dataset_a"},
)
images = [cat_image, dog_image]

# Group the images into one batch that processing stages can consume.
batch = ImageBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=images,
)
Usage in Stages
from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import ImageBatch


@dataclass
class ImageFilterStage(ProcessingStage[ImageBatch, ImageBatch]):
    """Drop images whose metadata width or height is below ``min_resolution``."""

    name: str = "ImageFilter"
    min_resolution: int = 256

    def inputs(self) -> tuple[list[str], list[str]]:
        # The stage reads the batch's ``data`` attribute.
        return (["data"], [])

    def outputs(self) -> tuple[list[str], list[str]]:
        # ... and writes a new ``data`` list holding the surviving images.
        return (["data"], [])

    def process(self, task: ImageBatch) -> ImageBatch | None:
        # Width is checked before height. A missing key counts as 0, so an
        # image without size metadata is always rejected.
        kept = [
            img
            for img in task.data
            if all(
                img.metadata.get(side, 0) >= self.min_resolution
                for side in ("width", "height")
            )
        ]

        if not kept:
            # No image met the resolution bar.
            return None

        # Build a fresh batch instead of mutating the incoming task; its
        # bookkeeping fields are passed through as-is.
        return ImageBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=kept,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
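Common Operations
The snippets below are `process` methods of a `ProcessingStage` subclass such as `ImageFilterStage` above; `self.model`, `self.name`, and `self.threshold` are attributes of that stage.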
Adding Embeddings
def process(self, task: ImageBatch) -> ImageBatch:
    """Attach an embedding to every image in ``task`` and return a new batch.

    Each ImageObject is updated in place: its ``embeddings`` attribute is
    set to whatever ``self.model.encode`` returns for the image's path.
    ``self.model`` is assumed to be an encoder configured elsewhere on the
    stage (TODO confirm its expected input). The returned batch wraps the
    same, now-updated, list of images.
    """
    for image in task.data:
        image.embeddings = self.model.encode(image.path)

    new_id = f"{task.task_id}_{self.name}"
    return ImageBatch(
        task_id=new_id,
        dataset_name=task.dataset_name,
        data=task.data,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
Filtering by Score
def process(self, task: ImageBatch) -> ImageBatch | None:
    """Keep the images whose ``aesthetic_score`` reaches ``self.threshold``.

    The score is read from each image's ``metadata``. An image with no
    ``aesthetic_score`` entry is compared as if its score were 0.

    Returns:
        A new ImageBatch holding the surviving images, or ``None`` when
        every image was filtered out.
    """
    survivors = list(
        filter(
            lambda img: img.metadata.get("aesthetic_score", 0) >= self.threshold,
            task.data,
        )
    )

    if survivors:
        return ImageBatch(
            task_id=f"{task.task_id}_{self.name}",
            dataset_name=task.dataset_name,
            data=survivors,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
    # Nothing passed the threshold.
    return None