[Fix] fp16 unscaling in train_dreambooth_lora_sdxl (#10889)
Fixes an fp16 unscaling bug in log_validation: the pipeline is no longer force-cast to torch_dtype when moved to the accelerator device, and autocast is skipped during the final validation run.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
SahilCarterr and sayakpaul authored Feb 24, 2025
1 parent db21c97 commit 170833c
Showing 1 changed file with 2 additions and 2 deletions.
examples/dreambooth/train_dreambooth_lora_sdxl.py (2 additions, 2 deletions)
```diff
@@ -203,7 +203,7 @@ def log_validation(

     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
@@ -213,7 +213,7 @@ def log_validation(
     if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
         autocast_ctx = nullcontext()
     else:
-        autocast_ctx = torch.autocast(accelerator.device.type)
+        autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

     with autocast_ctx:
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
```
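Why the dtype cast mattered: under fp16 mixed precision, casting the whole validation pipeline with `pipeline.to(accelerator.device, dtype=torch_dtype)` also downcasts the trainable LoRA parameters, and PyTorch's GradScaler then refuses to unscale fp16 gradients at the next optimizer step. Below is a minimal repro sketch of that failure mode; it is not code from this commit and assumes a CUDA device is available.

```python
# Sketch of the fp16 unscaling failure this commit avoids (assumption:
# a CUDA device is available). GradScaler.unscale_() only unscales fp32
# gradients; a parameter cast to fp16 produces fp16 gradients and triggers
# "ValueError: Attempting to unscale FP16 gradients."
import torch

param = torch.nn.Parameter(torch.ones(4, device="cuda", dtype=torch.float16))
optimizer = torch.optim.SGD([param], lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

loss = (param * 2.0).sum()
scaler.scale(loss).backward()  # gradients are fp16, matching the parameter
scaler.unscale_(optimizer)     # raises ValueError: Attempting to unscale FP16 gradients.
```

Dropping the dtype argument keeps the trainable parameters in their training dtype during intermediate validation; skipping autocast on the final validation run is consistent with the final pipeline presumably already being loaded in the desired dtype, so `nullcontext()` suffices there.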
