diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 9520afd95b9..520e96d22ab 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -907,7 +907,7 @@ def _validate_launch_command(args): warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`") if not args.multi_gpu and ( (args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1) - or (is_npu_available() and torch.npu.device_count > 1) + or (is_npu_available() and torch.npu.device_count() > 1) or (torch.cuda.device_count() > 1) ): warned.append(