Skip to content

Commit

Permalink
rework extras tab to use script system
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jan 23, 2023
1 parent 68303c9 commit b523019
Show file tree
Hide file tree
Showing 13 changed files with 675 additions and 464 deletions.
5 changes: 0 additions & 5 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){
return res
}

function get_extras_tab_index(){
const [,,...args] = [...arguments]
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}

function get_img2img_tab_index() {
let res = args_to_array(arguments)
res.splice(-2)
Expand Down
13 changes: 5 additions & 8 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
Expand Down Expand Up @@ -45,10 +44,8 @@ def validate_sampler_name(name):

def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict

def decode_base64_to_image(encoding):
Expand Down Expand Up @@ -244,7 +241,7 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict['image'] = decode_base64_to_image(reqDict['image'])

with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

Expand All @@ -260,7 +257,7 @@ def prepareFiles(file):
reqDict.pop('imageList')

with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

Expand Down
236 changes: 60 additions & 176 deletions modules/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,219 +1,103 @@
from __future__ import annotations
import os

import numpy as np
from PIL import Image

from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass

from modules import shared, images, devices, ui_components
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
from modules.shared import opts
import modules.gfpgan_model
import modules.codeformer_model


class LruCache(OrderedDict):
@dataclass(frozen=True)
class Key:
image_hash: int
info_hash: int
args_hash: int

@dataclass
class Value:
image: Image.Image
info: str

def __init__(self, max_size: int = 5, *args, **kwargs):
super().__init__(*args, **kwargs)
self._max_size = max_size

def get(self, key: LruCache.Key) -> LruCache.Value:
ret = super().get(key)
if ret is not None:
self.move_to_end(key) # Move to end of eviction list
return ret

def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
self[key] = value
while len(self) > self._max_size:
self.popitem(last=False)


cached_images: LruCache = LruCache(max_size=5)


def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()

shared.state.begin()
shared.state.job = 'extras'

imageArr = []
# Also keep track of original file names
imageNameArr = []
image_data = []
image_names = []
outputs = []

if extras_mode == 1:
#convert file to pillow image
for img in image_folder:
image = Image.open(img)
imageArr.append(image)
imageNameArr.append(os.path.splitext(img.orig_name)[0])
image_data.append(image)
image_names.append(os.path.splitext(img.orig_name)[0])
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'

if input_dir == '':
return outputs, "Please select an input directory.", ''
image_list = shared.listfiles(input_dir)
for img in image_list:
for filename in image_list:
try:
image = Image.open(img)
image = Image.open(filename)
except Exception:
continue
imageArr.append(image)
imageNameArr.append(img)
image_data.append(image)
image_names.append(filename)
else:
imageArr.append(image)
imageNameArr.append(None)
assert image, 'image not selected'

image_data.append(image)
image_names.append(None)

if extras_mode == 2 and output_dir != '':
outpath = output_dir
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples

# Extra operation definitions

def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)

if gfpgan_visibility < 1.0:
res = Image.blend(image, res, gfpgan_visibility)

info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
return (res, info)

def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)

if codeformer_visibility < 1.0:
res = Image.blend(image, res, codeformer_visibility)

info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
return (res, info)

def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
res = cropped
return res

def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
nonlocal upscaling_resize
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
return (image, info)

@dataclass
class UpscaleParams:
upscaler_idx: int
blend_alpha: float

def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
else:
res, info = cached_entry.image, cached_entry.info

if blended_result is None:
blended_result = res
else:
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
return (blended_result, info)

# Build a list of operations to run
facefix_ops: List[Callable] = []
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []

upscale_ops: List[Callable] = []
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []

if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))

upscale_ops.append(partial(run_upscalers_blend, step_params))

extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)

for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''

shared.state.textinfo = f'Processing image {image_name}'

infotext = ''

for image, name in zip(image_data, image_names):
shared.state.textinfo = name

existing_pnginfo = image.info or {}

image = image.convert("RGB")
info = ""
# Run each operation on each image
for op in extras_ops:
image, info = op(image, info)
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))

if opts.use_original_name_batch and image_name is not None:
basename = os.path.splitext(os.path.basename(image_name))[0]
scripts.scripts_postproc.run(pp, args)

if opts.use_original_name_batch and name is not None:
basename = os.path.splitext(os.path.basename(name))[0]
else:
basename = ''

if opts.enable_pnginfo: # append info before save
image.info = existing_pnginfo
image.info["extras"] = info
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])

if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
# Add second upscaler if applicable.
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
if opts.enable_pnginfo:
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext

images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
if save_output:
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)

if extras_mode != 2 or show_extras_results :
outputs.append(image)
if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)

devices.torch_gc()

return outputs, ui_components.plaintext_to_html(info), ''


def clear_cache():
cached_images.clear()

return outputs, ui_common.plaintext_to_html(infotext), ''


def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
"""old handler for API"""

args = scripts.scripts_postproc.create_args_for_run({
"Upscale": {
"upscale_mode": resize_mode,
"upscale_by": upscaling_resize,
"upscale_to_width": upscaling_resize_w,
"upscale_to_height": upscaling_resize_h,
"upscale_crop": upscaling_crop,
"upscaler_1_name": extras_upscaler_1,
"upscaler_2_name": extras_upscaler_2,
"upscaler_2_visibility": extras_upscaler_2_visibility,
},
"GFPGAN": {
"gfpgan_visibility": gfpgan_visibility,
},
"CodeFormer": {
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
},
})

return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
Loading

0 comments on commit b523019

Please sign in to comment.