-
Notifications
You must be signed in to change notification settings - Fork 316
/
Copy pathsample_inpainting.py
62 lines (52 loc) · 2.04 KB
/
sample_inpainting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import os, sys
from PIL import Image
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
EulerDiscreteScheduler
)
# The repository root is the parent of the directory holding this script.
# It is appended to sys.path so the `kolors` imports that follow resolve
# when the script is run directly.
_script_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(_script_dir)
sys.path.append(root_dir)
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
def infer(image_path, mask_path, prompt):
    """Inpaint an image under a mask and save the result as a JPEG.

    Loads the text encoder, tokenizer, VAE, scheduler and UNet from
    ``{root_dir}/weights/Kolors-Inpainting``, assembles them into a
    ``StableDiffusionXLInpaintPipeline`` moved to CUDA, runs one generation
    (height 1024, width 768, 25 steps, fixed seed 603) and writes the first
    returned image to
    ``{root_dir}/scripts/outputs/sample_inpainting_<basename>.jpg``, where
    ``<basename>`` is the file name of ``image_path`` without its extension.
    The output directory is created if it does not exist.

    Args:
        image_path: Path of the image to inpaint; converted to RGB.
        mask_path: Path of the mask image; converted to RGB.
            NOTE(review): which mask pixels mark the area to repaint is
            decided by the pipeline -- confirm there.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        None. The generated image is written to disk.
    """
    # Read the inputs first so that a bad path fails before the (large)
    # model weights are loaded. `convert` returns an independent image, so
    # the underlying files can be closed right away.
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    with Image.open(mask_path) as src:
        mask_image = src.convert('RGB')

    # Separator-aware replacement for splitting on '/' and '.' by hand.
    basename = os.path.splitext(os.path.basename(image_path))[0]

    ckpt_dir = f'{root_dir}/weights/Kolors-Inpainting'

    # Every model component is cast to half precision.
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()

    pipe = StableDiffusionXLInpaintPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler
    )
    pipe.to("cuda")
    # NOTE(review): presumably enabled to lower peak GPU memory -- confirm.
    pipe.enable_attention_slicing()

    # CPU generator with a fixed seed, so every run starts from the same
    # random state.
    generator = torch.Generator(device="cpu").manual_seed(603)

    result = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        height=1024,
        width=768,
        guidance_scale=6.0,
        generator=generator,
        num_inference_steps=25,
        # Roughly: "mutilated fingers, deformed fingers, deformed hands,
        # missing limbs, blurry, low quality".
        negative_prompt='残缺的手指,畸形的手指,畸形的手,残肢,模糊,低质量',
        num_images_per_prompt=1,
        strength=0.999
    ).images[0]

    # Make sure the destination exists; otherwise the save would fail only
    # after the whole generation has already been paid for.
    output_dir = f'{root_dir}/scripts/outputs'
    os.makedirs(output_dir, exist_ok=True)
    result.save(f'{output_dir}/sample_inpainting_{basename}.jpg')
if __name__ == '__main__':
    # `fire` is imported only when the file is executed as a script; it
    # turns `infer` into the command-line entry point.
    from fire import Fire

    Fire(infer)