From 1933ab4b4848b1f8b578c10f25bd050f5e246ac0 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 3 Apr 2024 12:46:34 -0400
Subject: [PATCH] Fix default_lr not being applied

---
 networks/dylora.py  | 21 ++++++++++++++++++---
 networks/lora.py    | 30 +++++++++++++++++++++++-------
 networks/lora_fa.py | 30 +++++++++++++++++++++++-------
 3 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/networks/dylora.py b/networks/dylora.py
index a73ade8bd..edc3e2229 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -407,7 +407,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         """

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -442,11 +449,19 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params

         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)

         if self.unet_loras:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)

         return all_params

diff --git a/networks/lora.py b/networks/lora.py
index 8d7619777..e082941e5 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -1035,7 +1035,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -1070,7 +1077,11 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params

         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)

         if self.unet_loras:
@@ -1085,14 +1096,19 @@ def assemble_params(loras, lr, lora_plus_ratio):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                if unet_lr is not None:
-                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                elif default_lr is not None:
-                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)

         else:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)

         return all_params

diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index fcc503e89..3f6774dd8 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
             return params

         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)

         if self.unet_loras:
@@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                if unet_lr is not None:
-                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                elif default_lr is not None:
-                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)

         else:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)

         return all_params
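
Note for reviewers: all three files get the same fix. prepare_optimizer_params previously
forwarded text_encoder_lr and unet_lr straight to assemble_params, so when either was None
the default_lr argument was silently ignored (only the block-wise branch consulted it, and
a None lr could reach the optimizer). The new code resolves each learning rate with a None
fallback before use. A minimal sketch of that rule follows; resolve_lr and resolve_block_lr
are illustrative helpers, not functions in the patch:

    # Minimal sketch of the fallback rule this patch applies.
    def resolve_lr(specific_lr, default_lr):
        # A component-specific lr (text_encoder_lr / unet_lr) wins when set;
        # otherwise default_lr takes effect.
        return specific_lr if specific_lr is not None else default_lr

    def resolve_block_lr(unet_lr, default_lr, lr_weight):
        # Block-wise branch: the resolved lr is still scaled per block, mirroring
        # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(...).
        return resolve_lr(unet_lr, default_lr) * lr_weight

    assert resolve_lr(None, 1e-4) == 1e-4   # default_lr now applies
    assert resolve_lr(5e-5, 1e-4) == 5e-5   # an explicit lr still wins
    assert resolve_lr(0.0, 1e-4) == 0.0     # 0.0 counts as explicitly set
    assert resolve_block_lr(None, 1e-4, 0.5) == 5e-5

The patch checks "is not None" rather than truthiness, so an explicit 0.0 (for example, to
freeze a component) is not silently replaced by default_lr.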