From db7c973daccd71d53a91c17fb962cb1fd64607db Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Mon, 10 Jul 2023 15:45:49 -0700 Subject: [PATCH 01/20] Update inline style code to multiline (#5291) --- torch_xla/csrc/aten_xla_bridge.cpp | 8 ++++++-- torch_xla/csrc/tensor_methods.cpp | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index d0e9e4845e7d..4542b6dbda4f 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -150,9 +150,13 @@ std::vector XlaCreateTensorList(const at::ITensorListRef& tensors) { std::vector to_translate(tensors.size()); size_t ix = 0; for (const auto& tensor : tensors) { - if (!tensor.defined()) continue; + if (!tensor.defined()) { + continue; + } auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); - if (!inner_tensor.defined()) continue; + if (!inner_tensor.defined()) { + continue; + } auto xtensor = TryGetXlaTensor(tensor); if (xtensor) { diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 25be3316a191..348d8c555bcc 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2482,7 +2482,9 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims) { input_shape.get().dimensions()); std::vector output_dimensions; for (int64_t dim : dims) { - if (dim >= input_dimensions.size()) continue; + if (dim >= input_dimensions.size()) { + continue; + } int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size()); output_dimensions = BuildSqueezedDimensions(input_dimensions, squeeze_dim); From e8e66d4041e284e99ac7a9d2e9a42ad21181cf52 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 10 Jul 2023 16:22:54 -0700 Subject: [PATCH 02/20] Fix typo in _test.yml (#5172) s/metadtaa/metadata/ --- .github/workflows/_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 7e70ed7e3574..203639f37bf5 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -125,7 +125,7 @@ jobs: INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '\"${CIRCLE_BUILD_NUM}\"', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' echo $INC_METADATA > inc_metadata.json - gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadtaa.json + gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json fi - name: Teardown Linux From 848c00d72ad0018f2f6ab3dfdd5b435649b968c2 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Tue, 11 Jul 2023 15:22:22 -0700 Subject: [PATCH 03/20] [SPMD][Virtual Device]All tensors should be in SPMD:0 C++ device (#5284) * Move all tensors to SPMD:0 C++ device under spmd context * fix load shards * fix test_mark_sharding_2d by not creating placeholder for virtual device * fix the waitdeviceop for spmd case * Fix test_shard_hashing * fix spmd device casting issue * remove hacks in test_xla_virtual_device.py * add test for new virtual device usage * fix review comments * fix IsTpuDevice * linter --- test/spmd/test_xla_sharding.py | 1 - test/spmd/test_xla_virtual_device.py | 63 +++++++++++++++++-- torch_xla/csrc/aten_xla_bridge.cpp | 13 ++-- torch_xla/csrc/aten_xla_type.cpp | 2 +- torch_xla/csrc/device.cpp | 18 
+++++- torch_xla/csrc/device.h | 4 ++ torch_xla/csrc/init_python_bindings.cpp | 38 ++++++----- .../csrc/runtime/pjrt_computation_client.cc | 4 ++ torch_xla/csrc/tensor.cpp | 6 +- torch_xla/csrc/tensor_util.cpp | 45 ++++++------- torch_xla/csrc/xla_graph_executor.cpp | 21 +++---- torch_xla/csrc/xla_sharding_util.cpp | 14 ----- torch_xla/csrc/xla_sharding_util.h | 4 -- torch_xla/experimental/xla_sharded_tensor.py | 7 +-- 14 files changed, 150 insertions(+), 90 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index a9fa7d3dc7d0..6f175c69862f 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -314,7 +314,6 @@ def test_execute_replicated_metrics(self): xt = torch.ones(2, 2).to(xm.xla_device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) xt += 2 - sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) xm.mark_step() xm.wait_device_ops() self.assertEqual(met.metric_data('ExecuteReplicatedTime')[0], 1) diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 99dbd4f9015c..cc9e9df93c52 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -79,9 +79,6 @@ def test_outbound_data_metrics(self): def test_non_tensor_scalar(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) - # TODO(JackCaoG)currently, execution will only happen if there is at least one - # tensor on non-spmd:0 device. - t1 = torch.randn(3, 3, device=xm.xla_device()) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], xm.xla_device(), @@ -95,9 +92,6 @@ def test_non_tensor_scalar(self): def test_mark_step_on_virtual_device(self): xm.mark_step() sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) - # TODO(JackCaoG)currently, execution will only happen if there is at least one - # tensor on non-spmd:0 device. 
- t1 = torch.randn(3, 3, device=xm.xla_device()) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], xm.xla_device(), @@ -108,6 +102,63 @@ def test_mark_step_on_virtual_device(self): self.assertNotIn('aten::div', torch_xla._XLAC._get_xla_tensor_debug_info(xt2)) + def test_virtual_device_no_upload(self): + met.clear_all() + device = xm.xla_device() + t1 = torch.randn(5, 5).to(device) + t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + # t1's upload to device should be deferred + self.assertIn("Tensor on host: with size [5, 5]", t1_debug_info) + self.assertNotIn("TransferToServerTime", met.metric_names()) + # t1 should be on SPMD device under spmd context + self.assertIn("Device: SPMD:0", t1_debug_info) + self.assertIn("IR: None", t1_debug_info) + self.assertIn("XLAData: None", t1_debug_info) + + def test_virtual_device_upload_after_mark_sharding(self): + met.clear_all() + partition_spec = (0, 1) + device = xm.xla_device() + t1 = torch.randn(8, 8).to(device) + t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) + xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), partition_spec) + t1_debug_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + # tensor should be uploaded to device after mark_sharding + self.assertIn("Tensor on host: None", t1_debug_info_new) + self.assertIn("xla::device_data", t1_debug_info_new) + self.assertIn("XLAShardedData", t1_debug_info_new) + self.assertIn("TransferToServerTime", met.metric_names()) + + def test_virtual_device_upload_after_tracing(self): + met.clear_all() + device = xm.xla_device() + t1 = torch.randn(8, 8).to(device) + t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) + t2 = t1 + t1 + t1_debug_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + # tensor should be uploaded to device after being used as input to other op. 
+ self.assertIn("Tensor on host: None", t1_debug_info_new) + self.assertIn("xla::device_data", t1_debug_info_new) + self.assertIn("TransferToServerTime", met.metric_names()) + + def test_virtual_device_upload_for_sharded_dataloader(self): + met.clear_counters() + device = xm.xla_device() + sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) + # tensor will have device as `SPMD:0` in c++ + t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)], + device, + input_sharding=sharding_spec)[0] + t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) + self.assertIn("Device: SPMD:0", t1_debug_info) + # tensor should be uploaded to device after send_cpu_data_to_device + sharding_spec + self.assertIn("Tensor on host: None", t1_debug_info) + self.assertIn("xla::device_data", t1_debug_info) + self.assertIn("XLAShardedData", t1_debug_info) + self.assertIn("TransferToServerTime", met.metric_names()) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 4542b6dbda4f..d49302ab6674 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -39,10 +39,15 @@ class AtenXlaDeviceMapper { private: AtenXlaDeviceMapper() { - for (auto& device_str : - torch_xla::runtime::GetComputationClient()->GetLocalDevices()) { - devices_.emplace_back(ParseDeviceString(device_str)); - devices_ordinals_[devices_.back()] = devices_.size() - 1; + if (UseVirtualDevice()) { + devices_.emplace_back(ParseDeviceString("SPMD:0")); + devices_ordinals_[devices_.back()] = 0; + } else { + for (auto& device_str : + torch_xla::runtime::GetComputationClient()->GetLocalDevices()) { + devices_.emplace_back(ParseDeviceString(device_str)); + devices_ordinals_[devices_.back()] = devices_.size() - 1; + } } } diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 7ae4fb08b463..23c6e35a4a5a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -467,7 +467,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, if (!self_tensor) { static bool sync_update = runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true) && - !ShardingUtil::UseVirtualDevice(); + !UseVirtualDevice(); XLA_CHECK(dst_tensor); dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); } else if (!dst_tensor) { diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index ce8f365a93de..fad59c187687 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -37,7 +37,9 @@ std::string DeviceType::toString() const { torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) { if (device_spec.empty()) { std::string default_device_spec = - runtime::GetComputationClient()->GetDefaultDevice(); + UseVirtualDevice() + ? 
"SPMD:0" + : runtime::GetComputationClient()->GetDefaultDevice(); XLA_CHECK(!default_device_spec.empty()); return ParseDeviceString(default_device_spec); } @@ -101,4 +103,18 @@ torch::lazy::BackendDevice SetCurrentDevice( return current; } +bool ShouldUseVirtualDevice() { + bool use_virtual_device = + runtime::sys_util::GetEnvBool("XLA_USE_SPMD", false); + if (use_virtual_device) { + TF_LOG(INFO) << "Using SPMD virtual device optimization"; + } + return use_virtual_device; +} + +bool UseVirtualDevice() { + static bool use_virtual_device = ShouldUseVirtualDevice(); + return use_virtual_device; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index fd57b6e95ffa..3b11dd398564 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -42,6 +42,10 @@ static inline torch::lazy::BackendDevice GetDeviceOrCurrent( return device != nullptr ? *device : GetCurrentDevice(); } +// Test whether the XLA_USE_SPMD environment variable is set to enable the +// virtual device optimization. +bool UseVirtualDevice(); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_DEVICE_H_ diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ffb54202cfcc..f9b4d9569e1c 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -385,7 +385,7 @@ std::string GetXLATensorDebugInfo(const at::Tensor& tensor) { auto at_tensor = xtensor->CurrentTensorData(); ss << "Tensor on host: "; if (at_tensor) { - ss << " with size " << at_tensor->sizes() << "\n"; + ss << "with size " << at_tensor->sizes() << "\n"; } else { ss << "None\n"; } @@ -1150,7 +1150,7 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& devices) { NoGilSection nogil; XLAGraphExecutor::Get()->WaitDeviceOps(devices); - if (ShardingUtil::UseVirtualDevice()) { + if (UseVirtualDevice()) { std::vector spmd_device = {"SPMD:0"}; runtime::GetComputationClient()->WaitDeviceOps(spmd_device); } else { @@ -1339,8 +1339,7 @@ void InitXlaModuleBindings(py::module m) { const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(ShardingUtil::UseVirtualDevice()) - << "Please set `XLA_USE_SPMD=1`"; + XLA_CHECK(UseVirtualDevice()) << "Please set `XLA_USE_SPMD=1`"; XLATensorPtr xtensor = bridge::GetXlaTensor(input); xla::OpSharding sharding = ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, @@ -1419,23 +1418,33 @@ void InitXlaModuleBindings(py::module m) { // shape. Note that this padding is _not_ included in the global indices // returned by `_get_local_shard_indices`. 
m.def("_get_local_shards", - [](const at::Tensor& input) -> std::vector { + [](const at::Tensor& input) + -> std::tuple, std::vector> { XLATensorPtr xtensor = bridge::GetXlaTensor(input); XLA_CHECK(xtensor->GetXlaData() != nullptr) << "Shard data is not available"; XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; - XLA_CHECK(ShardingUtil::UseVirtualDevice()) + XLA_CHECK(UseVirtualDevice()) << "Virtual device must be enabled to use _get_local_shards"; auto handle = UnwrapXlaData(xtensor->GetXlaData()); - auto shard_handles = + std::vector shard_handles = runtime::GetComputationClient()->GetDataShards(handle); std::vector shards; - for (auto& shard_handle : shard_handles) { - auto xshard = XLATensor::Create(WrapXlaData(shard_handle)); - shards.push_back(bridge::AtenFromXlaTensor(std::move(xshard))); + std::vector str_devices; + shards.reserve(shard_handles.size()); + str_devices.reserve(shard_handles.size()); + // Tansfer shards from the device and create cpu tensors. + for (const runtime::ComputationClient::DataPtr shard_handle : + shard_handles) { + shards.push_back( + XlaDataToTensors( + {WrapXlaData(shard_handle)}, + TensorTypeFromXlaType(shard_handle->shape().element_type())) + .front()); + str_devices.push_back(shard_handle->device()); } - return shards; + return std::make_tuple(shards, str_devices); }); // Returns the indices of the shards into the global tensor as either // a Python list of slices for each dimension or a Python Ellipsis object @@ -1504,8 +1513,7 @@ void InitXlaModuleBindings(py::module m) { << "Input shard shape must include padding: " << shard.sizes() << " vs " << shard_shape; } - auto xla_devices = GetXlaDevices(devices); - auto xla_data = ShardingUtil::CreateShardedData(shards, xla_devices, + auto xla_data = ShardingUtil::CreateShardedData(shards, devices, xtensor->shape(), sharding); xtensor->SetXlaData(WrapXlaData(xla_data)); }); @@ -1703,8 +1711,8 @@ void InitXlaModuleBindings(py::module m) { torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str()); // Device will be Virtual device if SPMD is enabled. torch::lazy::BackendDevice device = - ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0") - : torch_xla::GetCurrentDevice(); + UseVirtualDevice() ? ParseDeviceString("SPMD:0") + : torch_xla::GetCurrentDevice(); auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier( hash, graph_inputs, device); std::vector retlist; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 657d6a0b6051..af0397cc0717 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -520,6 +520,10 @@ PjRtComputationClient::ExecuteComputation( << device; // Grab the shared lock and block the `WaitDeviceOps` until buffer is // ready. + // TODO(JackCaoG): This lock should acquired outside of the lockfn and + // passed in. It is possible that lockfn started after ExecuteComputation + // released the xla_graph_executor lock, which will create a short windows + // where device is unlcoked while execution is still running. 
auto lock = lock_device_shared(device); TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device << " Done"; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 44efa4f394b8..6e86dcaa5954 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -338,6 +338,7 @@ torch::lazy::Value XLATensor::GetIrValue() const { c10::optional tensor_data = CurrentTensorData(); XLA_CHECK(tensor_data); AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice())); + data()->tensor_data = c10::nullopt; return data()->ir_value; } @@ -492,9 +493,8 @@ void XLATensor::SetTensor(at::Tensor tensor) { } void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) { - torch::lazy::BackendDevice device = ShardingUtil::UseVirtualDevice() - ? ParseDeviceString("SPMD:0") - : GetDevice(); + torch::lazy::BackendDevice device = + UseVirtualDevice() ? ParseDeviceString("SPMD:0") : GetDevice(); if (sync) { at::Tensor typed_tensor = torch::lazy::CopyTensor(tensor, dtype(), /*copy=*/false); diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index cacd12ac4183..93c8311eb8f8 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -106,13 +106,20 @@ bool Use32BitLong() { return use_32bit_long; } +bool IsTpuDevice(XlaDeviceType hw_type) { + static bool spmd_device_is_tpu = + (hw_type == XlaDeviceType::SPMD) && + runtime::GetComputationClient()->GetDefaultDevice().find("TPU") == 0; + return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; +} + xla::PrimitiveType XlaTypeFromTensorType( at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) { XlaDeviceType hw_type = static_cast(device.type()); switch (scalar_type) { case at::ScalarType::Double: - return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::F64 - : xla::PrimitiveType::F32; + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::F64 + : xla::PrimitiveType::F32; case at::ScalarType::Float: return xla::PrimitiveType::F32; case at::ScalarType::BFloat16: @@ -600,19 +607,7 @@ torch::lazy::BackendDataPtr TensorToXlaData( const at::Tensor& tensor, const xla::Shape& shape, const torch::lazy::BackendDevice& device) { TORCH_LAZY_TIMED("TensorToData"); - if (ShardingUtil::UseVirtualDevice()) { - // Scalar value will be replicated, no need to delay the transfer here. - // TODO(JackCaoG): fix this for more general cases. - if (device.type() == (int8_t)XlaDeviceType::SPMD && shape.rank() > 0) { - // When SPMD is enabled, we want to delay the data transfer for XLA - // tensors until the data is sharded. So, we skip the data transfer - // here and simply return a placeholder for the backend data ptr. - // Data will only be transferred via CreateTensorsData, when users - // call the mark_sharding API. - return WrapXlaData(runtime::GetComputationClient()->CreateDataPlaceholder( - "SPMD:0", shape)); - } - + if (UseVirtualDevice()) { // The tensor is bypassing the virtual device, so it should be replicated // to all devices. std::vector local_devices = @@ -856,7 +851,7 @@ std::vector CreateTensorsData( TORCH_LAZY_TIMED("TensorToData"); XLA_CHECK_EQ(tensors.size(), devices.size()); - if (ShardingUtil::UseVirtualDevice()) { + if (UseVirtualDevice()) { // When running in SPMD mode, tensors here in the unsharded // CreateTensorsData should be implicitly replicated to all devices. 
// This case should always apply when using SPMD regardless @@ -936,7 +931,7 @@ std::vector CreateTensorsData( std::vector source_tensors; // in std::vector new_handles; // out - if (ShardingUtil::UseVirtualDevice()) { + if (UseVirtualDevice()) { // GetLocalDevices returns the list of local devices specified by their // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]). std::vector local_devices = @@ -1160,8 +1155,8 @@ xla::PrimitiveType GetDevicePrimitiveType( if (DowncastBF16() || DowncastF16()) { return xla::PrimitiveType::F32; } - return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::F64 - : xla::PrimitiveType::F32; + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::F64 + : xla::PrimitiveType::F32; case xla::PrimitiveType::F32: if (UseF16() || DowncastF16()) { return xla::PrimitiveType::F16; @@ -1169,18 +1164,18 @@ xla::PrimitiveType GetDevicePrimitiveType( return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; case xla::PrimitiveType::U16: - return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::U16 - : xla::PrimitiveType::U32; + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::U16 + : xla::PrimitiveType::U32; case xla::PrimitiveType::S16: - return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::S16 - : xla::PrimitiveType::S32; + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::S16 + : xla::PrimitiveType::S32; case xla::PrimitiveType::S64: return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; case xla::PrimitiveType::U64: return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; case xla::PrimitiveType::C128: - return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::C128 - : xla::PrimitiveType::C64; + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128 + : xla::PrimitiveType::C64; default: return type; } diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 26f84e854062..2714d6ef6fb0 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -81,7 +81,6 @@ auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* { std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( const torch::lazy::BackendDevice* device) { std::vector tensors; - torch::lazy::BackendDevice virtual_device = GetVirtualDevice(); auto fn = [&](DeviceContext* devctx) { std::lock_guard lock(devctx->lock); for (auto& uid_wptr : devctx->tensors_data) { @@ -93,8 +92,6 @@ std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( } }; ForAllDeviceContexts(fn, device); - // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode. - ForAllDeviceContexts(fn, &virtual_device); return tensors; } @@ -398,15 +395,20 @@ void XLAGraphExecutor::WaitDeviceOps(absl::Span devices) { wait_devices.insert(ParseDeviceString(device_str)); } } else { - for (auto& device_str : - runtime::GetComputationClient()->GetLocalDevices()) { - wait_devices.insert(ParseDeviceString(device_str)); + if (UseVirtualDevice()) { + wait_devices.insert(ParseDeviceString("SPMD:0")); + } else { + for (auto& device_str : + runtime::GetComputationClient()->GetLocalDevices()) { + wait_devices.insert(ParseDeviceString(device_str)); + } } } // The DeviceLockerArena::Get()->LockDevices() API returns a vector of // torch::lazy::ExceptionCleanup object, which is going to be freed // immediately, turning this operation into a lock barrier. 
DeviceLockerArena::Get()->LockDevices(wait_devices); + TF_VLOG(4) << "XLAGraphExecutor::WaitDeviceOps completed"; } std::vector XLAGraphExecutor::GetTensors( @@ -505,10 +507,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( tsl::profiler::TraceMeLevel::kInfo); runtime::util::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { - // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode. - if (tensors[i]->GetDevice().toString() != "SPMD:0") { - unique_device.set(tensors[i]->GetDevice()); - } + unique_device.set(tensors[i]->GetDevice()); } SyncTensorCollection coll; if (!unique_device) { @@ -649,7 +648,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( } std::vector sharding_specs(placeholders.size()); - if (ShardingUtil::UseVirtualDevice()) { + if (UseVirtualDevice()) { ShardingUtil::PrepareOutputShardingPropagation( placeholders, sharding_specs, output_shapes, cachedComputation->computation, device); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 36c723152811..ea91fb2b3118 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -141,20 +141,6 @@ std::vector> ExtractGroupMembers( } // namespace -bool ShouldUseVirtualDevice() { - bool use_virtual_device = - runtime::sys_util::GetEnvBool("XLA_USE_SPMD", false); - if (use_virtual_device) { - TF_LOG(INFO) << "Using SPMD virtual device optimization"; - } - return use_virtual_device; -} - -bool ShardingUtil::UseVirtualDevice() { - static bool use_virtual_device = ShouldUseVirtualDevice(); - return use_virtual_device; -} - bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { bool is_sharded = false; for (std::pair elem : diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 0f356651a634..86a18dc83f89 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -28,10 +28,6 @@ class ShardingUtil { // Determine the ShardingType of the given xla::OpSharding. static ShardingType GetShardingType(xla::OpSharding& sharding); - // Test whether the XLA_USE_SPMD environment variable is set to enable the - // virtual device optimization. - static bool UseVirtualDevice(); - // Annotates HLO instructions in the lowered computation and returns true if // the computation needs to be compiled with SPMD partitioning. For this call // to be effective, this needs to be called after the lowering and before diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py index 4ab0117f4c67..ce423b3918f8 100644 --- a/torch_xla/experimental/xla_sharded_tensor.py +++ b/torch_xla/experimental/xla_sharded_tensor.py @@ -109,12 +109,9 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): # which results from the sharding. @property def local_shards(self) -> List[XLAShard]: - shards = torch_xla._XLAC._get_local_shards(self.global_tensor) - devices = [str(shard.device) for shard in shards] + shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor) indices = torch_xla._XLAC._get_local_shard_indices(self.global_tensor) - return [ - XLAShard(s.cpu(), i, d) for s, i, d in zip(shards, indices, devices) - ] + return [XLAShard(s, i, d) for s, i, d in zip(shards, indices, devices)] # Load the given list of local shards into the underlying tensor's data # on the local devices. 
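
Note: the tests added in this patch boil down to the following user-facing flow. The sketch below is illustrative only — it assumes `XLA_USE_SPMD=1` on a PJRT runtime and hard-codes a four-device mesh for the example; everything else mirrors `test_xla_virtual_device.py` above.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Under XLA_USE_SPMD=1 every XLA tensor lives on the virtual SPMD:0 device,
# and the host-to-device transfer is deferred until the tensor is sharded
# or consumed by another op.
t = torch.randn(8, 8).to(xm.xla_device())
print(torch_xla._XLAC._get_xla_tensor_debug_info(t))  # "Tensor on host: with size [8, 8]"

num_devices = 4  # assumption for this example; use the real runtime device count
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# mark_sharding triggers the upload: the data lands on the local devices
# as XLAShardedData backing an xla::device_data IR node.
xs.mark_sharding(t, mesh, (0, 1))
print(torch_xla._XLAC._get_xla_tensor_debug_info(t))  # "XLAShardedData ..."
```
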
From 07d6f7f59c99cb016a4e3ec21f7d5b4f796f03a0 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 11 Jul 2023 17:46:14 -0700 Subject: [PATCH 04/20] Revert pr https://github.com/pytorch/xla/pull/2682 (#5215) --- torch_xla/csrc/tensor_impl.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 63d18b9f54a2..62804b6f2f2e 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -80,7 +80,6 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks; } - is_non_overlapping_and_dense_ = false; const_cast(this)->SetupSizeProperties(); set_sizes_and_strides(sym_sizes_, c10::fromIntArrayRefSlow( sizes_and_strides_.strides_arrayref())); From 9a0e24dec56e9dd69ee06424a289d38c06cebd84 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 11 Jul 2023 18:03:02 -0700 Subject: [PATCH 05/20] Make README more actionable (#5262) * Make README more actionable * move profiling guide link * text wrapping --- README.md | 298 +++++++++++++++++++++++++++++++--------------------- docs/gpu.md | 11 +- 2 files changed, 188 insertions(+), 121 deletions(-) diff --git a/README.md b/README.md index d27440b500de..a2b50e1df7f5 100644 --- a/README.md +++ b/README.md @@ -1,125 +1,133 @@ # PyTorch/XLA -Current CI status: [![CircleCI](https://circleci.com/gh/pytorch/xla.svg?style=svg)](https://circleci.com/gh/pytorch/xla) +Current CI status: ![GitHub Actions +status](https://github.com/pytorch/xla/actions/workflows/build_and_test.yml/badge.svg) -PyTorch/XLA is a Python package that uses the -[XLA deep learning compiler](https://www.tensorflow.org/xla) -to connect the [PyTorch deep learning framework](https://pytorch.org/) and -[Cloud TPUs](https://cloud.google.com/tpu/). You can try it right now, for free, -on a single Cloud TPU with [Google Colab](https://colab.research.google.com/), -and use it in production and on Cloud TPU Pods -with [Google Cloud](https://cloud.google.com/gcp). +PyTorch/XLA is a Python package that uses the [XLA deep learning +compiler](https://www.tensorflow.org/xla) to connect the [PyTorch deep learning +framework](https://pytorch.org/) and [Cloud +TPUs](https://cloud.google.com/tpu/). You can try it right now, for free, on a +single Cloud TPU VM with +[Kaggle](https://www.kaggle.com/discussions/product-feedback/369338)! 
-Take a look at one of our Colab notebooks to quickly try different PyTorch networks -running on Cloud TPUs and learn how to use Cloud TPUs as PyTorch devices: +Take a look at one of our [Kaggle +notebooks](https://github.com/pytorch/xla/tree/master/contrib/kaggle) to get +started: -* [Getting Started with PyTorch on Cloud TPUs](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb) -* [Training AlexNet on Fashion MNIST with a single Cloud TPU Core](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb) -* [Training AlexNet on Fashion MNIST with multiple Cloud TPU Cores](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/multi-core-alexnet-fashion-mnist.ipynb) -* [Fast Neural Style Transfer (NeurIPS 2019 Demo)](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/style_transfer_inference.ipynb) -* [Training A Simple Convolutional Network on MNIST](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb) -* [Training a ResNet18 Network on CIFAR10](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/resnet18-training.ipynb) -* [ImageNet Inference with ResNet50](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/resnet50-inference.ipynb) -* [Training DC-GAN using Colab Cloud TPU](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/DC-GAN.ipynb) +* [Stable Diffusion with PyTorch/XLA + 2.0](https://github.com/pytorch/xla/blob/master/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb) +* [Distributed PyTorch/XLA + Basics](https://github.com/pytorch/xla/blob/master/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb) -The rest of this README covers: +## Getting Started -* [User Guide & Best Practices](#user-guide--best-practices) -* [Running PyTorch on Cloud TPUs and GPU](#running-pytorchxla-on-cloud-tpu-and-gpu) -Google Cloud also runs networks faster than Google Colab. -* [Available docker images and wheels](#available-docker-images-and-wheels) -* [Performance Profiling and Auto-Metrics Analysis](#performance-profiling-and-auto-metrics-analysis) -* [Troubleshooting](#troubleshooting) -* [Providing Feedback](#providing-feedback) -* [Building and Contributing to PyTorch/XLA](#contributing) -* [Additional Reads](#additional-reads) +To install PyTorch/XLA a new VM: +``` +pip install torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl +``` +To update your existing training loop, make the following changes: -Additional information on PyTorch/XLA, including a description of its -semantics and functions, is available at [PyTorch.org](http://pytorch.org/xla/). - -## User Guide & Best Practices - -Our comprehensive user guides are available at: - -[Documentation for the latest release](https://pytorch.org/xla) - -[Documentation for master branch](https://pytorch.org/xla/master) +``` +-import torch.multiprocessing as mp ++import torch_xla.core.xla_model as xm ++import torch_xla.distributed.parallel_loader as pl ++import torch_xla.distributed.xla_multiprocessing as xmp + + def _mp_fn(index): + ... 
+ ++ # Move the model paramters to your XLA device ++ model.to(xm.xla_device()) ++ ++ # MpDeviceLoader preloads data to the XLA device ++ xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device()) + +- for inputs, labels in train_loader: ++ for inputs, labels in xla_train_loader: + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() +- optimizer.step() ++ ++ # `xm.optimizer_step` combines gradients across replocas ++ xm.optimizer_step() + + if __name__ == '__main__': +- mp.spawn(_mp_fn, args=(), nprocs=world_size) ++ # xmp.spawn automatically selects the correct world size ++ xmp.spawn(_mp_fn, args=()) +``` -See the [API Guide](API_GUIDE.md) for best practices when writing networks that -run on XLA devices(TPU, GPU, CPU and...) +If you're using `DistributedDataParallel`, make the following changes: -## Running PyTorch/XLA on Cloud TPU and GPU -* [Running on a single Cloud TPU](#running-on-a-single-cloud-tpu-vm) -* [Running on a Cloud TPU Pod](#how-to-run-on-tpu-vm-pods-distributed-training) -* [Running on a Cloud GPU](docs/gpu.md) +``` + import torch.distributed as dist +-import torch.multiprocessing as mp ++import torch_xla.core.xla_model as xm ++import torch_xla.distributed.parallel_loader as pl ++import torch_xla.distributed.xla_multiprocessing as xmp + + def _mp_fn(rank, world_size): + ... + +- os.environ['MASTER_ADDR'] = 'localhost' +- os.environ['MASTER_PORT'] = '12355' +- dist.init_process_group("gloo", rank=rank, world_size=world_size) ++ # Rank and world size are inferred from the XLA device runtime ++ dist.init_process_group("xla", init_method='pjrt://') ++ ++ model.to(xm.xla_device()) ++ # `gradient_as_bucket_view=tpu` required for XLA ++ ddp_model = DDP(model, gradient_as_bucket_view=True) + +- model = model.to(rank) +- ddp_model = DDP(model, device_ids=[rank]) ++ xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device()) + +- for inputs, labels in train_loader: ++ for inputs, labels in xla_train_loader: + optimizer.zero_grad() + outputs = ddp_model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + if __name__ == '__main__': +- mp.spawn(_mp_fn, args=(), nprocs=world_size) ++ xmp.spawn(_mp_fn, args=()) +``` ---- -## Running on a Single Cloud TPU VM +Additional information on PyTorch/XLA, including a description of its semantics +and functions, is available at [PyTorch.org](http://pytorch.org/xla/). See the +[API Guide](API_GUIDE.md) for best practices when writing networks that run on +XLA devices (TPU, GPU, CPU and...). -Google Cloud offers TPU VMs for more transparent and easier access to the TPU hardware. This is our **recommended way** of running PyTorch/XLA on Cloud TPU. Please check out our [Cloud TPU VM User Guide](https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm). To learn more about the Cloud TPU System Architecture, please check out [this doc](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_vms). +Our comprehensive user guides are available at: +[Documentation for the latest release](https://pytorch.org/xla) ---- +[Documentation for master branch](https://pytorch.org/xla/master) -## How to Run on TPU VM Pods (distributed training) -If a single TPU VM does not suit your requirement, you can consider using TPU Pod. TPU Pod is a collection of TPU devices connected by dedicated high-speed network interfaces. Please checkout our [Cloud TPU VM Pod User Guide](https://cloud.google.com/tpu/docs/pytorch-pods). 
+## PyTorch/XLA tutorials +* [Cloud TPU VM + quickstart](https://cloud.google.com/tpu/docs/run-calculation-pytorch) +* [Cloud TPU Pod slice + quickstart](https://cloud.google.com/tpu/docs/pytorch-pods) +* [Profiling on TPU + VM](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) +* [GPU guide](docs/gpu.md) ## Available docker images and wheels -### Docker -The following pre-built docker images are available. For running dockers, check [this doc](https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#docker-tpuvm) for TPUVM and [this doc](https://github.com/pytorch/xla/blob/master/docs/gpu.md#docker) for GPU. - -| Version | Cloud TPU VMs Docker | -| --- | ----------- | -2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | -1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | -nightly python 3.10 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` | -nightly python 3.8 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm` | -nightly python 3.10(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD` | -nightly python 3.8(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD` | -nightly at date(< 2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD` | - -
- -| Version | GPU CUDA 11.8 + Python 3.8 Docker | -| --- | ----------- | -| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8` | -| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8` | -| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` | -| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD` | - -
- -| Version | GPU CUDA 11.7 + Python 3.8 Docker | -| --- | ----------- | -| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7` | -| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7` | -| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7_YYYYMMDD` | -| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.7_YYYYMMDD` | - -
- -| Version | GPU CUDA 11.2 + Python 3.8 Docker | -| --- | ----------- | -| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2` | - -
- -| Version | GPU CUDA 11.2 + Python 3.7 Docker | -| --- | ----------- | -1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` | -1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` | - - - -To run on [compute instances with GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus). ### Wheel + | Version | Cloud TPU VMs Wheel | | --- | ----------- | | 2.0 | `https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl` | @@ -132,7 +140,9 @@ To run on [compute instances with GPUs](https://cloud.google.com/compute/docs/gp
-Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheels for `torch`, `torchvision`, and `torch_xla` at `https://storage.googleapis.com/tpu-pytorch/wheels/xrt`. +Note: For TPU Pod customers using XRT (our legacy runtime), we have custom +wheels for `torch`, `torchvision`, and `torch_xla` at +`https://storage.googleapis.com/tpu-pytorch/wheels/xrt`. | Package | Cloud TPU VMs Wheel (XRT on Pod, Legacy Only) | | --- | ----------- | @@ -167,11 +177,14 @@ Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheel | --- | ----------- | | 2.0 | `https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl` | -You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel of a specified date. To get the companion pytorch and torchvision nightly wheel, replace the `torch_xla` with `torch` or `torchvision` on above wheel links. +You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel +of a specified date. To get the companion pytorch and torchvision nightly wheel, +replace the `torch_xla` with `torch` or `torchvision` on above wheel links. -### Installing libtpu +#### Installing libtpu (before PyTorch/XLA 2.0) -For PyTorch/XLA release r2.0 and older and when developing PyTorch/XLA, install the `libtpu` pip package with the following command: +For PyTorch/XLA release r2.0 and older and when developing PyTorch/XLA, install +the `libtpu` pip package with the following command: ``` pip3 install torch_xla[tpuvm] @@ -179,36 +192,87 @@ pip3 install torch_xla[tpuvm] This is only required on Cloud TPU VMs. -## Performance Profiling and Auto-Metrics Analysis +### Docker -With PyTorch/XLA we provide a set of performance profiling tooling and auto-metrics analysis which you can check the following resources: -* [Official tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) -* [Colab notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/pytorch-xla-profiling-colab.ipynb) -* [Sample MNIST training script with profiling](https://github.com/pytorch/xla/blob/master/test/test_profile_mp_mnist.py) -* [Utility script for capturing performance profiles](https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py) +| Version | Cloud TPU VMs Docker | +| --- | ----------- | +2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm` | +1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm` | +nightly python 3.10 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` | +nightly python 3.8 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm` | +nightly python 3.10(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD` | +nightly python 3.8(>= 2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD` | +nightly at date(< 2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD` | + +
+ +| Version | GPU CUDA 11.8 + Python 3.8 Docker | +| --- | ----------- | +| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8` | +| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8` | +| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD` | +| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD` | + +
+ +| Version | GPU CUDA 11.7 + Python 3.8 Docker | +| --- | ----------- | +| 2.0 | `gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7` | +| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7` | +| nightly at date(>=2023/04/25) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7_YYYYMMDD` | +| nightly at date(<2023/04/25) | `gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.7_YYYYMMDD` | + +
+ +| Version | GPU CUDA 11.2 + Python 3.8 Docker | +| --- | ----------- | +| 1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2` | + +
+ +| Version | GPU CUDA 11.2 + Python 3.7 Docker | +| --- | ----------- | +1.13 | `gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2` | +1.12 | `gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2` | + + +To run on [compute instances with +GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus). ## Troubleshooting -If PyTorch/XLA isn't performing as expected, see the -[troubleshooting guide](TROUBLESHOOTING.md), which has suggestions for -debugging and optimizing your network(s). +If PyTorch/XLA isn't performing as expected, see the [troubleshooting +guide](TROUBLESHOOTING.md), which has suggestions for debugging and optimizing +your network(s). ## Providing Feedback The PyTorch/XLA team is always happy to hear from users and OSS contributors! -The best way to reach out is by filing an issue on this Github. Questions, -bug reports, feature requests, build issues, etc. are all welcome! +The best way to reach out is by filing an issue on this Github. Questions, bug +reports, feature requests, build issues, etc. are all welcome! ## Contributing See the [contribution guide](CONTRIBUTING.md). ## Disclaimer -This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the [CONTRIBUTORS](https://github.com/pytorch/xla/graphs/contributors) file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository [here](https://github.com/pytorch/xla/issues). + +This repository is jointly operated and maintained by Google, Facebook and a +number of individual contributors listed in the +[CONTRIBUTORS](https://github.com/pytorch/xla/graphs/contributors) file. For +questions directed at Facebook, please send an email to opensource@fb.com. For +questions directed at Google, please send an email to +pytorch-xla@googlegroups.com. For all other questions, please open up an issue +in this repository [here](https://github.com/pytorch/xla/issues). 
## Additional Reads + You can find additional useful reading materials in -* [Performance debugging on Cloud TPU VM](https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1) -* [Lazy tensor intro](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/) -* [Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM](https://cloud.google.com/blog/topics/developers-practitioners/scaling-deep-learning-workloads-pytorch-xla-and-cloud-tpu-vm) -* [Scaling PyTorch models on Cloud TPUs with FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/) +* [Performance debugging on Cloud TPU + VM](https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1) +* [Lazy tensor + intro](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/) +* [Scaling deep learning workloads with PyTorch / XLA and Cloud TPU + VM](https://cloud.google.com/blog/topics/developers-practitioners/scaling-deep-learning-workloads-pytorch-xla-and-cloud-tpu-vm) +* [Scaling PyTorch models on Cloud TPUs with + FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/) diff --git a/docs/gpu.md b/docs/gpu.md index 99ff86556612..afb8291ba30d 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -1,11 +1,14 @@ # How to run with PyTorch/XLA:GPU -PyTorch/XLA enables PyTorch users to utilize the XLA compiler which supports accelerators including TPU, GPU, CPU and … This doc will go over the basic steps to run PyTorch/XLA on a nvidia gpu instance +PyTorch/XLA enables PyTorch users to utilize the XLA compiler which supports accelerators including TPU, GPU, and CPU This doc will go over the basic steps to run PyTorch/XLA on a nvidia gpu instance ## Create a GPU instance Pytorch/XLA currently publish prebuilt docker images and wheels with cuda11.7/8 and python 3.8. We recommend users to create a GPU instance with corresponding config. For a full list of docker images and wheels, please refer to [this doc](https://github.com/pytorch/xla/tree/jackcao/gpu_doc#-available-images-and-wheels). ## Environment Setup + +To create a GPU VM in Google Compute Engine, follow the [Google Cloud documentation](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus). + ### Docker ``` sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7 @@ -23,7 +26,7 @@ Note that you need to restart the docker to make gpu devices visible in the dock ``` (pytorch) root@20ab2c7a2d06:/# nvidia-smi -Thu Dec 8 06:24:29 2022 +Thu Dec 8 06:24:29 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 | |-------------------------------+----------------------+----------------------+ @@ -35,7 +38,7 @@ Thu Dec 8 06:24:29 2022 | N/A 36C P0 38W / 300W | 0MiB / 16384MiB | 1% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -70,4 +73,4 @@ Epoch 1 train begin 06:12:38 | Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09 ``` ## AMP (AUTOMATIC MIXED PRECISION) -AMP is very useful on GPU training and PyTorch/XLA reuse Cuda's AMP rule. 
You can checkout our [mnist example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py). Note that we also used a modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) to avoid the additional sync between device and host. \ No newline at end of file +AMP is very useful on GPU training and PyTorch/XLA reuse Cuda's AMP rule. You can checkout our [mnist example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py). Note that we also used a modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) to avoid the additional sync between device and host. From cc4f30437841690ff08bfa7336e346148967b953 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani <118776932+khatwanimohit@users.noreply.github.com> Date: Wed, 12 Jul 2023 10:16:45 -0700 Subject: [PATCH 06/20] [SPMD] Use xs.Mesh in test_2d_tensor_3d_mesh (#5295) * use mesh in test_2d_tensor_3d_mesh * remove attributes patch --- test/spmd/test_xla_sharding.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 6f175c69862f..100b810bb0c6 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -480,32 +480,14 @@ def test_xla_sharded_hlo_dump(self): # scalar 5 should be replicated self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo) - @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU v2") - @patch('torch_xla.runtime.global_device_attributes') - @patch('torch_xla.core.xla_model.xla_device_hw') - def test_2d_tensor_3d_mesh(self, xla_device_mock, device_attributes_mock): - xla_device_mock.return_value = "TPU" - device_attributes_mock.return_value = [{ - 'coords': [0, 0, 0], - 'core_on_chip': 0 - }, { - 'coords': [1, 0, 0], - 'core_on_chip': 0 - }, { - 'coords': [0, 1, 0], - 'core_on_chip': 0 - }, { - 'coords': [1, 1, 0], - 'core_on_chip': 0 - }] + def test_2d_tensor_3d_mesh(self): ct1 = torch.randn(16, 16, device='cpu') ct2 = torch.randn(16, 16, device='cpu') expected = ct1 + ct2 t1 = ct1.to(xm.xla_device()) t2 = ct2.to(xm.xla_device()) - mesh = self._get_hybrid_mesh((1, self.n_devices, 1), - axis_names=('data', 'fsdp', 'tensor')) + mesh = self._get_mesh((1, self.n_devices, 1)) t1 = xs.mark_sharding(t1, mesh, partition_spec=(1, 2)) if self.n_devices > 1: hlo = torch_xla._XLAC._get_xla_tensors_hlo([t1.global_tensor]) From af0e0c39218af22c653e29681a1a956ff7e5c8f8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jul 2023 14:56:09 -0700 Subject: [PATCH 07/20] [SPMD] Add FSDP sharding for test_train_spmd_linear_model.py (#5299) Summary: This diff adds FSDP sharding for test_train_spmd_linear_model.py. 
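
As a quick illustration of the FSDP option (a simplified sketch, not the actual test code — the real script shards `fc1`/`fc2` of `SimpleLinear`, also shards the input loader, and derives the device count and mesh shape from the runtime):

```python
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
model = nn.Linear(128, 64).to(device)  # stand-in for the test's SimpleLinear

num_devices = 4  # assumption for this example; the test uses the runtime count
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

# FSDP-style: shard the weight along its 0th dimension across all devices.
xs.mark_sharding(model.weight, mesh, (0, 1))
```
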
Test Plan: PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_train_spmd_linear_model.py --sharding fsdp --- test/spmd/test_train_spmd_linear_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index 791bbee546f0..c86c496ff357 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -15,7 +15,7 @@ MODEL_OPTS = { '--sharding': { - 'choices': ['batch', 'megatron-lm'], + 'choices': ['batch', 'megatron-lm', 'fsdp'], 'nargs': '+', 'default': [], }, @@ -58,7 +58,6 @@ def forward(self, x): def train(): print('===> Preparing data..') - num_epochs = 18 lr = 0.1 train_loader = xu.SampleGenerator( data=(torch.zeros(FLAGS.batch_size, FLAGS.input_dim), @@ -78,6 +77,14 @@ def train(): train_loader = pl.MpDeviceLoader( train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1))) + if 'fsdp' in FLAGS.sharding: + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1))) + print('Sharding model weights') + # Shard the weights according to their 0th dim + xs.mark_sharding(model.fc1.weight, mesh, (0, 1)) + xs.mark_sharding(model.fc2.weight, mesh, (0, 1)) + if 'megatron-lm' in FLAGS.sharding: print('Sharding model weights') # Shard the first layer's weights row-wise From a27382a3ed03ea765b13121d34804ef45c3f65d0 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jul 2023 22:14:32 -0700 Subject: [PATCH 08/20] [SPMD] Avoid recompilations in xs.mark_sharding() (#5300) Summary: This pull requests fixes the recompilation issue in xs.mark_sharding(). xtensor->GetXlaData() will compile the program if xtensor is an IR in order to get the BackendData. I believe this is not intended given the error message below suggests only data type xtensors are supported. Test Plan: PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py --- test/spmd/test_xla_sharding.py | 19 +++++++++++++++---- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 100b810bb0c6..53974700fa6f 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -223,9 +223,14 @@ def test_mark_sharding_4d(self): def test_mark_sharding_partial(self): device = xm.xla_device() - t1 = torch.randn(4, 4).to(xm.xla_device()) - t2 = torch.randn(4, 4).to(xm.xla_device()) - expected = (t1 @ t2).cpu() + t1 = torch.randn(4, 4).to(device) + t2 = torch.randn(4, 4).to(device) + # Somehow the eager cpu result is different from the xla result. + expected = t1 @ t2 + # To re-materialize t1 and t2. + xm.mark_step() + xm.wait_device_ops() + expected = expected.cpu() # Shard along two axes if four or more devices are available z_dim = 2 if self.n_devices >= 4 else 1 @@ -255,7 +260,12 @@ def test_partial_replication_addmm(self): xx = torch.randn(16, 128).to(device) xw = torch.randn(128, 256).to(device) xb = torch.randn(16, 256).to(device) - expected = (xx @ xw + xb).cpu() + + # Somehow the eager cpu result is different from the xla result. + expected = xx @ xw + xb + xm.mark_step() # To re-materialize xx, xw, and xb. 
+ xm.wait_device_ops() + expected = expected.cpu() xs.mark_sharding(xx, mesh, (0, None)) xs.mark_sharding(xw, mesh, (None, 1)) @@ -480,6 +490,7 @@ def test_xla_sharded_hlo_dump(self): # scalar 5 should be replicated self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo) + @unittest.skip("TODO(alanwaketan): Implement IR sharding to re-enable this.") def test_2d_tensor_3d_mesh(self): ct1 = torch.randn(16, 16, device='cpu') ct2 = torch.randn(16, 16, device='cpu') diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f9b4d9569e1c..c1cc829d2683 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1374,9 +1374,9 @@ void InitXlaModuleBindings(py::module m) { // If the at::Tensor data is not present, we need to re-download the // tensor from the physical device to CPU. In that case, the value // must be present on the backend device. - XLA_CHECK(xtensor->GetXlaData() != nullptr && + XLA_CHECK(xtensor->CurrentDataHandle() && xtensor->CurrentDataHandle()->HasValue()) - << "Cannot shard tensor. Data not present on any device."; + << "Cannot shard tensor. Data does not present on any device."; std::vector xla_tensors{xtensor}; cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; } From ff974276febb8e60a29fb00f7fb68b5236d157d1 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 13 Jul 2023 01:00:43 -0700 Subject: [PATCH 09/20] [SPMD] Support mark_sharding on IRs (#5301) Summary: This pull requests fixes the recompilation issue in xs.mark_sharding(). xtensor->GetXlaData() will compile the program if xtensor is an IR in order to get the BackendData. I believe this is not intended given the error message below suggests only data type xtensors are supported. 
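
Concretely, the IR path added here lets `mark_sharding` annotate the result of an operation before it executes, instead of requiring materialized device data. A minimal sketch mirroring the new `test_mark_sharding_ir` test (the four-device mesh is an assumption for the example):

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
xt1 = torch.randn(1, 128).to(device)
xt2 = torch.randn(1, 128).to(device)

# `actual` is a pending IR value (no device data yet); the sharding spec is
# attached directly to the IR node.
actual = xt1 + xt2

num_devices = 4  # assumption for this example; the test uses the runtime count
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))
xs.mark_sharding(actual, mesh, (0, 1))

print(torch_xla._XLAC._get_xla_sharding_spec(actual))  # e.g. {devices=[1,4]0,1,2,3}
```
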
Test Plan: PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py --- test/spmd/test_xla_sharding.py | 19 ++++++++++++++++++- torch_xla/csrc/init_python_bindings.cpp | 10 ++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 53974700fa6f..af2db93a8061 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -490,7 +490,6 @@ def test_xla_sharded_hlo_dump(self): # scalar 5 should be replicated self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo) - @unittest.skip("TODO(alanwaketan): Implement IR sharding to re-enable this.") def test_2d_tensor_3d_mesh(self): ct1 = torch.randn(16, 16, device='cpu') ct2 = torch.randn(16, 16, device='cpu') @@ -567,6 +566,24 @@ def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock): self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(), [[0, 1], [2, 3], [4, 5], [6, 7]]) + def test_mark_sharding_ir(self): + t1 = torch.randn(1, 128, device='cpu') + t2 = torch.randn(1, 128, device='cpu') + expected = t1 + t2 + + xt1 = t1.to(xm.xla_device()) + xt2 = t2.to(xm.xla_device()) + actual = xt1 + xt2 + xs.mark_sharding(actual, self._get_mesh((1, self.n_devices)), (0, 1)) + + if self.n_devices > 1: + annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( + [str(i) for i in range(self.n_devices)])) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(actual)) + + self.assertTrue(torch.allclose(expected, actual.cpu())) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c1cc829d2683..fa87a016b190 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1350,6 +1350,16 @@ void InitXlaModuleBindings(py::module m) { xtensor->shape(), static_cast(xtensor->GetDevice().type()))); + // For IR values, we directly attach the sharding spec to the xtensor. + if (xtensor->CurrentIrValue()) { + // TODO(alanwaketan): Do we want to check if there is any existing + // sharding spec? It seems okay to directly overwrite it. + xtensor->SetShardingSpec(*new_sharding_spec); + return; + } + + // For data, we need to deal with the data transfers between + // host and device. at::Tensor cpu_tensor; if (xtensor->CurrentTensorData().has_value()) { TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); From 21784cef8bff784fa2e33da78fcc6b9a99c88e16 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 13 Jul 2023 14:44:22 -0700 Subject: [PATCH 10/20] [SPMD] Allow dumping post optimizations hlo (#5302) Summary: This pull request partial reverts the change in #5266 to re-enble dumping post optimizations hlo. Test Plan: XLA_USE_SPMD=1 PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_xla_sharded_hlo_dump_post_optimizations --- configuration.yaml | 13 +++++++++++++ test/spmd/test_xla_sharding.py | 11 +++++++++++ torch_xla/csrc/ir_dump_util.cpp | 21 ++++++++++++++++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/configuration.yaml b/configuration.yaml index 8a8a147d8fdd..22ab8b687f34 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -501,3 +501,16 @@ variables: - Release Python's GIL when transferring data from the runtime. type: bool default_value: true + XLA_STABLEHLO_COMPILE: + descripton: + - Pass StableHLO to XLA PjRt client for compilatoin. This compilation + flag is experimental. 
The default_value will be set to true when + StableHLO workflow is mature. + type: bool + default_value: false + XLA_DUMP_POST_OPTIMIZATIONS: + descripton: + - Dump the HLO graph after optimizations. You need to use it together + with XLA_SAVE_TENSORS_FMT='hlo' and XLA_SAVE_TENSORS_FILE='your/location'. + type: bool + default_value: false diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index af2db93a8061..8f8d7bcc1fc3 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -584,6 +584,17 @@ def test_mark_sharding_ir(self): self.assertTrue(torch.allclose(expected, actual.cpu())) + @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}) + def test_xla_sharded_hlo_dump_post_optimizations(self): + t1 = torch.randn(1, 128).to(xm.xla_device()) + t2 = torch.randn(128, 1).to(xm.xla_device()) + xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1)) + + t3 = t1 @ t2 + hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3]) + if self.n_devices > 1: + self.assertIn('all-reduce', hlo) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 7c442d59fb2b..a1998c2a2b88 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -263,8 +263,27 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, // Annotate HLO sharding selectively in the compuation. // This is no-op if an instruction doesn't have any sharding annotation. - ShardingUtil::SetHloSharding(&lowering_ctx); + auto is_sharded = ShardingUtil::SetHloSharding(&lowering_ctx); xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); + + static bool dump_post_optimizations = + runtime::sys_util::GetEnvBool("XLA_DUMP_POST_OPTIMIZATIONS", false); + if (dump_post_optimizations) { + xla::Shape shape = MakeShapeWithDeviceLayout( + ConsumeValue(computation.GetProgramShape()).result(), + static_cast(device.type())); + std::vector instances; + instances.push_back({std::move(computation), device.toString(), + runtime::GetComputationClient()->GetCompilationDevices( + device.toString(), {}), + &shape, + /*parameter_is_tupled_arguments=*/false, is_sharded}); + std::vector> + computations = + runtime::GetComputationClient()->Compile(std::move(instances)); + computation = std::move(computations[0]->move_computation()); + } + switch (mode) { case EmitMode::kHloReadable: return ConsumeValue(runtime::util::GetComputationHloText(computation)); From 67ab9750b0acfa2313e2e6f18733992f23f98ecf Mon Sep 17 00:00:00 2001 From: Yash Shah <55116947+yashs97@users.noreply.github.com> Date: Fri, 14 Jul 2023 13:54:47 -0700 Subject: [PATCH 11/20] Add `_sharded_cpu_state_dict` for distributed checkpointing (#5288) * initiak commit * Add test workflow for `xrt` branch (#5241) * Add test workflow for `xrt` branch * Only run for PRs targeting XRT branch * Add function to generate stablehlo based callable from pytorch model (#5216) * Add function to generate stablehlo based callable from pytorch model Added function `torch_xla.experimental.stablehlo_saved_model.export_pytorch_model`. This function will take a pytorch Module and convert it into stablehlo bytecode. * Only run the main CI workflow on PRs targeting master and release branches (#5244) * Only run main CI for master and release branches. 
* Disabling XRT tests on main CI * AMP for TPUs v3 (#5161) * remove duplicate autocast_test (#5246) * Remove `test_experimental_pjrt_tpu.py` from TPU CI (#5247) * Install `expecttest` in xla_test_job.yaml (#5252) * Add IAM roles for cloudbuild_editors (#5251) * [Functionalization] Remove view in view_symint (#5231) * [Functionalization] Remove view in view_symint Summary: This pull request removes views in tensor_method::view_symint. Test Plan: XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=TPU python ../test/test_view_ops.py -v -k TestViewOpsXLA.test_view_view PJRT_DEVICE=TPU python ../test/test_view_ops.py -v -k TestViewOpsXLA.test_view_view * Fix linters * fixed the test * ran the linter --------- Co-authored-by: Xiongfei Wei * Delete XRT from the main branch (#5240) * Delete XRT from the main branch * Remove dead import * formatting * Remove disable_xrt build option * Fix runtime init * Revert "Remove disable_xrt build option" This reverts commit ba312e76e069bef40c8f9803a672b29409862804. * Add disable XRT option back * formatting * Prune mesh service * Remove obsolete test * Remove other run server script * Remove XRT config * Update PJRT default device test * Add a file I forgot to save * if using_pjrt -> @requires_pjrt * Remove irrelevant test case * Remove XRT env vars * fix md link * formatting * Remove extra `requires_pjrt` * merge conflicts * Add other autocast back * Add nightly build for cuda 12 (#5253) * Fix the linter command in the CI (#5254) * fix linter command * ran linter * Jack cao g/fix spmd buff is null (#5256) * Fix that non-tensor scalar can't be handled by virtual device * add test * comment * Skip calling as_strided in empty_strided_symint if the input has dynamic dimensions. (#5239) * Skip calling as_strided in empty_strided_symint. * only return empty_symint conditionally. * add a comment * Add XRT nightly builds (#5261) * Add XRT nightly builds * remove space * [OpenXLA] Migrate to pull XLA from OpenXLA (#5202) PyTorch/XLA migrate to pull XLA from OpenXLA by replacing TensorFlow with OpenXLA after deprecating XRT usage, and replace TensorFlow-pin with OpenXLA-pin to May09 * Add ToString method for both PjrtData and PjrtShardedData (#5265) * Add ToString method for both PjrtData and PjrtShardedData * on cpu same config will become replicated, dont't check actual op sharding type * Update Sharded graph HLO dumping (#5266) * Enable PjRt Client Compilation with StableHLO (#5233) * Enable xla PjRt client compilation with StableHLO * add XLA_STABLEHLO_COMPILE to configuration.yaml * fix merge conflict * dummy commit to trigger ci * Revert "dummy commit to trigger ci" This reverts commit f7aec233d18637e242427c4542b12cf65c431ebc. * Disable Bazel remote cache for forked PR (#5259) * disable bazel remote cache if gcloud key is empty * remove remote cache from setup.py * experiment with debug msg * fix flag * add more logs * skip remote chache if credential file is empty * add comment * add logs * add check in test and coverage script * fix condition in coverage test * advance branch pr * allow remote cache if gloud file isn't specified explicitly * remove dummy comment * Suppress debug symbols in OpenXLA code (#5269) * [SPMD] Sharding n-d tensor on (n+1)-d Mesh (#5268) * Make TPU detection more robust (#5271) * Clean bazel stuff on distutils clean. 
(#5274) * Clean bazel stuff on distutils clean * Fix python formatting * Delete unused .so file, and .lds files (#5275) * [OpenXLA] Delete unused .so file and .lds files * Fix the error when export_torch_model is given a non-tensor (#5277) However the generated StableHLO graph still hardcodes the non-tensor value. this is not correct, will fix later. * Dsiable test_simple_model_with_different_input_shape since it is curretnly broken by pytorch (#5282) * Always do build_ext in python setup.py develop (#5273) Bazel should figure out that _XLAC.so is current or not, and trigger rebuild if any cpp files changed. * Remove or improve several hardcoded TPU test conditions (#5272) * Remove or improve several hardcoded TPU test conditions * Fix test condition * Add `runtime.host_index` (#5283) * Make it an error if calling sizes() on a dynamic tensor. (#4998) * Err if calling sizes() on dynamic tensor * try to set has_symbolic_sizes_strides_ * resolve merge conflict * enable CONTINUE_ON_ERROR * fixed the python test test_SizeEq_should_not_compile_for_identical_symints * fix test_index_types * set CONTINUE_ON_ERROR to true * remove some unwanted code. * add a print * directly set has_symbolic_sizes_strides_ = true * make some fixes. * fix empty_strided_symint * ran linter * change error type in the test. * fix comments * ran linter * Fix the error where mark_step does not materalize tensors on SPMD:0 (#5281) * Fix the error where mark_step does not materalize tensors on SPMD:0 * typo * fix test_non_tensor_scalar * Disable torch._dynamo.config.automatic_dynamic_shapes (#5285) * Set torch._dynamo.config.automatic_dynamic_shapes to False * Enable DynamoInferenceBasicTest.test_simple_model_with_different_input_shape * run linter * wrap only if sharding type is non-replicated * Handle non-tensors * run linter * Call wrap_if_sharded first * Add exception in test for unsharded tensor * fix test * Use torch.Tensor instead of torch.tensor * use .cpu() only for tensors --------- Co-authored-by: Will Cromar Co-authored-by: qihqi Co-authored-by: Meghan Cowan Co-authored-by: Mateusz Lewko Co-authored-by: Jiewen Tan Co-authored-by: Xiongfei Wei Co-authored-by: Wonjoo Lee Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com> Co-authored-by: Siyuan Liu Co-authored-by: stgpetrovic Co-authored-by: Mohit Khatwani <118776932+khatwanimohit@users.noreply.github.com> --- test/spmd/test_xla_distributed_checkpoint.py | 21 +++++++++ .../_distributed_checkpoint_helpers.py | 46 +++++++++++++++++-- .../experimental/distributed_checkpoint.py | 17 ++----- 3 files changed, 66 insertions(+), 18 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index b8516667bb84..0d2f3e318cce 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,4 +1,5 @@ import os +import sys import tempfile import unittest import test_xla_sharding_base @@ -14,6 +15,8 @@ create_default_global_save_plan, ) from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental._distributed_checkpoint_helpers import ( + _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @@ -244,6 +247,24 @@ def test_resolve_shard_data(self): self.assertTrue(torch.allclose(shard.data, resolved_data)) +class 
DistributedCheckpointHelpersTest(DistributedCheckpointTestBase): + + def test_sharded_cpu_state_dict(self): + model = self.SimpleLinear().to(xm.xla_device()) + state_dict = model.state_dict() + sharded_cpu_state_dict = _sharded_cpu_state_dict(state_dict) + self.assertCountEqual(sharded_cpu_state_dict, + ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']) + for name, param in sharded_cpu_state_dict.items(): + if name == 'fc1.weight': + # _sharded_cpu_state_dict returns _CpuShards only for sharded tensors + if _is_sharded_tensor(param): + self.assertTrue(isinstance(param, _CpuShards)) + else: + self.assertTrue(isinstance(param, torch.Tensor)) + self.assertTrue(param.device == torch.device("cpu")) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/_distributed_checkpoint_helpers.py index 5f1a8f6489e7..b49e7419dcd9 100644 --- a/torch_xla/experimental/_distributed_checkpoint_helpers.py +++ b/torch_xla/experimental/_distributed_checkpoint_helpers.py @@ -2,10 +2,14 @@ # stable. Once the upstream makes these stable, we should take a dependency on # their APIs. +import dataclasses + import torch +import torch_xla.experimental.xla_sharding as xs from torch.distributed.checkpoint.planner import SavePlan from typing import ( + Any, Callable, Collection, Dict, @@ -14,12 +18,13 @@ MutableMapping, Sequence, Tuple, - TypeVar, Union, cast, ) -from torch.distributed.checkpoint.metadata import ( - STATE_DICT_TYPE,) +from torch.distributed.checkpoint.metadata import (MetadataIndex, + STATE_DICT_TYPE) +from torch_xla.experimental.xla_sharding import XLAShardedTensor, ShardingType +from torch.utils._pytree import tree_map PATH_ITEM = Union[str, int] OBJ_PATH = Tuple[PATH_ITEM, ...] @@ -186,4 +191,37 @@ def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], # recording here for the narrow op and 'local_shard' should be a # leaf variable in the autograd graph. narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) - return narrowed_tensor \ No newline at end of file + return narrowed_tensor + + +def _is_sharded_tensor(x: Any) -> bool: + """Return true if the tensor's data is sharded across multiple devices""" + return isinstance( + x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED + + +def _unwrap_xla_sharded_tensor(x: Any) -> Any: + if isinstance(x, XLAShardedTensor): + return x.global_tensor + return x + + +@dataclasses.dataclass +class _CpuShards: + shards: List[xs.XLAShard] + global_shape: torch.Size + + +def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Converts a state_dict on XLA device to a sharded state_dict on CPU. 
+ """ + + def move_state_dict_to_cpu(v): + v = xs.wrap_if_sharded(v) + if not _is_sharded_tensor(v): + v = _unwrap_xla_sharded_tensor(v) + return v.cpu() if isinstance(v, torch.Tensor) else v + return _CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape) + + return tree_map(move_state_dict_to_cpu, state_dict) diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint.py index cc4738ca09ff..c2a1545cfc3f 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/distributed_checkpoint.py @@ -33,14 +33,15 @@ ) from torch.distributed.checkpoint.utils import find_state_dict_object from torch.utils._pytree import tree_map -from torch_xla.experimental.xla_sharding import (XLAShardedTensor, XLAShard, - ShardingType) +from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard from torch_xla.experimental._distributed_checkpoint_helpers import ( FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, + _is_sharded_tensor, set_element, narrow_tensor_by_index, + _unwrap_xla_sharded_tensor, ) from typing import Any, Dict, List, Tuple, Union @@ -373,15 +374,3 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE, chunks = [_create_chunk_from_shard_index(index) for index in shard_indices] items.extend(create_read_items_for_chunk_list(fqn, md, chunks)) return items - - -def _is_sharded_tensor(x: Any) -> bool: - """Return true if the tensor's data is sharded across multiple devices""" - return isinstance( - x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED - - -def _unwrap_xla_sharded_tensor(x: Any) -> Any: - if isinstance(x, XLAShardedTensor): - return x.global_tensor - return x From 080fdcf600fd6cfeb6878ca200609b841976332c Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Mon, 17 Jul 2023 10:13:44 -0700 Subject: [PATCH 12/20] Supoort unordered sharding spec correctly (#5305) * Supoort non-ordered sharding spec correctly * use permute instead of transpose * use dim > 2 to suit TPU v3(otherwise can't be divide evenly) --- test/spmd/test_xla_sharding.py | 49 ++++++++++++++++++++++++++ torch_xla/experimental/xla_sharding.py | 24 ++++++++++--- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 8f8d7bcc1fc3..3c5c71addde8 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -221,6 +221,55 @@ def test_mark_sharding_4d(self): actual = (xt + xt).cpu() self.assertTrue(torch.allclose(expected, actual)) + def test_mark_sharding_not_ordered_sharding_spec_2d(self): + device = xm.xla_device() + t1 = torch.randn(8, 16, device='cpu') + expected = t1 + t1 + + xt1 = t1.to(device) + # Shard along first dimension + xt1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (1, 0)) + for local_shard in xt1.local_shards: + self.assertEqual(local_shard.data.size()[0], 8 / self.n_devices) + self.assertEqual(local_shard.data.size()[1], 16) + self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) + + def test_mark_sharding_not_ordered_sharding_spec_3d(self): + device = xm.xla_device() + t1 = torch.randn(4, 8, 16, device='cpu') + expected = t1 + t1 + + xt1 = t1.to(device) + z_dim = 2 if self.n_devices >= 4 else 1 + # Expect local shard size to be [4, 8 / (self.n_devices / z_dim), 16 / z_dim] + xt1 = xs.mark_sharding(xt1, + self._get_mesh((z_dim, 1, self.n_devices // z_dim)), + (1, 2, 0)) + for local_shard in xt1.local_shards: + 
self.assertEqual(local_shard.data.size()[0], 4) + self.assertEqual(local_shard.data.size()[1], 8 / (self.n_devices / z_dim)) + self.assertEqual(local_shard.data.size()[2], 16 / z_dim) + self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) + + def test_mark_sharding_not_ordered_sharding_spec_4d(self): + device = xm.xla_device() + t1 = torch.randn(32, 4, 8, 16, device='cpu') + expected = t1 + t1 + + xt1 = t1.to(device) + z_dim = 2 if self.n_devices >= 4 else 1 + # Expect local shard size to be [32 / (self.n_devices / z_dim), 4, 8 , 16 / z_dim] + xt1 = xs.mark_sharding( + xt1, self._get_mesh((z_dim, 1, 1, self.n_devices // z_dim)), + (3, 1, 2, 0)) + for local_shard in xt1.local_shards: + self.assertEqual(local_shard.data.size()[0], + 32 / (self.n_devices / z_dim)) + self.assertEqual(local_shard.data.size()[1], 4) + self.assertEqual(local_shard.data.size()[2], 8) + self.assertEqual(local_shard.data.size()[3], 16 / z_dim) + self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) + def test_mark_sharding_partial(self): device = xm.xla_device() t1 = torch.randn(4, 4).to(device) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index d3c339b3d5cb..ee7e32117b43 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -319,8 +319,24 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]], return sharding_type -def _get_tile_assignment(mesh: Mesh) -> List[int]: - return mesh.get_logical_mesh().tolist() +def _get_tile_assignment(mesh: Mesh, + partition_spec: Tuple[Union[int, None]]) -> List[int]: + # Use Torch.tensor here to make use of the torch.transpose_ + mesh_list_tensor = torch.tensor(mesh.get_logical_mesh().tolist()) + # This is partial sharding case, tile_assigniment will be ignore in favor of + # group_assignment and replication_groups. + if (mesh_list_tensor.dim() != len(partition_spec)): + return mesh_list_tensor.tolist() + partition_spec_list = list(partition_spec) + for i in range(len(partition_spec_list)): + if partition_spec_list[i] == None: + partition_spec_list[i] = i + # We currently do not support partition_spec like [0, None, 1, 3]. The None at partition_spec[1] + # suggested that we want to replicate on Mesh[1], hence we can't use Mesh[1] in + # partition_spec[2] + assert torch.unique( + torch.tensor(partition_spec_list)).size()[0] == len(partition_spec_list) + return mesh_list_tensor.permute(partition_spec_list).tolist() def _get_group_assignment( @@ -399,7 +415,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, assert len(specs) == len(np.unique(specs)), \ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." 
- tile_assignment = _get_tile_assignment(mesh) + tile_assignment = _get_tile_assignment(mesh, partition_spec) # check for sharding 2D tensor on a 3D mesh original_shape = tuple(t.shape) # number of dims to expand on tensor @@ -464,7 +480,7 @@ class ShardingSpec: @xr.requires_pjrt def __post_init__(self): partition_spec, mesh = self.partition_spec, self.mesh - self._tile_assignment = _get_tile_assignment(mesh) + self._tile_assignment = _get_tile_assignment(mesh, partition_spec) self._sharding_type = _get_sharding_type(partition_spec, xr.global_device_count()) self._group_assignment, self._replication_groups = _get_group_assignment( From aac03da08007af22af9126c2320e974cc7803f40 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Tue, 18 Jul 2023 10:20:57 -0700 Subject: [PATCH 13/20] Support unordered sharding spec for partial replication (#5316) * Suport unordered sharding spec for partial replication * add 4d test * handle 2d tensor with 2d mesh case * refactoring --- test/spmd/test_xla_sharding.py | 89 ++++++++++++++++++++++++++ torch_xla/experimental/xla_sharding.py | 26 ++++---- 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 3c5c71addde8..21e1fd3843a0 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -301,6 +301,95 @@ def test_mark_sharding_partial(self): actual = (xt1 @ t2).cpu() self.assertTrue(torch.allclose(expected, actual)) + def test_mark_sharding_not_ordered_partial_3d(self): + device = xm.xla_device() + t1 = torch.randn(8, 16, 32).to(device) + t2 = torch.randn(8, 16, 32).to(device) + # Somehow the eager cpu result is different from the xla result. + expected = t1 + t2 + # To re-materialize t1 and t2. + xm.mark_step() + xm.wait_device_ops() + expected = expected.cpu() + + # Shard along two axes if four or more devices are available + z_dim = 2 if self.n_devices >= 4 else 1 + mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim)) + + # Expect local shard size to be [8, 16 / z_dim, 32] + xt1 = xs.mark_sharding(t1, mesh, (1, 0, None)) + + for local_shard in xt1.local_shards: + self.assertEqual(local_shard.data.size()[0], 8) + self.assertEqual(local_shard.data.size()[1], 16 / z_dim) + self.assertEqual(local_shard.data.size()[2], 32) + + # partial replication requires >1 devices; otherwise, it's replicated. + if self.n_devices > 1: + # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way. + self.assertTrue('last_tile_dim_replicate' in + torch_xla._XLAC._get_xla_sharding_spec(t1)) + self.assertTrue('[%d,%d,1,%d]' % + (1, z_dim, self.n_devices // + z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1)) + actual = (xt1 + t2).cpu() + self.assertTrue(torch.allclose(expected, actual)) + + def test_mark_sharding_not_ordered_partial_4d(self): + device = xm.xla_device() + t1 = torch.randn(8, 16, 32, 64).to(device) + t2 = torch.randn(8, 16, 32, 64).to(device) + # Somehow the eager cpu result is different from the xla result. + expected = t1 + t2 + # To re-materialize t1 and t2. 
+ xm.mark_step() + xm.wait_device_ops() + expected = expected.cpu() + + # Shard along two axes if four or more devices are available + z_dim = 2 if self.n_devices >= 4 else 1 + mesh = self._get_mesh((z_dim, 1, 1, self.n_devices // z_dim)) + + # Expect local shard size to be [8, 16, 32 / z_dim, 64] + xt1 = xs.mark_sharding(t1, mesh, (2, None, 0, None)) + + for local_shard in xt1.local_shards: + self.assertEqual(local_shard.data.size()[0], 8) + self.assertEqual(local_shard.data.size()[1], 16) + self.assertEqual(local_shard.data.size()[2], 32 / z_dim) + self.assertEqual(local_shard.data.size()[3], 64) + + # partial replication requires >1 devices; otherwise, it's replicated. + if self.n_devices > 1: + # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way. + self.assertTrue('last_tile_dim_replicate' in + torch_xla._XLAC._get_xla_sharding_spec(t1)) + self.assertTrue('[1,1,%d,1,%d]' % + (z_dim, + (self.n_devices // + z_dim)) in torch_xla._XLAC._get_xla_sharding_spec(t1)) + actual = (xt1 + t2).cpu() + self.assertTrue(torch.allclose(expected, actual)) + + def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self): + ct1 = torch.randn(16, 16, device='cpu') + ct2 = torch.randn(16, 16, device='cpu') + expected = ct1 + ct2 + + t1 = ct1.to(xm.xla_device()) + t2 = ct2.to(xm.xla_device()) + mesh = self._get_mesh((1, self.n_devices, 1)) + # sharding spec here is not ordered. + xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1)) + if self.n_devices > 1: + hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor]) + sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % ( + self.n_devices, ','.join( + [str(d) for d in mesh.get_logical_mesh().flatten()])) + self.assertIn(sharding_annotation, hlo) + actual = (xt1 + t2).cpu() + self.assertTrue(torch.allclose(expected, actual)) + def test_partial_replication_addmm(self): device = xm.xla_device() z_dim = 2 if self.n_devices >= 4 else 1 diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index ee7e32117b43..6d841a33b8da 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -323,10 +323,6 @@ def _get_tile_assignment(mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> List[int]: # Use Torch.tensor here to make use of the torch.transpose_ mesh_list_tensor = torch.tensor(mesh.get_logical_mesh().tolist()) - # This is partial sharding case, tile_assigniment will be ignore in favor of - # group_assignment and replication_groups. - if (mesh_list_tensor.dim() != len(partition_spec)): - return mesh_list_tensor.tolist() partition_spec_list = list(partition_spec) for i in range(len(partition_spec_list)): if partition_spec_list[i] == None: @@ -339,26 +335,28 @@ def _get_tile_assignment(mesh: Mesh, return mesh_list_tensor.permute(partition_spec_list).tolist() -def _get_group_assignment( - sharding_type: ShardingType, mesh: Mesh, - partition_spec: Tuple[Union[int, None]]) -> Tuple[List, List]: +def _get_group_assignment(sharding_type: ShardingType, mesh: Mesh, + partition_spec: Tuple[Union[int, None]], + tile_assignment: List) -> Tuple[List, List]: group_assignment = list() replication_groups = list() + # TODO(JackCaoG): 3d mesh on 2d tensor + mesh_shape_list = list(torch.tensor(tile_assignment).size()) if sharding_type is ShardingType.PARTIAL: # Shard across groups and replicate within subgroups; replicated dims # will be used to group replication devices. 
tile_dims = [d for d in partition_spec if d is not None] - replicated_dims = set(range(len(mesh.mesh_shape))) - set(tile_dims) + replicated_dims = set(range(len(mesh_shape_list))) - set(tile_dims) - group_list = [np.array(mesh.get_logical_mesh().tolist())] + group_list = [np.array(tile_assignment)] for d in tile_dims: _group_list = list() for group_members in group_list: - _group_list += np.split(group_members, mesh.mesh_shape[d], d) + _group_list += np.split(group_members, mesh_shape_list[d], d) group_list = _group_list replication_groups = [group.flatten().tolist() for group in group_list] - group_tile_shape = list(mesh.mesh_shape) + group_tile_shape = mesh_shape_list for d in replicated_dims: group_tile_shape[d] = 1 group_assignment = np.arange(len(replication_groups)).reshape( @@ -415,7 +413,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, assert len(specs) == len(np.unique(specs)), \ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." - tile_assignment = _get_tile_assignment(mesh, partition_spec) # check for sharding 2D tensor on a 3D mesh original_shape = tuple(t.shape) # number of dims to expand on tensor @@ -426,9 +423,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, shape = (1,) * tensor_expand + (*original_shape,) t = t.expand(shape) + tile_assignment = _get_tile_assignment(mesh, partition_spec) sharding_type = _get_sharding_type(partition_spec, num_devices) group_assignment, replication_groups = _get_group_assignment( - sharding_type, mesh, partition_spec) + sharding_type, mesh, partition_spec, tile_assignment) def tensor_squeeze(t, tensor_expand): if tensor_expand: @@ -484,7 +482,7 @@ def __post_init__(self): self._sharding_type = _get_sharding_type(partition_spec, xr.global_device_count()) self._group_assignment, self._replication_groups = _get_group_assignment( - self._sharding_type, mesh, partition_spec) + self._sharding_type, mesh, partition_spec, self._tile_assignment) def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: """ From def08b439829c7e2ba6a20913c570905235948f2 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 18 Jul 2023 14:42:11 -0700 Subject: [PATCH 14/20] Fix mismatched GPU docker image in the doc. 
(#5319) --- docs/gpu.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/gpu.md b/docs/gpu.md index afb8291ba30d..7bf8f665dc28 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -18,7 +18,7 @@ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker -sudo docker run --gpus all -it -d gcr.io/tpu-pytorch/xla:nightly_3.7\8_cuda_11.2 bin/bash +sudo docker run --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7 bin/bash sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash ``` From 37b85187ed0ca7d13bc26429e3eb00ef5c498e23 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Tue, 18 Jul 2023 15:07:56 -0700 Subject: [PATCH 15/20] quick refactor on _get_group_assignment (#5318) --- torch_xla/experimental/xla_sharding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 6d841a33b8da..e97f0b96b722 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -335,12 +335,11 @@ def _get_tile_assignment(mesh: Mesh, return mesh_list_tensor.permute(partition_spec_list).tolist() -def _get_group_assignment(sharding_type: ShardingType, mesh: Mesh, +def _get_group_assignment(sharding_type: ShardingType, partition_spec: Tuple[Union[int, None]], tile_assignment: List) -> Tuple[List, List]: group_assignment = list() replication_groups = list() - # TODO(JackCaoG): 3d mesh on 2d tensor mesh_shape_list = list(torch.tensor(tile_assignment).size()) if sharding_type is ShardingType.PARTIAL: # Shard across groups and replicate within subgroups; replicated dims @@ -426,7 +425,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, tile_assignment = _get_tile_assignment(mesh, partition_spec) sharding_type = _get_sharding_type(partition_spec, num_devices) group_assignment, replication_groups = _get_group_assignment( - sharding_type, mesh, partition_spec, tile_assignment) + sharding_type, partition_spec, tile_assignment) def tensor_squeeze(t, tensor_expand): if tensor_expand: @@ -482,7 +481,7 @@ def __post_init__(self): self._sharding_type = _get_sharding_type(partition_spec, xr.global_device_count()) self._group_assignment, self._replication_groups = _get_group_assignment( - self._sharding_type, mesh, partition_spec, self._tile_assignment) + self._sharding_type, partition_spec, self._tile_assignment) def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: """ From d9c92bd3ed9d6a1620b3a3d6438f39d1bc9b7125 Mon Sep 17 00:00:00 2001 From: qihqi Date: Tue, 18 Jul 2023 18:04:46 -0700 Subject: [PATCH 16/20] Add tf independent serialization (#5308) Create a serialization format for StableHLO graphs and weights without tf.saved_model Need to not use tensorflow because tensorflow is no longer dependency of pytorch/xla. Information saved are enough to reconstruct the tf.saved_model for serving. Information stored: * metadata on which tensor maps which input position * StableHLO version number * metadata on which tensor corresponds to user input or parameter * metadata on shape and dtype of each tensor. * Tensors themselves are saved as numpy arrays using np.save. 
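To make the stored pieces above concrete, here is a sketch of the on-disk layout
produced by the new _save_program_bundle helper, plus a minimal save/load round
trip. The directory and model names are illustrative only; the entry points
(torch._export.export, save_as_stablehlo, _load_program_bundle) are the ones
added or exercised in this patch:

    my_stablehlo_dir/
      data/        # one np.save file per state_dict entry, named by parameter key
      constants/   # additional hard-coded constants, files named 0, 1, ...
      functions/   # <name>.meta (StableHLOFunctionMeta as JSON) and
                   # <name>.mlir (StableHLO bytecode) per exported callable

    import torch._export
    from torch_xla.experimental import stablehlo_saved_model

    # `model` and `sample_args` are placeholders for a torch.nn.Module and its inputs.
    exported = torch._export.export(model, sample_args)
    stablehlo_saved_model.save_as_stablehlo(exported, sample_args, 'my_stablehlo_dir')
    bundle = stablehlo_saved_model._load_program_bundle('my_stablehlo_dir')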
--- test/stablehlo/test_stablehlo_dump.py | 15 + .../experimental/stablehlo_saved_model.py | 263 +++++++++++++++++- 2 files changed, 277 insertions(+), 1 deletion(-) diff --git a/test/stablehlo/test_stablehlo_dump.py b/test/stablehlo/test_stablehlo_dump.py index f608683f73ca..b14ca3e013d4 100644 --- a/test/stablehlo/test_stablehlo_dump.py +++ b/test/stablehlo/test_stablehlo_dump.py @@ -1,6 +1,9 @@ +import tempfile import torch_xla import torch_xla.core.xla_model as xm +from torch_xla.experimental import stablehlo_saved_model import torch +import torch._export import torchvision import unittest from torch import nn @@ -84,6 +87,18 @@ def test_cat(self): # FIXME: Currently the dim=1 is hard coded self.assertTrue('dim = 1' in stablehlo) + def test_save_load(self): + model = ElementwiseAdd() + inputs = model.get_random_inputs() + exported = torch._export.export(model, inputs) + bundle = stablehlo_saved_model._exported_program_to_stablehlo_bundle( + exported, inputs) + with tempfile.TemporaryDirectory() as tempdir: + stablehlo_saved_model._save_program_bundle(bundle, tempdir) + bundle2 = stablehlo_saved_model._load_program_bundle(tempdir) + + self.assertEqual(bundle.stablehlo_funcs, bundle2.stablehlo_funcs) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/stablehlo_saved_model.py b/torch_xla/experimental/stablehlo_saved_model.py index 4d54379614bc..68e87be06d2b 100644 --- a/torch_xla/experimental/stablehlo_saved_model.py +++ b/torch_xla/experimental/stablehlo_saved_model.py @@ -1,15 +1,21 @@ import copy +from dataclasses import dataclass +import enum +import json import shutil import os import re +from typing import List, Tuple, Optional, Mapping, Any, Dict +import dataclasses + import numpy as np import torch from torch import nn +from torch.fx import _pytree as fx_pytree import torch_xla from torch_xla.core import xla_model as xm from torch_xla.core import dynamo_bridge from torch_xla.debug import metrics -import tensorflow as tf import torchvision import torch._dynamo as torchdynamo @@ -137,3 +143,258 @@ def export_torch_model(model: torch.nn.Module, # mistakenlly update the input tensors. torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) return stablehlo_model + + +class VariableType(enum.Enum): + INPUT_ARG = 'input_arg' + PARAMETER = 'parameter' + CONSTANT = 'constant' + + +@dataclass +class VariableSignature: # either argument or parameters + shape: List[int] + dtype: str + + +@dataclass +class InputLocation: + type_: VariableType + position: int = -1 + name: str = '' + + @classmethod + def parameter(cls, name: str): + return cls(type_=VariableType.PARAMETER, name=name) + + @classmethod + def input_arg(cls, position: int): + return cls(type_=VariableType.INPUT_ARG, position=position) + + @classmethod + def constant(cls, position): + return cls(type_=VariableType.CONSTANT, position=position) + + +@dataclass +class StableHLOFunctionMeta: + # name of the callable. + name: str + # version. 
+ stablehlo_version: str + # order is the order accepted by the stablehlo + input_signature: List[VariableSignature] + output_signature: List[VariableSignature] + # An input to the underlying stable callable can come from + # the arguments the user supplied, OR a parameter, OR a constant + input_locations: List[InputLocation] + + +class StableHLOJSONSerializer(json.JSONEncoder): + + def default(self, obj): + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + if isinstance(obj, VariableType): + return obj.value + return super().default(obj) + + +def stablehlo_obj_hook(dct): + targets = [ + StableHLOFunctionMeta, + VariableSignature, + InputLocation, + VariableSignature, + ] + + def _match_field(clazz): + # A dataclass become a dict on serialization then, + # it the dict must have values for all the fields in that dataclass. + return set(f.name for f in dataclasses.fields(clazz)) == dct.keys() + + def _try_convert_as_enum(v): + try: + return VariableType(v) + except: + return v + + for clazz in targets: + if _match_field(clazz): + new_dict = {k: _try_convert_as_enum(v) for k, v in dct.items()} + return clazz(**new_dict) + + +@dataclass +class StableHLOModelBundle: + # original state dict; but torch.Tensor's converted to np.array + state_dict: Dict[str, Any] + # Additional constants that we decide to hardcode. + additional_constants: List[np.ndarray] + # can support the case of multiple callable of the same model. + stablehlo_funcs: List[Tuple[StableHLOFunctionMeta, bytes]] + + +@dataclass +class StableHLOExportOptions: + pass + + +def _exported_program_to_stablehlo_bundle(exported_model, args): + xm.mark_step() + metrics.clear_counters() + device = xm.xla_device() + + if exported_model.call_spec.in_spec is not None: + args = fx_pytree.tree_flatten_spec(args, exported_model.call_spec.in_spec) + else: + args = copy.deepcopy(args) + + args = [ + x.to(device=device) if isinstance(x, torch.Tensor) else x for x in args + ] + + input_ids = { + torch_xla._XLAC._xla_get_tensor_id(tensor): i + for i, tensor in enumerate(args) + if isinstance(tensor, torch.Tensor) + } + # NOTE call convention: (parameters, buffers, user_inputs) + param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers + state = exported_model.state_dict + param_buffer_values = tuple(state[key].to( + device=device) if isinstance(state[key], torch.Tensor) else state[key] + for key in param_and_buffer_keys) + + with torch.no_grad(): + res = torch.fx.Interpreter(exported_model.graph_module).run( + *param_buffer_values, *args, enable_io_processing=False) + + ( + graph_input_tensor_ids, + graph_input_xla_values, + ) = torch_xla._XLAC._get_tensors_xla_device_data_node(res) + + tensor_id_to_state_name = { + torch_xla._XLAC._xla_get_tensor_id(value): name + for name, value in zip(param_and_buffer_keys, param_buffer_values) + if isinstance(value, torch.Tensor) + } + stablehlo_content = xm.get_stablehlo_bytecode(res) + + pos_to_orig_pos = {} + pos_to_param = {} + input_locations = [] + input_signatures = [] + additional_constants = [] + for hlo_input_pos, (tensor_id, tensor_value) in enumerate( + zip(graph_input_tensor_ids, graph_input_xla_values)): + if tensor_id in input_ids: # this is input + location = InputLocation.input_arg(position=input_ids[tensor_id]) + elif tensor_id in tensor_id_to_state_name: + location = InputLocation.parameter( + name=tensor_id_to_state_name[tensor_id]) + else: # additional constants that WE created + location = 
InputLocation.constant(position=len(additional_constants)) + additional_constants.append(tensor_value) + input_locations.append(location) + input_signatures.append( + VariableSignature( + shape=list(tensor_value.shape), + dtype=str(tensor_value.dtype).replace('torch.', ''))) + + output_signature = [ + VariableSignature( + shape=list(tensor.shape), + dtype=str(tensor_value.dtype).replace('torch.', '')) for tensor in res + ] + + torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + + meta = StableHLOFunctionMeta( + name='forward', + # TODO(qihqi) populate version from runtime + stablehlo_version="0.0.0", + input_signature=input_signatures, + output_signature=output_signature, + input_locations=input_locations, + ) + + return StableHLOModelBundle( + stablehlo_funcs=[(meta, stablehlo_content)], + state_dict=exported_model.state_dict, + additional_constants=additional_constants, + ) + + +class StableHLOExportOptions: + pass + + +def _save_program_bundle(bundle: StableHLOModelBundle, + stablehlo_dir: os.PathLike) -> None: + + data_dir = os.path.join(stablehlo_dir, 'data') + os.makedirs(data_dir, exist_ok=True) + for key, val in bundle.state_dict.items(): + with open(os.path.join(stablehlo_dir, 'data', key), 'wb') as f: + np.save(f, val.cpu().detach().numpy()) + + # save metadata and stablehlo bytecode + func_dir = os.path.join(stablehlo_dir, 'functions') + os.makedirs(func_dir, exist_ok=True) + for meta, bytecode in bundle.stablehlo_funcs: + with open(os.path.join(func_dir, meta.name + '.meta'), 'w') as f: + json.dump(meta, f, cls=StableHLOJSONSerializer) + with open(os.path.join(func_dir, meta.name + '.mlir'), 'wb') as f: + f.write(bytecode) + + const_dir = os.path.join(stablehlo_dir, 'constants') + os.makedirs(const_dir, exist_ok=True) + for i, constant in enumerate(bundle.additional_constants): + with open(os.path.join(const_dir, str(i)), 'wb') as f: + np.save(f, constant.cpu().detach().numpy()) + + +def _iter_dir(path: os.PathLike): + for name in os.listdir(path): + with open(os.path.join(path, name), 'rb') as f: + yield name, f + + +def _load_program_bundle(stablehlo_dir: os.PathLike) -> StableHLOModelBundle: + state_dict = {} + for name, f in _iter_dir(os.path.join(stablehlo_dir, 'data')): + state_dict[name] = np.load(f, allow_pickle=True) + + constants = [] + for name, f in _iter_dir(os.path.join(stablehlo_dir, 'constants')): + # name of constants are ints + constants.append((int(name), np.load(f, allow_pickle=True))) + constants = [v for k, v in sorted(constants)] + + metas = [] + name_to_bytecode = {} + stablehlo_funcs = [] + for name, f in _iter_dir(os.path.join(stablehlo_dir, 'functions')): + if name.endswith('.meta'): + metas.append(json.load(f, object_hook=stablehlo_obj_hook)) + else: + name_to_bytecode[os.path.splitext(name)[0]] = f.read() + + for meta in metas: + stablehlo_funcs.append((meta, name_to_bytecode[meta.name])) + + return StableHLOModelBundle( + stablehlo_funcs=stablehlo_funcs, + additional_constants=constants, + state_dict=state_dict) + + +def save_as_stablehlo(exported_model: 'ExportedProgram', + args: Tuple[Any], + stablehlo_dir: os.PathLike, + options: Optional[StableHLOExportOptions] = None): + + bundle = _exported_program_to_stablehlo_bundle(exported_model, args) + _save_program_bundle(bundle, stablehlo_dir) From ebb51208b8e87edccce53e6615e577c6e6b72aaa Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Wed, 19 Jul 2023 10:15:27 -0700 Subject: [PATCH 17/20] Disable coverage for now (#5321) --- 
.github/workflows/build_and_test.yml | 51 ++++++++++++++-------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index e45a7a4d41bb..a14845341f83 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -50,32 +50,33 @@ jobs: secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - test-cpu-coverage: - name: "Collect CPU test coverage" - if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' - uses: ./.github/workflows/_test.yml - needs: build - with: - docker-image: ${{ needs.build.outputs.docker-image }} - collect-coverage: true - timeout-minutes: 120 - disable-xrt: 1 - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} + # break by gcc version update https://github.com/pytorch/xla/commit/e7e189961bd669c33939e269c248b391fe156d38 + # test-cpu-coverage: + # name: "Collect CPU test coverage" + # if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' + # uses: ./.github/workflows/_test.yml + # needs: build + # with: + # docker-image: ${{ needs.build.outputs.docker-image }} + # collect-coverage: true + # timeout-minutes: 120 + # disable-xrt: 1 + # secrets: + # gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} - test-gpu-coverage: - name: "Collect GPU test coverage" - if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' - uses: ./.github/workflows/_test.yml - needs: build - with: - docker-image: ${{ needs.build.outputs.docker-image }} - runner: linux.8xlarge.nvidia.gpu - timeout-minutes: 210 - collect-coverage: true - disable-xrt: 1 - secrets: - gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} + # test-gpu-coverage: + # name: "Collect GPU test coverage" + # if: github.event_name == 'push' && github.event.ref == 'refs/heads/master' + # uses: ./.github/workflows/_test.yml + # needs: build + # with: + # docker-image: ${{ needs.build.outputs.docker-image }} + # runner: linux.8xlarge.nvidia.gpu + # timeout-minutes: 210 + # collect-coverage: true + # disable-xrt: 1 + # secrets: + # gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} push-docs: name: "Build & publish docs" From 9a995403001d757980b804605a3fc32dcdf039c2 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Wed, 19 Jul 2023 14:05:57 -0700 Subject: [PATCH 18/20] Enable Some input output aliasing under SPMD (#5320) --- test/spmd/test_xla_sharding.py | 12 ++++++++++ torch_xla/csrc/init_python_bindings.cpp | 21 +++++++++++++----- torch_xla/csrc/xla_graph_executor.cpp | 29 +++++++++++++++++++++++-- torch_xla/csrc/xla_sharding_util.cpp | 5 +++++ torch_xla/csrc/xla_sharding_util.h | 4 ++++ 5 files changed, 63 insertions(+), 8 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 21e1fd3843a0..52ed34a10874 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -733,6 +733,18 @@ def test_xla_sharded_hlo_dump_post_optimizations(self): if self.n_devices > 1: self.assertIn('all-reduce', hlo) + def test_sharded_tensor_aliasing(self): + met.clear_all() + partition_spec = (0, 1) + xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], + dtype=torch.float, + device=xm.xla_device()) + xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), + partition_spec) + xst1 += 1 + xm.mark_step() + self.assertEqual(met.metric_data("InputOutputAliasCount")[0], 1) + if __name__ == '__main__': test = unittest.main() diff --git 
a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index fa87a016b190..d295f5ad3e96 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -355,6 +355,15 @@ std::string GetTensorsHloGraph(const std::vector& tensors, return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode); } +std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { + auto sharding_spec = xtensor->sharding_spec(); + if (sharding_spec != nullptr) { + auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding); + return hlo_sharding->ToString(); + } + return std::string(); +} + std::string GetXLATensorDebugInfo(const at::Tensor& tensor) { auto xtensor = bridge::TryGetXlaTensor(tensor); if (!xtensor) { @@ -366,6 +375,11 @@ std::string GetXLATensorDebugInfo(const at::Tensor& tensor) { ss << "Device: " << xtensor->GetDevice() << "\n"; ss << "XLA Shape: " << xtensor->shape().get().ToString() << "\n"; + std::string sharding_spec_str = GetXLAShardingSpec(xtensor); + ss << "ShardingSpec: " + << ((sharding_spec_str.size() > 0) ? sharding_spec_str : "None"); + ss << "\n"; + torch::lazy::Value ir_value = xtensor->CurrentIrValue(); ss << "IR: "; if (ir_value) { @@ -1406,12 +1420,7 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_get_xla_sharding_spec", [](const at::Tensor& input) -> std::string { XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto sharding_spec = xtensor->sharding_spec(); - if (sharding_spec != nullptr) { - auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding); - return hlo_sharding->ToString(); - } - return std::string(); + return GetXLAShardingSpec(xtensor); }); m.def("_get_xla_sharding_type", [](const at::Tensor& input) -> std::optional { diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 2714d6ef6fb0..7cc8c6c94109 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1164,7 +1164,32 @@ XLAGraphExecutor::BuildInputOutputAliases( // Need to check whether existing buffer and the new value has the same // shape and the existing buffer has not been aliased before aliasing // the existing and new buffer. - if (parameter_data_shape == root_shape && alias_map[output_index] < 0) { + + bool equal_sharding; + // get sharding for the parameter data + std::optional parameter_sharding = + torch_xla::runtime::GetComputationClient()->GetDataSharding( + UnwrapXlaData(parameters_data[i])); + // get sharding for output tensor + size_t output_tensor_index = indices[output_index]; + XLATensor::ShardingSpecPtr output_sharding = + tensors[output_tensor_index]->sharding_spec(); + if (!parameter_sharding && !output_sharding) { + // Both parameter and output does not have sharding. + // TODO(JackCaoG): It is possible that output might get a sharding + // after sharding propagation. Consier not aliased here(if under SPMD + // mode). + equal_sharding = true; + } else if (parameter_sharding && output_sharding) { + equal_sharding = ShardingUtil::EqualOpShardings( + *parameter_sharding, output_sharding->sharding); + } else { + // one of the parameter and output does not have sharding. 
+ equal_sharding = false; + } + + if (parameter_data_shape == root_shape && alias_map[output_index] < 0 && + equal_sharding) { // parameter is not a tuple so param_index will always be {} lowering_ctx->builder()->SetUpAlias( {/*output_index=*/static_cast(output_index)}, @@ -1216,7 +1241,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( // since the current aliasing compares the unpartitioned input and output // shapes which can lead to an incorrect aliasing pairs if sharded. if (enable_aliasing && coll.config.sync_ltc_data && - coll.config.force_ltc_data && !is_sharded) { + coll.config.force_ltc_data) { // We can only alias at the step barrier, when force_ltc_data is true. // Consider the case: // 1. Tensor A(DEVICE_DATA) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index ea91fb2b3118..e2cbd4bfaefb 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -181,6 +181,11 @@ bool ShardingUtil::EqualShardingSpecs(const XLATensor::ShardingSpec& a, return xla::protobuf_util::ProtobufEquals(a.sharding, b.sharding); } +bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, + const xla::OpSharding& b) { + return xla::protobuf_util::ProtobufEquals(a, b); +} + xla::OpSharding ShardingUtil::CreateOpSharding( const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, ShardingType sharding_type) { diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 86a18dc83f89..0d512a33b7ef 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -38,6 +38,10 @@ class ShardingUtil { static bool EqualShardingSpecs(const XLATensor::ShardingSpec& a, const XLATensor::ShardingSpec& b); + // Returns true if two OpShardings are the same. + static bool EqualOpShardings(const xla::OpSharding& a, + const xla::OpSharding& b); + // Creates an xla::OpSharding. `tile_assignmnent` is required for TILED // `sharding_type` and `replication_groups` for `PARTIAL`. 
static xla::OpSharding CreateOpSharding(const py::list& tile_assignment, From 8c20cbd6a27130974a9df9f2bf3dec151061ca86 Mon Sep 17 00:00:00 2001 From: Yash Shah <55116947+yashs97@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:44:02 -0700 Subject: [PATCH 19/20] Use `_sharded_cpu_state_dict` functionality to Write Items for SPMD Save Planner (#5315) * initial commit * add suggested changes * add unit test * fix test * fix test * add suggested changes * remove is_sharded_tensor check * check if device type is xla in `wrap_if_sharded` * change order * update resolve_data and add more tests * run linter * use subtest * formatting fixes * run linter --- test/spmd/test_xla_distributed_checkpoint.py | 89 +++++++++++++------ .../experimental/distributed_checkpoint.py | 59 +++++++----- torch_xla/experimental/xla_sharding.py | 1 + 3 files changed, 98 insertions(+), 51 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 0d2f3e318cce..bf30ced90e42 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -57,7 +57,8 @@ def _save_and_restore(self, model_in, model_out, save_planner=None, - load_planner=None): + load_planner=None, + is_sharded_cpu_state_dict=False): """ Checkpoint model_in using the provided save_planner and load into model_out using the provided load_planner, and assert model_out equals model_in after @@ -66,18 +67,22 @@ def _save_and_restore(self, tmpdir = tempfile.mkdtemp() # Save an unsharded model using the provided save planner + model_in_state_dict = model_in.state_dict() + if is_sharded_cpu_state_dict: + model_in_state_dict = _sharded_cpu_state_dict(model_in_state_dict) + model_out_state_dict = model_out.state_dict() dist_cp.save_state_dict( - state_dict=model_in.state_dict(), + state_dict=model_in_state_dict, storage_writer=dist_cp.FileSystemWriter(tmpdir), planner=save_planner, no_dist=True, # Single-host checkpoint doesn't require a process group ) - # Load the checkpoint using the provided load planner for p1, p2 in zip(model_in.parameters(), model_out.parameters()): self.assertFalse(torch.allclose(p1, p2)) + dist_cp.load_state_dict( - state_dict=model_out.state_dict(), + state_dict=model_out_state_dict, storage_reader=dist_cp.FileSystemReader(tmpdir), planner=load_planner, no_dist=True, # Single-host checkpoint doesn't require a process group @@ -95,9 +100,15 @@ def test_unsharded_to_sharded(self): # TODO(jonbolin): Enable tests for resharding into coarser meshes @unittest.skip("View assignment with virtual device is not yet supported") def test_sharded_to_unsharded(self): - model = self.SimpleLinear().to(xm.xla_device()) - sharded_model = self._get_sharded_model() - self._save_and_restore(sharded_model, model, save_planner=SPMDSavePlanner()) + for chkpt_on_cpu in [True, False]: + with self.subTest(chkpt_on_cpu): + model = self.SimpleLinear().to(xm.xla_device()) + sharded_model = self._get_sharded_model() + self._save_and_restore( + sharded_model, + model, + save_planner=SPMDSavePlanner(), + is_sharded_cpu_state_dict=chkpt_on_cpu) # TODO(jonbolin): Enable tests for resharding into coarser meshes @unittest.skip("View assignment with virtual device is not yet supported") @@ -186,15 +197,16 @@ def test_resolve_and_commit_sharded_tensor(self): class SPMDSavePlannerTest(DistributedCheckpointTestBase): - def _get_save_planner(self, model): + def _get_save_planner(self, model, is_sharded_cpu_state_dict=False): # Create an SPMDSavePlanner for the given 
model. planner = SPMDSavePlanner() - planner.set_up_planner(model.state_dict(), True) + if not is_sharded_cpu_state_dict: + planner.set_up_planner(model.state_dict(), True) + else: + planner.set_up_planner(_sharded_cpu_state_dict(model.state_dict()), True) return planner - def test_state_dict_separation(self): - model = self._get_sharded_model() - planner = self._get_save_planner(model) + def _planner_assertions(self, planner): if self.n_devices > 1: # The state_dict should be flattened and separated self.assertCountEqual(planner.sharded_state_dict, ['fc1.weight']) @@ -208,27 +220,46 @@ def test_state_dict_separation(self): planner.unsharded_state_dict, ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']) - def test_local_save_plan(self): + def test_state_dict_separation(self): model = self._get_sharded_model() planner = self._get_save_planner(model) - plan = planner.create_local_plan() - parameter_count = len(list(model.parameters())) + self._planner_assertions(planner) + + def test_save_state_dict_with_cpu_shards(self): + model = self._get_sharded_model() + planner = self._get_save_planner(model, is_sharded_cpu_state_dict=True) + self._planner_assertions(planner) if self.n_devices > 1: - # When the model is sharded across devices, fc1.weight will result in - # self.n_devices WriteItems while all other tensors result in a single - # WriteItem. - sharded_write_items = [ - wi for wi in plan.items if wi.index.fqn == 'fc1.weight' - ] - self.assertEqual(self.n_devices, len(sharded_write_items)) - # Every other parameter should have a single WriteItem - unsharded_write_items = [ - x for x in plan.items if x not in sharded_write_items - ] - self.assertEqual(parameter_count - 1, len(unsharded_write_items)) - else: + self.assertTrue( + isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards)) + + def test_local_save_plan(self): + + def _write_item_assertions(plan, n_devices, parameter_count): + if n_devices > 1: + # When the model is sharded across devices, fc1.weight will result in + # self.n_devices WriteItems while all other tensors result in a single + # WriteItem. 
+ sharded_write_items = [ + wi for wi in plan.items if wi.index.fqn == 'fc1.weight' + ] + self.assertEqual(self.n_devices, len(sharded_write_items)) + # Every other parameter should have a single WriteItem + unsharded_write_items = [ + x for x in plan.items if x not in sharded_write_items + ] + self.assertEqual(parameter_count - 1, len(unsharded_write_items)) + else: + self.assertEqual(parameter_count, len(plan.items)) # If unsharded, there should be a single WriteItem per model parameter - self.assertEqual(parameter_count, len(plan.items)) + + for chkpt_on_cpu in [True, False]: + with self.subTest(chkpt_on_cpu): + model = self._get_sharded_model() + planner = self._get_save_planner(model, chkpt_on_cpu) + plan = planner.create_local_plan() + parameter_count = len(list(model.parameters())) + _write_item_assertions(plan, self.n_devices, parameter_count) @unittest.skipIf(xr.global_device_count() == 1, "Multiple devices required to shard tensors") diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/distributed_checkpoint.py index c2a1545cfc3f..89d7f5b33323 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/distributed_checkpoint.py @@ -1,3 +1,4 @@ +from copy import copy import dataclasses import io import numpy as np @@ -35,14 +36,8 @@ from torch.utils._pytree import tree_map from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard from torch_xla.experimental._distributed_checkpoint_helpers import ( - FLATTEN_MAPPING, - flatten_state_dict, - dedup_tensors, - _is_sharded_tensor, - set_element, - narrow_tensor_by_index, - _unwrap_xla_sharded_tensor, -) + FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor, + set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards) from typing import Any, Dict, List, Tuple, Union __all__ = [ @@ -70,7 +65,7 @@ def __init__(self): # Flattened state_dict tracking all sharded tensors to be checkpointed self.sharded_state_dict: Dict[str, XLAShardedTensor] = None - # Flattend state_dict tracking all other state_dict items + # Flattened state_dict tracking all other state_dict items self.unsharded_state_dict: Dict[str, Any] = None # Upon the first `resolve_data` call for a WriteItem associated with a @@ -90,10 +85,12 @@ def set_up_planner(self, state_dict: STATE_DICT_TYPE, state_dict, self.mappings = flatten_state_dict(state_dict) state_dict = tree_map(xs.wrap_if_sharded, state_dict) - # Select only XLAShardedTensors which are not replicated, since the - # default planner can handle everything else. + # Select only XLAShardedTensors which are not replicated or _CpuShards, + # since the default planner can handle everything else. self.sharded_state_dict = { - k: v for k, v in state_dict.items() if _is_sharded_tensor(v) + k: v + for k, v in state_dict.items() + if _is_sharded_tensor(v) or isinstance(v, _CpuShards) } unsharded = dict(state_dict.items() - self.sharded_state_dict.items()) self.unsharded_state_dict = tree_map(_unwrap_xla_sharded_tensor, unsharded) @@ -105,7 +102,7 @@ def create_local_plan(self) -> SavePlan: # Track the flattened mappings in the plan metadata plan = dataclasses.replace(plan, planner_data=self.mappings) - # Extend the plan for sharded tensor data + # Extend the plan for sharded tensor data and _CpuShards. 
xla_write_items = _create_xla_write_items(self.sharded_state_dict) plan.items.extend(xla_write_items) return plan @@ -138,9 +135,10 @@ def lookup_object(self, index: MetadataIndex) -> Any: if index.fqn not in self._local_shards: xtensor = self.sharded_state_dict[index.fqn] - assert isinstance(xtensor, - XLAShardedTensor), f"Unsupported object type: {xtensor}" - self._local_shards[index.fqn] = xtensor.local_shards + if isinstance(xtensor, XLAShardedTensor): + self._local_shards[index.fqn] = xtensor.local_shards + elif isinstance(xtensor, _CpuShards): + self._local_shards[index.fqn] = copy(xtensor.shards) shard = self._local_shards[index.fqn][index.index] assert shard is not None, f"WriteItem has already been processed: {index}" @@ -187,7 +185,7 @@ def __init__(self): # Flattened state_dict tracking all sharded tensors to be restored self.sharded_state_dict: Dict[str, XLAShardedTensor] = None - # Flattend state_dict tracking all other state_dict items + # Flattened state_dict tracking all other state_dict items self.unsharded_state_dict: Dict[str, Any] = None # Upon the first `resolve_tensor` call for a ReadItem associated with a @@ -297,8 +295,7 @@ def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: return self._pending_elements[fqn] -= np.prod(read_item.lengths) - assert self._pending_elements[ - fqn] >= 0, f"Too many writes for tensor {index.fqn}" + assert self._pending_elements[fqn] >= 0, f"Too many writes for tensor {fqn}" if self._pending_elements[fqn] == 0: # Load local shards into the XLAShardedTensor and release the shards # from CPU @@ -340,15 +337,33 @@ def _create_write_items_for_xla_sharded_tensor( return items +def _create_write_items_for_cpu_shards( + fqn: str, cpu_shards: _CpuShards) -> List[WriteItem]: + items = [] + for xla_shard in cpu_shards.shards: + prop = TensorProperties.create_from_tensor(xla_shard.data) + for shard_ind, indices in enumerate(xla_shard.indices): + write_item = _create_write_item_from_indices(fqn, shard_ind, indices, + cpu_shards.global_shape, + prop) + items.append(write_item) + return items + + def _create_xla_write_items(state_dict: STATE_DICT_TYPE) -> List[WriteItem]: """ Iterate through the state_dict and return WriteItems for all local shards """ items = [] for fqn, v in state_dict.items(): - assert isinstance(v, XLAShardedTensor - ), '_create_xla_write_items only accepts XLAShardedTensor' - items.extend(_create_write_items_for_xla_sharded_tensor(fqn, v)) + if isinstance(v, XLAShardedTensor): + items.extend(_create_write_items_for_xla_sharded_tensor(fqn, v)) + elif isinstance(v, _CpuShards): + items.extend(_create_write_items_for_cpu_shards(fqn, v)) + else: + raise TypeError( + "_create_xla_write_items accepts either XLAShardedTensor or _CpuShards as value type." + ) return items diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index e97f0b96b722..ee9413021ec4 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -458,6 +458,7 @@ def wrap_if_sharded(x: Any) -> Any: Otherwise, returns the input. 
""" if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and + x.device.type == 'xla' and torch_xla._XLAC._get_xla_sharding_type(x) is not None): return XLAShardedTensor(x) return x From d507067ae0b13e7b5c3968cdee5258e599ba786f Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Fri, 21 Jul 2023 11:41:31 -0700 Subject: [PATCH 20/20] handle single tensor for method send_to_device_single (#5317) * handle single tensor for method send_to_device_single * fix broadcast parameter --- test/test_operations.py | 7 +++++++ torch_xla/core/xla_model.py | 8 +++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 4d56c59ba38f..cab2b908d3f9 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1970,6 +1970,13 @@ def test_send_to_device_grad(self): dt = xm.send_cpu_data_to_device([t], xla_device) self.assertTrue(dt[0].requires_grad) + def test_send_to_device_single(self): + xla_device = xm.xla_device() + t = _gen_tensor(2, 2) + dt = xm.send_cpu_data_to_device(t, xla_device) + self.assertEqual(dt[0].device, xla_device) + self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t))) + def test_nms(self): BOXES = ( (0, 0, 3, 2), diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 96b0bc450d2a..28251fc5f9a2 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -678,7 +678,7 @@ def collective_broadcast(tensors: List[torch.Tensor], 1 if get_ordinal() == root_ordinal else 0, dtype=tensor.dtype) # Transfer scale tensor as device data instead of constant 1 or 0. xscale = send_cpu_data_to_device(scale, tensor.device) - tensor.mul_(xscale) + tensor.mul_(xscale[0]) all_reduce(REDUCE_SUM, tensors, groups=groups, pin_layout=pin_layout) @@ -1011,7 +1011,7 @@ def select_fn(v): return ToXlaTensorArena(convert_fn, select_fn).transform(data) -def send_cpu_data_to_device(data, device, input_sharding=None): +def send_cpu_data_to_device(datas, device, input_sharding=None): def convert_fn(tensors): devices = [str(device)] * len(tensors) @@ -1025,7 +1025,9 @@ def convert_fn(tensors): def select_fn(v): return type(v) == torch.Tensor and v.device.type == 'cpu' - return ToXlaTensorArena(convert_fn, select_fn).transform(data) + if type(datas) is torch.Tensor: + datas = [datas] + return ToXlaTensorArena(convert_fn, select_fn).transform(datas) def xla_rendezvous(payload: bytes = b'',