feat(tpu): remove nprocs from xla.spawn #3324

Merged 1 commit on Jan 13, 2025
Changes from all commits
7 changes: 6 additions & 1 deletion src/accelerate/commands/launch.py
@@ -855,6 +855,7 @@ def deepspeed_launcher(args):
 
 def tpu_launcher(args):
     import torch_xla.distributed.xla_multiprocessing as xmp
+    from torch_xla import device_count
 
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
@@ -875,13 +876,17 @@ def tpu_launcher(args):
             f"Your training script should have a function named {args.main_training_function}, or you should pass a "
             "different value to `--main_training_function`."
         )
+    if args.num_processes and args.num_processes != device_count():
+        raise ValueError(
+            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
+        )
 
     # Patch sys.argv
     sys.argv = [mod.__file__] + args.training_script_args
 
     main_function = getattr(mod, args.main_training_function)
     with patch_environment(**current_env):
-        xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
+        xmp.spawn(PrepareForLaunch(main_function), args=())
 
 
 def tpu_pod_launcher(args):
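For reference, the call pattern the launcher now relies on can be sketched outside of accelerate as below. This is a minimal sketch, assuming a TPU host with a recent torch_xla that exposes torch_xla.device_count; the NUM_PROCESSES environment variable and the _mp_fn body are illustrative stand-ins for accelerate's --num_processes flag and the user's training function, not part of the PR.

import os

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import device_count


def _mp_fn(index):
    # Each spawned process receives its ordinal and binds to one XLA device.
    print(f"process {index} running on {xm.xla_device()}")


if __name__ == "__main__":
    # Hypothetical stand-in for accelerate's --num_processes flag.
    requested = int(os.environ.get("NUM_PROCESSES", "0"))
    if requested and requested != device_count():
        raise ValueError(
            f"Number of processes ({requested}) must match the number of TPU devices ({device_count()})"
        )
    # nprocs is omitted: xmp.spawn starts one process per device known to the XLA runtime.
    xmp.spawn(_mp_fn, args=())

The validation mirrors the new check in tpu_launcher: an explicit process count is only accepted when it matches what the runtime reports.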
7 changes: 3 additions & 4 deletions src/accelerate/launchers.py
@@ -135,19 +135,18 @@ def train(*args):
     if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
         # TPU launch
         import torch_xla.distributed.xla_multiprocessing as xmp
+        from torch_xla import device_count
 
         if len(AcceleratorState._shared_state) > 0:
             raise ValueError(
                 "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                 "your training function. Restart your notebook and make sure no cells initializes an "
                 "`Accelerator`."
             )
-        if num_processes is None:
-            num_processes = 8
 
         launcher = PrepareForLaunch(function, distributed_type="XLA")
-        print(f"Launching a training on {num_processes} TPU cores.")
-        xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
+        print(f"Launching a training on {device_count()} TPU cores.")
+        xmp.spawn(launcher, args=args, start_method="fork")
     elif in_colab and get_gpu_info()[1] < 2:
         # No need for a distributed launch otherwise as it's either CPU or one GPU.
         if torch.cuda.is_available():
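A minimal usage sketch for the notebook path, assuming a Colab or Kaggle TPU runtime with accelerate and torch_xla installed; training_function is a hypothetical user function, not part of the PR.

from accelerate import Accelerator, notebook_launcher


def training_function():
    # The Accelerator must be created inside the launched function, as the error message above enforces.
    accelerator = Accelerator()
    accelerator.print(f"training on {accelerator.num_processes} processes")


# With this change the TPU branch reports and uses device_count() instead of a hard-coded
# default of 8, so num_processes does not need to be passed for TPU notebooks.
notebook_launcher(training_function)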
16 changes: 14 additions & 2 deletions tests/xla_spawn.py
@@ -30,6 +30,7 @@
 from pathlib import Path
 
 import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla import device_count
 
 
 def parse_args():
@@ -46,7 +47,13 @@ def parse_args():
     )
 
     # Optional arguments for the launch helper
-    parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
+    num_devices = device_count()
+    parser.add_argument(
+        "--num_cores",
+        type=int,
+        default=num_devices,
+        help="Number of TPU cores to use (1 or number of available devices).",
+    )
 
     # positional
     parser.add_argument(
@@ -76,7 +83,12 @@ def main():
     mod = importlib.import_module(mod_name)
 
     # Patch sys.argv
-    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
+    sys.argv = [args.training_script] + args.training_script_args
+    num_cores = args.num_cores
+    if num_cores == device_count() and num_cores != 1:
+        # There is an error in xmp.spawn that causes it to fail when num_cores is specified and not 1, so we set it to
+        # None when it matches the number of devices.
+        num_cores = None
     xmp.spawn(mod._mp_fn, args=(), nprocs=num_cores)


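The nprocs handling added to the test helper can be summed up in a small function. normalize_nprocs below is a hypothetical name, not part of the script; it assumes the constraint described in the PR comment, namely that recent torch_xla releases only accept nprocs values of 1 or None.

from typing import Optional

from torch_xla import device_count


def normalize_nprocs(requested: int) -> Optional[int]:
    """Return a value that is safe to pass as the nprocs argument of xmp.spawn."""
    if requested == device_count() and requested != 1:
        # Using every device: let torch_xla infer the process count itself.
        return None
    return requested

With that helper, the final call would read xmp.spawn(mod._mp_fn, args=(), nprocs=normalize_nprocs(args.num_cores)).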