Commit
remove some commented out code
will-cromar committed Nov 29, 2023
1 parent cb64bee commit e4856a3
Showing 2 changed files with 0 additions and 139 deletions.
85 changes: 0 additions & 85 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -320,10 +320,6 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
<< ", shape=" << handle->shape() << ")";
// TODO: handle replicated data
// if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
// // Data is replicated, return the first shard
// return sharded_data->shards[0];
// }
xla::XlaBuilder builder("ReplicateShardedData");
xla::Shape shape = handle->shape();
builder.SetSharding(handle->GetSharding());
@@ -489,87 +485,6 @@ IfrtComputationClient::ExecuteComputation(
const std::string& device, const ExecuteComputationOptions& options) {
// TODO: Implement sharded exec in IFRT
XLA_ERROR() << __FUNCTION__ << " not implemented";
// // Shared ownership of the timed section ensures that it will only get logged
// // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
// // complete; a copy is held from the lambda that releases it when done.
// auto timed = std::make_shared<metrics::TimedSection>(ExecuteMetric());
// tsl::profiler::TraceMe activity("IfrtComputationClient::ExecuteComputation",
// tsl::profiler::TraceMeLevel::kInfo);
// TF_VLOG(1) << "Executing Ifrt computation on " << device;
// const IfrtComputation& pjrt_computation =
// dynamic_cast<const IfrtComputation&>(computation);

// xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
// XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

// std::vector<tsl::RCReference<xla::ifrt::Array>> buffers;
// buffers.reserve(arguments.size());
// for (auto& argument : arguments) {
// const IfrtData* pjrt_data = dynamic_cast<IfrtData*>(argument.get());

// // XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
// // << pjrt_device->DebugString() << " vs "
// // << pjrt_data->buffer->device()->DebugString();
// buffers.push_back(pjrt_data->buffer);
// }

// xla::ExecuteOptions execute_options;
// execute_options.untuple_result = options.explode_tuple;
// execute_options.strict_shape_checking = false;

// // Required as of cl/518733871
// execute_options.use_major_to_minor_data_layout_for_callbacks = true;

// xla::ifrt::DeviceList device_list({pjrt_device});
// xla::ifrt::LoadedExecutable::ExecuteResult result =
// pjrt_computation.executable
// ->Execute(absl::MakeSpan(buffers), execute_options, device_list)
// .value();

// xla::ifrt::Future<xla::Status> returned_future = result.status;

// auto results = result.outputs;
// std::vector<DataPtr> datas;
// datas.reserve(results.size());
// for (auto& result : results) {
// tsl::RCReference<xla::ifrt::Array> buffer = std::move(result);

// std::shared_ptr<IfrtData> data =
// std::make_shared<IfrtData>(device, std::move(buffer));

// datas.push_back(data);
// }
// CreateDataHandlesCounter()->AddValue(datas.size());

// auto mwait = std::make_shared<util::MultiWait>(1);
// auto lockfn = [&, this, device, returned_future = std::move(returned_future),
// timed]() mutable {
// TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
// << device;
// // Grab the shared lock and block the `WaitDeviceOps` until buffer is
// // ready.
//     // TODO(JackCaoG): This lock should be acquired outside of the lockfn and
//     // passed in. It is possible that lockfn started after ExecuteComputation
//     // released the xla_graph_executor lock, which will create a short window
//     // where the device is unlocked while execution is still running.
// auto lock = lock_device_shared(device);
// TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
// << " Done";
// // Signal that `ExecuteSharded` has completed for the ExecuteTime
// // metric. Copies the `timed` shared pointer into the lambda.
// XLA_CHECK(returned_future.IsValid())
// << "returned_future in ExecuteComputation is empty";
// returned_future.OnReady(
// [timed, lock = std::move(lock)](xla::Status unused) mutable {
// timed.reset();
// TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
// });
// };

// env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));

// TF_VLOG(1) << "Returning " << datas.size() << " results";
// return datas;
}

std::vector<ComputationClient::DataPtr>
54 changes: 0 additions & 54 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -187,60 +187,6 @@ class IfrtComputationClient : public ComputationClient {

tsl::RCReference<xla::ifrt::Array> ReplicateShardedData(
const std::shared_ptr<IfrtData> handle);
// struct PjRtShardedData : public Data {
// PjRtShardedData(std::string device, xla::Shape shape) = delete;

// PjRtShardedData(std::string device, xla::Shape shape,
// std::vector<std::shared_ptr<PjRtData>> shards,
// xla::OpSharding sharding)
// : Data(std::move(device), std::move(shape)),
// shards(shards),
// sharding(sharding) {}

// Handle GetHandle() override {
// // Always returns `Handle` of the first shard.
// return shards[0]->GetHandle();
// }

// void Assign(const torch::lazy::BackendData& data) override {
// const PjRtShardedData& pjrt_sharded_data =
// dynamic_cast<const PjRtShardedData&>(data);
// if (&pjrt_sharded_data != this) {
// shards = std::move(pjrt_sharded_data.shards);
// }
// }

// bool HasValue() const override {
// if (shards.empty()) {
// return false;
// }

// for (auto& shard : shards) {
// if (!shard->HasValue()) {
// return false;
// }
// }
// return true;
// }

// std::string ToString() const override {
// std::stringstream ss;
// ss << "XLAShardedData: \n";
// ss << " Data Device: " << device() << "\n";
// ss << " Data Shape: " << shape().ToString() << "\n";
// ss << " OpSharding: "
// << xla::HloSharding::FromProto(sharding)->ToString() << "\n";
// ss << " NumShards: " << shards.size() << "\n";
// return ss.str();
// }

// bool HasSharding() const override { return true; }

// xla::OpSharding GetSharding() const override { return sharding; }

// std::vector<std::shared_ptr<PjRtData>> shards;
// xla::OpSharding sharding;
// };

struct IfrtComputation : public Computation {
IfrtComputation(xla::XlaComputation computation,
