Building Generative Models for Continuous Data via Continuous Interpolants¶
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_moons
Task Setup¶
To demonstrate how Conditional Flow Matching works, we use sklearn to create and sample from custom 2D distributions.
To start, we define our "dataloader", so to speak. This is the `sample_moons` function.
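Next we define the prior distribution whose samples will be converted to the moon distribution shown below. A custom PriorDistribution (for example, 8 equidistant Gaussians) could be used for this; the code in this notebook uses the standard `GaussianPrior`.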
def sample_moons(n, normalize=False):
    """Sample ``n`` points from scikit-learn's two-moons dataset as a tensor.

    The raw points (generated with ``noise=0.08``) are scaled by 3 and
    shifted by -1.

    :param n: number of points to draw.
    :param normalize: when true, each coordinate is additionally centred on
        its batch mean, divided by its batch standard deviation and
        multiplied by 2.
    :return: a tensor holding the ``n`` sampled 2D points.
    """
    points, _labels = make_moons(n_samples=n, noise=0.08)
    moons = torch.Tensor(points) * 3 - 1
    if not normalize:
        return moons
    centred = moons - moons.mean(0)
    return centred / moons.std(0) * 2
# Draw 1000 target samples and scatter-plot them as a quick look at the data distribution.
x1 = sample_moons(1000)
plt.scatter(x1[:, 0], x1[:, 1])
<matplotlib.collections.PathCollection at 0x71fad37d11e0>
Model Creation¶
Here we define a simple 4 layer MLP and define our optimizer
# Problem size and optimisation hyper-parameters.
dim = 2
hidden_size = 64
batch_size = 256

# MLP taking a point concatenated with its time (dim + 1 inputs) and
# returning a dim-sized output. Three Linear+SELU stages feed a final
# Linear projection; modules are created in the same order as written out
# by hand, so parameter initialisation is unchanged.
model = torch.nn.Sequential(
    *[
        module
        for width_in in (dim + 1, hidden_size, hidden_size)
        for module in (torch.nn.Linear(width_in, hidden_size), torch.nn.SELU())
    ],
    torch.nn.Linear(hidden_size, dim),
)
optimizer = torch.optim.Adam(model.parameters())
Continuous Flow Matching Interpolant¶
Here we import our desired interpolant objects.
The continuous flow matcher and the desired time distribution.
from bionemo.moco.interpolants import ContinuousFlowMatcher
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.distributions.prior import GaussianPrior
# Time distribution and prior are built with their default arguments.
uniform_time = UniformTimeDistribution()
simple_prior = GaussianPrior()
# Forwarded as the `sigma` argument of ContinuousFlowMatcher below
# (presumably the interpolation noise level — confirm in the bionemo.moco docs).
sigma = 0.1
# The flow matcher is configured for a model that predicts velocities.
interpolant = ContinuousFlowMatcher(time_distribution=uniform_time,
prior_distribution=simple_prior,
sigma=sigma,
prediction_type="velocity")
# Place both the model and the interpolant on the same device.
# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell does not fail on hosts without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)
interpolant = interpolant.to_device(DEVICE)
Training Loop¶
# Flow-matching training: regress the model output onto the interpolant's
# target for freshly sampled (prior, data, time) batches.
shape = (batch_size, dim)
for k in range(20000):
    optimizer.zero_grad()

    # Fresh batch: prior samples, data samples and one time per sample.
    x0 = interpolant.sample_prior(shape).to(DEVICE)
    x1 = sample_moons(batch_size).to(DEVICE)
    t = interpolant.sample_time(batch_size)

    # Interpolated points and the matching regression target.
    xt = interpolant.interpolate(x1, t, x0)
    ut = interpolant.calculate_target(x1, x0)

    # The network sees each point with its time appended as an extra column.
    model_input = torch.cat([xt, t[:, None]], dim=-1)
    vt = model(model_input)

    loss = interpolant.loss(vt, ut, target_type="velocity").mean()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 5000 steps.
    if (k + 1) % 5000 == 0:
        print(f"{k+1}: loss {loss.item():0.3f}")
5000: loss 2.766 10000: loss 2.730 15000: loss 3.084 20000: loss 2.839
Setting Up Generation¶
Now we need to import the desired inference time schedule. This is what gives us the time values to iterate through to iteratively generate from our model.
Here we show the output time schedule as well as the discretization between time points. Note that different inference time schedules may have different shapes, resulting in a non-uniform dt.
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
# Linear inference schedule with 100 steps.
inference_sched = LinearInferenceSchedule(nsteps = 100)
# Time points at which the model is evaluated, moved to the compute device.
schedule = inference_sched.generate_schedule().to(DEVICE)
# Step sizes; the sampling loops zip these with `schedule`.
dts = inference_sched.discretize().to(DEVICE)
# Bare expression: in a notebook this displays both tensors.
schedule, dts
(tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600, 0.2700, 0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500, 0.3600, 0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500, 0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300, 0.5400, 0.5500, 0.5600, 0.5700, 0.5800, 0.5900, 0.6000, 0.6100, 0.6200, 0.6300, 0.6400, 0.6500, 0.6600, 0.6700, 0.6800, 0.6900, 0.7000, 0.7100, 0.7200, 0.7300, 0.7400, 0.7500, 0.7600, 0.7700, 0.7800, 0.7900, 0.8000, 0.8100, 0.8200, 0.8300, 0.8400, 0.8500, 0.8600, 0.8700, 0.8800, 0.8900, 0.9000, 0.9100, 0.9200, 0.9300, 0.9400, 0.9500, 0.9600, 0.9700, 0.9800, 0.9900], device='cuda:0'), tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100], device='cuda:0'))
Sample from the trained model¶
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
# Keep detached CPU copies only, so the stored trajectory does not hold on to
# autograd graphs or GPU memory (same convention as the later sampling cells).
trajectory = [sample.detach().cpu()]
for dt, t in zip(dts, schedule):
    # One time entry per sample, as built by the schedule's pad_time helper.
    full_t = inference_sched.pad_time(inf_size, t, DEVICE)
    # Inference only: no gradients are needed for the model forward pass.
    with torch.no_grad():
        vt = model(torch.cat([sample, full_t[:, None]], dim=-1))  # model output at (sample, t)
    sample = interpolant.step(vt, sample, dt, full_t)
    trajectory.append(sample.detach().cpu())  # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# First trajectory entry (the prior sample) in black. The flow-matching
# schedule starts at t = 0, so the start is labelled z(0), matching the other
# flow-matching plots in this notebook.
plt.scatter(traj[0, :plot_limit, 0], traj[0, :plot_limit, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')

# All intermediate time points in olive.
for i in range(1, traj.shape[0]-1):
    plt.scatter(traj[i, :plot_limit, 0], traj[i, :plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Last trajectory entry (the generated sample) in blue.
plt.scatter(traj[-1, :plot_limit, 0], traj[-1, :plot_limit, 1], s=4, alpha=1, c="blue", label='z(1)')

# Empty scatter that only adds the "Flow" legend entry, since the loop above is unlabelled.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Sample from underlying score model¶
Low-temperature sampling is a heuristic, and it is unclear what effect it has on the final distribution. Intuitively, it cuts the tails and focuses more on the modes; in practice, its exact effect on the final distribution is not well understood.¶
`gt_mode` is a hyperparameter that must be chosen experimentally.¶
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
for dt, t in zip(dts, schedule):
    # Named full_t rather than `time` so the imported `time` module is not shadowed.
    full_t = inference_sched.pad_time(inf_size, t, DEVICE)
    with torch.no_grad():
        vt = model(torch.cat([sample, full_t[:, None]], dim=-1))
    # Stochastic score-based update at noise temperature 1.0 with gt_mode "tan".
    sample = interpolant.step_score_stochastic(vt, sample, dt, full_t, noise_temperature=1.0, gt_mode = "tan")
    trajectory.append(sample.detach().cpu())
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(1)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.title("Stochastic score sampling Temperature = 1.0")
plt.show()
What happens if you just sample from a random model?¶
# Freshly initialised (untrained) copy of the MLP architecture used above.
model = torch.nn.Sequential(
    torch.nn.Linear(dim + 1, hidden_size),
    torch.nn.SELU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.SELU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.SELU(),
    torch.nn.Linear(hidden_size, dim),
).to(DEVICE)

inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
for dt, t in zip(dts, schedule):
    # Named full_t rather than `time` so the imported `time` module is not shadowed.
    full_t = inference_sched.pad_time(inf_size, t, DEVICE)
    with torch.no_grad():
        vt = model(torch.cat([sample, full_t[:, None]], dim=-1))
    sample = interpolant.step(vt, sample, dt, full_t)
    trajectory.append(sample.detach().cpu())
plot_limit = 1024
traj = torch.stack(trajectory).cpu().detach().numpy()
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(1)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
    """Stack of ``TimeLinear`` layers with a ReLU after every layer but the last.

    :param dim_in: input feature size of the first layer.
    :param dim_out: output feature size of the last layer.
    :param dim_hids: hidden widths, in order; must contain at least one entry.
    """

    def __init__(
        self, dim_in: int, dim_out: int, dim_hids: List[int],
    ):
        super().__init__()
        # Input layer, then one layer per consecutive pair of hidden widths,
        # then the output layer.
        stack = [TimeLinear(dim_in, dim_hids[0])]
        for width_in, width_out in zip(dim_hids[:-1], dim_hids[1:]):
            stack.append(TimeLinear(width_in, width_out))
        stack.append(TimeLinear(dim_hids[-1], dim_out))
        self.layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run ``x`` through every layer, passing ``t`` to each one."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x, t))
        # No activation after the final layer.
        return output_layer(x, t)
class TimeLinear(nn.Module):
    """Linear layer whose output is multiplied element-wise by an embedding of ``t``.

    :param dim_in: input feature size of the linear map.
    :param dim_out: output feature size; also the size of the time embedding.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.dim_in, self.dim_out = dim_in, dim_out
        self.time_embedding = TimeEmbedding(dim_out)
        self.fc = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Return ``fc(x)`` scaled by the time embedding reshaped to ``(-1, dim_out)``."""
        projected = self.fc(x)
        gate = self.time_embedding(t).view(-1, self.dim_out)
        return gate * projected
class TimeEmbedding(nn.Module):
    """Embed timesteps: sinusoidal features followed by a two-layer MLP.

    :param hidden_size: width of both MLP layers and of the returned embedding.
    :param frequency_embedding_size: size of the sinusoidal feature vector
        that is fed to the MLP.
    """

    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        # Cosine features first, then sine features.
        parts = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Odd dim: append one zero column so the width equals dim.
            parts.append(torch.zeros_like(args[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t: torch.Tensor):
        """Return the MLP embedding of ``t``; a 0-dim ``t`` is treated as a batch of one."""
        batched_t = t.unsqueeze(-1) if t.ndim == 0 else t
        t_freq = self.timestep_embedding(batched_t, self.frequency_embedding_size)
        return self.mlp(t_freq)
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior import GaussianPrior
# NOTE: this cell rebinds DEVICE, uniform_time, simple_prior and interpolant,
# replacing the flow-matching objects created earlier in the notebook.
DEVICE = "cuda:0"
# Discrete time with 1000 steps, the same nsteps as the noise schedule below.
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
# DDPM interpolant configured for a model that predicts noise.
interpolant = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "noise",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Train the Model¶
# Hyper-parameters, model and optimizer for the noise-prediction DDPM run.
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(
    dim_in=dim,
    dim_out=dim,
    dim_hids=[hidden_size] * num_hiddens,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)

# Place both the model and the interpolant on the same device
DEVICE = "cuda"
model = model.to(DEVICE)
interpolant = interpolant.to_device(DEVICE)

shape = (batch_size, dim)
for k in range(20000):
    optimizer.zero_grad()

    # Fresh batch: prior noise, data samples and one timestep per sample.
    x0 = interpolant.sample_prior(shape).to(DEVICE)
    x1 = sample_moons(batch_size).to(DEVICE)
    t = interpolant.sample_time(batch_size)
    xt = interpolant.interpolate(x1, t, x0)

    # The model's prediction is scored against the prior sample x0.
    eps = model(xt, t)
    loss = interpolant.loss(eps, x0, t).mean()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 1000 steps.
    if (k + 1) % 1000 == 0:
        print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.337 2000: loss 0.320 3000: loss 0.260 4000: loss 0.328 5000: loss 0.324 6000: loss 0.427 7000: loss 0.254 8000: loss 0.352 9000: loss 0.365 10000: loss 0.390 11000: loss 0.332 12000: loss 0.265 13000: loss 0.362 14000: loss 0.394 15000: loss 0.405 16000: loss 0.340 17000: loss 0.326 18000: loss 0.357 19000: loss 0.330 20000: loss 0.409
Let's visualize what the interpolation looks like during training for different times¶
# One fixed batch of prior and data samples, reused for every timestep below.
x0 = interpolant.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
for t in range(0, 900, 100):
    # Zero out a sampled time tensor and add t, giving a batch where every entry equals t.
    tt = interpolant.sample_time(batch_size)*0 + t
    out = interpolant.interpolate(x1, tt, x0)
    points = out.cpu().detach()
    plt.scatter(points[:, 0], points[:, 1])
    plt.title(f"Time = {t}")
    plt.show()
Create the inference time schedule and sample from the model¶
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(nsteps = 1000, direction = "diffusion").generate_schedule(device= DEVICE)
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
# Keep detached CPU copies only: over 1000 steps, graph-attached GPU tensors
# would hold autograd history and device memory for no benefit.
trajectory = [sample.detach().cpu()]
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full((inf_size,), t).to(DEVICE)
    # Inference only: no gradients are needed for the model forward pass.
    with torch.no_grad():
        eps_hat = model(sample, full_t)  # the model was trained to predict noise
    sample = interpolant.step_noise(eps_hat, full_t, sample)
    trajectory.append(sample.detach().cpu())  # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
/home/dreidenbach/mambaforge/envs/moco_bionemo/lib/python3.10/site-packages/IPython/core/pylabtools.py:170: UserWarning: Creating legend with loc="best" can be slow with large amounts of data. fig.canvas.print_figure(bytes_io, **kw)
inf_size = 1024
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
time_shape = (inf_size,)
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        eps_hat = model(sample, full_t)  # predicted noise
    sample = interpolant.step(eps_hat, full_t, sample)
    trajectory.append(sample.detach().cpu())  # kept for plotting
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Notice that this yields very similar results to using the underlying score function in the stochastic score-based CFM example¶
Notice that it makes no difference whether or not we convert the predicted noise to data inside the .step() function¶
Let's try other cool sampling functions¶
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(
    nsteps=1000, direction="diffusion"
).generate_schedule(device=DEVICE)
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
time_shape = (inf_size,)
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        eps_hat = model(sample, full_t)  # predicted noise
    # DDIM update instead of the default step.
    sample = interpolant.step_ddim(eps_hat, full_t, sample)
    trajectory.append(sample.detach().cpu())  # kept for plotting
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What happens when you sample from an untrained model with DDPM¶
# Freshly initialised (untrained) network with the same architecture.
model = Network(
    dim_in=dim,
    dim_out=dim,
    dim_hids=[hidden_size] * num_hiddens,
).to(DEVICE)

inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
time_shape = (inf_size,)
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        vt = model(sample, full_t)  # output of the untrained network
    sample = interpolant.step_noise(vt, full_t, sample)
    trajectory.append(sample.detach().cpu())
plot_limit = 1024
traj = torch.stack(trajectory).cpu().detach().numpy()
plt.figure(figsize=(6, 6))

# First trajectory entry (the prior sample) in black. Labelled z(S) / z(0) to
# match the convention used by the other DDPM plots in this notebook.
plt.scatter(traj[0, :plot_limit, 0], traj[0, :plot_limit, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# All intermediate time points in olive.
for i in range(1, traj.shape[0]-1):
    plt.scatter(traj[i, :plot_limit, 0], traj[i, :plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Last trajectory entry in blue.
plt.scatter(traj[-1, :plot_limit, 0], traj[-1, :plot_limit, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only adds the "Flow" legend entry, since the loop above is unlabelled.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's switch the parameterization of DDPM from noise to data¶
Here instead of training the model to learn the noise we want to learn the raw data. Both options are valid and the choice of which depends on the underlying modeling task.
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
# NOTE: this cell rebinds DEVICE, uniform_time, simple_prior and interpolant,
# replacing the noise-prediction DDPM objects created earlier.
DEVICE = "cuda:0"
# Discrete time with 1000 steps, the same nsteps as the noise schedule below.
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
# Same DDPM setup as before, except the model output is now interpreted as data.
interpolant = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Let us first train the model with a weight such that it is theoretically equivalent to the simple noise matching loss. See Equation 9 from https://arxiv.org/pdf/2202.00512¶
# Hyper-parameters, model and optimizer for the data-prediction DDPM run.
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(
    dim_in=dim,
    dim_out=dim,
    dim_hids=[hidden_size] * num_hiddens,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)

# Place both the model and the interpolant on the same device
DEVICE = "cuda"
model = model.to(DEVICE)
interpolant = interpolant.to_device(DEVICE)

shape = (batch_size, dim)
for k in range(20000):
    optimizer.zero_grad()

    # Fresh batch: prior noise, data samples and one timestep per sample.
    x0 = interpolant.sample_prior(shape).to(DEVICE)
    x1 = sample_moons(batch_size).to(DEVICE)
    t = interpolant.sample_time(batch_size)
    xt = interpolant.interpolate(x1, t, x0)

    # The prediction is scored against the data x1 with the "data_to_noise" weighting.
    x_hat = model(xt, t)
    loss = interpolant.loss(x_hat, x1, t, weight_type="data_to_noise").mean()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 1000 steps.
    if (k + 1) % 1000 == 0:
        print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.908 2000: loss 32.665 3000: loss 0.371 4000: loss 0.970 5000: loss 0.434 6000: loss 0.814 7000: loss 0.599 8000: loss 0.545 9000: loss 0.594 10000: loss 5.172 11000: loss 0.415 12000: loss 0.699 13000: loss 0.400 14000: loss 0.416 15000: loss 0.904 16000: loss 0.785 17000: loss 0.428 18000: loss 0.541 19000: loss 0.346 20000: loss 1.336
inf_size = 1024
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
time_shape = (inf_size,)
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        x_hat = model(sample, full_t)  # predicted data
    sample = interpolant.step(x_hat, full_t, sample)
    trajectory.append(sample.detach().cpu())  # kept for plotting
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let us train with no loss weighting to optimize a true data matching loss for comparison¶
# Hyper-parameters, model and optimizer for the unweighted data-matching run.
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(
    dim_in=dim,
    dim_out=dim,
    dim_hids=[hidden_size] * num_hiddens,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)

# Place both the model and the interpolant on the same device
DEVICE = "cuda"
model = model.to(DEVICE)
interpolant = interpolant.to_device(DEVICE)

shape = (batch_size, dim)
for k in range(20000):
    optimizer.zero_grad()

    # Fresh batch: prior noise, data samples and one timestep per sample.
    x0 = interpolant.sample_prior(shape).to(DEVICE)
    x1 = sample_moons(batch_size).to(DEVICE)
    t = interpolant.sample_time(batch_size)
    xt = interpolant.interpolate(x1, t, x0)

    # The prediction is scored against the data x1 with weight_type="ones".
    x_hat = model(xt, t)
    loss = interpolant.loss(x_hat, x1, t, weight_type="ones").mean()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 1000 steps.
    if (k + 1) % 1000 == 0:
        print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 2.489 2000: loss 2.569 3000: loss 2.848 4000: loss 2.444 5000: loss 2.644 6000: loss 2.609 7000: loss 2.766 8000: loss 2.713 9000: loss 2.555 10000: loss 2.639 11000: loss 2.778 12000: loss 2.800 13000: loss 2.509 14000: loss 2.518 15000: loss 2.562 16000: loss 2.739 17000: loss 2.990 18000: loss 2.300 19000: loss 2.318 20000: loss 2.638
inf_size = 1024
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
time_shape = (inf_size,)
for t in schedule:
    # One timestep entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        x_hat = model(sample, full_t)  # predicted data
    sample = interpolant.step(x_hat, full_t, sample)
    trajectory.append(sample.detach().cpu())  # kept for plotting
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's try a continuous time analog interpolant to DDPM called VDM¶
This interpolant was used in Chroma and is described in great detail here https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf¶
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import VDM
from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform, LinearSNRTransform, LinearLogInterpolatedSNRTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
# NOTE: this cell rebinds DEVICE, uniform_time, simple_prior, interpolant and schedule.
DEVICE = "cuda:0"
# Continuous time this time (discrete_time=False).
uniform_time = UniformTimeDistribution(discrete_time=False)
simple_prior = GaussianPrior()
# VDM interpolant configured for a model that predicts data.
interpolant = VDM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = LinearLogInterpolatedSNRTransform(),
device=DEVICE)
# Unlike the DDPM cells, `schedule` is the schedule object itself here; the
# sampling cells call generate_schedule() and discretize() on it.
schedule = LinearInferenceSchedule(nsteps = 1000, direction="diffusion")
# Hyper-parameters and model for the VDM run.
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(
    dim_in=dim,
    dim_out=dim,
    dim_hids=[hidden_size] * num_hiddens,
)

# Place the model on the device, then build the optimizer over its parameters.
DEVICE = "cuda"
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)

shape = (batch_size, dim)
for k in range(20000):
    optimizer.zero_grad()

    # Fresh batch: prior noise, data samples and one time per sample.
    x0 = interpolant.sample_prior(shape).to(DEVICE)
    x1 = sample_moons(batch_size).to(DEVICE)
    t = interpolant.sample_time(batch_size)
    xt = interpolant.interpolate(x1, t, x0)

    # The prediction is scored against the data x1 with weight_type="ones".
    x_hat = model(xt, t)
    loss = interpolant.loss(x_hat, x1, t, weight_type="ones").mean()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 1000 steps.
    if (k + 1) % 1000 == 0:
        print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 1.331 2000: loss 1.067 3000: loss 1.343 4000: loss 1.291 5000: loss 1.249 6000: loss 0.954 7000: loss 1.063 8000: loss 1.179 9000: loss 1.246 10000: loss 1.641 11000: loss 1.088 12000: loss 1.208 13000: loss 1.274 14000: loss 0.927 15000: loss 1.078 16000: loss 1.109 17000: loss 1.046 18000: loss 1.235 19000: loss 1.325 20000: loss 1.159
inf_size = 1024
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
# Time points and their step sizes, iterated in lockstep.
ts = schedule.generate_schedule()
dts = schedule.discretize()
time_shape = (inf_size,)
for dt, t in zip(dts, ts):
    # One time entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        x_hat = model(sample, full_t)  # predicted data
    sample = interpolant.step(x_hat, full_t, sample, dt)
    trajectory.append(sample.detach().cpu())  # kept for plotting
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))

# Starting points (first trajectory entry) in black.
start = traj[0, :plot_limit]
plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate step in faint olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final points (last trajectory entry) in blue.
end = traj[-1, :plot_limit]
plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Empty scatter that only contributes the "Flow" legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
# Start from prior noise; every state is recorded as a detached CPU copy.
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)
trajectory = [sample.detach().cpu()]
# Time points and their step sizes, iterated in lockstep.
ts = schedule.generate_schedule()
dts = schedule.discretize()
time_shape = (inf_size,)
for dt, t in zip(dts, ts):
    # One time entry per sample.
    full_t = torch.full(time_shape, t).to(DEVICE)
    with torch.no_grad():
        x_hat = model(sample, full_t)  # predicted data
    # DDIM update instead of the default step.
    sample = interpolant.step_ddim(x_hat, full_t, sample, dt)
    trajectory.append(sample.detach().cpu())  # kept for plotting
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :plot_limit, 0], traj[0, :plot_limit, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :plot_limit, 0], traj[i, :plot_limit, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :plot_limit, 0], traj[-1, :plot_limit, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What is interesting here is that the deterministic sampling of DDIM best recovers the Flow Matching ODE samples¶
# Generate samples with the interpolant's `step_ode` update and visualise
# the path from the prior draw to the final state.
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
trajectory = [sample.detach().cpu()]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
    full_t = torch.full((inf_size,), t)
    full_t = full_t.to(DEVICE)
    with torch.no_grad():
        x_hat = model(sample, full_t)  # prediction fed to the interpolant update
    # `interpolant.step_hybrid_sde(x_hat, full_t, sample, dt)` can be swapped
    # in here to compare against the ODE update.
    sample = interpolant.step_ode(x_hat, full_t, sample, dt)
    trajectory.append(sample.detach().cpu())  # keep each state for plotting

traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024

plt.figure(figsize=(6, 6))

# First recorded state (the prior draw) in black.
start_points = traj[0, :plot_limit]
plt.scatter(start_points[:, 0], start_points[:, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Every intermediate state, i.e. all but the first and last, in olive.
for frame in traj[1:-1]:
    plt.scatter(frame[:plot_limit, 0], frame[:plot_limit, 1], s=0.2, alpha=0.2, c="olive")

# Final state in blue.
end_points = traj[-1, :plot_limit]
plt.scatter(end_points[:, 0], end_points[:, 1], s=4, alpha=1, c="blue", label='z(0)')

# Label-only scatter so that the unlabeled olive points get a legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
# Sample with `step_ode`, passing temperature = 1.5 instead of relying on
# the method's default, then plot the recorded trajectory.
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
trajectory = [sample.detach().cpu()]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
    full_t = torch.full((inf_size,), t).to(DEVICE)  # one time entry per sample
    with torch.no_grad():
        x_hat = model(sample, full_t)  # prediction fed to the interpolant update
    # `interpolant.step_hybrid_sde(x_hat, full_t, sample, dt)` is the
    # alternative update that can be tried in place of the call below.
    sample = interpolant.step_ode(x_hat, full_t, sample, dt, temperature = 1.5)
    trajectory.append(sample.detach().cpu())  # keep each state for plotting

traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024

# `traj` is indexed as (time point, sample, coordinate): one entry per
# recorded state, `inf_size` samples per state, and x/y in the last axis.
xs = traj[:, :plot_limit, 0]
ys = traj[:, :plot_limit, 1]

plt.figure(figsize=(6, 6))

# Initial state in black.
plt.scatter(xs[0], ys[0], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Intermediate states (everything except the first and last) in olive.
for i in range(1, len(traj) - 1):
    plt.scatter(xs[i], ys[i], s=0.2, alpha=0.2, c="olive")

# Final state in blue.
plt.scatter(xs[-1], ys[-1], s=4, alpha=1, c="blue", label='z(0)')

# Label-only scatter so that the unlabeled olive points get a legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
# Sample with `step_ode`, passing temperature = 0.5 instead of relying on
# the method's default, then plot the recorded trajectory.
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
trajectory = [sample.detach().cpu()]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
    full_t = torch.full((inf_size,), t).to(DEVICE)  # one time entry per sample
    with torch.no_grad():
        x_hat = model(sample, full_t)  # prediction fed to the interpolant update
    # `interpolant.step_hybrid_sde(x_hat, full_t, sample, dt)` is the
    # alternative update that can be tried in place of the call below.
    sample = interpolant.step_ode(x_hat, full_t, sample, dt, temperature = 0.5)
    trajectory.append(sample.detach().cpu())  # keep each state for plotting

traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024

plt.figure(figsize=(6, 6))

# Initial state in black.
plt.scatter(traj[0, :plot_limit, 0], traj[0, :plot_limit, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Intermediate states: walk the x and y planes of the middle frames together.
middle_x = traj[1:-1, :plot_limit, 0]
middle_y = traj[1:-1, :plot_limit, 1]
for frame_x, frame_y in zip(middle_x, middle_y):
    plt.scatter(frame_x, frame_y, s=0.2, alpha=0.2, c="olive")

# Final state in blue.
plt.scatter(traj[-1, :plot_limit, 0], traj[-1, :plot_limit, 1], s=4, alpha=1, c="blue", label='z(0)')

# Label-only scatter so that the unlabeled olive points get a legend entry.
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
# Sample with the interpolant's `step_hybrid_sde` update and plot the
# recorded trajectory from the prior draw to the final state.
inf_size = 1024
sample = interpolant.sample_prior((inf_size, 2)).to(DEVICE)  # Start with noise
trajectory = [sample.detach().cpu()]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
    full_t = torch.full((inf_size,), t).to(DEVICE)  # one time entry per sample
    with torch.no_grad():
        x_hat = model(sample, full_t)  # prediction fed to the interpolant update
    # `interpolant.step_ode(x_hat, full_t, sample, dt)` is the alternative
    # update that can be tried in place of the call below.
    sample = interpolant.step_hybrid_sde(x_hat, full_t, sample, dt)
    trajectory.append(sample.detach().cpu())  # keep each state for plotting

traj = torch.stack(trajectory).cpu().detach().numpy()
plot_limit = 1024

plt.figure(figsize=(6, 6))

# Marker styling shared by the intermediate points and their legend proxy.
flow_style = {"s": 0.2, "alpha": 0.2, "c": "olive"}

# Initial state in black.
plt.scatter(traj[0, :plot_limit, 0], traj[0, :plot_limit, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')

# Intermediate states (everything except the first and last) in olive.
for i in range(1, traj.shape[0] - 1):
    plt.scatter(traj[i, :plot_limit, 0], traj[i, :plot_limit, 1], **flow_style)

# Final state in blue.
plt.scatter(traj[-1, :plot_limit, 0], traj[-1, :plot_limit, 1], s=4, alpha=1, c="blue", label='z(0)')

# Label-only scatter so that the unlabeled olive points get a legend entry.
plt.scatter([], [], label='Flow', **flow_style)
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()