nemo_automodel.components.datasets.diffusion.base_dataset

Module Contents

Classes

Name | Description
BaseMultiresolutionDataset | Abstract base class for multiresolution datasets with bucket-based sampling.

Data

logger

API

class nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset(
cache_dir: str,
quantization: int = 64
)
Abstract

Bases: Dataset

Abstract base class for multiresolution datasets with bucket-based sampling.

cache_dir = Path(cache_dir)
calculator
metadata = self._load_metadata()

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset.__getitem__(
idx: int
) -> typing.Dict
abstract

Load a single sample. Subclasses must implement.
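
The abstract-method contract can be illustrated with a minimal, self-contained sketch. It does not import the real class: `MultiresolutionBase` and `InMemoryDataset` are hypothetical stand-ins, the `caption` and `resolution` fields are invented for illustration, and having `__len__` return the number of metadata entries is an assumption rather than documented behavior.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class MultiresolutionBase(ABC):
    """Hypothetical stand-in for BaseMultiresolutionDataset (not the real class)."""

    def __init__(self, metadata: List[Dict]):
        self.metadata = metadata

    def __len__(self) -> int:
        # Assumption: the dataset length equals the number of metadata entries.
        return len(self.metadata)

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict:
        """Load a single sample. Subclasses must implement."""


class InMemoryDataset(MultiresolutionBase):
    """Toy subclass that builds each sample straight from its metadata entry."""

    def __getitem__(self, idx: int) -> Dict:
        item = self.metadata[idx]
        # "caption" and "resolution" are illustrative field names only.
        return {"caption": item["caption"], "resolution": tuple(item["resolution"])}


dataset = InMemoryDataset([{"caption": "a cat", "resolution": [512, 512]}])
sample = dataset[0]
```

Because `__getitem__` is abstract in this sketch, instantiating `MultiresolutionBase` directly raises `TypeError`; only a subclass that supplies the method can be constructed.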

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset.__len__() -> int

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset._aspect_ratio_to_name(
aspect_ratio: float
) -> str

Convert aspect ratio to a descriptive name.
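
The reference does not list the names or thresholds this method uses, so the following is only a sketch of what such a mapping might look like. The `tolerance` parameter, the labels `square`, `landscape`, and `portrait`, and the width-over-height convention are all assumptions.

```python
def aspect_ratio_to_name(aspect_ratio: float, tolerance: float = 0.05) -> str:
    """Map a width/height ratio to a coarse label (hypothetical scheme)."""
    if aspect_ratio <= 0:
        raise ValueError("aspect_ratio must be positive")
    # Ratios within `tolerance` of 1.0 are treated as square.
    if abs(aspect_ratio - 1.0) <= tolerance:
        return "square"
    # Assumes ratio = width / height, so values above 1.0 are wider than tall.
    return "landscape" if aspect_ratio > 1.0 else "portrait"


names = [aspect_ratio_to_name(r) for r in (1.0, 16 / 9, 9 / 16)]
```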

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset._group_by_bucket()

Group samples by bucket (aspect_ratio + resolution).
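
Grouping by an (aspect ratio, resolution) key can be sketched as below. This is not the library's implementation: the function name, the metadata field names `aspect_name` and `resolution`, and the choice to store sample indices per bucket are assumptions made for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

BucketKey = Tuple[str, Tuple[int, int]]


def group_by_bucket(metadata: List[dict]) -> Dict[BucketKey, List[int]]:
    """Map each (aspect-ratio name, resolution) pair to the indices of its samples."""
    buckets: Dict[BucketKey, List[int]] = defaultdict(list)
    for idx, item in enumerate(metadata):
        # Hypothetical field names; the real metadata schema may differ.
        key = (item["aspect_name"], tuple(item["resolution"]))
        buckets[key].append(idx)
    return dict(buckets)


samples = [
    {"aspect_name": "square", "resolution": [512, 512]},
    {"aspect_name": "landscape", "resolution": [768, 512]},
    {"aspect_name": "square", "resolution": [512, 512]},
]
buckets = group_by_bucket(samples)
```

Keeping indices rather than loaded samples in each bucket lets a sampler draw a whole batch from one bucket, so every item in the batch shares the same tensor shape.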

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset._load_metadata() -> typing.List[typing.Dict]

Load metadata from cache directory.

Expects a metadata.json file in the cache directory with a "shards" key that references the shard files.
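
A possible reading of this layout is sketched below. Only the `metadata.json` file name and its `shards` key come from the description above; the assumptions are that each shard entry is a path relative to the cache directory and that each shard file holds a JSON list of per-sample dictionaries. The shard file names and the `id` field are made up for the demo.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def load_metadata(cache_dir: str) -> List[Dict]:
    """Read metadata.json and concatenate the entries of every shard it lists."""
    cache = Path(cache_dir)
    with open(cache / "metadata.json") as f:
        index = json.load(f)

    items: List[Dict] = []
    for shard_name in index["shards"]:
        # Assumption: shard entries are relative paths to JSON lists of dicts.
        with open(cache / shard_name) as f:
            items.extend(json.load(f))
    return items


# Build a throwaway cache directory to exercise the loader.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "shard_0000.json").write_text(json.dumps([{"id": 0}, {"id": 1}]))
    (root / "shard_0001.json").write_text(json.dumps([{"id": 2}]))
    (root / "metadata.json").write_text(
        json.dumps({"shards": ["shard_0000.json", "shard_0001.json"]})
    )
    items = load_metadata(tmp)
```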

nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset.get_bucket_info() -> typing.Dict

Get bucket organization information.

nemo_automodel.components.datasets.diffusion.base_dataset.logger = logging.getLogger(__name__)