Helper function to wrap entry HLO #3920

Merged: 12 commits, merged Aug 27, 2022
10 changes: 10 additions & 0 deletions test/test_operations.py
@@ -1800,6 +1800,16 @@ def test_fn(m):

self.runAtenTest([torch.randint(1, 4, (7, 7), dtype=torch.uint8)], test_fn)

def test_too_many_parameter(self):

def test_fn(t):
# TPU can handle ~3500 parameters on v3 without parameter tupling.
for i in range(4000):
t += torch.tensor(i, dtype=torch.float, device=t.device)
return t

self.runAtenTest([torch.tensor(20.0)], test_fn)

def test_view_and_copy_(self):
xla_device = xm.xla_device()
x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu')
7 changes: 5 additions & 2 deletions third_party/xla_client/computation_client.h
@@ -134,16 +134,19 @@ class ComputationClient {
struct CompileInstance {
CompileInstance() = default;
CompileInstance(XlaComputation computation, std::string compilation_device,
std::vector<std::string> devices, const Shape* output_shape)
std::vector<std::string> devices, const Shape* output_shape,
bool parameter_is_tupled_arguments = false)
: computation(std::move(computation)),
compilation_device(std::move(compilation_device)),
devices(std::move(devices)),
output_shape(output_shape) {}
output_shape(output_shape),
parameter_is_tupled_arguments(parameter_is_tupled_arguments) {}

XlaComputation computation;
std::string compilation_device;
std::vector<std::string> devices;
const Shape* output_shape = nullptr;
bool parameter_is_tupled_arguments;
};

struct ExecuteOptions {
19 changes: 13 additions & 6 deletions third_party/xla_client/pjrt_computation_client.cc
@@ -9,8 +9,8 @@
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -49,7 +49,8 @@ PjRtComputationClient::PjRtComputationClient() {
TF_VLOG(1) << "Initializing PjRt CPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true);
int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1);
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
client_ =
std::move(xla::GetTfrtCpuClient(async, cpu_device_count).ValueOrDie());
} else if (device_type == "TPU") {
TF_VLOG(1) << "Initializing PjRt TPU client...";
int64_t max_inflight_computations = sys_util::GetEnvInt(
@@ -83,7 +84,8 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder(
std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
absl::Span<const TensorSource> tensors) {
tensorflow::profiler::TraceMe activity(
"PjRtComputationClient::TransferToServer", tensorflow::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::TransferToServer",
tensorflow::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> datas;
datas.reserve(tensors.size());
for (auto& tensor : tensors) {
@@ -119,7 +121,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
absl::Span<const DataPtr> handles) {
tensorflow::profiler::TraceMe activity(
"PjRtComputationClient::TransferFromServer", tensorflow::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::TransferFromServer",
tensorflow::profiler::TraceMeLevel::kInfo);
std::vector<xla::Literal> literals;
literals.reserve(handles.size());

@@ -137,7 +140,8 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
std::vector<ComputationClient::CompileInstance> instances) {
tensorflow::profiler::TraceMe activity(
"PjRtComputationClient::Compile", tensorflow::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::Compile",
tensorflow::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::ComputationPtr> computations;

for (auto& instance : instances) {
@@ -153,6 +157,8 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
compile_options.executable_build_options.set_num_partitions(1);
compile_options.executable_build_options.set_num_replicas(
client_->device_count());
compile_options.parameter_is_tupled_arguments =
instance.parameter_is_tupled_arguments;
std::unique_ptr<xla::PjRtExecutable> executable =
client_->Compile(instance.computation, compile_options).ValueOrDie();
std::shared_ptr<PjRtComputation> pjrt_computation =
@@ -172,7 +178,8 @@ PjRtComputationClient::ExecuteComputation(
absl::Span<const ComputationClient::DataPtr> arguments,
const std::string& device, const ExecuteComputationOptions& options) {
tensorflow::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteComputation", tensorflow::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::ExecuteComputation",
tensorflow::profiler::TraceMeLevel::kInfo);
TF_VLOG(1) << "Executing PjRt computation on " << device;
const PjRtComputation& pjrt_computation =
dynamic_cast<const PjRtComputation&>(computation);
37 changes: 37 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -577,4 +577,41 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp(
return unary_op(op);
}

xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
const xla::XlaComputation& computation,
const std::vector<xla::Shape>& parameter_shapes,
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair) {
xla::XlaBuilder builder(computation.proto().name());

// Construct a single tuple parameter.
xla::Shape input_tuple_shape;
input_tuple_shape.set_element_type(xla::PrimitiveType::TUPLE);
input_tuple_shape.mutable_tuple_shapes()->reserve(parameter_shapes.size());
for (int i = 0; i < parameter_shapes.size(); ++i) {
*input_tuple_shape.add_tuple_shapes() = parameter_shapes[i];
}
xla::XlaOp input_tuple = xla::Parameter(&builder, 0, input_tuple_shape, "in");

// Unpack the tuple elements into the parameters expected by the original computation.
std::vector<xla::XlaOp> inner_params;
inner_params.reserve(parameter_shapes.size());
for (int i = 0; i < parameter_shapes.size(); ++i) {
inner_params.push_back(xla::GetTupleElement(input_tuple, i));
}

// Call the original computation.
xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);

// Rebuild aliasing.
for (const auto& [input_index, output_index] : input_output_alias_pair) {
// Both the input and the output are tuples, so parameter_number is always 0.
builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
/*param_number=*/0,
/*param_index=*/xla::ShapeIndex({input_index}));
}

return builder.Build(orig_result);
}

} // namespace torch_xla
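
As a reading aid (not part of the diff), here is a minimal sketch of how the new helper could be exercised, assuming the standard XLA client builder API; the function name BuildAndWrap and the two scalar parameters are illustrative only. A two-parameter computation is wrapped so that its entry signature becomes a single tuple parameter, with no aliasing.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "torch_xla/csrc/helpers.h"

// Sketch: build a toy computation with two scalar parameters, then wrap it so
// the entry computation takes a single tuple parameter instead.
xla::XlaComputation BuildAndWrap() {
  xla::XlaBuilder builder("add_two_params");
  xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  xla::XlaOp a = xla::Parameter(&builder, 0, scalar, "a");
  xla::XlaOp b = xla::Parameter(&builder, 1, scalar, "b");
  xla::Add(a, b);
  xla::XlaComputation computation = builder.Build().ValueOrDie();

  // No aliasing in this toy example, so the alias-pair list is empty.
  xla::XlaComputation wrapped =
      torch_xla::XlaHelpers::WrapXlaComputation(
          computation, /*parameter_shapes=*/{scalar, scalar},
          /*input_output_alias_pair=*/{})
          .ValueOrDie();

  // The wrapped program now has exactly one parameter: a 2-element tuple
  // holding the two original scalars.
  return wrapped;
}

The runtime-side counterpart is the parameter_is_tupled_arguments flag threaded through CompileInstance above, which lets the PjRt compile options know that arguments will arrive as a single tuple, so the wrapped entry signature matches what the runtime feeds it.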
5 changes: 5 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -330,6 +330,11 @@ class XlaHelpers {
s_mat_mul_precision = precision;
}

static xla::StatusOr<xla::XlaComputation> WrapXlaComputation(
const xla::XlaComputation& computation,
const std::vector<xla::Shape>& parameter_shapes,
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair);

private:
static xla::PrecisionConfig::Precision s_mat_mul_precision;
};
41 changes: 35 additions & 6 deletions torch_xla/csrc/tensor.cpp
@@ -1602,10 +1602,11 @@ XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp(
return async_op.Schedule();
}

void XLATensor::BuildInputOutputAliases(
std::vector<std::pair<int64_t, int64_t>> XLATensor::BuildInputOutputAliases(
const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
LoweringContext* lowering_ctx) {
std::unordered_map<int64_t, size_t> output_tensor_id_map;
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
for (size_t i = 0; i < indices.size(); ++i) {
size_t tensor_index = indices[i];
int64_t tensor_id = tensors[tensor_index]->GetUniqueId();
@@ -1626,9 +1627,12 @@ void XLATensor::BuildInputOutputAliases(
const xla::Shape& root_shape = XlaHelpers::ShapeOfXlaOp(root);
if (parameters_data[i]->shape() == root_shape &&
alias_map[output_index] < 0) {
// The parameter is not a tuple, so param_index will always be {}.
lowering_ctx->builder()->SetUpAlias(
{static_cast<int64_t>(output_index)}, i, {});
{/*output_index=*/static_cast<int64_t>(output_index)},
/*param_number=*/i, /*param_index=*/{});
alias_map[output_index] = i;
input_output_alias_pair.push_back(std::make_pair(i, output_index));

TF_VLOG(6) << "Aliased parameter " << i << " with output "
<< output_index << ": " << parameters_data[i]->shape();
@@ -1637,6 +1641,7 @@
}
}
XLA_VALUE_METRIC("InputOutputAliasCount", alias_map.size());
return input_output_alias_pair;
}

XLATensor::CompilationResult XLATensor::Compile(
@@ -1652,6 +1657,10 @@ XLATensor::CompilationResult XLATensor::Compile(
tensorflow::profiler::TraceMeLevel::kInfo);
static const bool enable_aliasing =
xla::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
static const size_t parameter_wrapping_threadshold =
xla::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool using_pjrt =
xla::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0;
LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
po_data->post_order,
std::move(po_data->emission_map));
@@ -1664,6 +1673,7 @@ XLATensor::CompilationResult XLATensor::Compile(
// Annotate HLO sharding selectively in the computation.
ShardingUtil::SetHloSharding(&lowering_ctx);

std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
if (enable_aliasing && coll.config.sync_xla_data) {
// We can only alias at the step barrier, when force_xla_data is true.
// Consider the case:
@@ -1689,19 +1699,32 @@
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_xla_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
input_output_alias_pair =
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
}

xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());

bool should_wrap_parameter =
(program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
using_pjrt;
if (should_wrap_parameter) {
TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size()
<< " parameters. Threadshold = "
<< parameter_wrapping_threadshold;
computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
computation, program_shape.parameters(), input_output_alias_pair));
program_shape = ConsumeValue(computation.GetProgramShape());
}
xla::Shape shape = MakeShapeWithDeviceLayout(
program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));

std::vector<xla::ComputationClient::CompileInstance> instances;
instances.push_back({std::move(computation), coll.device.toString(),
xla::ComputationClient::Get()->GetCompilationDevices(
coll.device.toString(), devices),
&shape});
&shape, should_wrap_parameter});

TF_VLOG(3) << "Compiling IR graph hash "
<< torch::lazy::HashToString(coll.hash) << " on device "
@@ -1717,8 +1740,14 @@
<< " is computation hash "
<< torch::lazy::HashToString(torch::lazy::Hash(
computations.front()->computation().proto().SerializeAsString()));
XLA_CHECK_EQ(program_shape.parameters_size(),
po_data->parameters_data.size());
if (should_wrap_parameter) {
XLA_CHECK_EQ(program_shape.parameters_size(), 1);
XLA_CHECK_EQ(program_shape.parameters()[0].tuple_shapes_size(),
po_data->parameters_data.size());
} else {
XLA_CHECK_EQ(program_shape.parameters_size(),
po_data->parameters_data.size());
}

return {/*device=*/coll.device,
/*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(),
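For readers skimming the Compile changes above, the new gating logic reduces to the following predicate; this is a condensed sketch (the PR inlines the checks directly, and ShouldWrapParameters is a hypothetical name), showing that wrapping only kicks in under PjRt and once the program reaches the XLA_PARAMETER_WRAPPING_THREADSHOLD parameter count (default 3200).

#include <cstddef>

#include "tensorflow/compiler/xla/xla_client/sys_util.h"

// Sketch of the wrapping decision: only wrap when running on PjRt and the
// parameter count reaches the configurable threshold.
bool ShouldWrapParameters(size_t num_parameters) {
  static const size_t threshold =
      xla::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
  static const bool using_pjrt =
      !xla::sys_util::GetEnvString("PJRT_DEVICE", "").empty();
  return using_pjrt && num_parameters >= threshold;
}

When wrapping happens, the parameter-count check at the end of Compile switches from comparing against the flat parameter list to comparing against the tuple element count of the single wrapped parameter, which is what the new if/else branch above does.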
6 changes: 3 additions & 3 deletions torch_xla/csrc/tensor.h
@@ -1445,9 +1445,9 @@ class XLATensor : public c10::intrusive_ptr_target {
std::vector<XLATensorPtr>* tensors, SyncTensorCollection* coll,
PostOrderData* po_data);

static void BuildInputOutputAliases(const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices,
LoweringContext* lowering_ctx);
static std::vector<std::pair<int64_t, int64_t>> BuildInputOutputAliases(
const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices, LoweringContext* lowering_ctx);

static CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,