-
Notifications
You must be signed in to change notification settings - Fork 384
/
diffusion_utils.py
299 lines (261 loc) · 12.4 KB
/
diffusion_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import numpy as np
import tensorflow.compat.v1 as tf
from . import nn
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Each normal distribution is parameterized by its mean and its log-variance;
    the result is a TensorFlow tensor built from the arguments with TF ops.
    """
    # exp(logvar1 - logvar2) is the variance ratio sigma1^2 / sigma2^2.
    variance_ratio = tf.exp(logvar1 - logvar2)
    mean_gap_sq = tf.squared_difference(mean1, mean2)
    # exp(-logvar2) is 1 / sigma2^2.
    inv_variance2 = tf.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + mean_gap_sq * inv_variance2)
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    """
    Build a float64 beta schedule with a linear warm-up phase.

    The first ``int(num_diffusion_timesteps * warmup_frac)`` entries ramp
    linearly from ``beta_start`` to ``beta_end`` (inclusive); every remaining
    entry equals ``beta_end``.
    """
    ramp_len = int(num_diffusion_timesteps * warmup_frac)
    schedule = np.full(num_diffusion_timesteps, beta_end, dtype=np.float64)
    schedule[:ramp_len] = np.linspace(beta_start, beta_end, ramp_len, dtype=np.float64)
    return schedule
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    Return a 1-D float64 array of ``num_diffusion_timesteps`` betas.

    Supported ``beta_schedule`` names:
      'quad'     -- squares of values spaced linearly from sqrt(beta_start) to sqrt(beta_end)
      'linear'   -- values spaced linearly from beta_start to beta_end
      'warmup10' -- linear ramp over the first int(0.1 * T) steps, then constant beta_end
      'warmup50' -- linear ramp over the first int(0.5 * T) steps, then constant beta_end
      'const'    -- beta_end at every step
      'jsd'      -- 1/T, 1/(T-1), ..., 1 (beta_start / beta_end are ignored)

    Raises:
      NotImplementedError: if ``beta_schedule`` is not one of the names above.
    """
    n = num_diffusion_timesteps
    # Ordered (name, builder) pairs; scanned with == so matching semantics are
    # exactly those of a plain if/elif chain.
    builders = (
        ('quad', lambda: np.linspace(beta_start ** 0.5, beta_end ** 0.5, n, dtype=np.float64) ** 2),
        ('linear', lambda: np.linspace(beta_start, beta_end, n, dtype=np.float64)),
        ('warmup10', lambda: _warmup_beta(beta_start, beta_end, n, 0.1)),
        ('warmup50', lambda: _warmup_beta(beta_start, beta_end, n, 0.5)),
        ('const', lambda: beta_end * np.ones(n, dtype=np.float64)),
        ('jsd', lambda: 1. / np.linspace(n, 1, n, dtype=np.float64)),  # 1/T, 1/(T-1), ..., 1
    )
    for name, build in builders:
        if beta_schedule == name:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (n,)
    return betas
def noise_like(shape, noise_fn=tf.random_normal, repeat=False, dtype=tf.float32):
    """
    Draw noise of the given ``shape`` and ``dtype`` using ``noise_fn``.

    When ``repeat`` is truthy, a single sample of shape ``(1, *shape[1:])`` is
    drawn and repeated ``shape[0]`` times along axis 0, so every batch element
    receives the same noise. Otherwise ``noise_fn`` is called once with the
    full ``shape``.
    """
    if not repeat:
        return noise_fn(shape=shape, dtype=dtype)
    shared = noise_fn(shape=(1, *shape[1:]), dtype=dtype)
    return tf.repeat(shared, repeats=shape[0], axis=0)
