# clip_fft.py (forked from eps696/aphantasia)
import os
import warnings
warnings.filterwarnings("ignore")
import argparse
import numpy as np
from imageio import imread, imsave
import shutil
try:
    from googletrans import Translator, constants
except ImportError as e:
    # print("--> Not running with googletrans support", e)
    # googletrans optional (not needed if translation is not used)
    pass
import torch
import torchvision
import torch.nn.functional as F
import clip
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from sentence_transformers import SentenceTransformer
import lpips
from aphantasia.image import to_valid_rgb, fft_image, dwt_image
from aphantasia.utils import slice_imgs, derivat, sim_func, basename, img_list, img_read, plot_text, txt_clean, checkout, old_torch
from aphantasia import transforms
try: # progress bar for notebooks
    get_ipython().__class__.__name__
    from aphantasia.progress_bar import ProgressIPy as ProgressBar
except: # normal console
    from aphantasia.progress_bar import ProgressBar

clip_models = ['ViT-B/16', 'ViT-B/32', 'RN101', 'RN50x16', 'RN50x4', 'RN50']
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--in_img', default=None, help='input image')
    parser.add_argument('-t', '--in_txt', default=None, help='input text')
    parser.add_argument('-t2', '--in_txt2', default=None, help='input text - style')
    parser.add_argument('-w2', '--weight2', default=0.5, type=float, help='weight for style')
    parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract')
    parser.add_argument( '--out_dir', default='_out')
    parser.add_argument('-s', '--size', default='1280-720', help='Output resolution')
    parser.add_argument('-r', '--resume', default=None, help='Path to saved FFT snapshots, to resume from')
    parser.add_argument('-opt', '--opt_step', default=1, type=int, help='How many optimizing steps per save step')
    parser.add_argument('-tr', '--translate', action='store_true', help='Translate text with Google Translate')
    parser.add_argument('-ml', '--multilang', action='store_true', help='Use SBERT multilanguage model for text')
    parser.add_argument( '--save_pt', action='store_true', help='Save FFT snapshots for further use')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
    parser.add_argument( '--no-verbose', dest='verbose', action='store_false')
    parser.set_defaults(verbose=True)
    # training
    parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use')
    parser.add_argument( '--steps', default=200, type=int, help='Total iterations')
    parser.add_argument( '--samples', default=200, type=int, help='Samples to evaluate')
    parser.add_argument( '--lrate', default=0.05, type=float, help='Learning rate')
    parser.add_argument('-p', '--prog', action='store_true', help='Enable progressive lrate growth (up to double a.lrate)')
    # wavelet
    parser.add_argument( '--dwt', action='store_true', help='Use DWT instead of FFT')
    parser.add_argument('-w', '--wave', default='coif2', help='wavelets: db[1..], coif[1..], haar, dmey')
    # tweaks
    parser.add_argument('-a', '--align', default='uniform', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution')
    parser.add_argument('-tf', '--transform', default='custom', choices=['none', 'custom', 'elastic'], help='use augmenting transforms?')
    parser.add_argument( '--contrast', default=0.9, type=float)
    parser.add_argument( '--colors', default=1.5, type=float)
    parser.add_argument( '--decay', default=1.5, type=float)
    parser.add_argument('-sh', '--sharp', default=0.3, type=float)
    parser.add_argument('-mm', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1')
    parser.add_argument('-e', '--enforce', default=0, type=float, help='Enforce details (by boosting similarity between two parallel samples)')
    parser.add_argument('-x', '--expand', default=0, type=float, help='Boosts diversity (by enforcing difference between prev/next samples)')
    parser.add_argument('-n', '--noise', default=0, type=float, help='Add noise to suppress accumulation') # < 0.05 ?
    parser.add_argument('-nt', '--notext', default=0, type=float, help='Subtract typed text as image (avoiding graffiti?), [0..1]')
    parser.add_argument('-c', '--sync', default=0, type=float, help='Sync output to input image')
    parser.add_argument( '--invert', action='store_true', help='Invert criteria')
    parser.add_argument( '--sim', default='mix', help='Similarity function (dot/angular/spherical/mixed; None = cossim)')
    a = parser.parse_args()

    if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1]
    if len(a.size) == 1: a.size = a.size * 2
    if a.in_img is not None and a.sync != 0: a.align = 'overscan'
    if a.multilang is True: a.model = 'ViT-B/32' # sbert model is trained with ViT
    return a
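
# Example invocations (a sketch: the prompts below are hypothetical, but every flag
# is defined in get_args() above):
#   python clip_fft.py -t "some text prompt" -s 1280-720 --steps 300
#   python clip_fft.py -t "some text prompt" -t2 "some style prompt" --dwt -w coif2 --save_pt
# Frames are written to <out_dir>/<out_name>/ and assembled into <out_name>.mp4 by ffmpeg (see main()).
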
def main():
    a = get_args()

    shape = [1, 3, *a.size]
    if a.dwt is True:
        params, image_f, sz = dwt_image(shape, a.wave, a.sharp, a.colors, a.resume)
    else:
        params, image_f, sz = fft_image(shape, 0.01, a.decay, a.resume)
    if sz is not None: a.size = sz
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = lr1 * 0.01
    else:
        lr0 = a.lrate
    optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, amsgrad=True)
    sign = 1. if a.invert is True else -1.
    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.multilang is True:
        model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()

    def enc_text(txt):
        if a.multilang is True:
            emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)
    if a.sync > 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()
    out_name = []
    if a.in_txt is not None:
        if a.verbose is True: print(' topic text: ', a.in_txt)
        if a.translate:
            translator = Translator()
            a.in_txt = translator.translate(a.in_txt, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt)
        txt_enc = enc_text(a.in_txt)
        out_name.append(txt_clean(a.in_txt).lower()[:40])

        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(a.in_txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style text:', a.in_txt2)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)
        out_name.append(txt_clean(a.in_txt2).lower()[:40])

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)
        out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40])

    if a.multilang is True: del model_lang
    if a.in_img is not None and os.path.isfile(a.in_img):
        if a.verbose is True: print(' ref image:', basename(a.in_img))
        img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
        img_in = img_in[:,:3,:,:] # fix rgb channels
        in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0]
        img_enc = model_clip.encode_image(in_sliced).detach().clone()
        if a.sync > 0:
            sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
            sim_size = [s//2 for s in a.size]
            img_in = F.interpolate(img_in, sim_size, mode='bicubic', align_corners=True).float()
        else:
            del img_in
        del in_sliced; torch.cuda.empty_cache()
        out_name.append(basename(a.in_img).replace(' ', '_'))

    if a.verbose is True: print(' samples:', a.samples)
    out_name = '-'.join(out_name)
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(a.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    prev_enc = 0
    def train(i):
        loss = 0

        noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)
        img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
        out_enc = model_clip.encode_image(img_sliced)

        if a.in_txt is not None: # input text
            loss += sign * sim_func(txt_enc, out_enc, a.sim)
            if a.notext > 0:
                loss -= sign * a.notext * sim_func(txt_plot_enc, out_enc, a.sim)
        if a.in_txt2 is not None: # input text - style
            loss += sign * a.weight2 * sim_func(txt_enc2, out_enc, a.sim)
        if a.in_txt0 is not None: # subtract text
            loss += -sign * 0.3 * sim_func(txt_enc0, out_enc, a.sim)
        if a.in_img is not None and os.path.isfile(a.in_img): # input image
            loss += sign * 0.5 * sim_func(img_enc, out_enc, a.sim)
        if a.sync > 0 and a.in_img is not None and os.path.isfile(a.in_img): # image composition
            prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step)
            loss += prog_sync * a.sync * sim_loss(F.interpolate(img_out, sim_size, mode='bicubic', align_corners=True).float(), img_in, normalize=True).squeeze()
        if a.sharp != 0 and a.dwt is not True: # scharr|sobel|default
            loss -= a.sharp * derivat(img_out, mode='naiv')
            # loss -= a.sharp * derivat(img_sliced, mode='scharr')
        if a.enforce != 0:
            img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
            out_enc2 = model_clip.encode_image(img_sliced)
            loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
            del out_enc2; torch.cuda.empty_cache()
        if a.expand > 0:
            nonlocal prev_enc # prev_enc is defined in main(), so nonlocal (not global) binds to it
            if i > 0:
                loss += a.expand * sim_func(out_enc, prev_enc, a.sim)
            prev_enc = out_enc.detach() # .clone()
        del img_out, img_sliced, out_enc; torch.cuda.empty_cache()
        assert not isinstance(loss, int), ' Loss not defined, check the inputs'

        if a.prog is True:
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.opt_step == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            # empirical tone mapping
            if a.sync > 0 and a.in_img is not None:
                img = img ** 1.3
            elif a.sharp != 0:
                img = img ** (1 + a.sharp/2.)
            checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)), verbose=a.verbose)
            pbar.upd()

    pbar = ProgressBar(a.steps // a.opt_step)
    for i in range(a.steps):
        train(i)
    os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, out_name)))
    shutil.copy(img_list(tempdir)[-1], os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
    if a.save_pt is True:
        torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))

if __name__ == '__main__':
    main()
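
# Resuming (a sketch; the path below is hypothetical): --save_pt stores the optimized
# FFT/DWT parameters as <out_dir>/<out_name>.pt, and -r/--resume takes such a file
# as the starting point for a further run, e.g.:
#   python clip_fft.py -t "some text prompt" -r _out/some_text_prompt.pt --steps 100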