From 78080f5060441d224f3b4f49b2df3fb18931e54c Mon Sep 17 00:00:00 2001
From: 注水的矿工 <43202966+waterminer@users.noreply.github.com>
Date: Sun, 13 Aug 2023 00:44:30 +0800
Subject: [PATCH] 0.3.0 update (#10)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update README
* Add AI upscaling, with several algorithms supported
* Rework the launcher: it now accepts command-line arguments
* Add setup.py, so the project can be installed and imported into your own project as a library
* Fix a number of bugs
---
 README.md                                     |  25 +-
 {module => dataset_processor}/__init__.py     |   3 +-
 dataset_processor/data.py                     |  55 ++++
 dataset_processor/filter.py                   |  48 +++
 {module => dataset_processor}/processor.py    | 107 ++++---
 dataset_processor/tools/__init__.py           |   2 +
 {module => dataset_processor}/tools/tagger.py |  15 +-
 dataset_processor/tools/upscale.py            | 191 ++++++++++++
 dataset_processor/uitl.py                     | 289 ++++++++++++++++++
 doc/doc_cn.md                                 |  87 +++++-
 main.py                                       |  61 ++--
 module/data.py                                |  57 ----
 module/filter.py                              |  31 --
 module/uitl.py                                | 137 ---------
 requirements.txt                              |   6 +-
 setup.py                                      |  32 ++
 16 files changed, 824 insertions(+), 322 deletions(-)
 rename {module => dataset_processor}/__init__.py (50%)
 create mode 100644 dataset_processor/data.py
 create mode 100644 dataset_processor/filter.py
 rename {module => dataset_processor}/processor.py (57%)
 create mode 100644 dataset_processor/tools/__init__.py
 rename {module => dataset_processor}/tools/tagger.py (97%)
 create mode 100644 dataset_processor/tools/upscale.py
 create mode 100644 dataset_processor/uitl.py
 delete mode 100644 module/data.py
 delete mode 100644 module/filter.py
 delete mode 100644 module/uitl.py
 create mode 100644 setup.py

diff --git a/README.md b/README.md
index 1a7fb96..54e8d29 100644
--- a/README.md
+++ b/README.md
@@ -4,12 +4,14 @@ ✅批量处理图片
 
 包括且不限于:
+
 * 批量翻转
 * 批量随机裁切
 * 图片对比度增强
 
 ✅批量处理标签
 
 包括且不限于:
+
 * 批量删除标签
 * 批量插入标签
 * 批量修改标签
@@ -20,14 +22,31 @@
 ✅自动打标(试验性)
 
+✅AI图片放大(试验性)
+
 ✅子处理(试验性)
 
 ## TODO
 
-🚧重构自动打标代码
+🚧智能裁切
 
-🚧图片放大功能
+🚧重构自动打标代码
 
 🚧图形化界面
 
-如果这个项目给您提供了帮助,不妨点一个⭐star,谢谢
+## CREDITS
+
+### Upscale
+
+[Real ESRGAN](https://github.com/xinntao/Real-ESRGAN/):一种流行的AI放大方案
+
+[Real CUGAN](https://github.com/bilibili/ailab/tree/main/Real-CUGAN):更加适合二次元的AI放大方案
+
+[Real CUGAN-ncnn](https://github.com/Tohrusky/realcugan-ncnn-py):感谢这位作者提供的RealCUGAN工具包
+
+### Tagger
+
+[WD-1.4-Tagger From SmilingWolf](https://huggingface.co/SmilingWolf):自动打标模型
+
+---
+如果这个项目为您提供了帮助,不妨点一个⭐star,万分感谢!
diff --git a/module/__init__.py b/dataset_processor/__init__.py
similarity index 50%
rename from module/__init__.py
rename to dataset_processor/__init__.py
index c07e6dc..68340f8 100644
--- a/module/__init__.py
+++ b/dataset_processor/__init__.py
@@ -1,4 +1,5 @@
 from .data import Data
 from .filter import Filter
-from .processor import Processor,ProcessorError
+from .processor import Processor, ProcessorError
 from .uitl import *
+from .tools import *
diff --git a/dataset_processor/data.py b/dataset_processor/data.py
new file mode 100644
index 0000000..cdd73c7
--- /dev/null
+++ b/dataset_processor/data.py
@@ -0,0 +1,55 @@
+from PIL import Image
+import os
+
+
+class Data:
+
+    # Open an image and initialize its metadata
+    def __init__(self, path: str, name: str, ext: str):
+        self.token: list[str] = []
+        self.conduct = ""
+        self.repeat = 0
+        self.id = 0
+        self.name = name
+        self.ext = ext
+        self.path = path
+        # Read the image
+        self.img = Image.open(os.path.join(path, name + ext))
+        self.size = self.img.size
+
+    # Load tags from the paired text file
+    def input_token(self, file_name: str, option=None):
+        NO_CHECK = [  # tags excluded from cleaning
+            ':)', ';)', ':(', '>:)', '>:(', '\\(^o^)/',  # bracket emoticons
+            '^_^', '@_@', '>_@', '+_+', '+_-', 'o_o', '0_0', '|_|', '._.', '>_<', '=_=', '_', '<|>_<|>'  # underscore emoticons
+        ]
+        clean_tag = bool(option and option.clean_tag)
+        with open(os.path.join(self.path, file_name), "r") as f:
+            # rebuild the list: rebinding the loop variable alone would not
+            # change self.token
+            tokens = []
+            for tag in f.read(-1).split(","):
+                tag = tag.strip()
+                if clean_tag and tag not in NO_CHECK:
+                    tag = tag.replace("_", " ")
+                    tag = tag.replace("(", "\\(")
+                    tag = tag.replace(")", "\\)")
+                tokens.append(tag)
+            self.token = tokens
+
+    # Save the image and its tags
+    def save(self, output_dir, option):
+        # Default naming scheme: id + conduct tags, e.g. "000001_r0.jpg"
+        save_name = str(self.id).zfill(6) + self.conduct
+        if option:
+            if option.save_source_name or option.save_conduct_id:
+                save_name = str(self.id).zfill(6)
+                if option.save_source_name:
+                    save_name += '_' + self.name
+                if option.save_conduct_id:
+                    save_name += self.conduct
+        self.img.save(os.path.join(output_dir, save_name + self.ext))
+        with open(os.path.join(output_dir, save_name + ".txt"), mode="w") as f:
+            f.write(",".join(self.token))
+        self.img.close()
diff --git a/dataset_processor/filter.py b/dataset_processor/filter.py
new file mode 100644
index 0000000..c6068a4
--- /dev/null
+++ b/dataset_processor/filter.py
@@ -0,0 +1,48 @@
+from .data import Data
+
+
+class Filter:
+    """
+    Filter class: collects every data-filtering function.
+    Conventions:
+        def filter_name(data: Data, arg) -> bool:
+            # body
+            return bool
+    Returning True means the data is filtered out; False means it is kept.
+    """
+
+    def img_size(data: Data, size: list) -> bool:
+        min_size, max_size = tuple(size)  # avoid shadowing the min/max builtins
+        if min_size != -1 and (data.size[0] <= min_size or data.size[1] <= min_size):
+            return True
+        if max_size != -1 and (data.size[0] > max_size and data.size[1] > max_size):
+            return True
+        return False
+
+    def tag_filter(data: Data, tag) -> bool:
+        return tag in data.token
+
+    def tag_selector(data: Data, tag) -> bool:
+        return tag not in data.token
+
+    def tag_is_not_none(data: Data) -> bool:
+        return not data.token
+
+    def tag_is_none(data: Data) -> bool:
+        return bool(data.token)
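For reference, these filters are applied by name via `filter_manager` in `dataset_processor/uitl.py` (added below). A minimal, untested config sketch — the key names come from `conduct_manager`/`filter_manager`, the values are illustrative:

```yaml
conduct:
  - filters:
      - filter: img_size
        arg: [512, -1]   # [min, max]; -1 disables a bound, so this drops images 512px or smaller
    processor:
      - method: none     # no-op processor; this entry only filters
```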
diff --git a/module/processor.py b/dataset_processor/processor.py
similarity index 57%
rename from module/processor.py
rename to dataset_processor/processor.py
index 7b883bb..2942aff 100644
--- a/module/processor.py
+++ b/dataset_processor/processor.py
@@ -1,87 +1,86 @@
 from random import randint as random
 from .data import Data
+from .tools.tagger import Tagger
+from .tools.upscale import UpscaleModel
 from PIL import Image
+from PIL import ImageChops
 from PIL import ImageEnhance
 import numpy as np
 
+
 class Processor:
-    # 在这里定义处理方法
     """
+    Processor class: collects every data-processing function.
     Conventions:
-    def method_name(data: Data, args):
+    def method_name(data: Data, args) -> Data:
         # body
         return data
     """
 
-    def random_crop(data, size):
+    def random_crop(data: Data, size) -> Data:
         if not (data.size[0] <= size or data.size[1] <= size):
             x = random(1, data.size[0] - size)
             y = random(1, data.size[1] - size)
             box = (x, y, x + size, y + size)
             data.img = data.img.crop(box)
-            data.conduct += "_rc"
+            data.conduct += f"_rc{data.repeat}"
             data.size = data.img.size
         else:
             raise ImageTooSmallError(data.name + data.ext)
         return data
 
-    def flip(data):
+    def flip(data: Data) -> Data:
         data.img = data.img.transpose(Image.FLIP_LEFT_RIGHT)
-        data.conduct += "_f"
+        data.conduct += f"_f{data.repeat}"
         return data
 
-    def resize(data: Data, proportion: float):
-        data.size = (int(data.size[0] * proportion), int(data.size[1] * proportion))
-        data.img = data.img.resize(data.size)
-        data.conduct += "_r"
+    def resize(data: Data, proportion: float) -> Data:
+        size = (int(data.size[0] * proportion), int(data.size[1] * proportion))
+        data.img = data.img.resize(size)
+        data.conduct += f"_r{data.repeat}"
+        data.size = data.img.size
         return data
 
-    def force_resize(data: Data, size: list):
-        data.size = (size[0], size[1])
-        data.img = data.img.resize(data.size)
-        data.conduct += "_fr"
+    def force_resize(data: Data, size: list) -> Data:
+        data.img = data.img.resize(size)
+        data.conduct += f"_fr{data.repeat}"
+        data.size = data.img.size
         return data
 
-    def offset(data: Data,offset:int):
-        data.img = data.img.offset(offset,0)
-        data.conduct += "_off"
+    def offset(data: Data, offset: int) -> Data:
+        # PIL images no longer have an offset() method; ImageChops.offset is the working API
+        data.img = ImageChops.offset(data.img, offset, 0)
+        data.conduct += f"_off{data.repeat}"
         return data
 
-    def rotation(data: Data, rot:int):
+    def rotation(data: Data, rot: int) -> Data:
         data.img = data.img.rotate(rot)
-        data.conduct += "_rot"
+        data.conduct += f"_rot{data.repeat}"
         return data
 
-    def contrast_enhancement(data: Data): #对比度增强
+    def contrast_enhancement(data: Data) -> Data:  # contrast enhancement
         image = data.img
         enh_con = ImageEnhance.Contrast(image)
         contrast = 1.5
         data.img = enh_con.enhance(contrast)
-        data.conduct += "_con_e"
+        data.conduct += f"_con_e{data.repeat}"
         return data
 
-    def brightness_enhancement(data: Data):#亮度增强
+    def brightness_enhancement(data: Data) -> Data:  # brightness enhancement
         image = data.img
         enh_bri = ImageEnhance.Brightness(image)
         brightness = 1.5
         data.img = enh_bri.enhance(brightness)
-        data.conduct += "_bri_e"
+        data.conduct += f"_bri_e{data.repeat}"
         return data
 
-    def color_enhancement(data: Data):#颜色增强
+    def color_enhancement(data: Data) -> Data:  # color (saturation) enhancement
         image = data.img
         enh_col = ImageEnhance.Color(image)
         color = 1.5
         data.img = enh_col.enhance(color)
-        data.conduct += "_col_e"
+        data.conduct += f"_col_e{data.repeat}"
         return data
 
-    def random_enhancement(data: Data): #随机抖动
-        """
-        对图像进行颜色抖动
-        :param image: PIL的图像image
-        :return: 有颜色色差的图像image
-        """
+    def random_enhancement(data: Data) -> Data:  # random jitter
         image = data.img
         random_factor = np.random.randint(8, 31) / 10.  # random factor
         color_image = ImageEnhance.Color(image).enhance(random_factor)  # jitter saturation
         random_factor = np.random.randint(10, 21) / 10.  # random factor
         brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # jitter brightness
@@ -91,42 +90,42 @@ def random_enhancement(data: Data):
         random_factor = np.random.randint(10, 21) / 10.  # random factor
         contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # jitter contrast
         random_factor = np.random.randint(8, 20) / 10.  # random factor
         data.img = ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # jitter sharpness
-        data.conduct += "_ran_e"
+        data.conduct += f"_ran_e{data.repeat}"
         return data
 
-    def none(data: Data):
+    def none(data: Data) -> Data:
         """
         No-op, mainly for special cases
         """
         return data
 
-    def append_tag(data: Data, tag: str):
+    def append_tag(data: Data, tag: str) -> Data:
         data.token.append(tag)
         return data
 
-    def remove_tag(data: Data, tag: str):
+    def remove_tag(data: Data, tag: str) -> Data:
         if tag in data.token:
             data.token.remove(tag)
         else:
-            raise TagNotExistError(tag,data.name + data.ext)
+            raise TagNotExistError(tag, data.name + data.ext)
         return data
 
-    def insert_tag(data: Data, tag: str):
+    def insert_tag(data: Data, tag: str) -> Data:
         data.token.insert(0, tag)
         return data
 
-    def tag_move_forward(data: Data,tag:str):
+    def tag_move_forward(data: Data, tag: str) -> Data:
         """
         Move the matching tag to the front
         """
         if tag in data.token:
             data.token.remove(tag)
         else:
-            raise TagNotExistError(tag,data.name + data.ext)
+            raise TagNotExistError(tag, data.name + data.ext)
         data.token.insert(0, tag)
         return data
 
-    def rename_tag(data:Data,tags:list[str]):
+    def rename_tag(data: Data, tags: list[str]) -> Data:
         """
         Rename tag A to tag B
         """
         tag_a = tags[0]
         tag_b = tags[1]
         if tag_a in data.token:
             index = data.token.index(tag_a)
-            data.token.insert(index,tag_b)
+            data.token.insert(index, tag_b)
             data.token.remove(tag_a)
         else:
-            raise TagNotExistError(tag_a,data.name + data.ext)
+            raise TagNotExistError(tag_a, data.name + data.ext)
         return data
 
+    def tag_image(data: Data, tagger: Tagger) -> Data:
+        return tagger.tag_data(data)
+
+    def upscale_image(data: Data, upscale: UpscaleModel) -> Data:
+        data.img = upscale.upscale_data(data)
+        data.size = data.img.size
+        return data
+
 
 # Custom exceptions
 class ProcessorError(RuntimeError):
     def __init__(self, *args: object) -> None:
         super().__init__(*args)
 
 
 class ImageTooSmallError(ProcessorError):
     def __init__(self, name: str):
         print("image " + name + " is too small!")
 
 
 class TagNotExistError(ProcessorError):
-    def __init__(self,tag,name: str):
-        print("Tag"+ tag + "not exist in"+name+"!")
+    def __init__(self, tag, name: str):
+        print("Tag " + tag + " does not exist in " + name + "!")
diff --git a/dataset_processor/tools/__init__.py b/dataset_processor/tools/__init__.py
new file mode 100644
index 0000000..b6ef0d4
--- /dev/null
+++ b/dataset_processor/tools/__init__.py
@@ -0,0 +1,2 @@
+from .tagger import Tagger, TaggerOption
+from .upscale import UpscaleModel, UpcaleOption
diff --git a/module/tools/tagger.py b/dataset_processor/tools/tagger.py
similarity index 97%
rename from module/tools/tagger.py
rename to dataset_processor/tools/tagger.py
index 709abd5..8f0c0e4 100644
--- a/module/tools/tagger.py
+++ b/dataset_processor/tools/tagger.py
@@ -6,7 +6,7 @@
 import cv2
-from module import Data
+from dataset_processor import Data
 from dataclasses import dataclass, field
 import os
 from enum import Enum
@@ -71,11 +71,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
     def __init__(self,data_list:list[Data]) -> None:
         self.dataset = []
         for data in data_list:
-            self.dataset.append(
-                {'img':data.img,
-                 'sorce_data':data
-                }
-            )
+            self.dataset.append({'img': data.img, 'sorce_data': data})
 
     def __len__(self):
         return len(self.dataset)
@@ -235,3 +231,10 @@ def tag_data_list(self,data_list:list[Data]):
         if len(b_imgs) > 0:
             b_imgs = [(sorce_data, image) for sorce_data, image in b_imgs]  # Convert image_path to string
             self.run_batch(b_imgs)
+
+    def tag_data(self, data: Data):
+        img = preprocess_image(data.img)
+        b_imgs = [(data, img)]
+        self.run_batch(b_imgs)
+        return data
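The new `tag_data` method makes single-image tagging usable outside batch runs. A sketch of direct use, untested — `TaggerOption` defaults are assumed, and the `WD14_MOAT` member name is taken from the doc/doc_cn.md model table:

```python
from dataset_processor.tools.tagger import Tagger, TaggerOption, ModelType
from dataset_processor import Data

option = TaggerOption()
option.model_type = ModelType.WD14_MOAT  # member name assumed from doc/doc_cn.md
tagger = Tagger(option)                  # loads (and, if needed, downloads) the model

data = Data("./input", "000001", ".jpg")  # opens ./input/000001.jpg
tagger.tag_data(data)                     # run_batch writes the predicted tags onto data.token
print(data.token)
```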
diff --git a/dataset_processor/tools/upscale.py b/dataset_processor/tools/upscale.py
new file mode 100644
index 0000000..67b3efd
--- /dev/null
+++ b/dataset_processor/tools/upscale.py
@@ -0,0 +1,191 @@
+import numpy as np
+from torch import nn as nn
+from PIL.Image import Image, fromarray
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.download_util import load_file_from_url
+from huggingface_hub import hf_hub_download
+from realesrgan import RealESRGANer as RealESRGANModel
+from realcugan_ncnn_py import Realcugan as RealcuganModel
+
+from dataclasses import dataclass, field
+import os
+from enum import Enum, auto as enumauto
+
+from dataset_processor import Data
+
+
+class ModelType(Enum):
+    R_ESRGAN_2X = enumauto()
+    R_ESRGAN_4X = enumauto()
+    R_ESRNET_4X = enumauto()
+    R_ESRGAN_ANIME6B_4X = enumauto()
+    R_CUGAN_2X_CON = enumauto()
+    R_CUGAN_2X_ND = enumauto()
+    R_CUGAN_2X_D1 = enumauto()
+    R_CUGAN_2X_D2 = enumauto()
+    R_CUGAN_2X_D3 = enumauto()
+    R_CUGAN_3X_CON = enumauto()
+    R_CUGAN_3X_ND = enumauto()
+    R_CUGAN_3X_D3 = enumauto()
+    R_CUGAN_4X_CON = enumauto()
+    R_CUGAN_4X_ND = enumauto()
+    R_CUGAN_4X_D3 = enumauto()
+    CUSTOM = enumauto()
+
+
+@dataclass
+class UpcaleOption:
+    force_download: bool = field(default=False)
+    model_type: ModelType = field(default=ModelType.R_ESRGAN_2X)
+    model_path: str = field(default="./models")
+    custom_model_name: str = field(default="")
+    custom_model: nn.Module | None = field(default=None)
+    custom_scale: int = field(default=2)
+    tile: int = field(default=512)
+    tile_pad: int = field(default=10)
+    pre_pad: int = field(default=10)
+    half: bool = field(default=True)
+    gpuid: int = field(default=0)
+
+
+class CustomModelError(RuntimeError): ...
+
+
+class UpscaleModel():
+    REAL_ESRGAN_MODEL = [
+        ModelType.R_ESRGAN_2X,
+        ModelType.R_ESRGAN_4X,
+        ModelType.R_ESRNET_4X,
+        ModelType.R_ESRGAN_ANIME6B_4X,
+        ModelType.CUSTOM  # custom weights are loaded through the Real-ESRGAN runner as well
+    ]
+    REAL_CUGAN_MODEL = [
+        ModelType.R_CUGAN_2X_CON,
+        ModelType.R_CUGAN_2X_ND,
+        ModelType.R_CUGAN_2X_D1,
+        ModelType.R_CUGAN_2X_D2,
+        ModelType.R_CUGAN_2X_D3,
+        ModelType.R_CUGAN_3X_CON,
+        ModelType.R_CUGAN_3X_ND,
+        ModelType.R_CUGAN_3X_D3,
+        ModelType.R_CUGAN_4X_CON,
+        ModelType.R_CUGAN_4X_ND,
+        ModelType.R_CUGAN_4X_D3
+    ]
+
+    def __init__(self, option: UpcaleOption | None = None):
+        # a shared dataclass instance as default argument is a mutable-default trap
+        if option is None:
+            option = UpcaleOption()
+        print("Init upscale...")
+        self.realesrgan = None
+        self.realcugan = None
+        match option.model_type.value:
+            case ModelType.R_ESRGAN_2X.value:
+                url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
+                file = "RealESRGAN_x2plus.pth"
+                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+                scale = 2
+            case ModelType.R_ESRGAN_4X.value:
+                url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
+                file = "RealESRGAN_x4plus.pth"
+                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+                scale = 4
+            case ModelType.R_ESRNET_4X.value:
+                url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'
+                file = "RealESRNet_x4plus.pth"
+                scale = 4
+                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+            case ModelType.R_ESRGAN_ANIME6B_4X.value:
+                url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
+                file = "RealESRGAN_x4plus_anime_6B.pth"
+                model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+                scale = 4
+            case ModelType.R_CUGAN_2X_CON.value:
+                model = "models-se"
+                noise = -1
+                scale = 2
+            case ModelType.R_CUGAN_2X_ND.value:
+                model = "models-se"
+                noise = 0
+                scale = 2
+            case ModelType.R_CUGAN_2X_D1.value:
+                model = "models-se"
+                noise = 1
+                scale = 2
+            case ModelType.R_CUGAN_2X_D2.value:
+                model = "models-se"
+                noise = 2
+                scale = 2
+            case ModelType.R_CUGAN_2X_D3.value:
+                model = "models-se"
+                noise = 3
+                scale = 2
+            case ModelType.R_CUGAN_3X_CON.value:
+                model = "models-se"
+                noise = -1
+                scale = 3
+            case ModelType.R_CUGAN_3X_ND.value:
+                model = "models-se"
+                noise = 0
+                scale = 3
+            case ModelType.R_CUGAN_3X_D3.value:
+                model = "models-se"
+                noise = 3
+                scale = 3
+            case ModelType.R_CUGAN_4X_CON.value:
+                model = "models-se"
+                noise = -1
+                scale = 4
+            case ModelType.R_CUGAN_4X_ND.value:
+                model = "models-se"
+                noise = 0
+                scale = 4
+            case ModelType.R_CUGAN_4X_D3.value:
+                model = "models-se"
+                noise = 3
+                scale = 4
+            case ModelType.CUSTOM.value:
+                try:
+                    file = option.custom_model_name
+                    model = option.custom_model
+                    scale = option.custom_scale
+                    if not os.path.exists(os.path.join(option.model_path, file)):
+                        raise CustomModelError  # was ChildProcessError, which the except below could never catch
+                except CustomModelError:
+                    print("UpcaleOption: custom_model does not exist!")
+                    exit(1)
+            case _:
+                raise RuntimeError
+        print("Loading upscale model...")
+        if option.model_type in self.REAL_ESRGAN_MODEL:  # stopgap; these sources will move to Hugging Face downloads later
+            if os.path.exists(os.path.join(option.model_path, file)):
+                model_path = os.path.join(option.model_path, file)
+            else:
+                model_path = os.path.join(option.model_path, "real_esrgan")
+            if option.model_type is not ModelType.CUSTOM and (
+                    not os.path.exists(os.path.join(model_path, file)) or option.force_download):
+                model_path = load_file_from_url(
+                    url=url, model_dir=option.model_path, progress=True, file_name=None)
+            tile = option.tile
+            tile_pad = option.tile_pad
+            pre_pad = option.pre_pad
+            half = option.half
+            gpuid = option.gpuid
+            self.realesrgan = RealESRGANModel(
+                scale=scale,
+                model_path=model_path,
+                model=model,
+                tile=tile,
+                tile_pad=tile_pad,
+                pre_pad=pre_pad,
+                half=half,
+                gpu_id=gpuid)
+        if option.model_type in self.REAL_CUGAN_MODEL:
+            tile_size = option.tile
+            gpuid = option.gpuid
+            self.realcugan = RealcuganModel(gpuid, noise=noise, scale=scale, model=model, tilesize=tile_size)
+
+    def upscale_data(self, data: Data) -> Image:
+        if self.realesrgan:
+            np_img = np.array(data.img)
+            np_img, _ = self.realesrgan.enhance(np_img)
+            return fromarray(np_img)
+        if self.realcugan:
+            return self.realcugan.process_pil(data.img)
+        raise RuntimeError("No upscale backend was initialized")
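`UpscaleModel` works standalone the same way; a sketch under the same caveats, untested — the Real-ESRGAN weights download into `model_path` on first use, and the parameters shown simply override the dataclass defaults:

```python
from dataset_processor.tools.upscale import UpscaleModel, UpcaleOption, ModelType
from dataset_processor import Data

option = UpcaleOption(model_type=ModelType.R_ESRGAN_2X, half=False, tile=256)
model = UpscaleModel(option)

data = Data("./input", "000001", ".jpg")
data.img = model.upscale_data(data)  # returns a PIL image at 2x resolution
data.size = data.img.size
```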
diff --git a/dataset_processor/uitl.py b/dataset_processor/uitl.py
new file mode 100644
index 0000000..85ccbc6
--- /dev/null
+++ b/dataset_processor/uitl.py
@@ -0,0 +1,289 @@
+from tqdm import tqdm
+
+from dataset_processor import Data
+from dataset_processor import Filter
+from dataset_processor import Processor, ProcessorError
+from .tools.tagger import Tagger, TaggerOption, ModelType as TaggerType
+from .tools.upscale import UpcaleOption, UpscaleModel, ModelType as UpscaleType
+import copy
+import os
+
+# File classification
+IMG_EXT = [".png", ".jpg"]  # supported image formats
+TEXT_EXT = [".txt"]  # supported tag formats
+
+
+def tagger_builder(args: dict) -> Tagger:
+    option = TaggerOption()
+    if args.get('model_path'):
+        option.model_path = args['model_path']
+    if args.get('model_type'):
+        try:
+            option.model_type = TaggerType[args['model_type']]
+        except KeyError:
+            print(f"Invalid type:{args['model_type']}")
+    if args.get('force_download'):
+        option.force_download = args['force_download']
+    if args.get('undesired_tags'):
+        option.undesired_tags = args['undesired_tags']
+    if args.get('batch_size'):
+        option.batch_size = args['batch_size']
+    if args.get('max_data_loader_n_workers'):
+        option.max_data_loader_n_workers = args['max_data_loader_n_workers']
+    if args.get('remove_underscore'):
+        option.remove_underscore = args['remove_underscore']
+    if args.get('thresh'):
+        option.thresh = args['thresh']
+    if args.get('character_threshold'):
+        option.character_threshold = args['character_threshold']
+    if args.get('general_threshold'):
+        option.general_threshold = args['general_threshold']
+    return Tagger(option)
+
+
+def upscale_model_builder(args: dict) -> UpscaleModel:
+    option = UpcaleOption()
+    # set fields on the option instance, not on the UpcaleOption class
+    if args.get('model_path'):
+        option.model_path = args['model_path']
+    if args.get('force_download'):
+        option.force_download = args['force_download']
+    if args.get('model_type'):
+        try:
+            option.model_type = UpscaleType[args['model_type']]
+        except KeyError:
+            print(f"Invalid type:{args['model_type']}")
+    if args.get('tile'):
+        option.tile = args['tile']
+    if args.get('tile_pad'):
+        option.tile_pad = args['tile_pad']
+    if args.get('pre_pad'):
+        option.pre_pad = args['pre_pad']
+    if args.get('half'):
+        option.half = args['half']
+    return UpscaleModel(option)
+
+
+class MainOption:
+    def __init__(self, args=None):
+        args = args or {}
+        # use the two-argument .get() so an explicit False in the config is
+        # preserved instead of falling back to the default
+        self.save_source_name = args.get('save_source_name', False)
+        self.save_conduct_id = args.get('save_conduct_id', False)
+        self.save_sub = args.get('save_sub', False)
+        self.clean_tag = args.get('clean_tag', True)
+        self.tag_no_paired_data = args.get('tag_no_paired_data', True)
+        self.force_tag_all = args.get('force_tag_all', False)
+
+
+class DatasetProcessor:
+    """
+    Build a DatasetProcessor object to start data processing
+    """
+    upscale: UpscaleModel = None
+    tagger: Tagger = None
+    option: MainOption = None
+
+    # Forward declarations; the real implementations follow below
+    def data_list_builder(self, input_dir: str) -> list[Data]:
+        ...
+
+    def pair_token(self, token_file_list: list, data_list: list[Data]):
+        ...
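+    # Expected `conduct` structure — a sketch inferred from conduct_manager,
+    # filter_manager and processor_manager below (values are illustrative):
+    #   conduct:
+    #     - filters:                    # optional; names resolve against Filter
+    #         - {filter: img_size, arg: [512, -1]}
+    #       repeat: 2                   # optional, defaults to 1
+    #       processor:                  # names resolve against Processor
+    #         - {method: flip}
+    #       sub_conduct: [...]          # optional nested conduct list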
+
+    def __init__(self,
+                 input_dir: str,
+                 output_dir: str,
+                 conduct: dict,
+                 option: dict | None = None,
+                 tagger: dict | None = None,
+                 upscale: dict | None = None
+                 ):
+        self.input_dir = input_dir
+        self.conduct = conduct
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        self.output_dir = output_dir
+        if tagger and tagger.get('active'):
+            self.tagger = tagger_builder(tagger)
+        if upscale and upscale.get('active'):
+            self.upscale = upscale_model_builder(upscale)
+        if option:
+            self.option = MainOption(option)
+        else:
+            self.option = MainOption()
+        self.data_list = self.data_list_builder(input_dir)
+
+    # Pair tag files with images
+    def pair_token(self, token_file_list: list, data_list: list[Data]):
+        no_paired_data_list = []
+        for data in data_list:
+            for file_name in list(token_file_list):  # iterate a copy: we remove while looping
+                name = os.path.splitext(file_name)[0]
+                if name == data.name:
+                    data.input_token(file_name, self.option)
+                    token_file_list.remove(file_name)
+                    break
+            if not data.token:
+                no_paired_data_list.append(data)
+        return no_paired_data_list
+
+    # Read files and build the data list
+    def data_list_builder(self, input_dir: str) -> list[Data]:
+        data_list: list[Data] = []
+        token_list = []
+        no_paired_data_list = []
+        count = 0
+        print("Loading files...")
+        for file_name in tqdm(os.listdir(input_dir)):
+            name, ext = os.path.splitext(file_name)
+            if ext in IMG_EXT:
+                data_list.append(Data(input_dir, name, ext))
+                count += 1
+            if ext in TEXT_EXT:
+                token_list.append(file_name)
+        no_paired_data_list = self.pair_token(token_list, data_list)
+        token_list.clear()
+        print(f"Loaded {count} images; {len(no_paired_data_list)} of them have no paired tag file")
+        if self.tagger:
+            if self.option.tag_no_paired_data and no_paired_data_list != []:
+                print("Tagging of unlabeled images is enabled")
+                self.tagger.tag_data_list(no_paired_data_list)
+            if self.option.force_tag_all:
+                print("Force-tagging all images with the tagger")
+                self.tagger.tag_data_list(data_list)
+        return data_list
+
+    # Filter dispatch
+    def filter_manager(self, filter_list: list, data: Data) -> bool:
+        for filter in filter_list:
+            fun = getattr(Filter, filter.get('filter'))
+            if filter.get('arg'):
+                if fun(data, filter.get('arg')):
+                    return True
+            else:
+                if fun(data):
+                    return True
+        return False
+
+    # Processor dispatch
+    def processor_manager(self, processor_list: list, data: Data):
+        for processor in processor_list:
+            try:
+                fun = getattr(Processor, processor.get('method'))
+                if fun == Processor.tag_image:
+                    if self.tagger is None:
+                        raise NoneTaggerError('tag_image')
+                    data = fun(data, self.tagger)
+                elif fun == Processor.upscale_image:
+                    if self.upscale is None:
+                        raise NoneUpscaleError('upscale_image')
+                    data = fun(data, self.upscale)
+                elif bool(processor.get("arg")):
+                    data = fun(data, processor.get("arg"))
+                else:
+                    data = fun(data)
+            except ProcessorError:
+                raise
+            except AttributeError:
+                print(f"\nError: invalid method: {processor.get('method')}\nPlease check the config file")
+                exit(1)
+            except NoneUpscaleError as e:
+                print(f"\nError: {e.name} failed!")
+                print("Upscale is not active! Please add this entry to the config:")
+                print("======================")
+                print("upscale:\n  active: True")
+                print("======================")
+                exit(1)
+            except NoneTaggerError as e:
+                print(f"\nError: {e.name} failed!")
+                print("Tagger is not active! Please add this entry to the config:")
+                print("======================")
+                print("tagger:\n  active: True")
+                print("======================")
+                exit(1)
+        return data
+
+    def conduct_manager(self, conducts: list[dict], data_list: list[Data]) -> list[Data]:
+        """
+        Conduct-dispatch function. It accepts a data_list, but filename
+        collisions are possible then, so passing a single data object is recommended.
+        """
+        return_list = []
+        output_dir = self.output_dir
+        for conduct in conducts:
+            if conduct.get('sub_conduct'):
+                sub_data_list = [copy.copy(data) for data in data_list]
+                for data in sub_data_list:
+                    data.conduct += "_sub["
+                sub_data_list = self.conduct_manager(conduct.get('sub_conduct'), sub_data_list)
+                if sub_data_list:
+                    for data in sub_data_list:
+                        data.conduct += "]"
+                    data_list = copy.deepcopy(sub_data_list)
+                    if self.option.save_sub:
+                        sub_output = os.path.join(output_dir, "sub")
+                        if not (os.path.exists(sub_output)):
+                            os.mkdir(sub_output)
+                        for sub_data in sub_data_list:
+                            sub_data.save(sub_output, self.option)
+            for data in data_list:
+                filters = conduct.get('filters')
+                if filters:
+                    if self.filter_manager(filters, data):
+                        continue
+                if bool(conduct.get('repeat')):
+                    repeat = conduct.get('repeat')
+                else:
+                    repeat = 1
+                for j in range(0, repeat):
+                    data.repeat = j
+                    try:
+                        return_list.append(self.processor_manager(conduct.get('processor'), copy.deepcopy(data)))
+                    except ProcessorError:
+                        break
+        return return_list
+
+    def main(self):
+        """
+        Main entry point
+        """
+        print("Start processing images...")
+        for i in tqdm(range(0, len(self.data_list))):
+            data = self.data_list.pop()
+            data.id = i
+            data_list = self.conduct_manager(self.conduct, [data])
+            if data_list:
+                for data in data_list:
+                    data.save(self.output_dir, self.option)
+
+
+class NoneTaggerError(RuntimeError):
+    def __init__(self, name):
+        self.name = name
+
+
+class NoneUpscaleError(RuntimeError):
+    def __init__(self, name):
+        self.name = name
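The `DatasetProcessor` class above is also what the commit's "import as a library" claim means in practice: construct it directly and call `main()`. A minimal, untested sketch — the dictionaries mirror the YAML config:

```python
from dataset_processor import DatasetProcessor

conduct = [{'repeat': 2, 'processor': [{'method': 'flip'}]}]
processor = DatasetProcessor('./input', './output', conduct,
                             option={'save_conduct_id': True})
processor.main()  # writes the processed images plus their .txt tag files to ./output
```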
diff --git a/doc/doc_cn.md b/doc/doc_cn.md
index 254f9f4..1f540e4 100644
--- a/doc/doc_cn.md
+++ b/doc/doc_cn.md
@@ -65,11 +65,12 @@ conduct:
 
 |名称|说明|
 |--|--|
-|save_sorce_name|保存原文件名称|
+|save_source_name|保存原文件名称|
 |save_conduct_id|保存处理id|
-|save_repeat|保存重复次数|
 |save_sub|保存子处理|
-|clean_tag|清理标签(将"_"换成空格,给括号加上"\")|
+|clean_tag|清洗标签(将"_"换成空格,给括号加上"\"),默认开启|
+|tag_no_paired_data|自动对没有标签的图片进行打标,需要配置`tagger`,默认开启|
+|force_tag_all|强制对所有图片进行打标,需要配置`tagger`|
 
 以下是示例:
 
@@ -96,6 +97,7 @@ option:
 |color_enhancement|_col_e|饱和度增强|-|
 |random_enhancement|_ran_e|随机增强|-|
 |none|-|不做操作(用于特定场合)|-|
+|upscale_image|-|放大图片,要使用这个方法,请配置`upscale`|-|
 
 ### 标签处理
 
@@ -106,6 +108,7 @@
 |insert_tag|在标签组开头插入标签|标签(文本)|
 |tag_move_forward|将选定标签移到开头|标签(文本)|
 |rename_tag|重命名标签,将`A标签`重命名为`B标签`|['A标签','B标签']|
+|tag_image|对图片进行打标并覆盖原来的标签,要使用这个方法,请配置`tagger`|-|
 
 ## 过滤器说明
 
@@ -115,7 +118,7 @@
 | -- | -- | -- |
 |img_size|过滤特定尺寸的图片|数组,输入格式为[min,max],缺省填-1 |
 |tag_filter|过滤掉特定标签|标签(文本)|
-|tag_selecter|须要包含特定标签|标签(文本)|
+|tag_selector|须要包含特定标签|标签(文本)|
 |tag_is_not_none|只含有带标签的图片|-|
 |tag_is_none|只含有不带标签的图片|-|
 
@@ -154,13 +157,26 @@ conduct:
 
 ## 自动打标设置说明
 
-在Tagger后添加`active: True`可以启用自动打标
+在配置中添加以下条目即可启用自动打标:
 
-如果你已经全部完成打标了,依旧打开此项会大大拖慢速度(花时间读取模型),所以请自行选择是否打开
+``` yaml
+tagger:
+  active: True
+```
+
+如果你已经全部完成打标了,依旧打开此项会大大拖慢速度(花时间读取模型),所以请结合实际情况自行选择是否打开
+
+你可以像这样来配置打标设置:
 
-目前只支持对未打标文件进行打标
+``` yaml
+tagger:
+  active: True
+  model_type: WD14_MOAT
+```
+
+如果你不清楚是什么,请保持默认
 
-以下这下配置是可选的,如果你不清楚是什么,请保持默认
+### 配置项说明
 
 |名称|说明|参数|
 |--|--|--|
@@ -186,3 +202,58 @@
 |WD14_SWINV2|[链接](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)|0.3771|0.6854|
 |WD14_CONVNEXT|[链接](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)|0.3685|0.6810|
 |WD14_CONVNEXT2|[链接](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)|0.3710|0.6862|
+
+## 图片放大说明
+
+在配置中添加以下条目即可启用图片放大:
+
+``` yaml
+upscale:
+  active: True
+```
+
+你可以像这样来配置放大设置:
+
+``` yaml
+upscale:
+  active: True
+  model_type: R_CUGAN_2X_CON
+```
+
+如果你不清楚是什么,请保持默认
+
+### 配置项说明
+
+|名称|说明|参数|
+|--|--|--|
+|active|启用自动图片放大|布尔值|
+|model_path|模型路径,下载的模型都会放在此文件夹内(仅适用于Real_ESRGAN)|路径|
+|model_type|模型种类,具体看下一章|模型种类|
+|force_download|强制下载模型(仅适用于Real_ESRGAN)|布尔值|
+|tile|切分图片,减少显存占用,0为不裁切,默认是512|每块切片的分辨率(整型)|
+|tile_pad|切分pad尺寸,用于减轻合并伪影,默认是10(仅适用于Real_ESRGAN)|pad分辨率(整型)|
+|pre_pad|pad填充像素,用于减轻合并伪影,默认是10(仅适用于Real_ESRGAN)|pad填充像素(整型)|
+|half|半精度,如果您是20系或者更高,推荐打开来加速(仅适用于Real_ESRGAN)|布尔值|
+
+### 放大模型种类
+
+默认参数为`R_ESRGAN_2X`,可以按照喜好自行选择
+
+|值|说明|
+|--|--|
+|R_ESRGAN_2X|Real_ESRGAN算法,2X代表2倍放大,下同|
+|R_ESRGAN_4X|-|
+|R_ESRNET_4X|仅在Real_ESRGAN库中支持,作者尚未验证|
+|R_ESRGAN_ANIME6B_4X|Real_ESRGAN针对二次元训练的算法,仅有4X放大|
+|R_CUGAN_2X_CON|Real_CUGAN算法,针对二次元的AI放大算法,2X代表2倍放大,CON代表保守降噪策略,推荐原图清晰度较高下使用|
+|R_CUGAN_2X_ND|同上,ND代表不降噪,推荐原图清晰度非常高的情况下使用|
+|R_CUGAN_2X_D3|同上,D3代表3级降噪,等级越高降噪程度越高,仅有2X模型降噪分为三个等级,其余均只有3级降噪,推荐原图清晰度不高的情况下使用|
+|R_CUGAN_2X_D2|-|
+|R_CUGAN_2X_D1|-|
+|R_CUGAN_3X_CON|-|
+|R_CUGAN_3X_ND|-|
+|R_CUGAN_3X_D3|-|
+|R_CUGAN_4X_CON|-|
+|R_CUGAN_4X_ND|-|
+|R_CUGAN_4X_D3|-|
diff --git a/main.py b/main.py
index e7caab7..a464825 100644
--- a/main.py
+++ b/main.py
@@ -1,38 +1,43 @@
-from module import *
+from dataset_processor import *
 
-from module.tools.tagger import Tagger
-
-import os
 import yaml
 
-
-def main(input_dir, output_dir, conducts,option:dict|None=None, tagger:Tagger|None=None):
-    data_list = data_list_builder(input_dir,tagger)
-    if not (os.path.exists(output_dir)):
-        os.mkdir(output_dir)
-    i = 0
-    for i in range(0, len(data_list)):
-        data = data_list.pop()
-        data.id = i
-        data = conduct_manager(conducts,data,output_dir,option)
-        if data is None:
-            continue
-        else:
-            data.save(output_dir,option)
-
+from argparse import ArgumentParser
 
 if __name__ == "__main__":
-    with open("./conf.yaml", "r", encoding="utf-8") as f:
+    parser = ArgumentParser()
+    parser.add_argument(
+        '--input_dir',
+        default=None,
+        type=str,
+        help='input dir; if set, it overrides "input_dir" in the config // 数据集输入路径,如果指定则会覆盖配置文件中的input_dir'
+    )
+    parser.add_argument(
+        '--output_dir',
+        default=None,
+        type=str,
+        help='output dir; if set, it overrides "output_dir" in the config // 数据集输出路径,如果指定则会覆盖配置文件中的output_dir'
+    )
+    parser.add_argument(
+        '--config',
+        default='./conf.yaml',
+        type=str,
+        help='yaml config path, defaults to conf.yaml in the root directory // 指定yaml配置文件,默认读取根目录下conf.yaml'
+    )
+    args = parser.parse_args()
+    with open(args.config, "r", encoding="utf-8") as f:
         config = yaml.load(f.read(), yaml.FullLoader)
     # Settings
-    input_dir = config.get('path').get('input')  # input dir
-    output_dir = config.get('path').get('output')  # output dir
+    if args.input_dir:
+        input_dir = args.input_dir
+    else:
+        input_dir = config.get('path').get('input')  # input dir
+    if args.output_dir:
+        output_dir = args.output_dir
+    else:
+        output_dir = config.get('path').get('output')  # output dir
     # Parameters
     conducts = config.get('conduct')
     option = config.get('option')
     tagger = config.get('tagger')
-
-    if tagger:
-        if tagger['active']:
-            tagger=tagger_bulider(tagger)
-        else: tagger=None
-    main(input_dir,output_dir,conducts,option,tagger=tagger)
+    upscale = config.get('upscale')
+    DatasetProcessor(input_dir, output_dir, conducts, option, tagger, upscale).main()
\ No newline at end of file
diff --git a/module/data.py b/module/data.py
deleted file mode 100644
index caa9de3..0000000
--- a/module/data.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from PIL import Image
-import os
-
-
-class Data:
-
-    # 图片读取并初始化
-    def __init__(self, path: str, name: str, ext: str):
-        self.token: list[str] 
= [] - self.conduct = "" - self.repeat = 0 - self.id = 0 - self.name = name - self.ext = ext - self.path = path - # 读取图片 - self.img = Image.open(os.path.join(path, name + ext)) - self.size = self.img.size - - # 载入标签 - def input_token(self, file_name: str,option:dict|None=None): - NO_CHECK=[ #清洗排除标签 - ':)',';)',':(','>:)','>:(','\(^o^)/', #括号相关 - '^_^','@_@','>_@','+_+','+_-','o_o','0_0','|_|','._.','>_<','=_=','_','<|>_<|>' #下划线相关 - ] - if option: - if option.get(clean_tag): - clean_tag = True - with open(os.path.join(self.path, file_name), "r") as f: - self.token = f.read(-1).split(",") - for tag in self.token: - tag = tag.strip() - if clean_tag: - if tag not in NO_CHECK: - tag = tag.replace("_"," ") - tag = tag.replace("(","\\(") - tag = tag.replace(")","\\)") - - # 保存的方法 - def save(self, output_dir,option:dict|None=None): - #默认命名方式:id_conduct_repeat.ext 比如"000001_r_0.jpg" - save_name = str(self.id).zfill(6) + self.conduct +"_"+ str(self.repeat) - if option: - if option.get('save_sorce_name') or option.get('save_conduct_id') or option.get('save_repeat'): - save_name=str(self.id).zfill(6) - if option.get('save_sorce_name'): - save_name = save_name.join('_'+self.name) - if option.get('save_conduct_id'): - save_name = save_name.join(self.conduct) - if option.get('save_repeat'): - save_name = save_name.join('_'+self.repeat) - self.img.save(os.path.join(output_dir, save_name + self.ext)) - # print(save_name) - with open(os.path.join(output_dir, save_name + ".txt"), mode="w") as f: - text = ",".join(self.token) - f.write(text) - self.img.close() diff --git a/module/filter.py b/module/filter.py deleted file mode 100644 index 2fdbd2c..0000000 --- a/module/filter.py +++ /dev/null @@ -1,31 +0,0 @@ -from .data import Data -class Filter: - """ - 此处用于编写过滤器 - """ - def img_size(data:Data,size:list): - min,max = tuple(size) - x,y = data.size - if min != -1: - if data.size[0] <= min or data.size[1] <= min: - return True - if max != -1: - if data.size[0] > max and data.size[1] > max: - return True - else: return False - - def tag_filter(data:Data,tag): - if tag in data.token: return True - else: return False - - def tag_selecter(data:Data,tag): - if tag in data.token: return False - else: return True - - def tag_is_not_none(data:Data): - if data.token: return False - else: return True - - def tag_is_none(data:Data): - if data.token: return True - else: return False \ No newline at end of file diff --git a/module/uitl.py b/module/uitl.py deleted file mode 100644 index 92bfb72..0000000 --- a/module/uitl.py +++ /dev/null @@ -1,137 +0,0 @@ -from module import Data -from module import Filter -from module import Processor, ProcessorError -from .tools.tagger import Tagger,TaggerOption -import copy -import os - -# 文件分类 -IMG_EXT = [".png", ".jpg"] # 支持的图片格式 -TEXT_EXT = [".txt"] # 支持的标签格式 - - -# 匹配标签 -def pair_token(token_list: list, data_list: list[Data],option:dict|None=None): - no_paired_data_list = [] - for data in data_list: - for file_name in token_list: - splitext = os.path.splitext(file_name) - name = splitext[0] - if name == data.name: - data.input_token(file_name,option) - token_list.remove(file_name) - if not data.token: - no_paired_data_list.append(data) - return no_paired_data_list - - -# 读入文件的方法 -def data_list_builder(input_dir:str,option:dict|None=None,tagger:Tagger|None=None) -> list[Data]: - data_list: list[Data] = [] - token_list = [] - no_paired_data_list = [] - count = 0 - for file_name in os.listdir(input_dir): # 读取图片 - splitext = os.path.splitext(file_name) - name = splitext[0] - ext = 
splitext[1] - # 如果是图片 - if ext in IMG_EXT: - img = Data(input_dir, name, ext) - data_list.append(img) - count += 1 - if ext in TEXT_EXT: - token_list.append(file_name) - no_paired_data_list = pair_token(token_list, data_list,option) - token_list.clear() - print( - "一共读取" + str(count) + "张图片,其中有" + - str(no_paired_data_list.__len__()) + "张图片没有配对的标签" - ) - if tagger: - print("已启用打标") - tagger.tag_data_list(no_paired_data_list) - return data_list - - -def filter_manager(filter_list: list, data: Data) -> bool: - flag=False - for filter in filter_list: - fun = getattr(Filter, filter.get('filter')) - if filter.get('arg'): - if fun(data, filter.get('arg')): return True - else: - if fun(data): return True - return False - - -def processor_manager(processor_list: list, data: Data): - for processor in processor_list: - try: - fun = getattr(Processor, processor.get('method')) - if bool(processor.get("arg")): - data = fun(data, processor.get("arg")) - else: - data = fun(data) - except ProcessorError: - raise ProcessorError - except AttributeError: - print("输入错误:不存在的method:"+processor.get('method')+"\n请检查配置文件") - exit(1) - return data - - -def conduct_manager(conducts:list,data:Data,output_dir:str,option:dict|None=None)->Data: - new_data = copy.deepcopy(data) - for conduct in conducts: - if conduct.get('sub_conduct'): - sub_data = copy.copy(new_data) - sub_data.conduct += "_sub[" - sub_data = conduct_manager(conduct.get('sub_conduct'),sub_data,output_dir,option) - if sub_data is not None: - sub_data.conduct += "]" - new_data.img = sub_data.img.copy() - new_data.conduct += sub_data.conduct - if option.get('save_sub'): - sub_output=os.path.join(output_dir,"sub") - if not (os.path.exists(sub_output)): - os.mkdir(sub_output) - sub_data.save(sub_output,option) - filters = conduct.get('filters') - if filters: - if filter_manager(filters, new_data): continue - if bool(conduct.get('repeat')): - repeat = conduct.get('repeat') - else: - repeat = 1 - for j in range(0, repeat): - new_data.repeat = j - try: - return processor_manager(conduct.get('processor'), new_data) - except ProcessorError: - break - - -def tagger_bulider(args:dict)->Tagger: - option = TaggerOption() - if args.get('model_path'): - option.model_path = args['model_path'] - if args.get('model_type'): - option.model_type = args['model_type'] - if args.get('force_download'): - option.force_download = args['force_download'] - if args.get('undesired_tags'): - option.undesired_tags = args['undesired_tags'] - if args.get('batch_size'): - option.batch_size = args['batch_size'] - if args.get('max_data_loader_n_workers'): - option.max_data_loader_n_workers = args['max_data_loader_n_workers'] - if args.get('remove_underscore'): - option.remove_underscore = args['remove_underscore'] - if args.get('thresh'): - option.thresh = args['thresh'] - if args.get('character_threshold'): - option.character_threshold = args['character_threshold'] - if args.get('general_threshold'): - option.general_threshold = args['general_threshold'] - return Tagger(option) diff --git a/requirements.txt b/requirements.txt index 99c3305..04ce9c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=2.0.0 +torch~=2.0.1 numpy~=1.24.3 opencv-python~=4.8.0.74 keras~=2.13.1 @@ -6,4 +6,6 @@ tqdm~=4.65.0 Pillow~=10.0.0 PyYAML~=6.0.1 huggingface_hub -tensorflow \ No newline at end of file +tensorflow +realesrgan +realcugan-ncnn-py \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9f786b4 --- /dev/null +++ b/setup.py @@ -0,0 
+1,32 @@
+from setuptools import setup
+import os
+
+requires = []
+with open('requirements.txt', encoding='utf8') as f:
+    for x in f.readlines():
+        if x.strip():  # skip blank lines so install_requires stays valid
+            requires.append(x.strip())
+
+data_files = [('conf', ['conf.yaml'])]
+
+for f in os.listdir('./doc'):
+    data_files.append(('doc', ['doc/' + f]))
+
+setup(
+    name='dataset_processor',
+    version='0.3.0',
+    packages=['dataset_processor', 'dataset_processor.tools'],
+    url='https://github.com/waterminer/SD-DatasetProcessor',
+    license='GPLv3',
+    author='Water_miner',
+    author_email='420773173@qq.com',
+    description='A dataset preprocessing toolkit',
+    classifiers=[
+        'Development Status :: 3 - Alpha',
+        'Environment :: GPU :: NVIDIA CUDA :: 11.8',
+        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
+        'Programming Language :: Python :: 3.10',
+    ],
+    install_requires=requires,
+    data_files=data_files
+)
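For an end-to-end run with the pieces this patch adds, here is a hedged sketch of a complete `conf.yaml`. The top-level keys are the ones `main.py` and `DatasetProcessor` actually read (`path`, `conduct`, `option`, `tagger`, `upscale`); the values are illustrative:

```yaml
path:
  input: ./input
  output: ./output
option:
  save_conduct_id: True
conduct:
  - processor:
      - method: upscale_image   # requires the upscale section below
      - method: tag_image       # requires the tagger section below
tagger:
  active: True
upscale:
  active: True
  model_type: R_CUGAN_2X_CON
```

Invocation with the new CLI overrides:

```bash
python main.py --config ./conf.yaml --input_dir ./my_dataset --output_dir ./out
```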