From 1a104dc75ee5733af8ba17cc9778b39e26673734 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:36 +0900 Subject: [PATCH] make forward/backward pathes same ref #1363 --- networks/control_net_lllite_for_train.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 65b3520cf..366451b7f 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -7,8 +7,10 @@ import torch from library import sdxl_original_unet from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied @@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) self.cond_image = None - self.cond_emb = None def set_cond_image(self, cond_image): self.cond_image = cond_image - self.cond_emb = None def forward(self, x): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible # reshape / b,c,h,w -> b,h*w,c n, c, h, w = cx.shape @@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) cx = torch.cat([cx, self.down(x)], dim=1) cx = self.mid(cx)