From 58fb64819ab117e2b7bca6e87bae28901b616860 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:09 +0900 Subject: [PATCH] set static graph flag when DDP ref #1363 --- sdxl_train_control_net_lllite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 301310901..5ff060a9f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -289,6 +289,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: