Skip to content

Commit

Permalink
NPU implementation for FLUX
Browse files Browse the repository at this point in the history
  • Loading branch information
蒋硕 committed Oct 30, 2024
1 parent 189c617 commit c814e06
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,11 +1044,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))

if cur_class_images < args.num_class_images:
has_supported_fp16_accelerator = (
torch.cuda.is_available()
or torch.backends.mps.is_available()
or is_torch_npu_available()
)
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
Expand Down

0 comments on commit c814e06

Please sign in to comment.