bridge.diffusion.models.flux.layers#

FLUX embedding layers for diffusion models.

Module Contents#

Classes#

EmbedND

N-Dimensional Rotary Position Embedding generator.

MLPEmbedder

MLP embedder with two projection layers and SiLU activation.

TimeStepEmbedder

A neural network module that embeds timesteps for use in diffusion models.

Functions#

rope

Compute rotary position embeddings.

API#

bridge.diffusion.models.flux.layers.rope(pos: torch.Tensor, dim: int, theta: int) → torch.Tensor#

Compute rotary position embeddings.

This differs from the original RoPE used in FLUX: Megatron attention takes the outer product and computes sin/cos internally, so this function only needs to produce the frequencies, in the shape [seq, …, dim].

Parameters:
  • pos – Position tensor.

  • dim – Embedding dimension (must be even).

  • theta – Base frequency.

Returns:

Rotary position embeddings of shape […, dim//2].
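The frequency table this function returns can be sketched in pure Python (no torch). The helper name `rope_freqs` is hypothetical; it only illustrates the math, where pair index i of an even `dim` gets the angle `pos / theta**(2*i/dim)`:

```python
import math

def rope_freqs(pos, dim, theta):
    # Hypothetical sketch of the rope frequency computation: for each
    # position p and pair index i in [0, dim/2), the angle is
    # p / theta**(2 * i / dim), giving an output of shape [len(pos), dim // 2].
    assert dim % 2 == 0, "dim must be even"
    omega = [1.0 / theta ** (2 * i / dim) for i in range(dim // 2)]
    return [[p * w for w in omega] for p in pos]

freqs = rope_freqs([0, 1, 2], dim=4, theta=10000)
# Each row has dim // 2 = 2 entries; position 0 maps to all-zero angles.
```

The attention layer then takes these raw angles and forms sin/cos itself, which is why no sin/cos appears here.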

class bridge.diffusion.models.flux.layers.EmbedND(dim: int, theta: int, axes_dim: List[int])#

Bases: torch.nn.Module

N-Dimensional Rotary Position Embedding generator.

Generates a RoPE matrix with preset per-axis dimensions.

Parameters:
  • dim – Total embedding dimension.

  • theta – Base frequency for rotary embeddings.

  • axes_dim – List of dimensions for each axis.

Initialization

forward(ids: torch.Tensor) → torch.Tensor#

Compute N-dimensional rotary position embeddings.

Parameters:

ids – Position IDs tensor of shape [batch, seq, n_axes].

Returns:

Rotary embeddings tensor.
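The per-axis behavior can be sketched in pure Python for a single sequence. Both helper names below are hypothetical; the sketch assumes EmbedND computes rope-style frequencies for each axis with its own width from `axes_dim` and concatenates them per position:

```python
def axis_freqs(pos, dim, theta):
    # Per-axis rotary angles: p / theta**(2 * i / dim) for i in [0, dim/2).
    return [[p / theta ** (2 * i / dim) for i in range(dim // 2)] for p in pos]

def embed_nd(ids, axes_dim, theta=10000):
    # Hypothetical sketch of EmbedND.forward for a single sequence:
    # ids is [seq][n_axes]; axis k contributes axes_dim[k] // 2 angles,
    # and the axes are concatenated per position.
    out = []
    for row_ids in ids:
        row = []
        for k, d in enumerate(axes_dim):
            row.extend(axis_freqs([row_ids[k]], d, theta)[0])
        out.append(row)
    return out

emb = embed_nd([[0, 1], [1, 0]], axes_dim=[4, 2])
# Each position row has 4 // 2 + 2 // 2 = 3 entries.
```

In FLUX-style models the axes typically index text/image position components (e.g. one axis per spatial coordinate), each with its own slice of the head dimension.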

class bridge.diffusion.models.flux.layers.MLPEmbedder(in_dim: int, hidden_dim: int)#

Bases: torch.nn.Module

MLP embedder with two projection layers and SiLU activation.

Parameters:
  • in_dim – Input dimension.

  • hidden_dim – Hidden/output dimension.

Initialization

forward(x: torch.Tensor) → torch.Tensor#

Forward pass through the MLP embedder.
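The two-layer structure can be sketched in pure Python. The helper `mlp_embed` and its explicit weight arguments are hypothetical (the real module holds learned `torch.nn.Linear` weights); the sketch assumes the documented shape out_layer(silu(in_layer(x))):

```python
import math

def silu(x):
    # SiLU activation: x * sigmoid(x) = x / (1 + e^(-x)).
    return x / (1.0 + math.exp(-x))

def mlp_embed(x, w1, b1, w2, b2):
    # Hypothetical sketch of MLPEmbedder.forward:
    # hidden = silu(W1 @ x + b1); out = W2 @ hidden + b2.
    h = [silu(sum(wi * xi for wi, xi in zip(row, x)) + b)
         for row, b in zip(w1, b1)]
    return [sum(wi * hi for wi, hi in zip(row, h)) + b
            for row, b in zip(w2, b2)]

out = mlp_embed([1.0], [[1.0]], [0.0], [[1.0]], [0.0])
# With identity weights this reduces to silu(1.0).
```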

class bridge.diffusion.models.flux.layers.TimeStepEmbedder(
embedding_dim: int,
hidden_dim: int,
flip_sin_to_cos: bool = True,
downscale_freq_shift: float = 0,
scale: float = 1,
max_period: int = 10000,
)#

Bases: torch.nn.Module

A neural network module that embeds timesteps for use in diffusion models.

It first projects the input timesteps into a higher-dimensional sinusoidal space and then embeds them with an MLP (multilayer perceptron). Together, the projection and embedding provide a learned representation of the timestep for use in further computations.

Parameters:
  • embedding_dim – The dimensionality of the timestep embedding space.

  • hidden_dim – The dimensionality of the hidden layer in the MLPEmbedder.

  • flip_sin_to_cos – Whether to flip the sine and cosine components.

  • downscale_freq_shift – A scaling factor for the frequency shift.

  • scale – A scaling factor applied to the timestep projections.

  • max_period – The maximum period for the sine and cosine functions.

Initialization

forward(timesteps: torch.Tensor) → torch.Tensor#

Compute timestep embeddings.

Parameters:

timesteps – Input timestep tensor.

Returns:

Embedded timesteps tensor.
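The sinusoidal projection stage (before the learned MLP) can be sketched in pure Python for a single scalar timestep. The function name is hypothetical; the sketch assumes the common diffusers-style scheme that the parameters above suggest, with log-spaced frequencies, half sines and half cosines, and `flip_sin_to_cos` putting the cosine half first:

```python
import math

def timestep_proj(t, embedding_dim, max_period=10000,
                  flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0):
    # Hypothetical sketch of the sinusoidal timestep projection:
    # frequencies decay log-linearly from 1 down toward 1/max_period.
    half = embedding_dim // 2
    freqs = [math.exp(-math.log(max_period) * i / (half - downscale_freq_shift))
             for i in range(half)]
    args = [scale * t * f for f in freqs]
    sin = [math.sin(a) for a in args]
    cos = [math.cos(a) for a in args]
    return cos + sin if flip_sin_to_cos else sin + cos

emb = timestep_proj(0.0, embedding_dim=4)
# t = 0 gives cos(0) = 1 for the cosine half and sin(0) = 0 for the sine half.
```

The MLP then maps this fixed projection into a learned embedding of size `hidden_dim`.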