class GaussianDiffusion:
    """
    Contains utilities for the diffusion model.

    From a numpy beta schedule, builds TensorFlow constants for the forward
    (noising) process q(x_t | x_0) and the posterior q(x_{t-1} | x_t, x_0), and
    provides graphs for the training loss, single-step sampling, full sampling,
    sampling with intermediate images, and latent interpolation.

    Timestep indices are 0-based: ``t == 0`` means diffused for one step and
    ``t == num_timesteps - 1`` is the last step. Only ``loss_type ==
    'noisepred'`` (the network predicts the added noise) is implemented; other
    values raise NotImplementedError in ``p_losses`` and ``p_mean_variance``.
    """

    def __init__(self, *, betas, loss_type, tf_dtype=tf.float32):
        """
        Args:
          betas: 1-D numpy array of per-step variances; every entry must lie in (0, 1].
          loss_type: training objective name; only 'noisepred' is implemented.
          tf_dtype: dtype of the TensorFlow constants created here.
        """
        self.loss_type = loss_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted by one step with a leading 1 (the empty product), so entry t
        # holds the cumulative product of alphas up to step t-1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        assert alphas_cumprod_prev.shape == (timesteps,)

        self.betas = tf.constant(betas, dtype=tf_dtype)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf_dtype)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf_dtype)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(np.sqrt(alphas_cumprod), dtype=tf_dtype)
        self.sqrt_one_minus_alphas_cumprod = tf.constant(np.sqrt(1. - alphas_cumprod), dtype=tf_dtype)
        self.log_one_minus_alphas_cumprod = tf.constant(np.log(1. - alphas_cumprod), dtype=tf_dtype)
        self.sqrt_recip_alphas_cumprod = tf.constant(np.sqrt(1. / alphas_cumprod), dtype=tf_dtype)
        self.sqrt_recipm1_alphas_cumprod = tf.constant(np.sqrt(1. / alphas_cumprod - 1), dtype=tf_dtype)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf_dtype)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf_dtype)
        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod), dtype=tf_dtype)
        self.posterior_mean_coef2 = tf.constant(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=tf_dtype)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

        Args:
          a: 1-D per-timestep coefficient tensor (one of the constants built in __init__).
          t: 1-D integer tensor of timestep indices; its static length is the batch size.
          x_shape: shape of the tensor the result is broadcast against; x_shape[0] must
            equal the batch size.
        """
        bs, = t.shape
        assert x_shape[0] == bs
        out = tf.gather(a, t)
        assert out.shape == [bs]
        return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        """
        Return (mean, variance, log_variance) of q(x_t | x_0).

        mean = sqrt(alpha_bar_t) * x_start; variance = 1 - alpha_bar_t. The variance
        terms have shape [B, 1, ..., 1] and broadcast against x_start.
        """
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)

        Returns sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise. When
        ``noise`` is None, standard normal noise is drawn with the same shape and
        dtype as ``x_start`` (as ``p_losses`` does).
        """
        if noise is None:
            # Match x_start's dtype so non-float32 inputs do not hit a dtype mismatch.
            noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """
        Invert q_sample: recover x_0 from x_t and the noise that produced it.

        x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise
        """
        assert x_t.shape == noise.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)

        Returns (posterior_mean, posterior_variance, posterior_log_variance_clipped).
        The mean has x_t's shape; the variance terms have shape [B, 1, ..., 1].
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_losses(self, denoise_fn, x_start, t, noise=None):
        """
        Training loss calculation

        Args:
          denoise_fn: callable ``(x_noisy, t) -> tensor`` returning a rank-4 tensor whose
            first three dims match x_start.
          x_start: clean data with a fully known rank-4 static shape [B, H, W, C].
          t: integer tensor of shape [B] with timestep indices.
          noise: optional noise used to diffuse x_start; drawn if None.
        Returns:
          Tensor of shape [B]: for 'noisepred', the squared difference between the true
          noise and the network output reduced by ``nn.meanflat`` (presumably a mean over
          the non-batch axes).
        Raises:
          NotImplementedError: if ``self.loss_type`` is not 'noisepred'.
        """
        B, H, W, _ = x_start.shape.as_list()
        assert t.shape == [B]

        if noise is None:
            noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype)
        assert noise.shape == x_start.shape and noise.dtype == x_start.dtype
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = denoise_fn(x_noisy, t)
        assert x_noisy.shape == x_start.shape
        assert x_recon.shape[:3] == [B, H, W] and len(x_recon.shape) == 4

        if self.loss_type == 'noisepred':
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            assert x_recon.shape == x_start.shape
            losses = nn.meanflat(tf.squared_difference(noise, x_recon))
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == [B]
        return losses

    def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool):
        """
        Mean and variance of the learned reverse step p(x_{t-1} | x_t).

        The network's noise prediction is converted to an x_0 estimate, optionally
        clipped to [-1, 1], then plugged into q(x_{t-1} | x_t, x_0). Returns
        (model_mean, posterior_variance, posterior_log_variance); the last two have
        shape [B, 1, 1, 1].

        Raises:
          NotImplementedError: if ``self.loss_type`` is not 'noisepred'.
        """
        if self.loss_type == 'noisepred':
            x_recon = self.predict_start_from_noise(x, t=t, noise=denoise_fn(x, t))
        else:
            raise NotImplementedError(self.loss_type)

        if clip_denoised:
            # NOTE(review): assumes data is scaled to [-1, 1] — confirm against the input pipeline.
            x_recon = tf.clip_by_value(x_recon, -1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        assert model_mean.shape == x_recon.shape == x.shape
        assert posterior_variance.shape == posterior_log_variance.shape == [x.shape[0], 1, 1, 1]
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, repeat_noise=False):
        """
        Sample from the model

        Returns model_mean + exp(0.5 * log_variance) * noise, where no noise is added for
        batch elements with t == 0. ``repeat_noise`` draws one noise sample shared across
        the batch (via ``noise_like``).
        """
        model_mean, _, model_log_variance = self.p_mean_variance(denoise_fn, x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, noise_fn, repeat_noise)
        assert noise.shape == x.shape
        # no noise when t == 0; cast to x's dtype so float64 models work too
        nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), x.dtype), [x.shape[0]] + [1] * (len(x.shape) - 1))
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

    def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
        """
        Generate samples

        Starts from noise_fn(shape) and applies ``p_sample`` for
        t = num_timesteps - 1 down to 0 inside a tf.while_loop. Returns the final
        sample, whose static shape equals ``shape``.
        """
        i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
        assert isinstance(shape, (tuple, list))
        img_0 = noise_fn(shape=shape, dtype=tf.float32)
        _, img_final = tf.while_loop(
            cond=lambda i_, _: tf.greater_equal(i_, 0),
            body=lambda i_, img_: [
                i_ - 1,
                self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn)
            ],
            loop_vars=[i_0, img_0],
            shape_invariants=[i_0.shape, img_0.shape],
            back_prop=False  # sampling never needs gradients
        )
        assert img_final.shape == shape
        return img_final

    def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_normal, repeat_noise_steps=-1):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time
        Args:
          repeat_noise_steps (int): Number of denoising timesteps in which the same noise
            is used across the batch. If >= 0, the initial noise is the same for all batch elements.
        Returns:
          (times, imgs): ``times`` is an int32 vector and ``imgs`` stacks every intermediate
          image along a new leading axis. ``times[k]`` is the timestep at which ``imgs[k]`` is
          fed to ``p_sample`` next; ``imgs[0]`` is the initial noise and the final entry of
          ``times`` is -1 (fully denoised).
        """
        i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
        assert isinstance(shape, (tuple, list))
        img_0 = noise_like(shape, noise_fn, repeat_noise_steps >= 0)
        # Seed both trajectories with a leading length-1 axis. Plain tensors (not tf.Variable)
        # avoid creating global/trainable variables that need initialization and would freeze
        # the starting noise across runs.
        times = tf.expand_dims(i_0, 0)
        imgs = tf.expand_dims(img_0, 0)

        # Steps with repeated noise (guarded so t never goes below 0 when
        # repeat_noise_steps exceeds num_timesteps)
        times, imgs = tf.while_loop(
            cond=lambda times_, _: tf.logical_and(
                tf.greater_equal(times_[-1], 0),
                tf.less_equal(self.num_timesteps - times_[-1], repeat_noise_steps)),
            body=lambda times_, imgs_: [
                tf.concat([times_, [times_[-1] - 1]], 0),
                tf.concat([imgs_, [self.p_sample(denoise_fn=denoise_fn,
                                                 x=imgs_[-1],
                                                 t=tf.fill([shape[0]], times_[-1]),
                                                 noise_fn=noise_fn,
                                                 repeat_noise=True)]], 0)
            ],
            loop_vars=[times, imgs],
            shape_invariants=[tf.TensorShape([None, *i_0.shape]),
                              tf.TensorShape([None, *img_0.shape])],
            back_prop=False
        )

        # Steps with different noise for each batch element
        times, imgs = tf.while_loop(
            cond=lambda times_, _: tf.greater_equal(times_[-1], 0),
            body=lambda times_, imgs_: [
                tf.concat([times_, [times_[-1] - 1]], 0),
                tf.concat([imgs_, [self.p_sample(denoise_fn=denoise_fn,
                                                 x=imgs_[-1],
                                                 t=tf.fill([shape[0]], times_[-1]),
                                                 noise_fn=noise_fn,
                                                 repeat_noise=False)]], 0)
            ],
            loop_vars=[times, imgs],
            shape_invariants=[tf.TensorShape([None, *i_0.shape]),
                              tf.TensorShape([None, *img_0.shape])],
            back_prop=False
        )
        assert imgs[-1].shape == shape
        return times, imgs

    def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
        """
        Interpolate between images.
        t == 0 means diffuse images for 1 timestep before mixing.

        Builds placeholders for two image batches ``x1``/``x2`` of ``shape``, a scalar mixing
        weight ``lam`` and a scalar start timestep ``t``. Both batches are diffused to step t,
        linearly mixed, then denoised from t down to 0.

        Returns:
          (x1, x2, lam, x_interp, t): the four placeholders to feed, and the result tensor.
        """
        assert isinstance(shape, (tuple, list))

        # Placeholders for real samples to interpolate
        x1 = tf.placeholder(tf.float32, shape)
        x2 = tf.placeholder(tf.float32, shape)
        # lam == 0.5 averages diffused images.
        lam = tf.placeholder(tf.float32, shape=())
        t = tf.placeholder(tf.int32, shape=())

        # Add noise via forward diffusion
        # TODO: use the same noise for both endpoints?
        t_batched = tf.fill([shape[0]], t)
        xt1 = self.q_sample(x1, t=t_batched)
        xt2 = self.q_sample(x2, t=t_batched)

        # Mix latents with linear interpolation.
        # (Alternative, constant-variance mix: tf.sqrt(1 - lam * lam) * xt1 + lam * xt2)
        xt_interp = (1 - lam) * xt1 + lam * xt2

        # Reverse diffusion (similar to self.p_sample_loop), starting from step t
        _, x_interp = tf.while_loop(
            cond=lambda i_, _: tf.greater_equal(i_, 0),
            body=lambda i_, img_: [
                i_ - 1,
                self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn)
            ],
            loop_vars=[t, xt_interp],
            shape_invariants=[t.shape, xt_interp.shape],
            back_prop=False
        )
        assert x_interp.shape == shape
        return x1, x2, lam, x_interp, t