From 7fe81502d04c1f68c85f276517e7144e6378c484 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 6 May 2024 11:09:32 +0900
Subject: [PATCH] update loraplus on dylora/lora_fa

---
 networks/dylora.py  | 46 ++++++++++++++++++++++++---------------
 networks/lora.py    |  7 +++++-
 networks/lora_fa.py | 52 +++++++++++++++++++++++++++++++-------------
 3 files changed, 71 insertions(+), 34 deletions(-)

diff --git a/networks/dylora.py b/networks/dylora.py
index 0546fc7ae..0d1701ded 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -18,10 +18,13 @@ import torch
 from torch import nn
 
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
+
 
 class DyLoRAModule(torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
         conv_alpha = 1.0
     else:
         conv_alpha = float(conv_alpha)
-    
+
     if unit is not None:
         unit = int(unit)
     else:
@@ -211,6 +214,16 @@ def create_network(
         unit=unit,
         varbose=True,
     )
+
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network
@@ -280,6 +293,10 @@ def __init__(
         self.alpha = alpha
         self.apply_to_conv = apply_to_conv
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info("create LoRA network from weights")
         else:
@@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
                         lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                         loras.append(lora)
             return loras
-        
+
         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
-        
+
         self.text_encoder_loras = []
         for i, text_encoder in enumerate(text_encoders):
             if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
             else:
                 index = None
                 logger.info("create LoRA for Text Encoder")
-            
+
             text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
             self.text_encoder_loras.extend(text_encoder_loras)
@@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
         self.unet_loras = create_modules(True, unet, target_modules)
         logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -407,15 +429,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
     """
 
     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
@@ -452,15 +466,13 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
 
         if self.unet_loras:
             params = assemble_params(
-                self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_loraplus_ratio or loraplus_ratio
+                self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
             )
             all_params.extend(params)
diff --git a/networks/lora.py b/networks/lora.py
index 61b8cd5a7..6e5645577 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -499,7 +499,8 @@ def create_network(
     loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
     loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
     loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
 
     if block_lr_weight is not None:
         network.set_block_lr_weight(block_lr_weight)
@@ -855,6 +856,10 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 9a608118a..58bcb2206 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -15,8 +15,10 @@ import torch
 import re
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network
@@ -529,7 +540,9 @@ def parse_floats(s):
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
@@ -803,11 +816,17 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
@@ -815,9 +834,13 @@ def __init__(
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             if self.conv_lora_dim is not None:
-                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )
 
         # create module instances
         def create_modules(
@@ -939,6 +962,11 @@ def create_modules(
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
@@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
@@ -1097,7 +1117,7 @@ def assemble_params(loras, lr, ratio):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
@@ -1105,7 +1125,7 @@ def assemble_params(loras, lr, ratio):
                 params = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
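
The change above moves the LoRA+ ratios from prepare_optimizer_params() arguments to network attributes that are set once at creation time via set_loraplus_lr_ratio(); when no ratio is passed, that setter is never called and the attributes keep the None defaults from __init__. The snippet below is a standalone sketch (a hypothetical helper, not code from the patch or from sd-scripts) showing how the kwargs are parsed and how each model part falls back to the generic ratio, mirroring the "specific or generic" expressions used in the diff:

    # Hypothetical helper illustrating the ratio handling added in this patch.
    def resolve_loraplus_ratios(**kwargs):
        ratio = kwargs.get("loraplus_lr_ratio", None)
        unet_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
        te_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)

        # values typically arrive as strings from the network args, so cast when present
        ratio = float(ratio) if ratio is not None else None
        unet_ratio = float(unet_ratio) if unet_ratio is not None else None
        te_ratio = float(te_ratio) if te_ratio is not None else None

        # each part uses its own ratio when given, otherwise the generic one
        return (unet_ratio or ratio), (te_ratio or ratio)

    print(resolve_loraplus_ratios(loraplus_lr_ratio="16"))                              # (16.0, 16.0)
    print(resolve_loraplus_ratios(loraplus_lr_ratio="16", loraplus_unet_lr_ratio="4"))  # (4.0, 16.0)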