---
title: ImageBatch
description: API reference for ImageBatch - the task type for image processing
---

`ImageBatch` is the task type that carries batches of images through image-processing stages in NeMo Curator.

## Import

```python
from nemo_curator.tasks import ImageBatch
```

## Class Definition

```python
from dataclasses import dataclass

from nemo_curator.tasks import Task
from nemo_curator.tasks.image import ImageObject


@dataclass
class ImageBatch(Task[list[ImageObject]]):
    """Task containing a batch of images.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: List of ImageObject instances.
    """

    task_id: str
    dataset_name: str
    data: list[ImageObject]
```

## ImageObject

Each image in the batch is represented by an `ImageObject`:

```python
from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass
class ImageObject:
    """Represents a single image with metadata.

    Attributes:
        path: Path to the image file.
        caption: Optional text caption for the image.
        metadata: Additional metadata dictionary.
        embeddings: Optional embedding vector.
    """

    path: str
    caption: str | None = None
    # default_factory gives every instance its own dict rather than a shared one
    metadata: dict[str, Any] = field(default_factory=dict)
    embeddings: np.ndarray | None = None
```

## Properties

### `num_items`

Returns the number of images in the batch.
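Conceptually, `num_items` counts the entries in `data`. The self-contained sketch here assumes exactly that (`len(self.data)`) and uses hypothetical stand-in classes, `SketchImageObject` and `SketchImageBatch`, so it runs without NeMo Curator installed. It illustrates the idea and is not the library's implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-ins for illustration only; NOT the NeMo Curator classes.
@dataclass
class SketchImageObject:
    path: str
    caption: str | None = None
    # Each instance gets its own metadata dict
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class SketchImageBatch:
    task_id: str
    dataset_name: str
    data: list[SketchImageObject]

    @property
    def num_items(self) -> int:
        # Assumed behavior: the item count is the length of the data list
        return len(self.data)


batch = SketchImageBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=[
        SketchImageObject(path="/data/images/image1.jpg"),
        SketchImageObject(path="/data/images/image2.jpg"),
    ],
)
print(batch.num_items)  # 2
```

The signature documented for the real property follows.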
```python
@property
def num_items(self) -> int:
    """Returns the number of images in this batch."""
```

## Creating ImageBatch

```python
from nemo_curator.tasks import ImageBatch
from nemo_curator.tasks.image import ImageObject

# Create image objects
images = [
    ImageObject(
        path="/data/images/image1.jpg",
        caption="A cat sitting on a couch",
        metadata={"source": "dataset_a"},
    ),
    ImageObject(
        path="/data/images/image2.jpg",
        caption="A dog playing in the park",
        metadata={"source": "dataset_a"},
    ),
]

# Create batch
batch = ImageBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=images,
)
```

## Usage in Stages

A stage that reads and writes image batches declares `ImageBatch` as both its input and output type. The example keeps only images whose `width` and `height` metadata values are at least `min_resolution`, and returns `None` when no images remain.

```python
from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import ImageBatch


@dataclass
class ImageFilterStage(ProcessingStage[ImageBatch, ImageBatch]):
    """Filter images based on metadata."""

    name: str = "ImageFilter"
    min_resolution: int = 256

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def process(self, task: ImageBatch) -> ImageBatch | None:
        # Keep images whose width and height both meet the threshold;
        # images missing either key default to 0 and are dropped
        filtered = [
            img
            for img in task.data
            if img.metadata.get("width", 0) >= self.min_resolution
            and img.metadata.get("height", 0) >= self.min_resolution
        ]
        # No images left, so emit no batch
        if not filtered:
            return None
        # Build a new batch, carrying over the task's _metadata and _stage_perf
        return ImageBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
```

## Common Operations

### Adding Embeddings

```python
def process(self, task: ImageBatch) -> ImageBatch:
    # self.model is assumed to be an embedding model configured on the stage
    for img in task.data:
        img.embeddings = self.model.encode(img.path)
    return ImageBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=task.data,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```

### Filtering by Score

```python
def process(self, task: ImageBatch) -> ImageBatch | None:
    # Images without an "aesthetic_score" entry default to 0
    filtered = [
        img
        for img in task.data
        if img.metadata.get("aesthetic_score", 0) >= self.threshold
    ]
    if not filtered:
        return None
    return ImageBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
```

## Source Code

[View source on GitHub](https://github.com/NVIDIA-NeMo/Curator/blob/main/nemo_curator/tasks/image.py)
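The score-filtering fragment under Common Operations needs NeMo Curator and a configured stage to run. As a dependency-free sketch of that pattern only, the code here swaps in hypothetical stand-ins: `StubImage`, `StubBatch`, and a free function `filter_by_score` in place of a stage's `process` method. None of these names exist in NeMo Curator, and the `_metadata`/`_stage_perf` fields are omitted.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-ins mirroring the documented fields; NOT NeMo Curator classes.
@dataclass
class StubImage:
    path: str
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class StubBatch:
    task_id: str
    dataset_name: str
    data: list[StubImage]


def filter_by_score(
    task: StubBatch, threshold: float, stage_name: str = "ScoreFilter"
) -> StubBatch | None:
    """Keep images whose metadata 'aesthetic_score' meets the threshold."""
    kept = [
        img
        for img in task.data
        # Images without a score default to 0, so a positive threshold drops them
        if img.metadata.get("aesthetic_score", 0) >= threshold
    ]
    if not kept:
        # Mirrors the documented stages: an empty result yields no batch
        return None
    return StubBatch(
        task_id=f"{task.task_id}_{stage_name}",
        dataset_name=task.dataset_name,
        data=kept,
    )


batch = StubBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=[
        StubImage("/data/images/image1.jpg", {"aesthetic_score": 0.8}),
        StubImage("/data/images/image2.jpg", {"aesthetic_score": 0.3}),
        StubImage("/data/images/image3.jpg"),
    ],
)
result = filter_by_score(batch, threshold=0.5)
print(result.task_id, [img.path for img in result.data])
# img_batch_001_ScoreFilter ['/data/images/image1.jpg']
```

Returning a new batch object rather than mutating the input keeps the original `task` unchanged, which matches how the documented stages construct a fresh `ImageBatch` for their output.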