# sample_ultra_res_demo.py

from uuid import uuid4
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import argparse
import math
from skimage import color
import cv2
import numpy as np
from imagen_pytorch.trainer import restore_parts
from imagen_pytorch.version import __version__
from packaging import version
from torchvision.utils import save_image
import torchvision.transforms as transforms
from train_ultra_res import init_imagen
from ultra_res_patient_dataset import MAG_LEVEL_SIZES
import os
import gc
from fsspec.core import url_to_fs
import warnings
# used to ignore CUDA warnings that clog stdout
# REMOVE if there are CUDA errors other than those expected
warnings.filterwarnings("ignore", category=UserWarning)
PATCH_SIZE = 1024
PATCH_SIZES = {1: 64, 2: 256, 3: 1024}
BATCH_SIZES = [128, 64, 6]
FILL_COLOR = 0.95
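# PATCH_SIZE is the side length, in pixels, of the generated images and of the
# conditioning crops; PATCH_SIZES maps each unet in the cascade to the
# resolution it outputs (64 -> 256 -> 1024); FILL_COLOR is the near-white value
# used to pad conditioning crops after they are shifted. Note that BATCH_SIZES
# is defined here but is not referenced elsewhere in this script.
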
def load_model(mag_level, unet_number, device, args):
    imagen = init_imagen(mag_level, unet_number, device=device)

    checkpoint_name = f"unet{unet_number}_mag{mag_level}"
    path = vars(args)[checkpoint_name]

    fs, _ = url_to_fs(path)
    with fs.open(path) as f:
        loaded_obj = torch.load(f, map_location='cpu')

    if version.parse(__version__) != version.parse(loaded_obj['version']):
        print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

    try:
        imagen.load_state_dict(loaded_obj['model'], strict=True)
    except RuntimeError:
        print("Failed loading state dict. Trying partial load")
        imagen.load_state_dict(restore_parts(imagen.state_dict(), loaded_obj['model']))

    return imagen

def print_memory_usage(rank):
    # Debugging helper (not called in this demo); reports memory for device 0
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    print(f"cuda:{rank} total memory: {t}, reserved memory: {r}, allocated memory: {a}, free memory: {r - a}")

def generate_image_distributed(rank, mag_level, unet_number, args, in_queue, out_queue, patches_generated, overlap, orientation, patch_pos, num_patches_width):
    device = torch.device(f"cuda:{rank}")
    print(f"started process on {device}")

    imagen = load_model(mag_level, unet_number, device, args)

    while True:
        item = in_queue.get()
        if item is None:
            break

        idx, batch_lowres_image, cond_image, pos = item

        inpaint_patch = None
        inpaint_mask = None

        # A patch can only be generated once the patches above it, next to it,
        # and diagonally above-next-to it have been generated (or fall outside
        # the set of patches being generated).
        if pos is not None:
            i, j = pos

            above_patch = None
            next_to_patch = None
            above_next_to_patch = None

            above = (i - 1, j)
            above_idx = -1 if above not in patch_pos else patch_pos.index(above)
            next_to = (i, j + orientation)
            next_to_idx = -1 if next_to not in patch_pos else patch_pos.index(next_to)
            above_next_to = (i - 1, j + orientation)
            above_next_to_idx = -1 if above_next_to not in patch_pos else patch_pos.index(above_next_to)

            above_exists = above_idx in patches_generated or above not in patch_pos
            next_to_exists = next_to_idx in patches_generated or next_to not in patch_pos
            above_next_to_exists = above_next_to_idx in patches_generated or above_next_to not in patch_pos

            # Variables needed to take upscaled crops from the conditioning image
            # so the generated patch blends nicely with it.
            unet_patch_size = PATCH_SIZES[unet_number]
            patch_width = int(MAG_LEVEL_SIZES[mag_level] * PATCH_SIZE / MAG_LEVEL_SIZES[mag_level - 1])
            patch_dist = int(patch_width * (1 - overlap))

            topleft_y = cond_image.shape[1] // 2 - patch_width // 2
            topleft_x = cond_image.shape[2] // 2 - patch_width // 2

            above_y = topleft_y - patch_dist
            above_x = topleft_x

            next_to_y = topleft_y
            next_to_x = topleft_x + orientation * patch_dist

            above_next_to_y = topleft_y - patch_dist
            above_next_to_x = topleft_x + orientation * patch_dist

            space_above = i != 0
            space_next_to = (orientation == 1 and j < num_patches_width - 1) or (orientation == -1 and j > 0)

            # Demo limit: only the 2x2 corner of patches (i, j <= 1) is generated.
            if i > 1 or j > 1:
                continue
            elif above_exists and next_to_exists and above_next_to_exists:
                if above_idx in patches_generated:
                    above_patch = patches_generated[above_idx][0]
                elif space_above:
                    above_patch = cond_image[:, above_y:above_y + patch_width, above_x:above_x + patch_width].unsqueeze(0)
                    above_patch = F.interpolate(above_patch, size=(unet_patch_size, unet_patch_size), mode='bilinear', align_corners=False)[0]

                if next_to_idx in patches_generated:
                    next_to_patch = patches_generated[next_to_idx][0]
                elif space_next_to:
                    next_to_patch = cond_image[:, next_to_y:next_to_y + patch_width, next_to_x:next_to_x + patch_width].unsqueeze(0)
                    next_to_patch = F.interpolate(next_to_patch, size=(unet_patch_size, unet_patch_size), mode='bilinear', align_corners=False)[0]

                if above_next_to_idx in patches_generated:
                    above_next_to_patch = patches_generated[above_next_to_idx][0]
                elif space_above and space_next_to:
                    above_next_to_patch = cond_image[:, above_next_to_y:above_next_to_y + patch_width, above_next_to_x:above_next_to_x + patch_width].unsqueeze(0)
                    above_next_to_patch = F.interpolate(above_next_to_patch, size=(unet_patch_size, unet_patch_size), mode='bilinear', align_corners=False)[0]
            else:
                # Neighbours are not ready yet: requeue this patch and move on.
                in_queue.put((idx, batch_lowres_image, cond_image, pos))
                continue

            print(f"generating patch at {pos} which is index {idx}", flush=True)

            # inpaint_patch holds the already-generated above, next_to, and
            # above_next_to content in the correct positions, and inpaint_mask
            # marks those regions so the new patch is blended into them.
            inpaint_patch = torch.zeros(3, unet_patch_size, unet_patch_size)
            inpaint_mask = torch.zeros(unet_patch_size, unet_patch_size)

            # e.g. overlap = 0.25 on a 1024 px patch gives a 256 px blending band
            overlap_pos = int(overlap * unet_patch_size)

            # At the top of the image, above_patch is None.
            # At the left/right edge of the image, next_to_patch is None.
            # At the top left/right corner of the image, above_next_to_patch is None.
            if above_patch is not None:
                inpaint_patch[:, :overlap_pos, :] = above_patch[:, -overlap_pos:, :]
                inpaint_mask[:overlap_pos, :] = 1

            if next_to_patch is not None:
                if orientation == -1:
                    inpaint_patch[:, :, :overlap_pos] = next_to_patch[:, :, -overlap_pos:]
                    inpaint_mask[:, :overlap_pos] = 1
                else:
                    inpaint_patch[:, :, -overlap_pos:] = next_to_patch[:, :, :overlap_pos]
                    inpaint_mask[:, -overlap_pos:] = 1

            if above_next_to_patch is not None:
                if orientation == -1:
                    inpaint_patch[:, :overlap_pos, :overlap_pos] = above_next_to_patch[:, -overlap_pos:, -overlap_pos:]
                else:
                    inpaint_patch[:, :overlap_pos, -overlap_pos:] = above_next_to_patch[:, -overlap_pos:, :overlap_pos]

            inpaint_patch = inpaint_patch.unsqueeze(0).to(device)
            inpaint_mask = inpaint_mask.unsqueeze(0).to(device)

            # Debug output: save the inpaint inputs and conditioning crops.
            save_image(inpaint_patch[0].cpu(), f"{args.sample_dir}/{uuid4()}_inpaint_patch_{pos}.png")
            save_image(inpaint_mask[0].cpu(), f"{args.sample_dir}/{uuid4()}_inpaint_mask_{pos}.png")
            save_image(cond_image.cpu(), f"{args.sample_dir}/{uuid4()}_cond_image_{pos}.png")
            if above_patch is not None:
                save_image(above_patch.cpu(), f"{args.sample_dir}/{uuid4()}_above_patch_{pos}.png")
            if next_to_patch is not None:
                save_image(next_to_patch.cpu(), f"{args.sample_dir}/{uuid4()}_next_to_patch_{pos}.png")
            if above_next_to_patch is not None:
                save_image(above_next_to_patch.cpu(), f"{args.sample_dir}/{uuid4()}_above_next_to_patch_{pos}.png")

        if cond_image is not None:
            cond_image = cond_image.unsqueeze(0).to(device)

        if batch_lowres_image is not None:
            batch_lowres_image = batch_lowres_image.to(device)

        batch_image = imagen.sample(
            batch_size=1,
            return_pil_images=False,
            cond_images=cond_image,
            start_image_or_video=batch_lowres_image,
            start_at_unet_number=unet_number,
            stop_at_unet_number=unet_number,
            inpaint_images=inpaint_patch,
            inpaint_masks=inpaint_mask,
            inpaint_resample_times=args.inpaint_resample,
            use_tqdm=False,
            device=device,
        )

        save_image(batch_image[0].cpu(), f"{args.sample_dir}/{uuid4()}_patch_{pos}.png")

        if pos is not None:
            print(f"{len(patches_generated)}/{len(patch_pos)} patches generated", flush=True)

        patches_generated[idx] = batch_image.cpu()
        out_queue.put((idx,))

    del imagen
    del cond_image
    del batch_lowres_image
    gc.collect()
    torch.cuda.empty_cache()

def generate_image_with_unet(mag_level, unet_number, args, lowres_image, cond_image, patch_pos, overlap, orientation, num_patches_width):
    in_queue = mp.Queue()
    out_queue = mp.Queue()
    patches_generated = mp.Manager().dict()
    processes = []

    if cond_image is not None:
        num_cond_images = cond_image.shape[0]
    else:
        num_cond_images = 1

    print(f"Generating {num_cond_images} images for mag {mag_level} and unet {unet_number}")

    for idx in range(num_cond_images):
        # Skip patches that were never generated by the previous unet
        if lowres_image is not None and lowres_image[idx] is None:
            continue

        # Queue the conditioning image, low-res input, and patch position for
        # this index; a worker process samples it with imagen.sample()
        idx_cond_image = None if cond_image is None else cond_image[idx]
        idx_lowres_image = None if lowres_image is None else lowres_image[idx]
        pos = None if patch_pos is None else patch_pos[idx]
        in_queue.put((idx, idx_lowres_image, idx_cond_image, pos))

    num_processes = min(args.num_gpus, num_cond_images)
    for rank in range(num_processes):
        p = mp.Process(target=generate_image_distributed, args=(rank, mag_level, unet_number, args, in_queue, out_queue, patches_generated, overlap, orientation, patch_pos, num_patches_width))
        p.start()
        processes.append(p)

    # Wait for the expected number of results: the demo generates only the
    # 2x2 corner (4 patches) at mag 1, and a single image at mag 0.
    for _ in range(4 if mag_level == 1 else 1):
        out_queue.get()

    # Tell the workers to exit, then wait for them
    for _ in range(num_processes):
        in_queue.put(None)
    for p in processes:
        p.join()

    images = [(patches_generated[idx] if idx in patches_generated else None) for idx in range(num_cond_images)]

    if cond_image is not None:
        del cond_image
    if lowres_image is not None:
        del lowres_image
    gc.collect()
    torch.cuda.empty_cache()

    return images

def generate_image(mag_level, args, cond_image=None, patch_pos=None, overlap=0.25, orientation=-1, num_patches_width=1):
    lowres_image = generate_image_with_unet(mag_level, 1, args, None, cond_image, patch_pos, overlap, orientation, num_patches_width)
    medres_image = generate_image_with_unet(mag_level, 2, args, lowres_image, cond_image, patch_pos, overlap, orientation, num_patches_width)
    highres_image = generate_image_with_unet(mag_level, 3, args, medres_image, cond_image, patch_pos, overlap, orientation, num_patches_width)
    return highres_image

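# The three calls above run the cascade one stage at a time: unet 1 samples the
# base 64x64 images, unet 2 upscales them to 256x256, and unet 3 to 1024x1024
# (the resolutions in PATCH_SIZES), with each stage started from the previous
# stage's output via start_image_or_video.
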
# mag0 images represent 40000x40000 patches, but are 1024x1024
# We need to get the positions of the centers of all patches
# that are 6500x6500 in this image, and use these as the conditioning
# images to generate the mag1 images.
#
# Each pixel in this image is 40000/1024 pixels in the original, and
# each pixel in the original is 1024/40000 pixels in this image.
#
# So a 6500x6500 patch is 6500 * 1024/40000 pixels in this image.
#
# So split the image into these patches, and move each image around
# so that the patch is at the center.
#
# FOR MAG2
# Zoomed image is now much larger than patch_size. Each PATCH_SIZE
# patch in the image is the correct scale for a 6500x6500 patch that
# will be used to condition mag2 generation.
#
# So basically, we just need to find all the patches we need to
# generate for mag2, then get a PATCH_SIZE crop around that area in
# the mag1 full scale image
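# Illustrative arithmetic, assuming the sizes quoted above (a 40000x40000 mag0
# field of view rendered at 1024x1024, mag1 patches of 6500x6500): one mag1
# region spans roughly 6500 * 1024 / 40000 ~ 166 pixels of the mag0 image, and
# with 25% overlap consecutive conditioning crops are centred about
# int(166 * 0.75) = 124 pixels apart. The exact values come from
# MAG_LEVEL_SIZES at runtime.
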
def get_cond_images(zoomed_image, mag_level, overlap=0.25):
    # width (in pixels) of a mag_level patch inside the lower-magnification
    # image it is cut from
    patch_width = int(MAG_LEVEL_SIZES[mag_level] * PATCH_SIZE / MAG_LEVEL_SIZES[mag_level - 1])
    patch_dist = int(patch_width * (1 - overlap))
    zoomed_image_width = zoomed_image.shape[3]

    # This takes the overlap into account
    num_patches_width = 1 + math.ceil((zoomed_image_width - patch_width) / patch_dist)

    # we want to filter out white (background-only) patches to save time
    if mag_level == 2:
        zoomed_image_np = zoomed_image[0].permute(1, 2, 0).cpu().numpy()

        # Mask out the background using hue/saturation thresholds
        img_hs = color.rgb2hsv(zoomed_image_np)
        img_hs = np.logical_and(img_hs[:, :, 0] > 0.5, img_hs[:, :, 1] > 0.02)

        # remove small objects
        img_hs = cv2.erode(img_hs.astype(np.uint8), np.ones((5, 5), np.uint8), iterations=1)

        # grow the mask
        kernel = np.ones((51, 51), np.uint8)
        img_hs = cv2.dilate(img_hs.astype(np.uint8), kernel, iterations=1)

        print("cond image details:")
        print("patch_width", patch_width)
        print("patch_dist", patch_dist)
        print("zoomed_image_width", zoomed_image_width)
        print("num_patches_width", num_patches_width)
        print("", flush=True)

        # keep only the patch_width x patch_width patches that contain some
        # foreground: iterate over patch positions and check the mask
        patch_pos = []
        for i in range(num_patches_width):
            for j in range(num_patches_width):
                y = i * patch_dist
                x = j * patch_dist
                patch = img_hs[y:y + patch_width, x:x + patch_width]

                # if any pixel in the patch is foreground, keep the patch
                if np.any(patch > 0.5):
                    patch_pos.append((i, j))
    else:
        patch_pos = [(i, j) for i in range(num_patches_width) for j in range(num_patches_width)]

    cond_images = []
    for i, j in patch_pos:
        y = i * patch_dist
        x = j * patch_dist
        center_y = y + patch_width // 2
        center_x = x + patch_width // 2

        # Move zoomed_image so that its centre aligns with the centre of this
        # patch, i.e. zoomed_image_width // 2 in the generated image lines up
        # with (center_y, center_x).
        shift_y = zoomed_image_width // 2 - center_y
        shift_x = zoomed_image_width // 2 - center_x

        # Shift the image horizontally and vertically
        shifted_img = torch.roll(zoomed_image[0], shifts=(shift_y, shift_x), dims=(1, 2))

        # Fill the wrapped-around regions with the fill colour
        # (no fill is needed when the shift is zero)
        if shift_y > 0:
            shifted_img[:, :shift_y, :] = FILL_COLOR
        elif shift_y < 0:
            shifted_img[:, shift_y:, :] = FILL_COLOR

        if shift_x > 0:
            shifted_img[:, :, :shift_x] = FILL_COLOR
        elif shift_x < 0:
            shifted_img[:, :, shift_x:] = FILL_COLOR

        # This shouldn't do anything for mag1 since zoomed_image is 1024x1024
        shifted_img = transforms.CenterCrop(PATCH_SIZE)(shifted_img)
        cond_images.append(shifted_img)

    return torch.stack(cond_images), patch_pos, num_patches_width

def get_next_patches(patches, orientation):
    processed_patches = []
    waiting_patches = []
    for i, j in patches:
        if (i - 1, j) not in patches and (i, j + orientation) not in patches and (i - 1, j + orientation) not in patches:
            processed_patches.append((i, j))
        else:
            waiting_patches.append((i, j))
    return processed_patches, waiting_patches

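# Small worked example of the dependency check above (illustrative only): with
# patches = [(0, 0), (0, 1), (1, 0), (1, 1)] and orientation = -1, only (0, 0)
# has none of its above / left / above-left neighbours inside the set, so the
# function returns ([(0, 0)], [(0, 1), (1, 0), (1, 1)]).
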
def generate_high_res_image(zoomed_image, mag_level, args):
    cond_images, patch_pos, num_patches_width = get_cond_images(zoomed_image, mag_level, overlap=args.overlap)

    # Sweep in whichever horizontal direction starts with more patches whose
    # neighbours are already outside the patch set
    num_top_left_patches = len(get_next_patches(patch_pos, -1)[0])
    num_top_right_patches = len(get_next_patches(patch_pos, 1)[0])
    orientation = -1 if num_top_left_patches >= num_top_right_patches else 1

    generate_image(mag_level, args, cond_image=cond_images, patch_pos=patch_pos, overlap=args.overlap, orientation=orientation, num_patches_width=num_patches_width)

def main():
    args = parse_args()

    try:
        os.makedirs(args.sample_dir)
    except FileExistsError:
        pass

    # Generate 20 demo samples: a mag0 overview image plus its mag1 patches
    for _ in range(20):
        mag0_images = generate_image(0, args)
        save_image(mag0_images[0][0], f'{args.sample_dir}/{uuid4()}_MAG0.png')
        generate_high_res_image(mag0_images[0], 1, args)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--unet1_mag0', type=str)
    parser.add_argument('--unet1_mag1', type=str)
    parser.add_argument('--unet1_mag2', type=str)
    parser.add_argument('--unet2_mag0', type=str)
    parser.add_argument('--unet2_mag1', type=str)
    parser.add_argument('--unet2_mag2', type=str)
    parser.add_argument('--unet3_mag0', type=str)
    parser.add_argument('--unet3_mag1', type=str)
    parser.add_argument('--unet3_mag2', type=str)
    parser.add_argument('--num_gpus', type=int)
    parser.add_argument('--inpaint_resample', type=int)
    parser.add_argument('--overlap', type=float)
    parser.add_argument('--sample_dir', default="samples", type=str)
    return parser.parse_args()

if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    main()

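# Example invocation (a sketch only; the checkpoint paths and hyperparameter
# values below are placeholders, not the authors' settings):
#
#   python sample_ultra_res_demo.py \
#       --unet1_mag0 ckpts/unet1_mag0.pt --unet2_mag0 ckpts/unet2_mag0.pt --unet3_mag0 ckpts/unet3_mag0.pt \
#       --unet1_mag1 ckpts/unet1_mag1.pt --unet2_mag1 ckpts/unet2_mag1.pt --unet3_mag1 ckpts/unet3_mag1.pt \
#       --unet1_mag2 ckpts/unet1_mag2.pt --unet2_mag2 ckpts/unet2_mag2.pt --unet3_mag2 ckpts/unet3_mag2.pt \
#       --num_gpus 2 --inpaint_resample 5 --overlap 0.25 --sample_dir samples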