Skip to content

Commit

Permalink
Merge pull request #1331 from kohya-ss/lora-plus
Browse files Browse the repository at this point in the history
Lora plus
  • Loading branch information
kohya-ss authored May 12, 2024
2 parents 1ffc0b3 + 4419041 commit 02298e3
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 193 deletions.
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.

- Fixed some bugs when using DeepSpeed. Related [#1247]
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
- `network_module` `networks.lora` and `networks.dylora` are available.

- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- Specify the learning rate and dim (rank) for each block.
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).

- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)

- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
Expand All @@ -171,7 +182,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `--fused_optimizer_groups``--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
- 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。

- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247]
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
- LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
- `--network_args``loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio``loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
- `network_module``networks.lora` および `networks.dylora` で使用可能です。

- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- ブロックごとに学習率および dim (rank) を指定することができます。
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。

- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)


### Apr 7, 2024 / 2024-04-07: v0.8.7
Expand Down
11 changes: 7 additions & 4 deletions docs/train_network_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf

詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。

SDXLは現在サポートしていません。

フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。

SDXL では down/up 9 個、middle 3 個の値を指定してください。

`--network_args` で以下の引数を指定してください。

- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
Expand All @@ -215,6 +215,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l

フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。

SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py`[階層別学習率](./train_SDXL-en.md) との互換性のためです。
対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。

`--network_args` で以下の引数を指定してください。

- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
Expand Down
78 changes: 63 additions & 15 deletions networks/dylora.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
import torch
from torch import nn
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class DyLoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
Expand Down Expand Up @@ -195,7 +198,7 @@ def create_network(
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)

if unit is not None:
unit = int(unit)
else:
Expand All @@ -211,6 +214,16 @@ def create_network(
unit=unit,
varbose=True,
)

loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

return network


Expand Down Expand Up @@ -280,6 +293,10 @@ def __init__(
self.alpha = alpha
self.apply_to_conv = apply_to_conv

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None

if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
Expand Down Expand Up @@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras

text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
Expand All @@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
else:
index = None
logger.info("create LoRA for Text Encoder")

text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)

Expand All @@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
self.unet_loras = create_modules(True, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
Expand Down Expand Up @@ -406,27 +428,53 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info(f"weights are merged")
"""

# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)

if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
)
all_params.extend(params)

return all_params

Expand Down
Loading

0 comments on commit 02298e3

Please sign in to comment.