nemo_automodel.components.models.bagel.autoencoder#

Autoencoder used by BAGEL Stage 2 image generation training.

Module Contents#

Classes#

AutoEncoderParams

Architecture parameters for the BAGEL/FLUX autoencoder.

AttnBlock

Single-head spatial attention block used in the VAE bottleneck.

ResnetBlock

Residual convolution block used by the autoencoder.

Downsample

Stride-2 downsample with explicit asymmetric padding.

Upsample

Nearest-neighbor upsample followed by a 3x3 convolution.

Encoder

BAGEL/FLUX autoencoder encoder.

Decoder

BAGEL/FLUX autoencoder decoder.

DiagonalGaussian

Convert latent moments to a Gaussian sample or mean.

AutoEncoder

BAGEL Stage 2 autoencoder wrapper.

Functions#

swish

Swish activation.

_log_load_warning

default_autoencoder_params

Return the BAGEL-7B-MoT autoencoder architecture parameters.

load_bagel_autoencoder

Load the BAGEL autoencoder from ae.safetensors.

Data#

API#

nemo_automodel.components.models.bagel.autoencoder.logger#

‘getLogger(…)’

class nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams#

Architecture parameters for the BAGEL/FLUX autoencoder.

resolution: int#

None

in_channels: int#

None

downsample: int#

None

ch: int#

None

out_ch: int#

None

ch_mult: list[int]#

None

num_res_blocks: int#

None

z_channels: int#

None

scale_factor: float#

None

shift_factor: float#

None

nemo_automodel.components.models.bagel.autoencoder.swish(x: torch.Tensor) -> torch.Tensor#

Swish activation.
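Swish (also known as SiLU) is usually defined as x * sigmoid(x). The sketch below evaluates that formula on plain Python floats. It assumes the module's swish applies the same formula elementwise to a tensor; the helper name swish_scalar is hypothetical and not part of this module.

```python
import math


def swish_scalar(x: float) -> float:
    """Swish / SiLU on one float: x * sigmoid(x), written to avoid overflow."""
    if x >= 0:
        sigmoid = 1.0 / (1.0 + math.exp(-x))
    else:
        # For negative x, exp(x) is small, so this form cannot overflow.
        e = math.exp(x)
        sigmoid = e / (1.0 + e)
    return x * sigmoid


# Hypothetical inputs: swish is close to 0 for very negative x and close to x for large x.
activated = [swish_scalar(v) for v in (-2.0, 0.0, 2.0)]
```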

class nemo_automodel.components.models.bagel.autoencoder.AttnBlock(in_channels: int)#

Bases: torch.nn.Module

Single-head spatial attention block used in the VAE bottleneck.

Initialization

attention(h_: torch.Tensor) -> torch.Tensor#

Apply scaled dot-product attention over flattened image positions.

forward(x: torch.Tensor) -> torch.Tensor#

Apply residual attention.
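To make "attention over flattened image positions" concrete, the following plain-Python sketch treats each of the H*W positions of a C-channel feature map as one token and applies single-head scaled dot-product attention, softmax(Q K^T / sqrt(C)) V. It is an illustration only: the function names are hypothetical, and it leaves out any normalization, learned query/key/value projections, and the residual connection that the real block may apply.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def flatten_positions(feature_map):
    """Turn a [C][H][W] nested list into H*W tokens of length C (row-major)."""
    c, h, w = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    return [[feature_map[ch][i][j] for ch in range(c)]
            for i in range(h) for j in range(w)]


def single_head_attention(q, k, v):
    """q, k, v are lists of N token vectors; returns N attended vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        out.append([sum(wt * vj[ch] for wt, vj in zip(weights, v))
                    for ch in range(len(v[0]))])
    return out


# Hypothetical 2-channel, 2x2 feature map; queries, keys and values share the tokens.
fmap = [[[1.0, 0.0], [0.0, 1.0]],
        [[0.5, 0.5], [0.5, 0.5]]]
tokens = flatten_positions(fmap)
attended = single_head_attention(tokens, tokens, tokens)
```

Every output vector is a convex combination of the value vectors, so a channel that is constant across positions (the second channel here) stays constant after attention.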

class nemo_automodel.components.models.bagel.autoencoder.ResnetBlock(in_channels: int, out_channels: int)#

Bases: torch.nn.Module

Residual convolution block used by the autoencoder.

Initialization

forward(x: torch.Tensor) -> torch.Tensor#

Run the residual block.

class nemo_automodel.components.models.bagel.autoencoder.Downsample(in_channels: int)#

Bases: torch.nn.Module

Stride-2 downsample with explicit asymmetric padding.

Initialization

forward(x: torch.Tensor) -> torch.Tensor#

Downsample spatial dimensions by 2.
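The effect of asymmetric padding on the output size can be checked with the standard convolution size formula. The sketch below assumes a 3x3 kernel with stride 2 and one pixel of padding on the right and bottom edges only; the source states just "stride-2" and "asymmetric padding", so the kernel size and padding amounts are assumptions, and conv_output_size is a hypothetical helper.

```python
def conv_output_size(size: int, kernel: int, stride: int,
                     pad_before: int, pad_after: int) -> int:
    """Output length of a 1-D convolution along one spatial axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


# Assumed configuration: 3x3 kernel, stride 2, padding (0, 1) on each spatial axis.
halved = conv_output_size(64, kernel=3, stride=2, pad_before=0, pad_after=1)

# Without any padding the same convolution would lose one extra position.
unpadded = conv_output_size(64, kernel=3, stride=2, pad_before=0, pad_after=0)
```

Under these assumptions an even input size is halved exactly (64 becomes 32), whereas the unpadded convolution would give 31.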

class nemo_automodel.components.models.bagel.autoencoder.Upsample(in_channels: int)#

Bases: torch.nn.Module

Nearest-neighbor upsample followed by a 3x3 convolution.

Initialization

forward(x: torch.Tensor) -> torch.Tensor#

Upsample spatial dimensions by 2.
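Nearest-neighbor upsampling by 2 repeats every value twice along each spatial axis. The sketch below shows this on a single-channel grid stored as nested lists; nearest_upsample_2x is a hypothetical helper, and the 3x3 convolution that follows in the real module is not reproduced here.

```python
def nearest_upsample_2x(grid):
    """Double height and width of a 2-D grid by repeating each value."""
    out = []
    for row in grid:
        wide = [value for value in row for _ in range(2)]  # repeat along width
        out.append(wide)
        out.append(list(wide))                             # repeat along height
    return out


# Hypothetical 2x2 input becomes a 4x4 output.
upsampled = nearest_upsample_2x([[1, 2],
                                 [3, 4]])
```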

class nemo_automodel.components.models.bagel.autoencoder.Encoder(
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
)#

Bases: torch.nn.Module

BAGEL/FLUX autoencoder encoder.

Initialization

forward(x: torch.Tensor) -> torch.Tensor#

Encode an image tensor to Gaussian latent moments.
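The shape of the encoder output can be estimated from the constructor arguments. The sketch below assumes that every resolution level except the last halves the spatial size, and that the "Gaussian latent moments" hold a mean and a log-variance, giving 2 * z_channels output channels. Both points are assumptions rather than statements from this page, and the numeric values are purely illustrative.

```python
def encoder_output_shape(resolution: int, ch_mult: list[int],
                         z_channels: int) -> tuple[int, int, int]:
    """Estimate (channels, height, width) of the latent moments for one image."""
    # Assumption: one stride-2 downsample between consecutive levels.
    factor = 2 ** (len(ch_mult) - 1)
    side = resolution // factor
    # Assumption: moments = mean and log-variance stacked along channels.
    return (2 * z_channels, side, side)


# Illustrative values only; see default_autoencoder_params() for the real ones.
shape = encoder_output_shape(resolution=256, ch_mult=[1, 2, 4, 4], z_channels=16)
```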

class nemo_automodel.components.models.bagel.autoencoder.Decoder(
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
)#

Bases: torch.nn.Module

BAGEL/FLUX autoencoder decoder.

Initialization

forward(z: torch.Tensor) -> torch.Tensor#

Decode latents to image tensors.

class nemo_automodel.components.models.bagel.autoencoder.DiagonalGaussian(sample: bool = True, chunk_dim: int = 1)#

Bases: torch.nn.Module

Convert latent moments to a Gaussian sample or mean.

Initialization

forward(z: torch.Tensor) -> torch.Tensor#

Sample from or return the mean of a diagonal Gaussian.
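The sample-or-mean behaviour can be sketched with the reparameterization used by most VAEs: split the moments into a mean half and a log-variance half, then return either the mean or mean + exp(0.5 * logvar) * eps with eps drawn from a standard normal. The plain-Python version below works on a flat list and assumes that mean comes first and log-variance second; diagonal_gaussian is a hypothetical function, not the module's class.

```python
import math
import random


def diagonal_gaussian(moments, sample=True, rng=None):
    """Split moments into (mean, logvar) halves and sample or return the mean."""
    half = len(moments) // 2
    mean, logvar = moments[:half], moments[half:]
    if not sample:
        return list(mean)
    rng = rng or random.Random()
    # Reparameterization: mean + std * eps, with std = exp(0.5 * logvar).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]


moments = [1.0, 2.0, 0.0, 0.0]        # hypothetical: mean = [1, 2], logvar = [0, 0]
mean_only = diagonal_gaussian(moments, sample=False)
drawn = diagonal_gaussian(moments, sample=True, rng=random.Random(0))
```

Returning the mean gives a deterministic latent, which is often preferred at inference time; sampling adds the noise implied by the predicted variance.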

class nemo_automodel.components.models.bagel.autoencoder.AutoEncoder(
params: nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams,
)#

Bases: torch.nn.Module

BAGEL Stage 2 autoencoder wrapper.

Initialization

encode(x: torch.Tensor) -> torch.Tensor#

Encode image tensors to scaled latents.

decode(z: torch.Tensor) -> torch.Tensor#

Decode scaled latents to image tensors.

forward(x: torch.Tensor) -> torch.Tensor#

Encode and decode image tensors.
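The "scaled latents" produced by encode and consumed by decode presumably relate to the raw latents through scale_factor and shift_factor. A common convention, assumed here and not confirmed by this page, is scaled = scale_factor * (z - shift_factor), with decode applying the inverse before running the decoder. The sketch uses plain floats and hypothetical helper names and values.

```python
def to_scaled(z: float, scale_factor: float, shift_factor: float) -> float:
    """Assumed encode-side normalization of a raw latent value."""
    return scale_factor * (z - shift_factor)


def from_scaled(z_scaled: float, scale_factor: float, shift_factor: float) -> float:
    """Assumed decode-side inverse of to_scaled."""
    return z_scaled / scale_factor + shift_factor


# Hypothetical factors; the real values come from AutoEncoderParams.
scale, shift = 0.5, 0.1
scaled = to_scaled(1.3, scale, shift)
restored = from_scaled(scaled, scale, shift)
```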

nemo_automodel.components.models.bagel.autoencoder._log_load_warning(missing: list[str], unexpected: list[str]) -> None#

nemo_automodel.components.models.bagel.autoencoder.default_autoencoder_params() -> nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams#

Return the BAGEL-7B-MoT autoencoder architecture parameters.

nemo_automodel.components.models.bagel.autoencoder.load_bagel_autoencoder(
local_path: str | None,
) -> tuple[nemo_automodel.components.models.bagel.autoencoder.AutoEncoder, nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams]#

Load the BAGEL autoencoder from ae.safetensors.

Parameters:

local_path – Local path to ae.safetensors. If None, the module is returned with randomly initialized weights.

Returns:

The autoencoder module and its architecture parameters.
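A possible way to exercise the loader is sketched below. It relies only on what this page documents: load_bagel_autoencoder returns the module together with its parameters, passing None yields randomly initialized weights, and the module exposes encode and decode. The (batch, channels, height, width) input layout, the use of a random float32 image, and the helper name roundtrip_random_image are assumptions. Imports are placed inside the function so that defining it does not require torch or nemo_automodel to be installed.

```python
def roundtrip_random_image(local_path=None):
    """Load the autoencoder and pass one random image through encode/decode.

    local_path: path to ae.safetensors, or None for random weights.
    Returns the shapes of the latents and of the reconstruction.
    """
    import torch
    from nemo_automodel.components.models.bagel.autoencoder import (
        load_bagel_autoencoder,
    )

    model, params = load_bagel_autoencoder(local_path)
    model.eval()

    # Assumed layout: (batch, channels, height, width), sized from the params.
    image = torch.randn(1, params.in_channels, params.resolution, params.resolution)

    with torch.no_grad():
        latents = model.encode(image)
        reconstruction = model.decode(latents)
    return latents.shape, reconstruction.shape
```

With local_path=None the reconstruction is meaningless as an image, but the call is still useful as a quick shape check.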