From 0cd49138c8e5c76afce37c2998931b9f9fb718e1 Mon Sep 17 00:00:00 2001
From: Hanusz Leszek
Date: Sat, 20 Aug 2022 18:45:43 +0200
Subject: [PATCH 1/2] Add Gradio interface for inpainting

---
 scripts/inpaint_gradio.py | 115 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 scripts/inpaint_gradio.py

diff --git a/scripts/inpaint_gradio.py b/scripts/inpaint_gradio.py
new file mode 100644
index 000000000..6263535ae
--- /dev/null
+++ b/scripts/inpaint_gradio.py
@@ -0,0 +1,115 @@
+from omegaconf import OmegaConf
+from PIL import Image
+import numpy as np
+import torch
+from main import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+import gradio as gr
+
+
+def make_batch(image, mask, device):
+
+    if image.size != (512, 512):
+        print("Resampling image to 512x512")
+        image = image.resize((512, 512), Image.Resampling.LANCZOS)
+
+    image = np.array(image.convert("RGB"))
+    image = image.astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+
+    if mask.size != (512, 512):
+        print("Resampling mask to 512x512")
+        mask = mask.resize((512, 512), Image.Resampling.LANCZOS)
+
+    mask = np.array(mask.convert("L"))
+    mask = mask.astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+    mask = torch.from_numpy(mask)
+
+    masked_image = (1 - mask) * image
+
+    batch = {"image": image, "mask": mask, "masked_image": masked_image}
+    for k in batch:
+        batch[k] = batch[k].to(device=device)
+        batch[k] = batch[k] * 2.0 - 1.0
+    return batch
+
+
+def run(
+    *,
+    image,
+    mask,
+    device,
+    model,
+    sampler,
+    steps,
+):
+    batch = make_batch(image, mask, device=device)
+
+    # encode masked image and concat downsampled mask
+    c = model.cond_stage_model.encode(batch["masked_image"])
+    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
+    c = torch.cat((c, cc), dim=1)
+
+    shape = (c.shape[1] - 1,) + c.shape[2:]
+    samples_ddim, _ = sampler.sample(
+        S=steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
+    )
+    x_samples_ddim = model.decode_first_stage(samples_ddim)
+
+    image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
+    mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
+    predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+    inpainted = (1 - mask) * image + mask * predicted_image
+    inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+    image = Image.fromarray(inpainted.astype(np.uint8))
+
+    return image
+
+
+if __name__ == "__main__":
+
+    config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
+    model = instantiate_from_config(config.model)
+    model.load_state_dict(
+        torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False
+    )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model = model.to(device)
+    sampler = DDIMSampler(model)
+
+    def gradio_run(sketch, nb_steps):
+
+        image = sketch["image"]
+        mask = sketch["mask"]
+
+        generated = run(
+            image=image,
+            mask=mask,
+            device=device,
+            model=model,
+            sampler=sampler,
+            steps=nb_steps,
+        )
+
+        return [generated]
+
+    inpaint_interface = gr.Interface(
+        gradio_run,
+        inputs=[
+            gr.Image(interactive=True, type="pil", tool="sketch"),
+            gr.Slider(minimum=1, maximum=200, value=50, label="Number of steps"),
+        ],
+        outputs=[
+            gr.Gallery(),
+        ],
+    )
+
+    with torch.no_grad():
+        with model.ema_scope():
+            inpaint_interface.launch()
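Note on the conditioning built in run() above: the masked image is encoded by the first stage, the mask is downsampled to the same spatial size and concatenated as an extra channel, and the sampler then draws latents one channel smaller than the conditioning. A minimal shape sketch, assuming the inpainting_big first stage downsamples the 512x512 input by a factor of 4 (the factor is an assumption, not stated in the patch):

    import torch

    # Stand-in for model.cond_stage_model.encode(batch["masked_image"]);
    # assumes a factor-4 spatial downsample of the 512x512 input
    c = torch.randn(1, 3, 128, 128)
    mask = torch.zeros(1, 1, 512, 512)

    # Downsample the mask to the latent resolution and append it as a channel
    cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)                  # (1, 4, 128, 128)

    # The sampler works in a latent space one channel smaller than the conditioning
    shape = (c.shape[1] - 1,) + tuple(c.shape[2:])
    print(shape)                                   # (3, 128, 128)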
From fc47d79339ccd46b0b8c809adf56d3e1899e04b7 Mon Sep 17 00:00:00 2001
From: Hanusz Leszek
Date: Sun, 21 Aug 2022 19:02:52 +0200
Subject: [PATCH 2/2] Improvements for images not 512x512 pixels

- Now creating an image the same size as the input by resizing the prediction
- Automatically rotate the image according to EXIF data
---
 scripts/inpaint_gradio.py | 79 +++++++++++++++++++++++++--------------
 1 file changed, 51 insertions(+), 28 deletions(-)

diff --git a/scripts/inpaint_gradio.py b/scripts/inpaint_gradio.py
index 6263535ae..55cab55af 100644
--- a/scripts/inpaint_gradio.py
+++ b/scripts/inpaint_gradio.py
@@ -1,27 +1,37 @@
 from omegaconf import OmegaConf
-from PIL import Image
+from PIL import Image, ImageOps
 import numpy as np
 import torch
+import torchvision.transforms.functional as F
 from main import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 import gradio as gr
 
 
-def make_batch(image, mask, device):
+def run(
+    *,
+    image,
+    mask,
+    device,
+    model,
+    sampler,
+    steps,
+):
+
+    # Transpose image if needed according to EXIF data
+    image = ImageOps.exif_transpose(image)
 
-    if image.size != (512, 512):
-        print("Resampling image to 512x512")
-        image = image.resize((512, 512), Image.Resampling.LANCZOS)
+    # Save original image size
+    orig_size = image.size
+    print(f"Original image size: {orig_size}")
 
+    # Convert image from PIL Image to torch tensor
     image = np.array(image.convert("RGB"))
     image = image.astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
 
-    if mask.size != (512, 512):
-        print("Resampling mask to 512x512")
-        mask = mask.resize((512, 512), Image.Resampling.LANCZOS)
-
+    # Convert mask from PIL Image to torch tensor
     mask = np.array(mask.convert("L"))
     mask = mask.astype(np.float32) / 255.0
     mask = mask[None, None]
@@ -29,31 +39,30 @@ def make_batch(image, mask, device):
     mask[mask >= 0.5] = 1
     mask = torch.from_numpy(mask)
 
+    # Rescale image and mask if needed, saving the unscaled original image and mask
+    orig_image = image
+    orig_mask = mask
+
+    if orig_size != (512, 512):
+        print("Resize image and mask to 512x512")
+        image = F.resize(image, (512, 512), interpolation=F.InterpolationMode.BICUBIC)
+        mask = F.resize(mask, (512, 512), interpolation=F.InterpolationMode.BICUBIC)
+
+    # Compute the masked image
     masked_image = (1 - mask) * image
 
+    # Save tensors in a batch dict and move them to the device
     batch = {"image": image, "mask": mask, "masked_image": masked_image}
     for k in batch:
         batch[k] = batch[k].to(device=device)
         batch[k] = batch[k] * 2.0 - 1.0
-    return batch
-
-
-def run(
-    *,
-    image,
-    mask,
-    device,
-    model,
-    sampler,
-    steps,
-):
-    batch = make_batch(image, mask, device=device)
 
     # encode masked image and concat downsampled mask
     c = model.cond_stage_model.encode(batch["masked_image"])
     cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
     c = torch.cat((c, cc), dim=1)
 
+    # Predict image
     shape = (c.shape[1] - 1,) + c.shape[2:]
     samples_ddim, _ = sampler.sample(
         S=steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
@@ -64,11 +73,24 @@ def run(
     mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
     predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
-    inpainted = (1 - mask) * image + mask * predicted_image
-    inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
-    image = Image.fromarray(inpainted.astype(np.uint8))
+    # Get final image tensor by combining the original masked image with the
+    # prediction inside the mask - resizing the prediction image if needed
+    if orig_size == (512, 512):
+        inpainted = (1 - mask) * image + mask * predicted_image
+        inpainted = inpainted.cpu()
+    else:
+        w, h = orig_size
+        print(f"Resize prediction to {w}x{h}")
+        predicted_image = F.resize(
+            predicted_image, (h, w), interpolation=F.InterpolationMode.BICUBIC
+        )
+        inpainted = (1 - orig_mask) * orig_image + orig_mask * predicted_image.cpu()
+
+    # Convert final image back to a PIL Image
+    inpainted = inpainted.numpy().transpose(0, 2, 3, 1)[0] * 255
+    image_result = Image.fromarray(inpainted.astype(np.uint8))
 
-    return image
+    return image_result
 
 
 if __name__ == "__main__":
@@ -97,7 +119,7 @@ def gradio_run(sketch, nb_steps):
             steps=nb_steps,
         )
 
-        return [generated]
+        return generated
 
     inpaint_interface = gr.Interface(
         gradio_run,
@@ -106,8 +128,9 @@ def gradio_run(sketch, nb_steps):
             gr.Slider(minimum=1, maximum=200, value=50, label="Number of steps"),
         ],
         outputs=[
-            gr.Gallery(),
+            gr.Image(),
         ],
+        article="To avoid rescaling, use an image of dimensions **512x512**.",
    )
 
     with torch.no_grad():
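Patch 2 above keeps the model input at 512x512 while returning an image at the original resolution: the input and mask are downscaled for sampling, and only the prediction is upscaled back before compositing with the unscaled mask. A self-contained sketch of that round trip, with random tensors standing in for the model output (sizes here are illustrative):

    import torch
    import torchvision.transforms.functional as F

    orig_size = (640, 480)                        # (width, height), as PIL reports it
    w, h = orig_size

    orig_image = torch.rand(1, 3, h, w)
    orig_mask = (torch.rand(1, 1, h, w) > 0.5).float()

    # Downscale to the 512x512 resolution the model expects
    image = F.resize(orig_image, [512, 512], interpolation=F.InterpolationMode.BICUBIC)

    # Stand-in for the decoded model prediction at 512x512
    predicted_image = torch.rand(1, 3, 512, 512)

    # Upscale only the prediction, then composite with the unscaled mask so
    # pixels outside the mask keep their original values
    predicted_image = F.resize(predicted_image, [h, w], interpolation=F.InterpolationMode.BICUBIC)
    inpainted = (1 - orig_mask) * orig_image + orig_mask * predicted_image
    print(inpainted.shape)                        # torch.Size([1, 3, 480, 640])

Compositing with the unscaled binary mask also avoids the soft edges that bicubic resampling of the mask would otherwise introduce.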