From f5c7367bcbbf04bbce18ec600d58119f06bf9cab Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 24 Aug 2022 02:24:42 +0000 Subject: [PATCH 01/12] Helper function to wrap entry HLO --- torch_xla/csrc/lowering_context.cpp | 3 ++ torch_xla/csrc/tensor.cpp | 65 +++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c1c9c7ba3529..374ba074cb33 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -14,6 +14,9 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor_util.h" +#include +using std::cerr; + namespace torch_xla { namespace { diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 574c8187b33e..96b906c82a2c 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -6,10 +6,12 @@ #include #include #include +#include #include #include #include #include +using std::cerr; #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" @@ -1639,6 +1641,63 @@ void XLATensor::BuildInputOutputAliases( XLA_VALUE_METRIC("InputOutputAliasCount", alias_map.size()); } +xla::StatusOr WrapComputation( + const xla::XlaComputation& computation, + const std::vector& parameter_shapes) { + xla::XlaBuilder builder(computation.proto().name()); + + // Construct a single tuple parameter. + const xla::XlaOp input_tuple = [&builder, ¶meter_shapes]() { + xla::Shape input_tuple; + input_tuple.set_element_type(xla::PrimitiveType::TUPLE); + input_tuple.mutable_tuple_shapes()->reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + *input_tuple.add_tuple_shapes() = parameter_shapes[i]; + } + return xla::Parameter(&builder, 0, input_tuple, "in"); + }(); + + // Handle the results of the original computation. + const std::vector inner_params = [&input_tuple, + ¶meter_shapes]() { + std::vector parameters; + parameters.reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + parameters.push_back(xla::GetTupleElement(input_tuple, i)); + } + return parameters; + }(); + + // Call the original computation. + xla::XlaOp orig_result; + orig_result = xla::Call(&builder, computation, inner_params); + + // Construct a single tuple result. + const std::vector results = [&orig_result]() { + std::vector results; + results.push_back(orig_result); + return results; + }(); + + xla::XlaOp result_tuple; + { result_tuple = xla::Tuple(&builder, results); } + + // // Preserve aliases. + // if (io_info.use_dummy_input()) { + // for (const auto& [input_index, output_index] : io_info.io_aliases) { + // Skip the dummy input at index 0. + // builder.SetUpAlias(xla::ShapeIndex({output_index}), 0, + // xla::ShapeIndex({input_index + 1})); + // } + // } else { + // for (const auto& [input_index, output_index] : io_info.io_aliases) { + // builder.SetUpAlias(xla::ShapeIndex({output_index}), 0, + // xla::ShapeIndex({input_index})); + // } + // } + return builder.Build(result_tuple); +} + XLATensor::CompilationResult XLATensor::Compile( const std::vector& tensors, absl::Span devices, const SyncTensorCollection& coll, @@ -1693,7 +1752,13 @@ XLATensor::CompilationResult XLATensor::Compile( } xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); + cerr << "hlo built = \n" + << ConsumeValue(xla::util::GetComputationHloText(computation)) << "\n"; xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); + xla::XlaComputation wrapped_computation = + ConsumeValue(WrapComputation(computation, program_shape.parameters())); + cerr << "wrapped hlo built = \n" + << ConsumeValue(xla::util::GetComputationHloText(wrapped_computation)) << "\n"; xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(coll.device.type())); From 07161dce6d97054f64fbf08484edd24b09ee71b9 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 24 Aug 2022 02:53:34 +0000 Subject: [PATCH 02/12] Use the wrapped computation during computation --- third_party/xla_client/pjrt_computation_client.cc | 1 + torch_xla/csrc/tensor.cpp | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc index e2951ff22440..a556912e52e2 100644 --- a/third_party/xla_client/pjrt_computation_client.cc +++ b/third_party/xla_client/pjrt_computation_client.cc @@ -153,6 +153,7 @@ std::vector PjRtComputationClient::Compile( compile_options.executable_build_options.set_num_partitions(1); compile_options.executable_build_options.set_num_replicas( client_->device_count()); + compile_options.parameter_is_tupled_arguments = true; std::unique_ptr executable = client_->Compile(instance.computation, compile_options).ValueOrDie(); std::shared_ptr pjrt_computation = diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 96b906c82a2c..7f6589586bae 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1752,18 +1752,20 @@ XLATensor::CompilationResult XLATensor::Compile( } xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); - cerr << "hlo built = \n" - << ConsumeValue(xla::util::GetComputationHloText(computation)) << "\n"; xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); xla::XlaComputation wrapped_computation = ConsumeValue(WrapComputation(computation, program_shape.parameters())); + xla::ProgramShape wrapped_program_shape = ConsumeValue(wrapped_computation.GetProgramShape()); + xla::Shape shape = MakeShapeWithDeviceLayout( + wrapped_program_shape.result(), static_cast(coll.device.type())); + + cerr << "hlo built = \n" + << ConsumeValue(xla::util::GetComputationHloText(computation)) << "\n"; cerr << "wrapped hlo built = \n" << ConsumeValue(xla::util::GetComputationHloText(wrapped_computation)) << "\n"; - xla::Shape shape = MakeShapeWithDeviceLayout( - program_shape.result(), static_cast(coll.device.type())); std::vector instances; - instances.push_back({std::move(computation), coll.device.toString(), + instances.push_back({std::move(wrapped_computation), coll.device.toString(), xla::ComputationClient::Get()->GetCompilationDevices( coll.device.toString(), devices), &shape}); From 51aecfbdbdcc7da92293f564fc944eaf1690bfb8 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Thu, 25 Aug 2022 01:41:50 +0000 Subject: [PATCH 03/12] don't tuple the wrapped hlo's result --- torch_xla/csrc/lowering_context.cpp | 3 +-- torch_xla/csrc/tensor.cpp | 15 ++++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 374ba074cb33..967d00c6ea8a 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -1,5 +1,6 @@ #include "torch_xla/csrc/lowering_context.h" +#include #include #include @@ -13,8 +14,6 @@ #include "torch_xla/csrc/computation.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor_util.h" - -#include using std::cerr; namespace torch_xla { diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 7f6589586bae..d01b45bc7e4b 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1695,7 +1695,9 @@ xla::StatusOr WrapComputation( // xla::ShapeIndex({input_index})); // } // } - return builder.Build(result_tuple); + + // return builder.Build(result_tuple); + return builder.Build(orig_result); } XLATensor::CompilationResult XLATensor::Compile( @@ -1755,14 +1757,17 @@ XLATensor::CompilationResult XLATensor::Compile( xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); xla::XlaComputation wrapped_computation = ConsumeValue(WrapComputation(computation, program_shape.parameters())); - xla::ProgramShape wrapped_program_shape = ConsumeValue(wrapped_computation.GetProgramShape()); - xla::Shape shape = MakeShapeWithDeviceLayout( - wrapped_program_shape.result(), static_cast(coll.device.type())); + xla::ProgramShape wrapped_program_shape = + ConsumeValue(wrapped_computation.GetProgramShape()); + xla::Shape shape = + MakeShapeWithDeviceLayout(wrapped_program_shape.result(), + static_cast(coll.device.type())); cerr << "hlo built = \n" << ConsumeValue(xla::util::GetComputationHloText(computation)) << "\n"; cerr << "wrapped hlo built = \n" - << ConsumeValue(xla::util::GetComputationHloText(wrapped_computation)) << "\n"; + << ConsumeValue(xla::util::GetComputationHloText(wrapped_computation)) + << "\n"; std::vector instances; instances.push_back({std::move(wrapped_computation), coll.device.toString(), From 6890aa1f410eba1a486dd5c96d84ce71cb3e570e Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Thu, 25 Aug 2022 02:12:42 +0000 Subject: [PATCH 04/12] clean up --- torch_xla/csrc/lowering_context.cpp | 2 -- torch_xla/csrc/tensor.cpp | 8 -------- 2 files changed, 10 deletions(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 967d00c6ea8a..c1c9c7ba3529 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/lowering_context.h" -#include #include #include @@ -14,7 +13,6 @@ #include "torch_xla/csrc/computation.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor_util.h" -using std::cerr; namespace torch_xla { namespace { diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index d01b45bc7e4b..57584e94711a 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -6,12 +6,10 @@ #include #include #include -#include #include #include #include #include -using std::cerr; #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" @@ -1763,12 +1761,6 @@ XLATensor::CompilationResult XLATensor::Compile( MakeShapeWithDeviceLayout(wrapped_program_shape.result(), static_cast(coll.device.type())); - cerr << "hlo built = \n" - << ConsumeValue(xla::util::GetComputationHloText(computation)) << "\n"; - cerr << "wrapped hlo built = \n" - << ConsumeValue(xla::util::GetComputationHloText(wrapped_computation)) - << "\n"; - std::vector instances; instances.push_back({std::move(wrapped_computation), coll.device.toString(), xla::ComputationClient::Get()->GetCompilationDevices( From a878fc6f76c5e4fe3c03cfec2cd9a68184932a1b Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Thu, 25 Aug 2022 02:59:44 +0000 Subject: [PATCH 05/12] preserve aliasing --- torch_xla/csrc/tensor.cpp | 43 ++++++++++++++++++++------------------- torch_xla/csrc/tensor.h | 6 +++--- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 57584e94711a..c8751b76e24e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1602,10 +1602,11 @@ XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp( return async_op.Schedule(); } -void XLATensor::BuildInputOutputAliases( +std::vector> XLATensor::BuildInputOutputAliases( const std::vector& tensors, absl::Span indices, LoweringContext* lowering_ctx) { std::unordered_map output_tensor_id_map; + std::vector> input_output_alias_pair; for (size_t i = 0; i < indices.size(); ++i) { size_t tensor_index = indices[i]; int64_t tensor_id = tensors[tensor_index]->GetUniqueId(); @@ -1626,9 +1627,12 @@ void XLATensor::BuildInputOutputAliases( const xla::Shape& root_shape = XlaHelpers::ShapeOfXlaOp(root); if (parameters_data[i]->shape() == root_shape && alias_map[output_index] < 0) { + // parameter is not a tuple so param_index will always be {} lowering_ctx->builder()->SetUpAlias( - {static_cast(output_index)}, i, {}); + {/*output_index=*/static_cast(output_index)}, + /*param_number=*/i, /*param_index=*/{}); alias_map[output_index] = i; + input_output_alias_pair.push_back(std::make_pair(i, output_index)); TF_VLOG(6) << "Aliased paramter " << i << " with output " << output_index << ": " << parameters_data[i]->shape(); @@ -1637,11 +1641,13 @@ void XLATensor::BuildInputOutputAliases( } } XLA_VALUE_METRIC("InputOutputAliasCount", alias_map.size()); + return input_output_alias_pair; } xla::StatusOr WrapComputation( const xla::XlaComputation& computation, - const std::vector& parameter_shapes) { + const std::vector& parameter_shapes, + std::vector> input_output_alias_pair) { xla::XlaBuilder builder(computation.proto().name()); // Construct a single tuple parameter. @@ -1680,21 +1686,14 @@ xla::StatusOr WrapComputation( xla::XlaOp result_tuple; { result_tuple = xla::Tuple(&builder, results); } - // // Preserve aliases. - // if (io_info.use_dummy_input()) { - // for (const auto& [input_index, output_index] : io_info.io_aliases) { - // Skip the dummy input at index 0. - // builder.SetUpAlias(xla::ShapeIndex({output_index}), 0, - // xla::ShapeIndex({input_index + 1})); - // } - // } else { - // for (const auto& [input_index, output_index] : io_info.io_aliases) { - // builder.SetUpAlias(xla::ShapeIndex({output_index}), 0, - // xla::ShapeIndex({input_index})); - // } - // } - - // return builder.Build(result_tuple); + for (const auto& [input_index, output_index] : input_output_alias_pair) { + // Both input and output will be a tuple so parameter_number will always be + // 0 + builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}), + /*param_number=*/0, + /*param_index=*/xla::ShapeIndex({input_index})); + } + return builder.Build(orig_result); } @@ -1723,6 +1722,7 @@ XLATensor::CompilationResult XLATensor::Compile( // Annotate HLO sharding selectively in the compuation. ShardingUtil::SetHloSharding(&lowering_ctx); + std::vector> input_output_alias_pair; if (enable_aliasing && coll.config.sync_xla_data) { // We can only alias at the step barrier, when force_xla_data is true. // Consider the case: @@ -1748,13 +1748,14 @@ XLATensor::CompilationResult XLATensor::Compile( // will later fetch the new value of A, which is incorrect. // But, when we issue a step barrier (force_xla_data == true) we have to // turn everything into DEVICE_DATA, so we can activate aliasing. - BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx); + input_output_alias_pair = + BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx); } xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); - xla::XlaComputation wrapped_computation = - ConsumeValue(WrapComputation(computation, program_shape.parameters())); + xla::XlaComputation wrapped_computation = ConsumeValue(WrapComputation( + computation, program_shape.parameters(), input_output_alias_pair)); xla::ProgramShape wrapped_program_shape = ConsumeValue(wrapped_computation.GetProgramShape()); xla::Shape shape = diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 868e338f6196..8721c909dfef 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1445,9 +1445,9 @@ class XLATensor : public c10::intrusive_ptr_target { std::vector* tensors, SyncTensorCollection* coll, PostOrderData* po_data); - static void BuildInputOutputAliases(const std::vector& tensors, - absl::Span indices, - LoweringContext* lowering_ctx); + static std::vector> BuildInputOutputAliases( + const std::vector& tensors, + absl::Span indices, LoweringContext* lowering_ctx); static CompilationResult Compile(const std::vector& tensors, absl::Span devices, From 53f9dc0bf12b6fe170c14d982e49d04a009ff07b Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 01:37:22 +0000 Subject: [PATCH 06/12] Add XLA_PARAMETER_WRAPPING_THREADSHOLD --- third_party/xla_client/computation_client.h | 7 ++-- .../xla_client/pjrt_computation_client.cc | 20 +++++++---- torch_xla/csrc/tensor.cpp | 36 +++++++++++++------ 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/third_party/xla_client/computation_client.h b/third_party/xla_client/computation_client.h index 0a112c3cdeaf..abc8f883fe58 100644 --- a/third_party/xla_client/computation_client.h +++ b/third_party/xla_client/computation_client.h @@ -134,16 +134,19 @@ class ComputationClient { struct CompileInstance { CompileInstance() = default; CompileInstance(XlaComputation computation, std::string compilation_device, - std::vector devices, const Shape* output_shape) + std::vector devices, const Shape* output_shape, + bool parameter_is_tupled_arguments = false) : computation(std::move(computation)), compilation_device(std::move(compilation_device)), devices(std::move(devices)), - output_shape(output_shape) {} + output_shape(output_shape), + parameter_is_tupled_arguments(parameter_is_tupled_arguments) {} XlaComputation computation; std::string compilation_device; std::vector devices; const Shape* output_shape = nullptr; + bool parameter_is_tupled_arguments; }; struct ExecuteOptions { diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc index a556912e52e2..ef8647ef654c 100644 --- a/third_party/xla_client/pjrt_computation_client.cc +++ b/third_party/xla_client/pjrt_computation_client.cc @@ -9,8 +9,8 @@ #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/pjrt/cpu_device.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" @@ -49,7 +49,8 @@ PjRtComputationClient::PjRtComputationClient() { TF_VLOG(1) << "Initializing PjRt CPU client..."; bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie()); + client_ = + std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie()); } else if (device_type == "TPU") { TF_VLOG(1) << "Initializing PjRt TPU client..."; int64_t max_inflight_computations = sys_util::GetEnvInt( @@ -83,7 +84,8 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( std::vector PjRtComputationClient::TransferToServer( absl::Span tensors) { tensorflow::profiler::TraceMe activity( - "PjRtComputationClient::TransferToServer", tensorflow::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::TransferToServer", + tensorflow::profiler::TraceMeLevel::kInfo); std::vector datas; datas.reserve(tensors.size()); for (auto& tensor : tensors) { @@ -119,7 +121,8 @@ std::vector PjRtComputationClient::TransferToServer( std::vector PjRtComputationClient::TransferFromServer( absl::Span handles) { tensorflow::profiler::TraceMe activity( - "PjRtComputationClient::TransferFromServer", tensorflow::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::TransferFromServer", + tensorflow::profiler::TraceMeLevel::kInfo); std::vector literals; literals.reserve(handles.size()); @@ -137,7 +140,8 @@ std::vector PjRtComputationClient::TransferFromServer( std::vector PjRtComputationClient::Compile( std::vector instances) { tensorflow::profiler::TraceMe activity( - "PjRtComputationClient::Compile", tensorflow::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::Compile", + tensorflow::profiler::TraceMeLevel::kInfo); std::vector computations; for (auto& instance : instances) { @@ -153,7 +157,8 @@ std::vector PjRtComputationClient::Compile( compile_options.executable_build_options.set_num_partitions(1); compile_options.executable_build_options.set_num_replicas( client_->device_count()); - compile_options.parameter_is_tupled_arguments = true; + compile_options.parameter_is_tupled_arguments = + instance.parameter_is_tupled_arguments; std::unique_ptr executable = client_->Compile(instance.computation, compile_options).ValueOrDie(); std::shared_ptr pjrt_computation = @@ -173,7 +178,8 @@ PjRtComputationClient::ExecuteComputation( absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { tensorflow::profiler::TraceMe activity( - "PjRtComputationClient::ExecuteComputation", tensorflow::profiler::TraceMeLevel::kInfo); + "PjRtComputationClient::ExecuteComputation", + tensorflow::profiler::TraceMeLevel::kInfo); TF_VLOG(1) << "Executing PjRt computation on " << device; const PjRtComputation& pjrt_computation = dynamic_cast(computation); diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index c8751b76e24e..ffa39e33d04e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1710,6 +1710,10 @@ XLATensor::CompilationResult XLATensor::Compile( tensorflow::profiler::TraceMeLevel::kInfo); static const bool enable_aliasing = xla::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true); + static const size_t parameter_wrapping_threadshold = + xla::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200); + static const bool using_pjrt = + xla::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0; LoweringContext lowering_ctx("SyncTensorsGraph", coll.device, po_data->post_order, std::move(po_data->emission_map)); @@ -1754,19 +1758,23 @@ XLATensor::CompilationResult XLATensor::Compile( xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); - xla::XlaComputation wrapped_computation = ConsumeValue(WrapComputation( - computation, program_shape.parameters(), input_output_alias_pair)); - xla::ProgramShape wrapped_program_shape = - ConsumeValue(wrapped_computation.GetProgramShape()); - xla::Shape shape = - MakeShapeWithDeviceLayout(wrapped_program_shape.result(), - static_cast(coll.device.type())); + + bool should_wrap_parameter = + (program_shape.parameters_size() >= parameter_wrapping_threadshold) && + using_pjrt; + if (should_wrap_parameter) { + computation = ConsumeValue(WrapComputation( + computation, program_shape.parameters(), input_output_alias_pair)); + program_shape = ConsumeValue(computation.GetProgramShape()); + } + xla::Shape shape = MakeShapeWithDeviceLayout( + program_shape.result(), static_cast(coll.device.type())); std::vector instances; - instances.push_back({std::move(wrapped_computation), coll.device.toString(), + instances.push_back({std::move(computation), coll.device.toString(), xla::ComputationClient::Get()->GetCompilationDevices( coll.device.toString(), devices), - &shape}); + &shape, should_wrap_parameter}); TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " @@ -1782,8 +1790,14 @@ XLATensor::CompilationResult XLATensor::Compile( << " is computation hash " << torch::lazy::HashToString(torch::lazy::Hash( computations.front()->computation().proto().SerializeAsString())); - XLA_CHECK_EQ(program_shape.parameters_size(), - po_data->parameters_data.size()); + if (should_wrap_parameter) { + XLA_CHECK_EQ(program_shape.parameters_size(), 1); + XLA_CHECK_EQ(program_shape.parameters()[0].tuple_shapes_size(), + po_data->parameters_data.size()); + } else { + XLA_CHECK_EQ(program_shape.parameters_size(), + po_data->parameters_data.size()); + } return {/*device=*/coll.device, /*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(), From a80d4ffdacad25fc298e96325af495914c932cd4 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 01:50:01 +0000 Subject: [PATCH 07/12] code refactor --- torch_xla/csrc/helpers.cpp | 53 ++++++++++++++++++++++++++++++++++++ torch_xla/csrc/helpers.h | 5 ++++ torch_xla/csrc/tensor.cpp | 55 +------------------------------------- 3 files changed, 59 insertions(+), 54 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 35797452e398..1dd96af9c95e 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -577,4 +577,57 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp( return unary_op(op); } +xla::StatusOr XlaHelpers::WrapXlaComputation( + const xla::XlaComputation& computation, + const std::vector& parameter_shapes, + std::vector> input_output_alias_pair) { + xla::XlaBuilder builder(computation.proto().name()); + + // Construct a single tuple parameter. + const xla::XlaOp input_tuple = [&builder, ¶meter_shapes]() { + xla::Shape input_tuple; + input_tuple.set_element_type(xla::PrimitiveType::TUPLE); + input_tuple.mutable_tuple_shapes()->reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + *input_tuple.add_tuple_shapes() = parameter_shapes[i]; + } + return xla::Parameter(&builder, 0, input_tuple, "in"); + }(); + + // Handle the results of the original computation. + const std::vector inner_params = [&input_tuple, + ¶meter_shapes]() { + std::vector parameters; + parameters.reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + parameters.push_back(xla::GetTupleElement(input_tuple, i)); + } + return parameters; + }(); + + // Call the original computation. + xla::XlaOp orig_result; + orig_result = xla::Call(&builder, computation, inner_params); + + // Construct a single tuple result. + const std::vector results = [&orig_result]() { + std::vector results; + results.push_back(orig_result); + return results; + }(); + + xla::XlaOp result_tuple; + { result_tuple = xla::Tuple(&builder, results); } + + for (const auto& [input_index, output_index] : input_output_alias_pair) { + // Both input and output will be a tuple so parameter_number will always be + // 0 + builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}), + /*param_number=*/0, + /*param_index=*/xla::ShapeIndex({input_index})); + } + + return builder.Build(orig_result); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 83625e410beb..bf796164c59c 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -330,6 +330,11 @@ class XlaHelpers { s_mat_mul_precision = precision; } + static xla::StatusOr WrapXlaComputation( + const xla::XlaComputation& computation, + const std::vector& parameter_shapes, + std::vector> input_output_alias_pair); + private: static xla::PrecisionConfig::Precision s_mat_mul_precision; }; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index ffa39e33d04e..7c74902caa4b 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1644,59 +1644,6 @@ std::vector> XLATensor::BuildInputOutputAliases( return input_output_alias_pair; } -xla::StatusOr WrapComputation( - const xla::XlaComputation& computation, - const std::vector& parameter_shapes, - std::vector> input_output_alias_pair) { - xla::XlaBuilder builder(computation.proto().name()); - - // Construct a single tuple parameter. - const xla::XlaOp input_tuple = [&builder, ¶meter_shapes]() { - xla::Shape input_tuple; - input_tuple.set_element_type(xla::PrimitiveType::TUPLE); - input_tuple.mutable_tuple_shapes()->reserve(parameter_shapes.size()); - for (int i = 0; i < parameter_shapes.size(); ++i) { - *input_tuple.add_tuple_shapes() = parameter_shapes[i]; - } - return xla::Parameter(&builder, 0, input_tuple, "in"); - }(); - - // Handle the results of the original computation. - const std::vector inner_params = [&input_tuple, - ¶meter_shapes]() { - std::vector parameters; - parameters.reserve(parameter_shapes.size()); - for (int i = 0; i < parameter_shapes.size(); ++i) { - parameters.push_back(xla::GetTupleElement(input_tuple, i)); - } - return parameters; - }(); - - // Call the original computation. - xla::XlaOp orig_result; - orig_result = xla::Call(&builder, computation, inner_params); - - // Construct a single tuple result. - const std::vector results = [&orig_result]() { - std::vector results; - results.push_back(orig_result); - return results; - }(); - - xla::XlaOp result_tuple; - { result_tuple = xla::Tuple(&builder, results); } - - for (const auto& [input_index, output_index] : input_output_alias_pair) { - // Both input and output will be a tuple so parameter_number will always be - // 0 - builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}), - /*param_number=*/0, - /*param_index=*/xla::ShapeIndex({input_index})); - } - - return builder.Build(orig_result); -} - XLATensor::CompilationResult XLATensor::Compile( const std::vector& tensors, absl::Span devices, const SyncTensorCollection& coll, @@ -1763,7 +1710,7 @@ XLATensor::CompilationResult XLATensor::Compile( (program_shape.parameters_size() >= parameter_wrapping_threadshold) && using_pjrt; if (should_wrap_parameter) { - computation = ConsumeValue(WrapComputation( + computation = ConsumeValue(XlaHelpers::WrapXlaComputation( computation, program_shape.parameters(), input_output_alias_pair)); program_shape = ConsumeValue(computation.GetProgramShape()); } From a198ce5b482ec5c8b1df07b46a35b3d6224ebc69 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 02:06:57 +0000 Subject: [PATCH 08/12] Add debug VLOG --- torch_xla/csrc/tensor.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 7c74902caa4b..bd074b724f0c 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -1710,6 +1710,9 @@ XLATensor::CompilationResult XLATensor::Compile( (program_shape.parameters_size() >= parameter_wrapping_threadshold) && using_pjrt; if (should_wrap_parameter) { + TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size() + << " parameters. Threadshold = " + << parameter_wrapping_threadshold; computation = ConsumeValue(XlaHelpers::WrapXlaComputation( computation, program_shape.parameters(), input_output_alias_pair)); program_shape = ConsumeValue(computation.GetProgramShape()); From 55797a727409a67ba17d5c6588d5208a54525bc0 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 02:23:06 +0000 Subject: [PATCH 09/12] Add test --- test/test_operations.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_operations.py b/test/test_operations.py index 8ce4f97e0ec3..d960ce3d0829 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1800,6 +1800,16 @@ def test_fn(m): self.runAtenTest([torch.randint(1, 4, (7, 7), dtype=torch.uint8)], test_fn) + def test_too_many_parameter(self): + + def test_fn(t): + # TPU can handle ~3500 parameters on v3 without parameter tupling. + for i in range(4000): + t += torch.tensor(i, dtype=torch.float, device=t.device) + return t + + self.runAtenTest([torch.tensor(20.0)], test_fn) + def test_view_and_copy_(self): xla_device = xm.xla_device() x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu') From 2bd6bb224d48e20c7384f53c06a90deb47a8aa05 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 02:38:03 +0000 Subject: [PATCH 10/12] clean up --- torch_xla/csrc/helpers.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 1dd96af9c95e..5f24fe3d4c30 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -616,9 +616,6 @@ xla::StatusOr XlaHelpers::WrapXlaComputation( return results; }(); - xla::XlaOp result_tuple; - { result_tuple = xla::Tuple(&builder, results); } - for (const auto& [input_index, output_index] : input_output_alias_pair) { // Both input and output will be a tuple so parameter_number will always be // 0 From bf9b61ec0a91b0d6cd51e3c1ce23f9b4f45f9fb2 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 02:55:17 +0000 Subject: [PATCH 11/12] clean up --- torch_xla/csrc/helpers.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 5f24fe3d4c30..7442745f9c1b 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -609,13 +609,7 @@ xla::StatusOr XlaHelpers::WrapXlaComputation( xla::XlaOp orig_result; orig_result = xla::Call(&builder, computation, inner_params); - // Construct a single tuple result. - const std::vector results = [&orig_result]() { - std::vector results; - results.push_back(orig_result); - return results; - }(); - + // Rebuild aliasing. for (const auto& [input_index, output_index] : input_output_alias_pair) { // Both input and output will be a tuple so parameter_number will always be // 0 From 6c8763af5be37f792eeca85fa08e1ed9467dbccd Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 26 Aug 2022 23:33:04 +0000 Subject: [PATCH 12/12] address review comments --- torch_xla/csrc/helpers.cpp | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 7442745f9c1b..e4b5a98a7e55 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -584,30 +584,23 @@ xla::StatusOr XlaHelpers::WrapXlaComputation( xla::XlaBuilder builder(computation.proto().name()); // Construct a single tuple parameter. - const xla::XlaOp input_tuple = [&builder, ¶meter_shapes]() { - xla::Shape input_tuple; - input_tuple.set_element_type(xla::PrimitiveType::TUPLE); - input_tuple.mutable_tuple_shapes()->reserve(parameter_shapes.size()); - for (int i = 0; i < parameter_shapes.size(); ++i) { - *input_tuple.add_tuple_shapes() = parameter_shapes[i]; - } - return xla::Parameter(&builder, 0, input_tuple, "in"); - }(); + xla::Shape input_tuple_shape; + input_tuple_shape.set_element_type(xla::PrimitiveType::TUPLE); + input_tuple_shape.mutable_tuple_shapes()->reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + *input_tuple_shape.add_tuple_shapes() = parameter_shapes[i]; + } + xla::XlaOp input_tuple = xla::Parameter(&builder, 0, input_tuple_shape, "in"); // Handle the results of the original computation. - const std::vector inner_params = [&input_tuple, - ¶meter_shapes]() { - std::vector parameters; - parameters.reserve(parameter_shapes.size()); - for (int i = 0; i < parameter_shapes.size(); ++i) { - parameters.push_back(xla::GetTupleElement(input_tuple, i)); - } - return parameters; - }(); + std::vector inner_params; + inner_params.reserve(parameter_shapes.size()); + for (int i = 0; i < parameter_shapes.size(); ++i) { + inner_params.push_back(xla::GetTupleElement(input_tuple, i)); + } // Call the original computation. - xla::XlaOp orig_result; - orig_result = xla::Call(&builder, computation, inner_params); + xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params); // Rebuild aliasing. for (const auto& [input_index, output_index] : input_output_alias_pair) {