Skip to content

Commit

Permalink
make forward/backward pathes same ref #1363
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jun 9, 2024
1 parent 58fb648 commit 1a104dc
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions networks/control_net_lllite_for_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import torch
from library import sdxl_original_unet
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

# input_blocksに適用するかどうか / if True, input_blocks are not applied
Expand Down Expand Up @@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)

self.cond_image = None
self.cond_emb = None

def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None

def forward(self, x):
if not self.enabled:
return super().forward(x)

if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible

# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
Expand Down Expand Up @@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None):
if not self.enabled:
return super().forward(x)

if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image)

cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)
Expand Down

0 comments on commit 1a104dc

Please sign in to comment.