nemo_automodel.components.datasets.diffusion.multi_tier_bucketing


Module Contents

Classes

Name | Description
MultiTierBucketCalculator | Calculate resolution buckets constrained by a maximum pixel budget.

Data

logger

API

class nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator(
quantization: int = 64,
max_pixels: typing.Optional[int] = None,
debug_mode: bool = False
)

Calculate resolution buckets constrained by a maximum pixel budget. Supports multiple aspect ratios, each scaled to the largest resolution that fits within the pixel budget.

ASPECT_RATIOS
RESOLUTION_PRESETS
buckets = self._generate_all_buckets()
max_pixels
nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator._build_lookup_structures()

Build efficient lookup structures.

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator._calculate_max_resolution(
aspect_ratio: float
) -> typing.Optional[typing.Tuple[int, int]]

Calculate the maximum resolution for an aspect ratio within the pixel budget.

For a given aspect ratio $r = w/h$ and pixel budget $P$:

$$
w \cdot h \le P,\qquad w = r\,h \;\Longrightarrow\; r\,h^2 \le P \;\Longrightarrow\; h \le \sqrt{P / r}
$$

Then $w = r \cdot h$.
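A minimal Python sketch of this calculation follows. The function name `max_resolution_for_aspect` is hypothetical, and flooring both dimensions to a multiple of `quantization` is an assumption made here so the quantized result can never exceed the budget. The class's own `_round_to_quantization` is documented as rounding to the nearest multiple, so its outputs may differ.

```python
import math
from typing import Optional, Tuple


def max_resolution_for_aspect(
    aspect_ratio: float, max_pixels: int, quantization: int = 64
) -> Optional[Tuple[int, int]]:
    """Hypothetical helper, not part of the nemo_automodel API.

    Returns the largest (width, height) with width/height close to
    aspect_ratio and width * height <= max_pixels.
    """
    # From r * h * h <= P, the largest admissible height is sqrt(P / r).
    height = math.sqrt(max_pixels / aspect_ratio)
    width = aspect_ratio * height
    # Assumption: floor (not round) to multiples of the quantization step,
    # so the quantized resolution stays within the pixel budget.
    q_height = int(height // quantization) * quantization
    q_width = int(width // quantization) * quantization
    if q_width == 0 or q_height == 0:
        return None  # budget too small for this aspect ratio
    return q_width, q_height
```

Because each floored side is no larger than its real-valued counterpart, the product stays at or below $P$.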

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator._generate_all_buckets() -> typing.List[typing.Dict]

Generate all unique resolution buckets within the pixel budget.
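One way to enumerate unique buckets is sketched below, assuming each aspect ratio is converted with the formula above and duplicate resolutions are dropped. The `generate_buckets` name and the dictionary keys (`id`, `resolution`, `aspect_ratio`, `pixels`) are illustrative assumptions; the real bucket dictionaries may use different keys.

```python
import math
from typing import Dict, List, Sequence


def generate_buckets(
    aspect_ratios: Sequence[float], max_pixels: int, quantization: int = 64
) -> List[Dict]:
    """Hypothetical sketch: one bucket per distinct quantized resolution."""
    seen = set()
    buckets: List[Dict] = []
    for ratio in aspect_ratios:
        h_real = math.sqrt(max_pixels / ratio)  # h <= sqrt(P / r)
        w_real = ratio * h_real                 # w = r * h
        width = int(w_real // quantization) * quantization
        height = int(h_real // quantization) * quantization
        # Skip ratios that do not fit at all, and resolutions already emitted.
        if width == 0 or height == 0 or (width, height) in seen:
            continue
        seen.add((width, height))
        buckets.append({
            "id": len(buckets),          # sequential id (assumed scheme)
            "resolution": (width, height),
            "aspect_ratio": width / height,
            "pixels": width * height,
        })
    return buckets
```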

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator._print_bucket_summary()

Print summary of generated buckets.

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator._round_to_quantization(
value: int
) -> int

Round value to the nearest multiple of quantization.
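A sketch of nearest-multiple rounding in integer arithmetic, assuming ties round up (Python's built-in `round` would instead round ties to even). `round_to_quantization` here is a standalone hypothetical function, not the private method itself.

```python
def round_to_quantization(value: int, quantization: int = 64) -> int:
    # Hypothetical standalone helper: integer round-half-up to the
    # nearest multiple of `quantization`.
    return ((value + quantization // 2) // quantization) * quantization
```

Values below `quantization / 2` round to 0, so a caller would need to clamp if zero-sized dimensions are not acceptable.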

classmethod

Create calculator from a named preset.

Parameters:

preset
str

One of ‘256p’, ‘512p’, ‘768p’, ‘1024p’, ‘1536p’

quantization
int, defaults to 64

Resolution quantization step in pixels; bucket dimensions are rounded to multiples of this value

Returns: MultiTierBucketCalculator

MultiTierBucketCalculator instance

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.get_all_buckets() -> typing.List[typing.Dict]

Get all buckets.

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.get_bucket_by_id(
bucket_id: int
) -> typing.Dict

Get bucket by ID.

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.get_bucket_by_resolution(
width: int,
height: int
) -> typing.Optional[typing.Dict]

Get bucket by exact resolution.
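Lookups by ID and by exact resolution are typically served from dictionaries built once from the bucket list, which is presumably what `_build_lookup_structures` prepares. The sketch below assumes that design; the `BucketIndex` class and the `id`/`resolution` keys are hypothetical, and raising `KeyError` for an unknown ID is an assumption (the documented return type of `get_bucket_by_id` is not optional).

```python
from typing import Dict, List, Optional, Tuple


class BucketIndex:
    """Hypothetical O(1) lookups by bucket id and by exact (width, height)."""

    def __init__(self, buckets: List[Dict]):
        self._by_id: Dict[int, Dict] = {b["id"]: b for b in buckets}
        self._by_resolution: Dict[Tuple[int, int], Dict] = {
            tuple(b["resolution"]): b for b in buckets
        }

    def get_bucket_by_id(self, bucket_id: int) -> Dict:
        # Assumption: unknown ids raise KeyError.
        return self._by_id[bucket_id]

    def get_bucket_by_resolution(self, width: int, height: int) -> Optional[Dict]:
        # Exact match only; returns None when no bucket has this resolution.
        return self._by_resolution.get((width, height))
```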

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.get_bucket_for_image(
image_width: int,
image_height: int
) -> typing.Dict

Get the best bucket for an image.

Parameters:

image_width
int

Original image width

image_height
int

Original image height

max_pixels

Deprecated per-query override of the pixel budget; set max_pixels in the constructor instead

Returns: Dict

Bucket dictionary with resolution and metadata
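The selection criterion is not spelled out above. A common choice, assumed in this sketch, is the bucket whose aspect ratio is closest to the image's, measured in log space so that 2:1 and 1:2 are equally far from 1:1. `best_bucket_for_image` and the `resolution` key are hypothetical.

```python
import math
from typing import Dict, List


def best_bucket_for_image(buckets: List[Dict], image_width: int, image_height: int) -> Dict:
    """Hypothetical: pick the bucket with the nearest aspect ratio (log distance)."""
    image_log_ratio = math.log(image_width / image_height)
    return min(
        buckets,
        key=lambda b: abs(math.log(b["resolution"][0] / b["resolution"][1]) - image_log_ratio),
    )
```

Usage: with buckets at 512x512, 704x320 and 320x704, a 1920x1080 image maps to 704x320 and a 1080x1920 image maps to 320x704.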

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.get_dynamic_batch_size(
resolution: typing.Tuple[int, int],
base_batch_size: int = 32,
base_resolution: typing.Tuple[int, int] = (512, 512)
) -> int

Calculate a dynamic batch size based on resolution. Larger images get smaller batches so that GPU memory usage stays roughly constant.

Parameters:

resolution
Tuple[int, int]

(width, height)

base_batch_size
int, defaults to 32

Batch size for base resolution

base_resolution
Tuple[int, int], defaults to (512, 512)

Reference resolution

Returns: int

Recommended batch size
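A minimal sketch of one plausible scaling rule, assuming the batch size is scaled inversely with pixel count so that `batch_size * width * height` stays roughly constant, floored, and clamped to at least 1. The function name `dynamic_batch_size` is hypothetical, and the real method may use a different formula.

```python
from typing import Tuple


def dynamic_batch_size(
    resolution: Tuple[int, int],
    base_batch_size: int = 32,
    base_resolution: Tuple[int, int] = (512, 512),
) -> int:
    """Hypothetical: scale batch size inversely with the number of pixels."""
    pixels = resolution[0] * resolution[1]
    base_pixels = base_resolution[0] * base_resolution[1]
    # Keep batch_size * pixels near base_batch_size * base_pixels;
    # integer floor, never below one sample per batch.
    return max(1, (base_batch_size * base_pixels) // pixels)
```

Linear scaling is a simplification: self-attention cost grows quadratically with sequence length, so memory at high resolutions may shrink faster than this rule assumes.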

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.MultiTierBucketCalculator.resize_and_crop(
image,
target_width: int,
target_height: int,
crop_mode: str = 'center'
) -> typing.Tuple

Resize and crop an image to the target resolution.

Parameters:

image

PIL Image or numpy array

target_width
int

Target width

target_height
int

Target height

crop_mode
str, defaults to 'center'

‘center’, ‘random’, or ‘smart’

Returns: Tuple

(resized_image, crop_offset_x, crop_offset_y)
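The geometry behind a `'center'` crop can be sketched without any image library, assuming the image is first scaled (aspect preserved) until it covers the target and the excess is then cropped equally from both sides. `cover_resize_center_crop` is a hypothetical helper that returns the resized size plus the crop offsets; it does not cover the `'random'` or `'smart'` modes.

```python
from typing import Tuple


def cover_resize_center_crop(
    src_width: int, src_height: int, target_width: int, target_height: int
) -> Tuple[Tuple[int, int], int, int]:
    """Hypothetical: return ((resized_w, resized_h), crop_offset_x, crop_offset_y)."""
    if src_width * target_height >= target_width * src_height:
        # Source is relatively wider: match the height, crop excess width.
        resized_h = target_height
        resized_w = max(target_width, round(src_width * target_height / src_height))
    else:
        # Source is relatively taller: match the width, crop excess height.
        resized_w = target_width
        resized_h = max(target_height, round(src_height * target_width / src_width))
    # Center the crop window inside the resized image.
    offset_x = (resized_w - target_width) // 2
    offset_y = (resized_h - target_height) // 2
    return (resized_w, resized_h), offset_x, offset_y
```

With Pillow, these numbers would be applied as `image.resize((resized_w, resized_h))` followed by `.crop((x, y, x + target_width, y + target_height))`.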

nemo_automodel.components.datasets.diffusion.multi_tier_bucketing.logger = logging.getLogger(__name__)