-
Notifications
You must be signed in to change notification settings - Fork 58
/
Copy pathfractalgen.py
199 lines (176 loc) · 7.43 KB
/
fractalgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from functools import partial
import torch
import torch.nn as nn
from models.ar import AR
from models.mar import MAR
from models.pixelloss import PixelLoss
class FractalGen(nn.Module):
""" Fractal Generative Model"""
def __init__(self,
img_size_list,
embed_dim_list,
num_blocks_list,
num_heads_list,
generator_type_list,
label_drop_prob=0.1,
class_num=1000,
attn_dropout=0.1,
proj_dropout=0.1,
guiding_pixel=False,
num_conds=1,
r_weight=1.0,
grad_checkpointing=False,
fractal_level=0):
super().__init__()
# --------------------------------------------------------------------------
# fractal specifics
self.fractal_level = fractal_level
self.num_fractal_levels = len(img_size_list)
# --------------------------------------------------------------------------
# Class embedding for the first fractal level
if self.fractal_level == 0:
self.num_classes = class_num
self.class_emb = nn.Embedding(class_num, embed_dim_list[0])
self.label_drop_prob = label_drop_prob
self.fake_latent = nn.Parameter(torch.zeros(1, embed_dim_list[0]))
torch.nn.init.normal_(self.class_emb.weight, std=0.02)
torch.nn.init.normal_(self.fake_latent, std=0.02)
# --------------------------------------------------------------------------
# Generator for the current level
if generator_type_list[fractal_level] == "ar":
generator = AR
elif generator_type_list[fractal_level] == "mar":
generator = MAR
else:
raise NotImplementedError
self.generator = generator(
seq_len=(img_size_list[fractal_level] // img_size_list[fractal_level+1]) ** 2,
patch_size=img_size_list[fractal_level+1],
cond_embed_dim=embed_dim_list[fractal_level-1] if fractal_level > 0 else embed_dim_list[0],
embed_dim=embed_dim_list[fractal_level],
num_blocks=num_blocks_list[fractal_level],
num_heads=num_heads_list[fractal_level],
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
guiding_pixel=guiding_pixel if fractal_level > 0 else False,
num_conds=num_conds,
grad_checkpointing=grad_checkpointing,
)
# --------------------------------------------------------------------------
# Build the next fractal level recursively
if self.fractal_level < self.num_fractal_levels - 2:
self.next_fractal = FractalGen(
img_size_list=img_size_list,
embed_dim_list=embed_dim_list,
num_blocks_list=num_blocks_list,
num_heads_list=num_heads_list,
generator_type_list=generator_type_list,
label_drop_prob=label_drop_prob,
class_num=class_num,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
guiding_pixel=guiding_pixel,
num_conds=num_conds,
r_weight=r_weight,
grad_checkpointing=grad_checkpointing,
fractal_level=fractal_level+1
)
else:
# The final fractal level uses PixelLoss.
self.next_fractal = PixelLoss(
c_channels=embed_dim_list[fractal_level],
depth=num_blocks_list[fractal_level+1],
width=embed_dim_list[fractal_level+1],
num_heads=num_heads_list[fractal_level+1],
r_weight=r_weight,
)
def forward(self, imgs, cond_list):
"""
Forward pass to get loss recursively.
"""
if self.fractal_level == 0:
# Compute class embedding conditions.
class_embedding = self.class_emb(cond_list)
if self.training:
# Randomly drop labels according to label_drop_prob.
drop_latent_mask = (torch.rand(cond_list.size(0)) < self.label_drop_prob).unsqueeze(-1).cuda().to(class_embedding.dtype)
class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
else:
# For evaluation (unconditional NLL), use a constant mask.
drop_latent_mask = torch.ones(cond_list.size(0)).unsqueeze(-1).cuda().to(class_embedding.dtype)
class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
cond_list = [class_embedding for _ in range(5)]
# Get image patches and conditions for the next level
imgs, cond_list, guiding_pixel_loss = self.generator(imgs, cond_list)
# Compute loss recursively from the next fractal level.
loss = self.next_fractal(imgs, cond_list)
return loss + guiding_pixel_loss
def sample(self, cond_list, num_iter_list, cfg, cfg_schedule, temperature, filter_threshold, fractal_level,
visualize=False):
"""
Generate samples recursively.
"""
if fractal_level < self.num_fractal_levels - 2:
next_level_sample_function = partial(
self.next_fractal.sample,
num_iter_list=num_iter_list,
cfg_schedule="constant",
fractal_level=fractal_level + 1
)
else:
next_level_sample_function = self.next_fractal.sample
# Recursively sample using the current generator.
return self.generator.sample(
cond_list, num_iter_list[fractal_level], cfg, cfg_schedule,
temperature, filter_threshold, next_level_sample_function, visualize
)
def fractalar_in64(**kwargs):
model = FractalGen(
img_size_list=(64, 4, 1),
embed_dim_list=(1024, 512, 128),
num_blocks_list=(32, 8, 3),
num_heads_list=(16, 8, 4),
generator_type_list=("ar", "ar", "ar"),
fractal_level=0,
**kwargs)
return model
def fractalmar_in64(**kwargs):
model = FractalGen(
img_size_list=(64, 4, 1),
embed_dim_list=(1024, 512, 128),
num_blocks_list=(32, 8, 3),
num_heads_list=(16, 8, 4),
generator_type_list=("mar", "mar", "ar"),
fractal_level=0,
**kwargs)
return model
def fractalmar_base_in256(**kwargs):
model = FractalGen(
img_size_list=(256, 16, 4, 1),
embed_dim_list=(768, 384, 192, 64),
num_blocks_list=(24, 6, 3, 1),
num_heads_list=(12, 6, 3, 4),
generator_type_list=("mar", "mar", "mar", "ar"),
fractal_level=0,
**kwargs)
return model
def fractalmar_large_in256(**kwargs):
model = FractalGen(
img_size_list=(256, 16, 4, 1),
embed_dim_list=(1024, 512, 256, 64),
num_blocks_list=(32, 8, 4, 1),
num_heads_list=(16, 8, 4, 4),
generator_type_list=("mar", "mar", "mar", "ar"),
fractal_level=0,
**kwargs)
return model
def fractalmar_huge_in256(**kwargs):
model = FractalGen(
img_size_list=(256, 16, 4, 1),
embed_dim_list=(1280, 640, 320, 64),
num_blocks_list=(40, 10, 5, 1),
num_heads_list=(16, 8, 4, 4),
generator_type_list=("mar", "mar", "mar", "ar"),
fractal_level=0,
**kwargs)
return model