Merge branch 'dev' into train_resume_step
kohya-ss committed Jun 11, 2024
2 parents 4dbcef4 + 3259928 commit 4a44188
Showing 5 changed files with 11 additions and 10 deletions.
.github/workflows/typos.yml (2 changes: 1 addition & 1 deletion)
@@ -18,4 +18,4 @@ jobs:
      - uses: actions/checkout@v4

      - name: typos-action
-        uses: crate-ci/typos@v1.19.0
+        uses: crate-ci/typos@v1.21.0
_typos.toml (2 changes: 2 additions & 0 deletions)
@@ -2,6 +2,7 @@
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started

[default.extend-identifiers]
+ddPn08="ddPn08"

[default.extend-words]
NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
koo="koo"
yos="yos"
wn="wn"
+hime="hime"


[files]
library/ipex/attention.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@

# pylint: disable=protected-access, missing-function-docstring, line-too-long

-# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
+# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
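
The two constants above read slicing thresholds (in GB) from the environment. As a rough sketch of the slicing idea the fixed comment describes (an illustration under assumptions, not this file's actual implementation; tensor_gb, maybe_sliced_sdpa, and the fixed chunk count are invented here), attention can be computed over query chunks once the estimated allocation crosses the trigger rate:

import os

import torch
import torch.nn.functional as F

# Same env var as above; the value is a threshold in GB.
sdpa_slice_trigger_rate = float(os.environ.get("IPEX_SDPA_SLICE_TRIGGER_RATE", 4))

def tensor_gb(t: torch.Tensor) -> float:
    # Rough tensor size in GB; a stand-in for a real estimate based on
    # the attention output shape.
    return t.numel() * t.element_size() / 1024**3

def maybe_sliced_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Under the 4GB-per-block limit on ARC GPUs, switch to chunked
    # attention once the estimated size crosses the trigger rate.
    if tensor_gb(q) <= sdpa_slice_trigger_rate:
        return F.scaled_dot_product_attention(q, k, v)
    outs = [F.scaled_dot_product_attention(qc, k, v) for qc in q.chunk(4, dim=-2)]
    return torch.cat(outs, dim=-2)

Slicing along the query dimension is exact, since each query row attends over all keys independently; the threshold exists because chunking costs extra kernel launches.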
networks/control_net_lllite_for_train.py (12 changes: 4 additions & 8 deletions)
@@ -7,8 +7,10 @@
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
+
setup_logging()
import logging
+
logger = logging.getLogger(__name__)

# whether to apply to input_blocks / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie
        add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)

        self.cond_image = None
-        self.cond_emb = None

    def set_cond_image(self, cond_image):
        self.cond_image = cond_image
-        self.cond_emb = None

    def forward(self, x):
        if not self.enabled:
            return super().forward(x)

-        if self.cond_emb is None:
-            self.cond_emb = self.lllite_conditioning1(self.cond_image)
-        cx = self.cond_emb
+        cx = self.lllite_conditioning1(self.cond_image)  # make forward and backward compatible

        # reshape / b,c,h,w -> b,h*w,c
        n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None):
        if not self.enabled:
            return super().forward(x)

-        if self.cond_emb is None:
-            self.cond_emb = self.lllite_conditioning1(self.cond_image)
-        cx = self.cond_emb
+        cx = self.lllite_conditioning1(self.cond_image)

        cx = torch.cat([cx, self.down(x)], dim=1)
        cx = self.mid(cx)
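
Both forward methods previously computed the conditioning embedding once and then reused the stored tensor; the merge recomputes it on every call, and the inline comment gives the reason. One reading (inferred, not stated in the commit): a tensor cached across calls still belongs to the autograd graph of the step that produced it, so a second backward pass through it fails. A minimal standalone sketch of that failure mode, not code from this repository:

import torch
import torch.nn as nn

class CachedCond(nn.Module):
    # The old pattern: compute the conditioning embedding once, reuse the tensor.
    def __init__(self):
        super().__init__()
        self.conditioning = nn.Linear(4, 4)
        self.cond_emb = None

    def forward(self, cond_image):
        if self.cond_emb is None:
            self.cond_emb = self.conditioning(cond_image)
        return self.cond_emb

m = CachedCond()
x = torch.randn(2, 4)
m(x).sum().backward()  # first step works
m(x).sum().backward()  # RuntimeError: trying to backward through the graph a second time

Recomputing per call, as the new code does, builds a fresh graph on every forward, which repeated forward/backward cycles during training require.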
sdxl_train_control_net_lllite.py (3 changes: 3 additions & 0 deletions)
@@ -289,6 +289,9 @@ def train(args):
    # apparently the accelerator takes care of this for us
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

+    if isinstance(unet, DDP):
+        unet._set_static_graph()  # avoid error for multiple use of the parameter
+
    if args.gradient_checkpointing:
        unet.train()  # according to TI example in Diffusers, train is required -> since this was switched to the original U-Net, it can actually be removed
    else:
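
_set_static_graph() declares to DDP that the graph structure and the set of used parameters do not change across iterations, which allows a parameter to take part in more than one forward pass without tripping the reducer check that each gradient be marked ready only once per backward. A minimal sketch of the wrapping step, assuming an initialized process group (wrap_ddp is an illustrative helper, not part of this script):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model: nn.Module, device_id: int) -> DDP:
    # Assumes torch.distributed.init_process_group() has already been called.
    ddp = DDP(model.to(device_id), device_ids=[device_id])
    # Private API, as used in this commit; recent PyTorch versions also
    # accept DDP(..., static_graph=True) at construction time.
    ddp._set_static_graph()
    return ddp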
