Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace print with logger if they are logs #905

Merged
merged 10 commits into the base branch on
Feb 4, 2024
12 changes: 7 additions & 5 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

from library.utils import get_my_logger
logger = get_my_logger(__name__)
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
Expand Down Expand Up @@ -51,11 +53,11 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
Expand Down Expand Up @@ -88,7 +90,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
Expand All @@ -99,7 +101,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)

# mixed precisionに対応した型を用意しておき適宜castする
Expand Down Expand Up @@ -460,7 +462,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
4 changes: 3 additions & 1 deletion finetune/blip/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from library.utils import get_my_logger
logger = get_my_logger(__name__)

class BLIP_Base(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -235,6 +237,6 @@ def load_checkpoint(model,url_or_filename):
del state_dict[key]

msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
logger.info('load checkpoint from %s'%url_or_filename)
return model,msg

40 changes: 21 additions & 19 deletions finetune/clean_captions_and_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import re

from tqdm import tqdm
from library.utils import get_my_logger
logger = get_my_logger(__name__)

PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
Expand Down Expand Up @@ -36,13 +38,13 @@ def clean_tags(image_key, tags):
tokens = tags.split(", rating")
if len(tokens) == 1:
# WD14 taggerのときはこちらになるのでメッセージは出さない
# print("no rating:")
# print(f"{image_key} {tags}")
# logger.info("no rating:")
# logger.info(f"{image_key} {tags}")
pass
else:
if len(tokens) > 2:
print("multiple ratings:")
print(f"{image_key} {tags}")
logger.info("multiple ratings:")
logger.info(f"{image_key} {tags}")
tags = tokens[0]

tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
Expand Down Expand Up @@ -124,43 +126,43 @@ def clean_caption(caption):

def main(args):
    """Clean the tags and captions in a metadata JSON file and write the result.

    Loads the metadata dict from ``args.in_json``, normalizes each image's
    ``tags`` entry via ``clean_tags`` and its ``caption`` entry via
    ``clean_caption``, then writes the cleaned metadata to ``args.out_json``.
    Logs (instead of crashing) when an image is missing tags or a caption.
    """
    if os.path.exists(args.in_json):
        logger.info(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        logger.error("no metadata / メタデータファイルがありません")
        return

    logger.info("cleaning captions and tags.")
    image_keys = list(metadata.keys())
    for image_key in tqdm(image_keys):
        tags = metadata[image_key].get('tags')
        if tags is None:
            logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
        else:
            org = tags
            tags = clean_tags(image_key, tags)
            metadata[image_key]['tags'] = tags
            # in debug mode, show before/after only when cleaning changed something
            if args.debug and org != tags:
                logger.info("FROM: " + org)
                logger.info("TO: " + tags)

        caption = metadata[image_key].get('caption')
        if caption is None:
            logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
        else:
            org = caption
            caption = clean_caption(caption)
            metadata[image_key]['caption'] = caption
            if args.debug and org != caption:
                logger.info("FROM: " + org)
                logger.info("TO: " + caption)

    # write out the cleaned metadata and finish
    logger.info(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand All @@ -178,10 +180,10 @@ def setup_parser() -> argparse.ArgumentParser:

args, unknown = parser.parse_known_args()
if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
print("All captions and tags in the metadata are processed.")
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
print("メタデータ内のすべてのキャプションとタグが処理されます。")
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
logger.warning("All captions and tags in the metadata are processed.")
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
args.in_json = args.out_json
args.out_json = unknown[0]
elif len(unknown) > 0:
Expand Down
20 changes: 11 additions & 9 deletions finetune/make_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
import library.train_util as train_util
from library.utils import get_my_logger
logger = get_my_logger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -47,7 +49,7 @@ def __getitem__(self, idx):
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None

return (tensor, img_path)
Expand All @@ -74,19 +76,19 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path

cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
logger.info(f"Current Working Directory is: {cwd}")
os.chdir("finetune")

print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")

print(f"loading BLIP caption: {args.caption_weights}")
logger.info(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
logger.info("BLIP loaded")

# captioningする
def run_batch(path_imgs):
Expand All @@ -106,7 +108,7 @@ def run_batch(path_imgs):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f'{image_path} {caption}')

# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
Expand Down Expand Up @@ -136,7 +138,7 @@ def run_batch(path_imgs):
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, img_tensor))
Expand All @@ -146,7 +148,7 @@ def run_batch(path_imgs):
if len(b_imgs) > 0:
run_batch(b_imgs)

print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
21 changes: 11 additions & 10 deletions finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from transformers.generation.utils import GenerationMixin

import library.train_util as train_util

from library.utils import get_my_logger
logger = get_my_logger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -35,8 +36,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps

Expand Down Expand Up @@ -70,16 +71,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""

print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")

# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")

# captioningする
def run_batch(path_imgs):
Expand All @@ -97,7 +98,7 @@ def run_batch(path_imgs):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")

# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
Expand Down Expand Up @@ -126,7 +127,7 @@ def run_batch(path_imgs):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, image))
Expand All @@ -137,7 +138,7 @@ def run_batch(path_imgs):
if len(b_imgs) > 0:
run_batch(b_imgs)

print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
18 changes: 10 additions & 8 deletions finetune/merge_captions_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import get_my_logger
logger = get_my_logger(__name__)

def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")

if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json

if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}

print("merge caption texts to metadata json.")
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
Expand All @@ -38,12 +40,12 @@ def main(args):

metadata[image_key]['caption'] = caption
if args.debug:
print(image_key, caption)
logger.info(f"{image_key} {caption}")

# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
Expand Down
Loading