diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index daa9ea8bb32d..cea7d3c26dca 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -320,10 +320,6 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
              << ", shape=" << handle->shape() << ")";
   // TODO: handle replicated data
-  // if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
-  //   // Data is replicated, return the first shard
-  //   return sharded_data->shards[0];
-  // }
   xla::XlaBuilder builder("ReplicateShardedData");
   xla::Shape shape = handle->shape();
   builder.SetSharding(handle->GetSharding());
@@ -489,87 +485,6 @@ IfrtComputationClient::ExecuteComputation(
     const std::string& device, const ExecuteComputationOptions& options) {
   // TODO: Implement sharded exec in IFRT
   XLA_ERROR() << __FUNCTION__ << " not implemented";
-  // // Shared ownership of the timed section ensures that it will only get logged
-  // // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
-  // // complete; a copy is held from the lambda that releases it when done.
-  // auto timed = std::make_shared<metrics::TimedSection>(ExecuteMetric());
-  // tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation",
-  //                                 tsl::profiler::TraceMeLevel::kInfo);
-  // TF_VLOG(1) << "Executing Ifrt computation on " << device;
-  // const IfrtComputation& pjrt_computation =
-  //     dynamic_cast<const IfrtComputation&>(computation);
-
-  // xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
-  // XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
-
-  // std::vector<tsl::RCReference<xla::ifrt::Array>> buffers;
-  // buffers.reserve(arguments.size());
-  // for (auto& argument : arguments) {
-  //   const IfrtData* pjrt_data = dynamic_cast<IfrtData*>(argument.get());
-
-  //   // XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
-  //   //     << pjrt_device->DebugString() << " vs "
-  //   //     << pjrt_data->buffer->device()->DebugString();
-  //   buffers.push_back(pjrt_data->buffer);
-  // }
-
-  // xla::ExecuteOptions execute_options;
-  // execute_options.untuple_result = options.explode_tuple;
-  // execute_options.strict_shape_checking = false;
-
-  // // Required as of cl/518733871
-  // execute_options.use_major_to_minor_data_layout_for_callbacks = true;
-
-  // xla::ifrt::DeviceList device_list({pjrt_device});
-  // xla::ifrt::LoadedExecutable::ExecuteResult result =
-  //     pjrt_computation.executable
-  //         ->Execute(absl::MakeSpan(buffers), execute_options, device_list)
-  //         .value();
-
-  // xla::ifrt::Future<xla::Status> returned_future = result.status;
-
-  // auto results = result.outputs;
-  // std::vector<DataPtr> datas;
-  // datas.reserve(results.size());
-  // for (auto& result : results) {
-  //   tsl::RCReference<xla::ifrt::Array> buffer = std::move(result);
-
-  //   std::shared_ptr<IfrtData> data =
-  //       std::make_shared<IfrtData>(device, std::move(buffer));
-
-  //   datas.push_back(data);
-  // }
-  // CreateDataHandlesCounter()->AddValue(datas.size());
-
-  // auto mwait = std::make_shared<util::MultiWait>(1);
-  // auto lockfn = [&, this, device, returned_future = std::move(returned_future),
-  //                timed]() mutable {
-  //   TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
-  //              << device;
-  //   // Grab the shared lock and block the `WaitDeviceOps` until buffer is
-  //   // ready.
-  //   // TODO(JackCaoG): This lock should acquired outside of the lockfn and
-  //   // passed in. It is possible that lockfn started after ExecuteComputation
-  //   // released the xla_graph_executor lock, which will create a short windows
-  //   // where device is unlcoked while execution is still running.
-  //   auto lock = lock_device_shared(device);
-  //   TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
-  //              << " Done";
-  //   // Signal that `ExecuteSharded` has completed for the ExecuteTime
-  //   // metric. Copies the `timed` shared pointer into the lambda.
-  //   XLA_CHECK(returned_future.IsValid())
-  //       << "returned_future in ExecuteComputation is empty";
-  //   returned_future.OnReady(
-  //       [timed, lock = std::move(lock)](xla::Status unused) mutable {
-  //         timed.reset();
-  //         TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
-  //       });
-  // };
-
-  // env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
-
-  // TF_VLOG(1) << "Returning " << datas.size() << " results";
-  // return datas;
 }
 
 std::vector<ComputationClient::DataPtr>
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 091a5ee49201..468c1050783d 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -187,60 +187,6 @@ class IfrtComputationClient : public ComputationClient {
   tsl::RCReference<xla::ifrt::Array> ReplicateShardedData(
       const std::shared_ptr<IfrtData> handle);
 
-  // struct PjRtShardedData : public Data {
-  //   PjRtShardedData(std::string device, xla::Shape shape) = delete;
-
-  //   PjRtShardedData(std::string device, xla::Shape shape,
-  //                   std::vector<std::shared_ptr<PjRtData>> shards,
-  //                   xla::OpSharding sharding)
-  //       : Data(std::move(device), std::move(shape)),
-  //         shards(shards),
-  //         sharding(sharding) {}
-
-  //   Handle GetHandle() override {
-  //     // Always returns `Handle` of the first shard.
-  //     return shards[0]->GetHandle();
-  //   }
-
-  //   void Assign(const torch::lazy::BackendData& data) override {
-  //     const PjRtShardedData& pjrt_sharded_data =
-  //         dynamic_cast<const PjRtShardedData&>(data);
-  //     if (&pjrt_sharded_data != this) {
-  //       shards = std::move(pjrt_sharded_data.shards);
-  //     }
-  //   }
-
-  //   bool HasValue() const override {
-  //     if (shards.empty()) {
-  //       return false;
-  //     }
-
-  //     for (auto& shard : shards) {
-  //       if (!shard->HasValue()) {
-  //         return false;
-  //       }
-  //     }
-  //     return true;
-  //   }
-
-  //   std::string ToString() const override {
-  //     std::stringstream ss;
-  //     ss << "XLAShardedData: \n";
-  //     ss << "  Data Device: " << device() << "\n";
-  //     ss << "  Data Shape: " << shape().ToString() << "\n";
-  //     ss << "  OpSharding: "
-  //        << xla::HloSharding::FromProto(sharding)->ToString() << "\n";
-  //     ss << "  NumShards: " << shards.size() << "\n";
-  //     return ss.str();
-  //   }
-
-  //   bool HasSharding() const override { return true; }
-
-  //   xla::OpSharding GetSharding() const override { return sharding; }
-
-  //   std::vector<std::shared_ptr<PjRtData>> shards;
-  //   xla::OpSharding sharding;
-  // };
 
   struct IfrtComputation : public Computation {
     IfrtComputation(xla::XlaComputation computation,