nemo_automodel.components.models.bagel.autoencoder


Autoencoder used by BAGEL Stage 2 image generation training.

Module Contents

Classes

| Name | Description |
| --- | --- |
| AttnBlock | Single-head spatial attention block used in the VAE bottleneck. |
| AutoEncoder | BAGEL Stage 2 autoencoder wrapper. |
| AutoEncoderParams | Architecture parameters for the BAGEL/FLUX autoencoder. |
| Decoder | BAGEL/FLUX autoencoder decoder. |
| DiagonalGaussian | Convert latent moments to a Gaussian sample or mean. |
| Downsample | Stride-2 downsample with explicit asymmetric padding. |
| Encoder | BAGEL/FLUX autoencoder encoder. |
| ResnetBlock | Residual convolution block used by the autoencoder. |
| Upsample | Nearest-neighbor upsample followed by a 3x3 convolution. |

Functions

| Name | Description |
| --- | --- |
| _log_load_warning | - |
| default_autoencoder_params | Return the BAGEL-7B-MoT autoencoder architecture parameters. |
| load_bagel_autoencoder | Load the BAGEL autoencoder from ae.safetensors. |
| swish | Swish activation. |

Data

logger

API

class nemo_automodel.components.models.bagel.autoencoder.AttnBlock(
in_channels: int
)

Bases: Module

Single-head spatial attention block used in the VAE bottleneck.

k
= nn.Conv2d(in_channels, in_channels, kernel_size=1)
norm
proj_out
= nn.Conv2d(in_channels, in_channels, kernel_size=1)
q
= nn.Conv2d(in_channels, in_channels, kernel_size=1)
v
= nn.Conv2d(in_channels, in_channels, kernel_size=1)
nemo_automodel.components.models.bagel.autoencoder.AttnBlock.attention(
h_: torch.Tensor
) -> torch.Tensor

Apply scaled dot-product attention over flattened image positions.
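
To make "attention over flattened image positions" concrete, here is a minimal sketch in plain Python rather than PyTorch. It assumes the feature map is flattened to N = H * W position vectors of length C and that the scores are scaled by 1/sqrt(C). The 1x1 `q`, `k`, `v` projections are left out, and the helper names `flatten_positions` and `single_head_attention` are hypothetical, not part of the module.

```python
import math


def flatten_positions(fmap):
    """Turn a C x H x W nested list into N = H * W vectors of length C."""
    c, h, w = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [[fmap[ch][i][j] for ch in range(c)] for i in range(h) for j in range(w)]


def single_head_attention(q, k, v):
    """Scaled dot-product attention; q, k, v are lists of N vectors of length C."""
    c = len(q[0])
    scale = 1.0 / math.sqrt(c)  # assumed scaling: 1 / sqrt(channels)
    out = []
    for qi in q:
        # Similarity of this query position against every key position.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax over the N positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w_ / total for w_ in weights]
        # Weighted average of the value vectors.
        out.append([sum(w_ * vj[d] for w_, vj in zip(weights, v)) for d in range(c)])
    return out
```

Every position attends to every other position, so the cost grows with N squared. That is affordable in the bottleneck, where the spatial size is smallest.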

nemo_automodel.components.models.bagel.autoencoder.AttnBlock.forward(
x: torch.Tensor
) -> torch.Tensor

Apply residual attention.

class nemo_automodel.components.models.bagel.autoencoder.AutoEncoder(
params: nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams
)

Bases: Module

BAGEL Stage 2 autoencoder wrapper.

decoder
encoder
reg
= DiagonalGaussian()
scale_factor
= params.scale_factor
shift_factor
= params.shift_factor
nemo_automodel.components.models.bagel.autoencoder.AutoEncoder.decode(
z: torch.Tensor
) -> torch.Tensor

Decode scaled latents to image tensors.

nemo_automodel.components.models.bagel.autoencoder.AutoEncoder.encode(
x: torch.Tensor
) -> torch.Tensor

Encode image tensors to scaled latents.

nemo_automodel.components.models.bagel.autoencoder.AutoEncoder.forward(
x: torch.Tensor
) -> torch.Tensor

Encode and decode image tensors.
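
The page does not spell out what "scaled latents" means. The sketch below assumes the convention used by FLUX-style autoencoders: `encode` maps a raw latent `z` to `scale_factor * (z - shift_factor)`, and `decode` inverts that before running the decoder. Check this assumption against the source. The helper names and the numeric values are placeholders, not the BAGEL defaults.

```python
def scale_latent(z, scale_factor, shift_factor):
    """Assumed encode-side normalization: shift first, then scale."""
    return [scale_factor * (v - shift_factor) for v in z]


def unscale_latent(z, scale_factor, shift_factor):
    """Assumed decode-side inverse: undo the scale, then undo the shift."""
    return [v / scale_factor + shift_factor for v in z]


# Placeholder values for illustration only.
raw = [0.25, -1.5, 3.0]
scaled = scale_latent(raw, scale_factor=0.5, shift_factor=0.1)
restored = unscale_latent(scaled, scale_factor=0.5, shift_factor=0.1)
```

Under this assumption the two steps are exact inverses, so only the encoder and decoder networks contribute reconstruction error.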

class nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams(
resolution: int,
in_channels: int,
downsample: int,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
scale_factor: float,
shift_factor: float
)
Dataclass

Architecture parameters for the BAGEL/FLUX autoencoder.

ch
int
ch_mult
list[int]
downsample
int
in_channels
int
num_res_blocks
int
out_ch
int
resolution
int
scale_factor
float
shift_factor
float
z_channels
int
class nemo_automodel.components.models.bagel.autoencoder.Decoder(
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int
)

Bases: Module

BAGEL/FLUX autoencoder decoder.

conv_in
conv_out
ffactor
= 2 ** (self.num_resolutions - 1)
mid
= nn.Module()
norm_out
num_resolutions
= len(ch_mult)
up
= nn.ModuleList()
z_shape
= (1, z_channels, curr_res, curr_res)
nemo_automodel.components.models.bagel.autoencoder.Decoder.forward(
z: torch.Tensor
) -> torch.Tensor

Decode latents to image tensors.
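
The attribute expressions above determine the latent shape the decoder expects. The sketch below reproduces that arithmetic in plain Python. The line `curr_res = resolution // ffactor` is an assumption, since `curr_res` is not defined on this page. The example arguments are hypothetical and not necessarily the BAGEL-7B-MoT values.

```python
def decoder_latent_shape(resolution, ch_mult, z_channels):
    """Return (ffactor, z_shape) following the attribute expressions above."""
    num_resolutions = len(ch_mult)
    ffactor = 2 ** (num_resolutions - 1)
    # Assumption: the latent resolution is the image resolution divided by ffactor.
    curr_res = resolution // ffactor
    z_shape = (1, z_channels, curr_res, curr_res)
    return ffactor, z_shape


# Hypothetical configuration, for illustration only.
ffactor, z_shape = decoder_latent_shape(resolution=256, ch_mult=[1, 2, 4, 4], z_channels=16)
```

With four resolution levels there are three stride-2 stages, which gives an overall spatial factor of 8.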

class nemo_automodel.components.models.bagel.autoencoder.DiagonalGaussian(
sample: bool = True,
chunk_dim: int = 1
)

Bases: Module

Convert latent moments to a Gaussian sample or mean.

nemo_automodel.components.models.bagel.autoencoder.DiagonalGaussian.forward(
z: torch.Tensor
) -> torch.Tensor

Sample from or return the mean of a diagonal Gaussian.
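
The following sketch shows one common way to turn latent moments into a sample. It assumes the input concatenates the mean and the log-variance along `chunk_dim`, and that sampling uses the reparameterization `mean + exp(0.5 * logvar) * eps`. It works on a flat Python list instead of a tensor, and the function name is hypothetical.

```python
import math
import random


def diagonal_gaussian(moments, sample=True, rng=None):
    """Split moments into (mean, logvar) halves and sample or return the mean."""
    half = len(moments) // 2
    mean, logvar = moments[:half], moments[half:]
    if not sample:
        return list(mean)
    rng = rng or random.Random()
    # Reparameterization: std = exp(0.5 * logvar), eps ~ N(0, 1).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, logvar)]
```

Passing `sample=False` makes the output deterministic. That can be useful when latents must be reproducible across runs.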

class nemo_automodel.components.models.bagel.autoencoder.Downsample(
in_channels: int
)

Bases: Module

Stride-2 downsample with explicit asymmetric padding.

conv
nemo_automodel.components.models.bagel.autoencoder.Downsample.forward(
x: torch.Tensor
) -> torch.Tensor

Downsample spatial dimensions by 2.
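
To see why the padding placement matters, the sketch below applies the standard convolution output-size formula. It assumes a 3x3 kernel with stride 2 and one pixel of padding on the right and bottom edges only. That layout is common in FLUX/LDM-style encoders but is not confirmed by this page. `conv_out_size` is a hypothetical helper.

```python
def conv_out_size(size, kernel, stride, pad_before, pad_after):
    """Output length of a 1-D convolution along one spatial axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


# Asymmetric padding (0 before, 1 after) halves an even input exactly.
even_asym = conv_out_size(256, kernel=3, stride=2, pad_before=0, pad_after=1)
# For odd inputs, asymmetric and symmetric padding give different sizes.
odd_asym = conv_out_size(7, kernel=3, stride=2, pad_before=0, pad_after=1)
odd_sym = conv_out_size(7, kernel=3, stride=2, pad_before=1, pad_after=1)
```

For even inputs both padding schemes produce the same output size. They differ in which pixels each kernel window covers, so weights trained with one layout are not interchangeable with the other.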

class nemo_automodel.components.models.bagel.autoencoder.Encoder(
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int
)

Bases: Module

BAGEL/FLUX autoencoder encoder.

conv_in
conv_out
down
= nn.ModuleList()
mid
= nn.Module()
norm_out
num_resolutions
= len(ch_mult)
nemo_automodel.components.models.bagel.autoencoder.Encoder.forward(
x: torch.Tensor
) -> torch.Tensor

Encode an image tensor to Gaussian latent moments.

class nemo_automodel.components.models.bagel.autoencoder.ResnetBlock(
in_channels: int,
out_channels: int
)

Bases: Module

Residual convolution block used by the autoencoder.

conv1
conv2
nin_shortcut
norm1
norm2
nemo_automodel.components.models.bagel.autoencoder.ResnetBlock.forward(
x: torch.Tensor
) -> torch.Tensor

Run the residual block.

class nemo_automodel.components.models.bagel.autoencoder.Upsample(
in_channels: int
)

Bases: Module

Nearest-neighbor upsample followed by a 3x3 convolution.

conv
nemo_automodel.components.models.bagel.autoencoder.Upsample.forward(
x: torch.Tensor
) -> torch.Tensor

Upsample spatial dimensions by 2.
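
As a small illustration of the nearest-neighbor step only (the 3x3 convolution that follows is omitted), this plain-Python sketch doubles a 2-D grid by repeating each value along both axes. The function name is hypothetical.

```python
def upsample_nearest_2x(grid):
    """Repeat every value twice along width and height."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # duplicate along width
        out.append(wide)
        out.append(list(wide))  # duplicate along height
    return out


doubled = upsample_nearest_2x([[1, 2], [3, 4]])
```

Nearest-neighbor interpolation adds no new values. The convolution that follows is what smooths the blocky result.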

nemo_automodel.components.models.bagel.autoencoder._log_load_warning(
missing: list[str],
unexpected: list[str]
) -> None
nemo_automodel.components.models.bagel.autoencoder.default_autoencoder_params() -> nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams

Return the BAGEL-7B-MoT autoencoder architecture parameters.

nemo_automodel.components.models.bagel.autoencoder.load_bagel_autoencoder(
local_path: str | None
) -> tuple[nemo_automodel.components.models.bagel.autoencoder.AutoEncoder, nemo_automodel.components.models.bagel.autoencoder.AutoEncoderParams]

Load the BAGEL autoencoder from ae.safetensors.

Parameters:

local_path
str | None

Local path to ae.safetensors. If None, the module is returned with randomly initialized weights.

Returns: tuple[AutoEncoder, AutoEncoderParams]

The autoencoder module and its architecture parameters.

nemo_automodel.components.models.bagel.autoencoder.swish(
x: torch.Tensor
) -> torch.Tensor

Swish activation.
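
Swish is defined as `x * sigmoid(x)`. The scalar sketch below uses only the standard library and a numerically stable sigmoid. The module's own function operates on tensors, so `swish_scalar` is a hypothetical stand-in for illustration.

```python
import math


def stable_sigmoid(x):
    """Sigmoid that avoids overflow for large negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def swish_scalar(x):
    """Swish activation for a single float: x * sigmoid(x)."""
    return x * stable_sigmoid(x)
```

For large positive inputs the function approaches the identity, and for large negative inputs it approaches zero.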

nemo_automodel.components.models.bagel.autoencoder.logger = logging.getLogger(__name__)