forked from crowsonkb/v-diffusion-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path cfg_modify_image.py
executable file
·154 lines (131 loc) · 6.31 KB
/
cfg_modify_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""Applies a text prompt to an existing image by finding a latent that would produce it
with the unconditioned DDIM ODE, then integrating the text-conditional DDIM ODE starting
from that latent."""
# NOTE: this module docstring is also passed to argparse as the CLI description
# in main(), so editing it changes the --help output.
import argparse
from functools import partial
from pathlib import Path
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from diffusion import get_model, get_models, sampling, utils
# Absolute directory containing this script; main() falls back to
# MODULE_DIR / 'checkpoints/<model>.pth' when --checkpoint is not given.
MODULE_DIR = Path(__file__).resolve().parents[0]
def parse_prompt(prompt, default_weight=3.):
    """Split a ``TEXT[:WEIGHT]`` / ``URL[:WEIGHT]`` prompt into text and weight.

    The text after the last colon is used as the weight only when it parses as
    a float. Otherwise the whole prompt is kept as the text and
    ``default_weight`` is used. This lets prompts that themselves contain
    colons (e.g. ``"a sign that says: hello"``) or URLs with a port
    (e.g. ``http://host:8080/a.png``) work instead of raising ``ValueError``.

    Args:
        prompt: the raw prompt string from the command line.
        default_weight: weight used when the prompt carries no numeric suffix.

    Returns:
        A ``(text, weight)`` tuple; ``weight`` is always a float.
    """
    text, sep, suffix = prompt.rpartition(':')
    if sep:
        try:
            return text, float(suffix)
        except ValueError:
            pass  # the suffix is part of the prompt (or URL), not a weight
    return prompt, float(default_weight)
def resize_and_center_crop(image, size):
    """Scale ``image`` to cover ``size`` and center-crop it to exactly ``size``.

    Args:
        image: a PIL image (``image.size`` is ``(width, height)``).
        size: target ``(width, height)``.

    Returns:
        The cropped image. ``TF.center_crop`` expects ``(height, width)``,
        hence ``size[::-1]``.
    """
    width, height = image.size
    fac = max(size[0] / width, size[1] / height)
    # int() truncation of a float product such as 255.99999 would give 255,
    # leaving the resized image a pixel short of the crop (or 0 px wide for
    # tiny targets). Clamp each side to at least the target.
    new_size = (max(size[0], int(fac * width)), max(size[1], int(fac * height)))
    image = image.resize(new_size, Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
def main():
    """Command-line entry point: modify an init image toward the given prompts.

    The init image is first run through the sampler in reverse with the
    unconditional model (an all-zeros CLIP embedding) to obtain a latent.
    That latent is then sampled back with ``cfg_model_fn``, which mixes the
    unconditional and prompt-conditioned model outputs. The result is saved
    to ``--output``.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('init', type=str,
                   help='the init image')
    p.add_argument('prompts', type=str, default=[], nargs='*',
                   help='the text prompts to use')
    p.add_argument('--images', type=str, default=[], nargs='*', metavar='IMAGE',
                   help='the image prompts')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--device', type=str,
                   help='the device to use')
    p.add_argument('--max-timestep', '-mt', type=float, default=1.,
                   help='the maximum timestep')
    p.add_argument('--method', type=str, default='plms',
                   choices=['ddim', 'prk', 'plms', 'pie', 'plms2'],
                   help='the sampling method to use')
    p.add_argument('--model', type=str, default='cc12m_1_cfg', choices=['cc12m_1_cfg'],
                   help='the model to use')
    p.add_argument('--output', '-o', type=str, default='out.png',
                   help='the output filename')
    p.add_argument('--size', type=int, nargs=2,
                   help='the output image size')
    p.add_argument('--steps', type=int, default=50,
                   help='the number of timesteps')
    args = p.parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    model = get_model(args.model)()
    _, side_y, side_x = model.shape
    if args.size:
        side_x, side_y = args.size
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    if device.type == 'cuda':
        model = model.half()
    model = model.to(device).eval().requires_grad_(False)
    clip_model_name = model.clip_model if hasattr(model, 'clip_model') else 'ViT-B/16'
    clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
    clip_model.eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    init = Image.open(utils.fetch(args.init)).convert('RGB')
    init = resize_and_center_crop(init, (side_x, side_y))
    init = utils.from_pil_image(init).to(device)[None]

    # Condition 0 is the unconditional (all-zeros) embedding. Its weight is
    # 1 - sum(prompt weights), so the prompt weights act as guidance scales.
    zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
    target_embeds, weights = [zero_embed], []
    for prompt in args.prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for prompt in args.images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = clip_model.visual.input_resolution
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)
    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    # The condition set is fixed for the whole run: concatenate it once rather
    # than on every model evaluation inside the sampling loop.
    n_conds = len(target_embeds)
    cond_embeds = torch.cat(target_embeds)

    def cfg_model_fn(x, t):
        """Evaluate the model under every condition and return their weighted sum."""
        n = x.shape[0]
        # Batch all conditions together: rows are ordered condition-major,
        # matching repeat_interleave on the embeddings.
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = cond_embeds.repeat_interleave(n, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    def run():
        """Invert the init image to a latent, then sample it with guidance and save."""
        t = torch.linspace(0, 1, args.steps + 1, device=device)
        steps = utils.get_spliced_ddpm_cosine_schedule(t)
        # Keep only the part of the schedule at or below --max-timestep.
        steps = steps[steps <= args.max_timestep]
        uncond_args = {'clip_embed': zero_embed}
        # Sample back down the same schedule; [:-1] drops the smallest timestep.
        sample_steps = steps.flip(0)[:-1]
        if args.method == 'ddim':
            x = sampling.reverse_sample(model, init, steps, uncond_args)
            out = sampling.sample(cfg_model_fn, x, sample_steps, 0, {})
        else:
            # prk / plms / pie / plms2 share one signature (sampling.<method>_sample).
            # Resolve lazily so only the chosen sampler is looked up.
            sample_fn = getattr(sampling, f'{args.method}_sample')
            x = sample_fn(model, init, steps, uncond_args, is_reverse=True)
            out = sample_fn(cfg_model_fn, x, sample_steps, {})
        utils.to_pil_image(out[0]).save(args.output)

    try:
        run()
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C aborts sampling quietly (no output file is written).
        pass
if __name__ == '__main__':
    # Only launch the command-line tool when executed directly, not on import.
    main()