API Reference: Tasks

AudioBatch

View as Markdown

AudioBatch is the task type for audio processing in NeMo Curator.

Import

from nemo_curator.tasks import AudioBatch

Class Definition

from dataclasses import dataclass

@dataclass
class AudioBatch(Task[dict | list[dict]]):
    """A batch of audio manifest entries flowing through a curation pipeline.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: Audio manifest payload — either a single manifest dict or a
            list of manifest dicts.
    """

    task_id: str
    dataset_name: str
    data: dict | list[dict]

Audio Manifest Format

Audio data follows the NeMo manifest format:

{
  "audio_filepath": "/path/to/audio.wav",
  "duration": 5.2,
  "text": "Transcription text...",
  "speaker": "speaker_001",
  "metadata": {
    "sample_rate": 16000,
    "channels": 1
  }
}

Properties

num_items

Get the number of audio samples in the batch.

@property
def num_items(self) -> int:
    """Return the number of audio manifest entries held in this batch."""

Creating AudioBatch

from nemo_curator.tasks import AudioBatch

# A batch can wrap a single manifest entry...
single_entry = {
    "audio_filepath": "/data/audio/sample.wav",
    "duration": 5.2,
    "text": "Hello world",
}

batch = AudioBatch(
    task_id="audio_001",
    dataset_name="speech_dataset",
    data=single_entry,
)

# ...or several manifest entries at once.
entries = [
    {"audio_filepath": "/data/audio/s1.wav", "duration": 3.1},
    {"audio_filepath": "/data/audio/s2.wav", "duration": 4.5},
]

batch = AudioBatch(
    task_id="audio_batch_001",
    dataset_name="speech_dataset",
    data=entries,
)

Usage in Stages

from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import AudioBatch

@dataclass
class DurationFilterStage(ProcessingStage[AudioBatch, AudioBatch]):
    """Keep only audio samples whose duration lies in [min_duration, max_duration]."""

    name: str = "DurationFilter"
    min_duration: float = 1.0
    max_duration: float = 30.0

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def process(self, task: AudioBatch) -> AudioBatch | None:
        # Normalize to a list so single-entry and multi-entry batches
        # take the same code path.
        entries = task.data
        if isinstance(entries, dict):
            entries = [entries]

        kept = []
        for entry in entries:
            # Entries missing a "duration" key are treated as 0 seconds.
            duration = entry.get("duration", 0)
            if self.min_duration <= duration <= self.max_duration:
                kept.append(entry)

        # An empty result drops the batch from the pipeline entirely.
        if not kept:
            return None

        return AudioBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=kept[0] if len(kept) == 1 else kept,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )

Common Operations

ASR Transcription

def process(self, task: AudioBatch) -> AudioBatch:
    """Transcribe every audio file in the batch and attach the hypothesis text."""
    if isinstance(task.data, list):
        entries = task.data
    else:
        entries = [task.data]

    for entry in entries:
        # Mutates the manifest entry in place with the ASR output.
        entry["text"] = self.asr_model.transcribe(entry["audio_filepath"])

    return AudioBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=entries[0] if len(entries) == 1 else entries,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Quality Scoring

def process(self, task: AudioBatch) -> AudioBatch:
    """Score entries with word error rate where both transcripts are present."""
    if isinstance(task.data, list):
        entries = task.data
    else:
        entries = [task.data]

    for entry in entries:
        # WER requires both the hypothesis ("text") and the reference.
        if "text" in entry and "reference_text" in entry:
            entry["wer"] = compute_wer(entry["reference_text"], entry["text"])

    return AudioBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=entries[0] if len(entries) == 1 else entries,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Source Code

View source on GitHub