diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 08138033155..12f67088a2c 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -1002,8 +1002,12 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): def launch_command(args): # Sanity checks - if sum([args.multi_gpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1: - raise ValueError("You can only pick one between `--multi_gpu`, `--use_deepspeed`, `--tpu`, `--use_fsdp`.") + if sum([args.multi_gpu, args.cpu, args.tpu, args.mps, args.use_deepspeed, args.use_fsdp]) > 1: + raise ValueError( + "You can only use one of `--cpu`, `--multi_gpu`, `--mps`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time." + ) + if args.multi_gpu and args.num_processes < 2: + raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.") defaults = None warned = []