diff --git a/library/train_util.py b/library/train_util.py
index 15c23f3cc..048ed2ce3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2920,6 +2920,9 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
diff --git a/networks/dylora.py b/networks/dylora.py
index 637f33450..0546fc7ae 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -406,27 +406,63 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         logger.info(f"weights are merged")
     """

-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_B" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         return all_params
diff --git a/networks/lora.py b/networks/lora.py
index d1208040f..edbbdc0d8 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -1034,21 +1034,55 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    print("NO LR skipping!")
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1096,20 @@ def enumerate_params(loras):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
-
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

         else:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.unet_loras,
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         return all_params
diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 919222ce8..9a608118a 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -1033,22 +1033,54 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1094,20 @@ def enumerate_params(loras: List[LoRAModule]):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
-
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

         else:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.unet_loras,
+                unet_lr if unet_lr is not None else default_lr,
+                unet_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         return all_params
@@ -1093,6 +1124,9 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()

+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
diff --git a/train_network.py b/train_network.py
index aad5a7194..9670490ae 100644
--- a/train_network.py
+++ b/train_network.py
@@ -64,34 +64,69 @@ def generate_step_logs(

         lrs = lr_scheduler.get_last_lr()

-        if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
-            if args.network_train_unet_only:
-                logs["lr/unet"] = float(lrs[0])
-            elif args.network_train_text_encoder_only:
-                logs["lr/textencoder"] = float(lrs[0])
-            else:
-                logs["lr/textencoder"] = float(lrs[0])
-                logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
-
-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
-                )
-        else:
+        if len(lrs) > 4:
             idx = 0
             if not args.network_train_unet_only:
                 logs["lr/textencoder"] = float(lrs[0])
                 idx = 1

             for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
+                lora_plus = ""
+                group_id = i
+
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    lora_plus = '_lora+' if i % 2 == 1 else ''
+                    group_id = int((i / 2) + (i % 2 + 0.5))
+
+                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
                 if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
+                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
+        else:
+            if args.network_train_text_encoder_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/textencoder"] = float(lrs[0])
+
+            elif args.network_train_unet_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    logs["lr/unet"] = float(lrs[0])
+                    logs["lr/unet_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/unet"] = float(lrs[0])
+            else:
+                if len(lrs) == 2:
+                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/textencoder_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
+                        logs["lr/unet"] = float(lrs[0])
+                        logs["lr/unet_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
+                        logs["lr/all"] = float(lrs[0])
+                        logs["lr/all_lora+"] = float(lrs[1])
+                    else:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/unet"] = float(lrs[-1])
+                elif len(lrs) == 4:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                    logs["lr/unet"] = float(lrs[2])
+                    logs["lr/unet_lora+"] = float(lrs[3])
+                else:
+                    logs["lr/all"] = float(lrs[0])
+
+            if (
+                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+                )
+
         return logs

     def assert_extra_args(self, args, train_dataset_group):
@@ -338,7 +373,7 @@ def train(self, args):
         # 後方互換性を確保するよ
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
             )
@@ -347,6 +382,11 @@ def train(self, args):

         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

+        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+            assert (
+                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
+            ), "LoRA+ and Prodigy/DAdaptation is not supported"
+
         # dataloaderを準備する
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers