From 89722bd0e6bf2336a85c257c18b6f338a528d305 Mon Sep 17 00:00:00 2001 From: Bhavya Date: Tue, 1 Oct 2024 20:39:19 +0000 Subject: [PATCH 1/4] Allow MpDeviceLoader to shard dictionaries of tensor with different shapes --- docs/spmd_advanced.md | 18 +++ test/run_tests.sh | 1 + test/spmd/test_mp_input_sharding.py | 149 +++++++++++++++++++++++ torch_xla/distributed/parallel_loader.py | 75 +++++++++--- 4 files changed, 228 insertions(+), 15 deletions(-) create mode 100644 test/spmd/test_mp_input_sharding.py diff --git a/docs/spmd_advanced.md b/docs/spmd_advanced.md index 4cd07a558c9f..a6bd0762b0d1 100644 --- a/docs/spmd_advanced.md +++ b/docs/spmd_advanced.md @@ -14,6 +14,24 @@ train_loader = pl.MpDeviceLoader( input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None))) ``` +It is also possible to specify a different `input_sharding` for each element of the batch if they are different shapes: + +```python +# if batch = next(train_loader) looks like +# {'x': , 'y': } + +# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator +train_loader = pl.MpDeviceLoader( + train_loader, # wraps PyTorch DataLoader + device, + # specify different sharding for each input of the batch. + input_sharding={ + 'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)), + 'y': xs.ShardingSpec(input_mesh, ('data', None)) + } +) +``` + ### Virtual Device Optimization PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined. This is to overlap the data transfer with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after _the tensor has been defined, we need an optimization to prevent unnecessary transfer of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique to place the tensor data on a virtual device SPMD:0 first, before uploading to the physical devices when all the sharding decisions are finalized. Every tensor data in SPMD mode is placed on a virtual device, SPMD:0. The virtual device is exposed to the user as an XLA device XLA:0 with the actual shards on physical devices, like TPU:0, TPU:1, etc. diff --git a/test/run_tests.sh b/test/run_tests.sh index 9a8c8fce9d5d..0912d53ded5a 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -245,6 +245,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_dtensor_integration2.py" run_test "$CDIR/spmd/test_xla_auto_sharding.py" run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py" + run_test "$CDIR/spmd/test_mp_input_sharding.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py new file mode 100644 index 000000000000..c88f1ed63fbf --- /dev/null +++ b/test/spmd/test_mp_input_sharding.py @@ -0,0 +1,149 @@ +import sys +import numpy as np +import unittest + +import torch +import torch_xla +from torch_xla import runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd import Mesh +import torch_xla.distributed.spmd as xs +import torch_xla.distributed.parallel_loader as pl + +xr.use_spmd() + + +class MpInputShardingTest(unittest.TestCase): + + class fake_dataloader: + + def __init__(self, batch, size=1): + self.batch = batch + self.batch_size = size + self.counter = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.counter < self.batch_size: + self.counter += 1 + return self.batch + raise StopIteration + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_multiple_inputs(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, + device, + input_sharding={ + 'x': xs.ShardingSpec(mesh, ('x', None)), + 'y': xs.ShardingSpec(mesh, ('x', None, None)) + }) + train_loader = iter(train_loader) + data = next(train_loader) + annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation_x, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation_y, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_single_tensor(self): + device = xm.xla_device() + batch = torch.randn((16, 128)) + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None))) + train_loader = iter(train_loader) + data = next(train_loader) + annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data)) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_error_single_tensor_with_input_sharding_dict(self): + device = xm.xla_device() + batch = torch.randn((16, 128)) + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) + train_loader = iter(train_loader) + with self.assertRaises(ValueError): + data = next(train_loader) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_input_sharding_none(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + + train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None) + train_loader = iter(train_loader) + data = next(train_loader) + annotation = '{replicated}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_error_missing_keys(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + mesh = xs.get_1d_mesh('x') + train_loader = pl.MpDeviceLoader( + train_loader, + device, + input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) + train_loader = iter(train_loader) + with self.assertRaises(KeyError): + data = next(train_loader) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_input_sharding_not_dict(self): + device = xm.xla_device() + num_devices = xr.global_runtime_device_count() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} + train_loader = self.fake_dataloader(batch) + mesh = xs.get_1d_mesh('x') + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None))) + train_loader = iter(train_loader) + data = next(train_loader) + annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation_x, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation_y, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index a0304b4523a1..85d13139d033 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -12,7 +12,7 @@ class PerDeviceQueue(object): def __init__(self, device, loader_prefetch_size, device_prefetch_size): self.device = device - self.loader_queue = kq.Queue(maxsize=loader_prefetch_size) + self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size) self.queue = kq.Queue(maxsize=device_prefetch_size) self.close_queue_count = itertools.count() @@ -46,6 +46,8 @@ def next(self): self._batches_yielded += 1 item = self._loader.next_item(self._device) + if isinstance(item, Exception): + raise item if item is None: xm.mark_step() raise StopIteration @@ -56,7 +58,7 @@ class ParallelLoader(object): """Wraps an existing PyTorch DataLoader with background data upload. Args: - loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be + cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be wrapped. devices (`torch.device`...): The list of devices where the data has to be sent. The i-th sample returned by the `loader` will be sent to `devices[i @@ -74,13 +76,12 @@ class ParallelLoader(object): host_to_device_transfer_threads (int, optional): The number of threads that work in parallel to transfer data from loader queue to device queue. Default: 1 - input_sharding (ShardingSpec, optional): Sharding spec to apply to - compatible input tensors after loading. - Default: None + input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding + spec to apply to compatible input tensors after loading. """ def __init__(self, - loader, + cpu_loader, devices, batchdim=0, batches_per_execution=1, @@ -88,7 +89,7 @@ def __init__(self, device_prefetch_size=8, host_to_device_transfer_threads=1, input_sharding=None): - self._loader = loader + self._cpu_loader = cpu_loader self._devices = [torch.device(x) for x in devices] self._batchdim = batchdim self._batches_per_execution = batches_per_execution @@ -140,7 +141,7 @@ def close(self): self._done = True for dqueue in self._queues.values(): dqueue.queue.close() - dqueue.loader_queue.close() + dqueue.cpu_loader_queue.close() for thread in self._threads: thread.join() @@ -151,7 +152,7 @@ def batches_per_execution(self): def _loader_worker(self): queues = list(self._queues.values()) - data_iter = enumerate(self._loader) + data_iter = enumerate(self._cpu_loader) batch = [] try: @@ -163,21 +164,66 @@ def _loader_worker(self): batch.append(data) if len(batch) == len(self._devices): for queue_no, device_batch in enumerate(batch): - queues[queue_no].loader_queue.put(device_batch) + queues[queue_no].cpu_loader_queue.put(device_batch) batch = [] finally: for dqueue in queues: - dqueue.loader_queue.close_write() + dqueue.cpu_loader_queue.close_write() def _get_batch(self, dqueue): batch = [] - while dqueue.queue.max_size() > len(batch): - item = dqueue.loader_queue.get() + while len(batch) < dqueue.queue.max_size(): + item = dqueue.cpu_loader_queue.get() if item is None: break batch.append(item) return batch + def send_cpu_data_to_device(self, batches, device): + """Move batch to device. + Args: + batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch + present in the cpu memory + device: TPU device where the batch should be moved + + Returns: + result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the + input batch is a dict. Otherwise, returns a list of torch.Tensor. + """ + result = None + if isinstance(self._input_sharding, dict): + if not isinstance(batches[0], dict): + return [ + ValueError( + f"input batch should be a dict when input sharding is a dict." + ) + ] + result = [] + for batch in batches: + xla_batch = {} + missing_keys = [] + for key, tensor in batch.items(): + assert type(tensor) == torch.Tensor + sharding_spec = None + if self._input_sharding: + if key not in self._input_sharding: + missing_keys.append(key) + continue + sharding_spec = self._input_sharding[key] + + # xla_tensor is a list of tensors. + xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec) + xla_batch[key] = xla_tensor[0] + if len(missing_keys) != 0: + # Returning exception as raising in the dataloading thread doesn't surface the problem in the main thread. + return [ + KeyError(f"Keys: {missing_keys} are missing from input_sharding.") + ] + result.append(xla_batch) + else: + result = xm.send_cpu_data_to_device(batches, device, self._input_sharding) + return result + def _worker(self, dqueue, host_to_device_transfer_threads): device = torch.device(dqueue.device) @@ -187,8 +233,7 @@ def _worker(self, dqueue, host_to_device_transfer_threads): if not batch: break with torch.no_grad(): - batch = xm.send_cpu_data_to_device(batch, device, - self._input_sharding) + batch = self.send_cpu_data_to_device(batch, device) for data in batch: dqueue.queue.put(data) finally: From 88bbd54aaef28f3fb1decc5e326c084b7af9ab44 Mon Sep 17 00:00:00 2001 From: Bhavya Date: Tue, 1 Oct 2024 20:44:21 +0000 Subject: [PATCH 2/4] Fix formatting --- test/spmd/test_mp_input_sharding.py | 4 +++- torch_xla/distributed/parallel_loader.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py index c88f1ed63fbf..6b78a3714e79 100644 --- a/test/spmd/test_mp_input_sharding.py +++ b/test/spmd/test_mp_input_sharding.py @@ -85,7 +85,9 @@ def test_error_single_tensor_with_input_sharding_dict(self): mesh = xs.get_1d_mesh('x') train_loader = pl.MpDeviceLoader( - train_loader, device, input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) + train_loader, + device, + input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) train_loader = iter(train_loader) with self.assertRaises(ValueError): data = next(train_loader) diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index 85d13139d033..96a65d74b19b 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -195,8 +195,7 @@ def send_cpu_data_to_device(self, batches, device): if not isinstance(batches[0], dict): return [ ValueError( - f"input batch should be a dict when input sharding is a dict." - ) + f"input batch should be a dict when input sharding is a dict.") ] result = [] for batch in batches: From fb9c25711256408d79e8bad1369c8f7b3ffd4d59 Mon Sep 17 00:00:00 2001 From: Bhavya Date: Wed, 2 Oct 2024 21:02:40 +0000 Subject: [PATCH 3/4] Fix indent and add test to tpu ci --- docs/spmd_advanced.md | 4 ++-- test/tpu/run_tests.sh | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/spmd_advanced.md b/docs/spmd_advanced.md index a6bd0762b0d1..369fdfe25700 100644 --- a/docs/spmd_advanced.md +++ b/docs/spmd_advanced.md @@ -10,7 +10,7 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parall train_loader = pl.MpDeviceLoader( train_loader, # wraps PyTorch DataLoader device, - # assume 4d input and we want to shard at the batch dimension. + # assume 4d input and we want to shard at the batch dimension. input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None))) ``` @@ -24,7 +24,7 @@ It is also possible to specify a different `input_sharding` for each element of train_loader = pl.MpDeviceLoader( train_loader, # wraps PyTorch DataLoader device, - # specify different sharding for each input of the batch. + # specify different sharding for each input of the batch. input_sharding={ 'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)), 'y': xs.ShardingSpec(input_mesh, ('data', None)) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 52d1de5b1505..89661e29a58f 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -5,6 +5,7 @@ set -xue python3 test/test_operations.py -v python3 test/pjrt/test_runtime_tpu.py python3 test/pjrt/test_collective_ops_tpu.py +python3 test/spmd/test_mp_input_sharding.py python3 test/spmd/test_xla_sharding.py python3 test/spmd/test_xla_virtual_device.py python3 test/spmd/test_xla_distributed_checkpoint.py From 487d47152c9c6e90ddee9dec32581c99bff28ebd Mon Sep 17 00:00:00 2001 From: Bhavya Date: Wed, 2 Oct 2024 22:10:41 +0000 Subject: [PATCH 4/4] Add exception_queue to parallel loader --- torch_xla/distributed/parallel_loader.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index 96a65d74b19b..a177c92b59d8 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -1,4 +1,5 @@ import itertools +import queue import threading import torch import torch_xla @@ -46,9 +47,9 @@ def next(self): self._batches_yielded += 1 item = self._loader.next_item(self._device) - if isinstance(item, Exception): - raise item if item is None: + if not self._loader._exception_queue.empty(): + raise self._loader._exception_queue.get() xm.mark_step() raise StopIteration return item @@ -95,6 +96,7 @@ def __init__(self, self._batches_per_execution = batches_per_execution self._done = False self._queues = dict() + self._exception_queue = queue.Queue() self._input_sharding = input_sharding self._threads = [] for device in self._devices: @@ -193,10 +195,8 @@ def send_cpu_data_to_device(self, batches, device): result = None if isinstance(self._input_sharding, dict): if not isinstance(batches[0], dict): - return [ - ValueError( - f"input batch should be a dict when input sharding is a dict.") - ] + raise ValueError( + f"input batch should be a dict when input sharding is a dict.") result = [] for batch in batches: xla_batch = {} @@ -215,9 +215,8 @@ def send_cpu_data_to_device(self, batches, device): xla_batch[key] = xla_tensor[0] if len(missing_keys) != 0: # Returning exception as raising in the dataloading thread doesn't surface the problem in the main thread. - return [ - KeyError(f"Keys: {missing_keys} are missing from input_sharding.") - ] + raise KeyError( + f"Keys: {missing_keys} are missing from input_sharding.") result.append(xla_batch) else: result = xm.send_cpu_data_to_device(batches, device, self._input_sharding) @@ -232,7 +231,13 @@ def _worker(self, dqueue, host_to_device_transfer_threads): if not batch: break with torch.no_grad(): - batch = self.send_cpu_data_to_device(batch, device) + try: + batch = self.send_cpu_data_to_device(batch, device) + except Exception as e: + # _worker is being run in a daemon thread, raise the error + # will not work. Put the error in an error queue instead. + self._exception_queue.put(e) + break for data in batch: dqueue.queue.put(data) finally: