<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="description"
content="Relative trajectory balance, an objective for unbiased posterior sampling with diffusion priors. ">
<meta name="keywords" content="diffusion models, gflownets, fine-tuning">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Off Policy Diffusion Samplers</title>
<link href="https://fonts.googleapis.com/css?family=Google+Sans|Noto+Sans|Castoro"
rel="stylesheet">
<link rel="stylesheet" href="./static/css/bulma.min.css">
<link rel="stylesheet" href="./static/css/bulma-carousel.min.css">
<link rel="stylesheet" href="./static/css/bulma-slider.min.css">
<link rel="stylesheet" href="./static/css/fontawesome.all.min.css">
<link rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/jpswalsh/academicons@1/css/academicons.min.css">
<link rel="stylesheet" href="./static/css/index.css">
<link rel="icon" href="./static/images/favicon.svg">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script defer src="./static/js/fontawesome.all.min.js"></script>
<script src="./static/js/bulma-carousel.min.js"></script>
<script src="./static/js/bulma-slider.min.js"></script>
<script src="./static/js/index.js"></script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({tex2jax: {inlineMath: [['$','$'], ['\\(','\\)']]}});
</script>
<script type="text/javascript"
src="https://cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS_CHTML">
</script>
</head>
<body>
<nav class="navbar" role="navigation" aria-label="main navigation" style="background-color: #333; color: #fff;">
<div class="navbar-brand">
<a role="button" class="navbar-burger" aria-label="menu" aria-expanded="false" style="color: #fff;">
<span aria-hidden="true"></span>
<span aria-hidden="true"></span>
<span aria-hidden="true"></span>
</a>
</div>
<div class="navbar-menu">
<div class="navbar-start" style="flex-grow: 1; justify-content: center;">
<a class="navbar-item" href="#abstract" style="color: #fff; border-bottom: 0px solid #fff;">
Abstract
</a>
<a class="navbar-item" href="#setup" style="color: #fff; border-bottom: 0px solid #fff;">
Setup
</a>
<a class="navbar-item" href="#method" style="color: #fff; border-bottom: 0px solid #fff;">
Method
</a>
<a class="navbar-item" href="#results" style="color: #fff; border-bottom: 0px solid #fff;">
Empirical Results
</a>
<a class="navbar-item" href="#BibTeX" style="color: #fff; border-bottom: 0px solid #fff;">
BibTeX
</a>
</div>
</div>
</nav>
<section class="hero">
<div class="hero-body">
<div class="container is-max-desktop">
<div class="columns is-centered">
<div class="column has-text-centered">
<h1 class="title is-1 publication-title">Amortizing intractable inference in diffusion models for vision, language, and control</h1>
<div class="is-size-5 publication-authors">
<span class="author-block">
<a href="https://hyperpotatoneo.github.io">Siddarth Venkatraman</a><sup>*1</sup></span>
<span class="author-block">
<a href="https://mj10.github.io">Moksh Jain</a><sup>*1</sup></span>
<span class="author-block">
<a href="https://lucascimeca.com">Luca Scimeca</a><sup>*1</sup>
</span>
<span class="author-block">
<a href="https://minsuukim.github.io">Minsu Kim</a><sup>*1,2</sup>
</span>
<span class="author-block">
<a href="https://x.com/MarcinSendera">Marcin Sendera</a><sup>*1,3</sup>
</span>
<br/>
<span class="author-block">
<a href="https://hasanmohsin.github.io">Mohsin Hasan</a><sup>1</sup>
</span>
<span class="author-block">
<a href="https://mila.quebec/en/directory/luke-rowe">Luke Rowe</a><sup>1</sup>
</span>
<span class="author-block">
<a href="https://sarthmit.github.io">Sarthak Mittal</a><sup>1</sup>
</span>
<span class="author-block">
<a href="https://pablo-lemos.github.io">Pablo Lemos</a><sup>1,4,5</sup>
</span>
<span class="author-block">
<a href="https://folinoid.com">Emmanuel Bengio</a><sup>6</sup>
</span>
<span class="author-block">
<a href="https://mila.quebec/en/directory/alexandre-adam">Alexandre Adam</a><sup>1,4</sup>
</span>
<span class="author-block">
<a href="https://jarridrb.github.io">Jarrid Rector-Brooks</a><sup>1,5</sup>
</span>
<span class="author-block">
<a href="https://yoshuabengio.org">Yoshua Bengio</a><sup>1,7</sup>
</span>
<span class="author-block">
<a href="https://neo-x.github.io">Glen Berseth</a><sup>1,7</sup>
</span>
<span class="author-block">
<a href="https://malkin1729.github.io">Nikolay Malkin</a><sup>1</sup>
</span>
</div>
<div class="is-size-5 publication-authors">
<span class="author-block">Mila - Quebec AI Institute</span>
</div>
<div class="is-size-5 publication-authors">
<span class="author-block"><sup>*</sup>Equal Contribution</span>
<span class="author-block"><sup>1</sup>Université de Montréal</span>
<span class="author-block"><sup>2</sup>KAIST</span>
<br/>
<span class="author-block"><sup>3</sup>Jagiellonian University</span>
<span class="author-block"><sup>4</sup>Ciela Institute</span>
<span class="author-block"><sup>5</sup>Dreamfold</span>
<br/>
<span class="author-block"><sup>6</sup>Recursion</span>
<span class="author-block"><sup>7</sup>CIFAR AI Chair</span>
</div>
<div class="is-size-5 publication-authors">
<span class="author-block">
<b>NeurIPS 2024</b>
</span>
<div class="column has-text-centered">
<div class="publication-links">
<!-- PDF Link. -->
<span class="link-block">
<a href="https://arxiv.org/pdf/2405.20971.pdf"
class="external-link button is-normal is-rounded">
<span class="icon">
<i class="fas fa-file-pdf"></i>
</span>
<span>Paper</span>
</a>
</span>
<span class="link-block">
<a href="https://arxiv.org/abs/2405.20971"
class="external-link button is-normal is-rounded">
<span class="icon">
<i class="ai ai-arxiv"></i>
</span>
<span>arXiv</span>
</a>
</span>
<!-- Code Link. -->
<span class="link-block">
<a href="https://github.com/GFNOrg/diffusion-finetuning"
class="external-link button is-normal is-rounded">
<span class="icon">
<i class="fab fa-github"></i>
</span>
<span>Code</span>
</a>
</span>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="hero teaser">
<div class="container is-max-desktop">
<div class="hero-body">
<!-- <img src="./static/images/final_picture_teaser.jpg" -->
<!-- alt="Teaser image." -->
<!-- class="teaser-image"/> -->
<center>
<img src="./static/images/example.png" alt="GMM Example" style="width: 100%;"/>
</center>
</div>
</div>
</section>
<section class="section">
<div class="container is-max-desktop">
<!-- Abstract. -->
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" id="abstract">Abstract</h2>
<div class="content has-text-justified">
<p>
Diffusion models have emerged as effective distribution estimators in vision, language, and reinforcement learning, but their use as priors
in downstream tasks poses an intractable posterior inference problem. This paper studies amortized sampling of the posterior over data, $\mathbf{x}\sim p^\text{post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$,
in a model that consists of a diffusion generative model prior $p(\mathbf{x})$ and a black-box constraint or likelihood function $r(\mathbf{x})$. We state and prove the asymptotic
correctness of a data-free learning objective, relative trajectory balance, for training a diffusion model that samples from this posterior, a problem that
existing methods solve only approximately or in restricted cases. Relative trajectory balance arises from the generative flow network perspective on diffusion
models, which allows the use of deep reinforcement learning techniques to improve mode coverage. Experiments illustrate the broad potential of unbiased inference
of arbitrary posteriors under diffusion priors: in vision (classifier guidance), language (infilling under a discrete diffusion LLM), and multimodal data
(text-to-image generation). Beyond generative modeling, we apply relative trajectory balance to the problem of continuous control with a score-based behavior
prior, achieving state-of-the-art results on benchmarks in offline reinforcement learning.
</p>
</div>
</div>
</div>
<!--/ Abstract. -->
<hr>
<section class="section">
<div class="container is-max-desktop">
<!-- Animation. -->
<div class="columns is-centered">
<div class="column is-full-width">
<h2 class="title is-3" id="intractable">Intractable inference in diffusion models</h2>
<h3 class="title is-4" id="setup">Setup</h3>
<div class="content has-text-justified">
<p> Given a diffusion model prior \( p(\mathbf{x}) \) and a black-box likelihood function \( r(\mathbf{x}) \), our goal is to sample from the posterior \( p^{\text{post}}(\mathbf{x}) \propto p(\mathbf{x}) r(\mathbf{x}) \). Conventional approaches often rely on heuristic guidance, which yields biased samples or applies only in restricted cases. In contrast, we derive relative trajectory balance, a data-free objective rooted in the generative flow network (GFlowNet) perspective on diffusion models. The objective is asymptotically correct, and it permits off-policy training techniques from deep reinforcement learning that improve mode coverage.
</p>
</div>
<!--<h3 class="title is-4">Coverage Conditions and Geometric Relationships</h3>
<div class="content has-text-justified">
<p>
Some content here
</p>
</div>-->
<hr>
<h2 class="title is-3" id="method">Relative trajectory balance</h2>
<p>
The <b>relative trajectory balance (RTB)</b> constraint requires that, for every diffusion trajectory \( \tau \) ending in a sample \( x \), the ratio of the trajectory's probability under the posterior model \( p_\phi^{\text{post}} \) to its probability under the prior model \( p_\theta \) be proportional to the constraint function \( r(x) \). We enforce this constraint by minimizing the squared log-ratio loss:
</p>
<div class="box" style="background-color:aliceblue">
<div class="content has-text-justified">
$$L(x,\tau; \phi) = \left(\log\frac{Z_\phi\prod_{k=1}^{n}p_{F,\phi}^{{\rm post}}(s_k\mid s_{k-1})}{r(x)\prod_{k=1}^{n}p_{F,\theta}(s_{k}\mid s_{k-1})}\right)^2$$
<!-- <b></b> -->
</div>
</div>
<div class="content has-text-justified">
<p>
Here, \( Z_{\phi} \) is a learnable scalar that estimates the normalizing constant. If the loss is driven to zero on all diffusion trajectories, the posterior model samples without bias from the desired posterior distribution \( p^{\text{post}}(\mathbf{x}) \propto p_\theta(\mathbf{x}) r(\mathbf{x}) \), and \( Z_\phi \) equals the true normalizing constant \( \int p_\theta(\mathbf{x}) r(\mathbf{x})\,d\mathbf{x} \).
</p>
</div>
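<div class="content has-text-justified">
<p>
The claim above can be checked numerically on a toy model. The sketch below is not the authors' implementation: it assumes a hypothetical two-step discrete chain (a fixed start state, a binary intermediate state, and a binary sample), with made-up prior probabilities and reward values. It constructs the exact posterior transitions, evaluates the RTB loss on all four trajectories, and compares the sampler's marginal with \( p_\theta(x) r(x) / Z \). The names <code>rtb_loss</code>, <code>prior_s1</code>, <code>prior_x</code>, and <code>reward</code> are illustrative only.
</p>
</div>

```python
import itertools
import math

# Hypothetical two-step discrete chain: s0 is fixed, s1 and x = s2 are binary.
prior_s1 = [0.7, 0.3]                   # p_theta(s1 | s0)
prior_x = [[0.9, 0.1], [0.2, 0.8]]      # p_theta(x | s1), rows indexed by s1
reward = [1.0, 4.0]                     # r(x), an arbitrary positive constraint

pairs = list(itertools.product(range(2), range(2)))  # all (s1, x) trajectories

# Exact normalizer: Z = sum over trajectories of p_theta(tau) * r(x).
Z = sum(prior_s1[a] * prior_x[a][x] * reward[x] for a, x in pairs)

# Exact posterior transitions, built so that p_post(tau) = p_theta(tau) r(x) / Z.
mass = [sum(prior_x[a][x] * reward[x] for x in range(2)) for a in range(2)]
post_s1 = [prior_s1[a] * mass[a] / Z for a in range(2)]
post_x = [[prior_x[a][x] * reward[x] / mass[a] for x in range(2)]
          for a in range(2)]


def rtb_loss(log_z, logp_post_steps, logp_prior_steps, log_r):
    """Squared log-ratio from the RTB objective for a single trajectory."""
    residual = log_z + sum(logp_post_steps) - log_r - sum(logp_prior_steps)
    return residual ** 2


# The loss evaluated on every trajectory of the toy model.
losses = [
    rtb_loss(
        math.log(Z),
        [math.log(post_s1[a]), math.log(post_x[a][x])],
        [math.log(prior_s1[a]), math.log(prior_x[a][x])],
        math.log(reward[x]),
    )
    for a, x in pairs
]

# Marginal of the posterior sampler over x, and the target p_theta(x) r(x) / Z.
post_marginal = [sum(post_s1[a] * post_x[a][x] for a in range(2))
                 for x in range(2)]
target = [sum(prior_s1[a] * prior_x[a][x] for a in range(2)) * reward[x] / Z
          for x in range(2)]
```

<div class="content has-text-justified">
<p>
In this toy setting every entry of <code>losses</code> vanishes up to floating-point error, and <code>post_marginal</code> agrees with <code>target</code>. In practice the transitions come from neural networks and the loss is minimized by gradient descent on sampled trajectories.
</p>
</div>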
<div class="content has-text-justified">
<h3 class="title is-4" id="setup">Off-Policy Exploration</h3>
<p>
The RTB loss can be evaluated on any diffusion trajectory, so training trajectories need not be sampled from the current posterior model. This freedom to train off-policy can improve exploration and mode coverage; useful strategies include replay buffers and local search.
</p>
</div>
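<div class="content has-text-justified">
<p>
One way a replay buffer might be combined with on-policy samples is sketched below. This is a minimal illustration using only the Python standard library, not the buffer used in the released code: the class <code>TrajectoryReplayBuffer</code>, the function <code>mixed_batch</code>, the uniform sampling rule, and the <code>replay_fraction</code> parameter are all assumptions made for the example.
</p>
</div>

```python
import random
from collections import deque


class TrajectoryReplayBuffer:
    """Fixed-capacity store of (trajectory, log_reward) pairs (hypothetical helper)."""

    def __init__(self, capacity, seed=0):
        self._items = deque(maxlen=capacity)  # oldest entries are evicted first
        self._rng = random.Random(seed)

    def add(self, trajectory, log_reward):
        self._items.append((trajectory, log_reward))

    def sample(self, batch_size):
        # Uniform sampling with replacement; reward-prioritized rules are also possible.
        return [self._rng.choice(self._items) for _ in range(batch_size)]

    def __len__(self):
        return len(self._items)


def mixed_batch(on_policy, buffer, replay_fraction=0.5):
    """Replace a fraction of an on-policy batch with trajectories from the buffer."""
    n_replay = int(len(on_policy) * replay_fraction) if len(buffer) else 0
    return on_policy[: len(on_policy) - n_replay] + buffer.sample(n_replay)
```

<div class="content has-text-justified">
<p>
Each element of the mixed batch would then be scored with the RTB loss in the same way, regardless of whether it was drawn from the current model or from the buffer.
</p>
</div>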
<hr>
<h2 class="title is-3" id="results">Empirical Results</h2>
<h3 class="title is-4" id="results">Unconditional Image</h3>
<p>
We fine-tune unconditional diffusion priors for class-conditional generation on the MNIST and CIFAR-10 datasets. Starting from a pretrained unconditional model \( p_\theta(x) \), we train with the RTB objective to sample from the posterior conditioned on a class label \( c \), using the classifier likelihood \( r(x) = p(c | x) \) as the constraint. The figure below shows samples from the fine-tuned models. RTB balances reward maximization against sample diversity, both for single-class targets and for multimodal targets (e.g., even numbers).
</p>
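<div class="content has-text-justified">
<p>
As an illustration of the constraint \( r(x) = p(c | x) \), the sketch below computes \( \log r(x) \) from classifier logits. The single-class case follows the formula above directly. Summing classifier probabilities over a set of classes (for instance, the even digits) is our assumption about how a multimodal target could be encoded and is not stated on this page. The function names <code>log_softmax</code> and <code>log_reward</code> are hypothetical.
</p>
</div>

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of class logits."""
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(v - top) for v in logits))
    return [v - log_norm for v in logits]


def log_reward(logits, target_classes):
    """Return log r(x) = log of the summed p(c | x) over c in target_classes."""
    log_probs = log_softmax(logits)
    chosen = [log_probs[c] for c in target_classes]
    top = max(chosen)  # log-sum-exp over the selected classes
    return top + math.log(sum(math.exp(v - top) for v in chosen))


# Usage with made-up logits for a 10-class classifier.
example_logits = [0.0] * 10
single_class = log_reward(example_logits, [3])            # one target class
even_classes = log_reward(example_logits, [0, 2, 4, 6, 8])  # a set of classes
```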
<div class="content has-text-justified">
<center>
<img src="./static/images/cls_finetuning_samples.png" alt="Image" style="width: 100%;"/>
</center>
</div>
<h3 class="title is-4" id="results">Text-to-image</h3>
<p>
We apply RTB to KL-regularized fine-tuning of a text-to-image diffusion prior (Stable Diffusion v1.5) with a reward function trained on human preferences (ImageReward). Models fine-tuned with RTB achieve high reward while maintaining diversity in the generated images. In the samples below, shown for several text prompts, the first row is drawn from the diffusion prior, the second row from a biased posterior fine-tuned with KL-regularized RL (DPOK), and the third row from a posterior fine-tuned with RTB.
</p>
<div class="content has-text-justified">
<center>
<img src="./static/images/text2image_samples.png" alt="Text-to-image samples" style="width: 100%;"/>
</center>
<center>
<img src="./static/images/text2image_metrics.png" alt="Text-to-image metrics" style="width: 100%;"/>
</center>
</div>
<h3 class="title is-4" id="results">Diffusion language models</h3>
<p>
RTB is generally applicable to hierarchical generative models, including discrete diffusion. We apply RTB to infilling stories with a discrete diffusion model prior, outperforming finetuned autoregressive models for this task.
</p>
<div class="content has-text-justified">
<center>
<img src="./static/images/dlm_metrics.png" alt="LM results" style="width: 100%;"/>
</center>
</div>
<h3 class="title is-4" id="results">Offline RL</h3>
<p>
An important problem in offline RL is KL-regularized policy extraction: given a behavior policy, used as the prior, and a Q-function trained with an off-the-shelf Q-learning algorithm, find the policy that maximizes Q while staying close to the behavior policy. Diffusion policies are expressive and can model highly multimodal behavior policies. Given a diffusion behavior prior \(\mu(a|s)\) and a Q-function \(Q(s,a)\) trained with IQL, we use RTB to obtain the KL-regularized optimal policy \(\pi^*(a|s) \propto \mu(a|s)e^{Q(s,a)}\). The resulting policies match state-of-the-art results on the D4RL benchmark.
</p>
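<div class="content has-text-justified">
<p>
The form of \(\pi^*\) can be checked with a short, standard calculation, sketched here under the assumption of unit temperature (as in the expression above) and with \(Z(s)\) denoting a normalizing constant introduced only for this sketch. For any policy \(\pi\) at a fixed state \(s\),
</p>
$$\mathbb{E}_{a\sim\pi(\cdot\mid s)}\left[Q(s,a)\right] - D_{\rm KL}\left(\pi(\cdot\mid s)\,\|\,\mu(\cdot\mid s)\right) = \log Z(s) - D_{\rm KL}\left(\pi(\cdot\mid s)\,\|\,\pi^*(\cdot\mid s)\right),$$
$$\pi^*(a\mid s) = \frac{\mu(a\mid s)\,e^{Q(s,a)}}{Z(s)},\qquad Z(s) = \int \mu(a\mid s)\,e^{Q(s,a)}\,da.$$
<p>
The right-hand side is maximized exactly when \(\pi = \pi^*\), so the KL-regularized optimum is a posterior with prior \(\mu\) and constraint \(r = e^{Q}\), which is the setting that RTB addresses.
</p>
</div>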
<div class="content has-text-justified"></div>
<center>
<img src="./static/images/offline_rl_metrics.png" alt="Offline RL results" style="width: 100%;"/>
</center>
</div>
</div>
</div>
<!--/ Animation. -->
</div>
</section>
<hr>
<section class="section" id="BibTeX">
<div class="container is-max-desktop content">
<h2 class="title">BibTeX</h2>
<pre><code>
@inproceedings{
venkatraman2024amortizing,
title={Amortizing intractable inference in diffusion models for vision, language, and control},
author={Siddarth Venkatraman and Moksh Jain and Luca Scimeca and Minsu Kim and Marcin Sendera and Mohsin Hasan and Luke Rowe and Sarthak Mittal and Pablo Lemos and Emmanuel Bengio and Alexandre Adam and Jarrid Rector-Brooks and Yoshua Bengio and Glen Berseth and Nikolay Malkin},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=gVTkMsaaGI}
}
</code></pre>
</div>
</section>
<footer class="footer">
<div class="container">
<div class="content has-text-centered">
<a class="icon-link"
href="https://arxiv.org/pdf/2405.20971.pdf">
<i class="fas fa-file-pdf"></i>
</a>
<a class="icon-link" href="https://github.com/GFNOrg/diffusion-finetuning" class="external-link" disabled>
<i class="fab fa-github"></i>
</a>
</div>
<div class="columns is-centered">
<p>
We thank the Nerfies Team for their <a href="https://github.com/nerfies/nerfies.github.io">website template</a>.
</p>
</div>
</div>
</footer>
</body>
</html>