API Reference › Tasks
AudioBatch
AudioBatch is the task type for audio processing in NeMo Curator.
Import
from nemo_curator.tasks import AudioBatch
Class Definition
from dataclasses import dataclass


@dataclass
class AudioBatch(Task[dict | list[dict]]):
    """Task carrying audio manifest data through the pipeline.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: Audio manifest data — a single manifest dict or a list of them.
    """

    task_id: str
    dataset_name: str
    data: dict | list[dict]
Audio Manifest Format
Audio data follows the NeMo manifest format:
{
  "audio_filepath": "/path/to/audio.wav",
  "duration": 5.2,
  "text": "Transcription text...",
  "speaker": "speaker_001",
  "metadata": {
    "sample_rate": 16000,
    "channels": 1
  }
}
Properties
num_items
Get the number of audio samples in the batch.
@property
def num_items(self) -> int:
    """Returns the number of audio samples."""
Creating AudioBatch
from nemo_curator.tasks import AudioBatch

# Wrap a single manifest entry in a batch.
single_entry = {
    "audio_filepath": "/data/audio/sample.wav",
    "duration": 5.2,
    "text": "Hello world",
}

batch = AudioBatch(
    task_id="audio_001",
    dataset_name="speech_dataset",
    data=single_entry,
)

# A list of manifest entries works the same way.
entries = [
    {"audio_filepath": "/data/audio/s1.wav", "duration": 3.1},
    {"audio_filepath": "/data/audio/s2.wav", "duration": 4.5},
]

batch = AudioBatch(
    task_id="audio_batch_001",
    dataset_name="speech_dataset",
    data=entries,
)
Usage in Stages
from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import AudioBatch


@dataclass
class DurationFilterStage(ProcessingStage[AudioBatch, AudioBatch]):
    """Drop audio samples whose duration falls outside [min_duration, max_duration]."""

    name: str = "DurationFilter"
    min_duration: float = 1.0
    max_duration: float = 30.0

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def process(self, task: AudioBatch) -> AudioBatch | None:
        # Normalize to a list so a lone manifest dict and a batch of
        # dicts go through the same code path.
        entries = [task.data] if isinstance(task.data, dict) else task.data

        kept = [
            entry
            for entry in entries
            if self.min_duration <= entry.get("duration", 0) <= self.max_duration
        ]

        # Returning None removes the task from the pipeline entirely.
        if not kept:
            return None

        return AudioBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            # Preserve the input's single-dict vs list shape convention.
            data=kept[0] if len(kept) == 1 else kept,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
Common Operations
ASR Transcription
def process(self, task: AudioBatch) -> AudioBatch:
    """Run ASR over every sample in the batch, writing transcripts in place."""
    entries = task.data
    if not isinstance(entries, list):
        entries = [entries]

    # Each manifest entry gains a "text" field holding the ASR output.
    for entry in entries:
        entry["text"] = self.asr_model.transcribe(entry["audio_filepath"])

    return AudioBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        # Keep the single-dict vs list shape convention of the input.
        data=entries if len(entries) > 1 else entries[0],
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
Quality Scoring
def process(self, task: AudioBatch) -> AudioBatch:
    """Attach a word-error-rate score to every entry that has both transcripts."""
    entries = task.data
    if not isinstance(entries, list):
        entries = [entries]

    for entry in entries:
        # WER needs both the hypothesis and the ground-truth transcript.
        if "text" in entry and "reference_text" in entry:
            entry["wer"] = compute_wer(entry["reference_text"], entry["text"])

    return AudioBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        # Keep the single-dict vs list shape convention of the input.
        data=entries if len(entries) > 1 else entries[0],
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )