From aa73509890fb10b7f3bc2b9e8598f96382f2c979 Mon Sep 17 00:00:00 2001 From: steventk-g <107513673+steventk-g@users.noreply.github.com> Date: Thu, 15 Sep 2022 18:29:32 -0700 Subject: [PATCH] Op lowering for Einsum (#3843) * Op lowering for Einsum * Delete .torch_pin Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com> --- test/cpp/test_aten_xla_tensor.cpp | 97 ++++++++++++++++++++--- torch_xla/csrc/aten_autograd_ops.cpp | 45 +++++++++++ torch_xla/csrc/aten_autograd_ops.h | 10 +++ torch_xla/csrc/aten_xla_type.cpp | 13 +++ torch_xla/csrc/helpers.cpp | 9 ++- torch_xla/csrc/helpers.h | 4 + torch_xla/csrc/ops/einsum.cpp | 54 +++++++++++++ torch_xla/csrc/ops/einsum.h | 23 ++++++ torch_xla/csrc/ops/einsum_backward.cpp | 83 +++++++++++++++++++ torch_xla/csrc/ops/einsum_backward.h | 24 ++++++ torch_xla/csrc/ops/einsum_utilities.h | 91 +++++++++++++++++++++ torch_xla/csrc/ops/infer_output_shape.cpp | 23 ++++++ torch_xla/csrc/ops/infer_output_shape.h | 5 ++ torch_xla/csrc/ops/xla_ops.cpp | 1 + torch_xla/csrc/ops/xla_ops.h | 1 + torch_xla/csrc/reduction.cpp | 47 +++++++++++ torch_xla/csrc/reduction.h | 7 ++ torch_xla/csrc/tensor.h | 5 ++ torch_xla/csrc/tensor_methods.cpp | 35 ++++++++ xla_native_functions.yaml | 1 + 20 files changed, 567 insertions(+), 11 deletions(-) create mode 100644 torch_xla/csrc/ops/einsum.cpp create mode 100644 torch_xla/csrc/ops/einsum.h create mode 100644 torch_xla/csrc/ops/einsum_backward.cpp create mode 100644 torch_xla/csrc/ops/einsum_backward.h create mode 100644 torch_xla/csrc/ops/einsum_utilities.h diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index aa41755d7bec..424d8d9da6ba 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -3834,7 +3834,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) { @@ -3851,7 +3851,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) { @@ -3867,7 +3867,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMulBackward) { + torch::Tensor a = torch::rand( + {3, 2, 5}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + torch::Tensor b = torch::rand( + {3, 5, 4}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + std::string equation = "bij,bjk->bik"; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::einsum(equation, inputs); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward({a, b}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) { @@ -3884,8 +3901,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) { AllClose(c, xla_c); }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters()); + ExpectCounterNotChanged("aten::einsum", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) { @@ -3900,7 +3917,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonalBackward) { + torch::Tensor input = torch::rand( + {3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + std::string equation = "ii->i"; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::einsum(equation, inputs); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) { @@ -3915,7 +3947,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonalBackward) { + torch::Tensor input = torch::rand( + {4, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + std::string equation = "...ii->...i"; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::einsum(equation, inputs); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) { @@ -3930,7 +3977,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::permute", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermuteBackward) { + torch::Tensor input = torch::rand( + {2, 3, 4, 5}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + std::string equation = "...ij->...ji"; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::einsum(equation, inputs); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) { @@ -3946,7 +4008,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxisBackward) { + torch::Tensor x = torch::rand( + {2, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + torch::Tensor y = + torch::rand({4}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + std::string equation = "ijj,k->ik"; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::einsum(equation, inputs); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward({x, y}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestBilinear) { diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 4088cb789d9f..89b12d171dc7 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -18,6 +18,51 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation) { namespace aten_autograd_ops { +torch::Tensor EinsumAutogradFunction::forward( + torch::autograd::AutogradContext* ctx, const c10::string_view equation, + at::TensorList tensors) { + std::string eq_str = std::string(equation); + ctx->saved_data["equation"] = eq_str; + + torch::autograd::variable_list vars; + for (const torch::Tensor const& tensor : tensors) { + vars.push_back(tensor); + } + ctx->save_for_backward(vars); + + std::vector xla_tensors = + bridge::GetXlaTensors(absl::MakeSpan(tensors)); + XLATensorPtr output = XLATensor::einsum(eq_str, xla_tensors); + return bridge::AtenFromXlaTensor(output); +} + +torch::autograd::variable_list EinsumAutogradFunction::backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + std::string equation = ctx->saved_data["equation"].toString()->string(); + torch::autograd::variable_list tensors = ctx->get_saved_variables(); + std::vector xla_tensors = + bridge::GetXlaTensors(absl::MakeSpan(tensors)); + + std::tuple outputs = XLATensor::einsum_backward( + bridge::GetXlaTensor(grad_output[0]), xla_tensors, equation); + + // For both einsum and max pool, we use "undef" as a placeholder for the + // non-tensor grad inputs, in this case the equation string. + torch::Tensor undef; + torch::autograd::variable_list grad_inputs = { + undef, bridge::AtenFromXlaTensor(std::get<0>(outputs))}; + + // einsum_backward will return a tuple with either one or two tensors defined. + // If both tensors in the tuple are defined, then we return both tensors. + // Otherwise, we only return the first tensor. + if (std::get<1>(outputs).defined()) { + grad_inputs.push_back(bridge::AtenFromXlaTensor(std::get<1>(outputs))); + } + + return grad_inputs; +} + torch::Tensor MaxPool2dAutogradFunction::forward( torch::autograd::AutogradContext* ctx, torch::Tensor self, torch::IntArrayRef kernel_size, torch::IntArrayRef stride, diff --git a/torch_xla/csrc/aten_autograd_ops.h b/torch_xla/csrc/aten_autograd_ops.h index 713ddb5517fd..4ad7eb5e16d8 100644 --- a/torch_xla/csrc/aten_autograd_ops.h +++ b/torch_xla/csrc/aten_autograd_ops.h @@ -7,6 +7,16 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation); namespace aten_autograd_ops { +struct EinsumAutogradFunction + : public torch::autograd::Function { + static torch::Tensor forward(torch::autograd::AutogradContext* ctx, + c10::string_view equation, + at::TensorList tensors); + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output); +}; + struct MaxPool2dAutogradFunction : public torch::autograd::Function { static torch::Tensor forward(torch::autograd::AutogradContext* ctx, diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 7ea5ed96b9cd..0511164e9c00 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -25,6 +25,7 @@ #include "torch_xla/csrc/generated/XLANativeFunctions.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" +#include "torch_xla/csrc/ops/einsum_utilities.h" #include "torch_xla/csrc/ops/index_ops.h" #include "torch_xla/csrc/pooling.h" #include "torch_xla/csrc/tensor_impl.h" @@ -1070,6 +1071,18 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self, bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor))); } +at::Tensor XLANativeFunctions::einsum(c10::string_view equation, + at::TensorList tensors) { + XLA_FN_COUNTER("xla::"); + // Einsum operations with more than 2 operands, like bilinear operations, are + // not currently supported in XLA + if (tensors.size() > 2 || + !EinsumUtilities::EquationIsValid(std::string(equation))) { + return at::native::einsum(equation, tensors); + } + return aten_autograd_ops::EinsumAutogradFunction::apply(equation, tensors); +} + at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, const at::Scalar& alpha, const at::Scalar& scale, diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index e4b5a98a7e55..40c1cfd1b9b8 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -421,6 +421,12 @@ xla::PrimitiveType XlaHelpers::PromoteType(xla::PrimitiveType type1, return type1; } +xla::PrimitiveType XlaHelpers::PromoteType(xla::PrimitiveType type1, + xla::PrimitiveType type2, + xla::PrimitiveType type3) { + return PromoteType(PromoteType(type1, type2), type3); +} + std::pair XlaHelpers::PromoteValues(xla::XlaOp op1, xla::XlaOp op2) { xla::PrimitiveType type1 = TypeOfXlaOp(op1); @@ -440,8 +446,7 @@ std::tuple XlaHelpers::PromoteValues( xla::PrimitiveType type1 = TypeOfXlaOp(op1); xla::PrimitiveType type2 = TypeOfXlaOp(op2); xla::PrimitiveType type3 = TypeOfXlaOp(op3); - xla::PrimitiveType result_type = - PromoteType(PromoteType(type1, type2), type3); + xla::PrimitiveType result_type = PromoteType(type1, type2, type3); if (type1 != result_type) { op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr); } diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index bf796164c59c..b2f1a752cea4 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -239,6 +239,10 @@ class XlaHelpers { static xla::PrimitiveType PromoteType(xla::PrimitiveType type1, xla::PrimitiveType type2); + static xla::PrimitiveType PromoteType(xla::PrimitiveType type1, + xla::PrimitiveType type2, + xla::PrimitiveType type3); + // Performs type promotion to make sure both operations return the same type. static std::pair PromoteValues(xla::XlaOp op1, xla::XlaOp op2); diff --git a/torch_xla/csrc/ops/einsum.cpp b/torch_xla/csrc/ops/einsum.cpp new file mode 100644 index 000000000000..4615b1789e58 --- /dev/null +++ b/torch_xla/csrc/ops/einsum.cpp @@ -0,0 +1,54 @@ +#include "torch_xla/csrc/ops/einsum.h" + +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/reduction.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::OpList& operands, + const std::string& equation) { + auto lower_for_shape_fn = + [&](absl::Span operands) -> xla::XlaOp { + return BuildEinsum(operands, equation); + }; + + std::vector shapes; + for (auto const& op : operands) { + shapes.push_back(GetXlaShape(op)); + } + + return InferOutputShape(absl::MakeSpan(shapes), lower_for_shape_fn); +} + +} // namespace + +torch::lazy::NodePtr Einsum::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands, equation_); +} + +Einsum::Einsum(const torch::lazy::OpList& operands, const std::string equation) + : XlaNode(torch::lazy::OpKind(at::aten::einsum), operands, + NodeOutputShape(operands, equation), + /*num_outputs=*/1, torch::lazy::MHash(equation)), + equation_(std::move(equation)) {} + +XlaOpVector Einsum::Lower(LoweringContext* loctx) const { + std::vector inputs; + auto& operand_list = operands(); + inputs.reserve(operand_list.size()); + for (size_t i = 0; i < operand_list.size(); ++i) { + inputs.push_back(loctx->GetOutputOp(operand_list[i])); + } + return ReturnOp(BuildEinsum(absl::MakeSpan(inputs), equation_), loctx); +} + +std::string Einsum::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", equation=(" << equation_ << ")"; + return ss.str(); +} + +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/einsum.h b/torch_xla/csrc/ops/einsum.h new file mode 100644 index 000000000000..308b87bb405c --- /dev/null +++ b/torch_xla/csrc/ops/einsum.h @@ -0,0 +1,23 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class Einsum : public XlaNode { + public: + Einsum(const torch::lazy::OpList& operands, const std::string equation); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::string& equation() const { return equation_; } + + private: + const std::string equation_; +}; + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/einsum_backward.cpp b/torch_xla/csrc/ops/einsum_backward.cpp new file mode 100644 index 000000000000..9976474af28f --- /dev/null +++ b/torch_xla/csrc/ops/einsum_backward.cpp @@ -0,0 +1,83 @@ +#include "torch_xla/csrc/ops/einsum_backward.h" + +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/einsum.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/reduction.h" + +namespace torch_xla { +namespace { + +std::vector GetOperandList( + c10::ArrayRef operands, + const torch::lazy::Value& grad_output) { + std::vector operand_list(operands.begin(), + operands.end()); + operand_list.insert(operand_list.begin(), grad_output); + return operand_list; +} + +xla::Shape NodeOutputShapes(const torch::lazy::Value& grad_output, + const torch::lazy::OpList& inputs, + const std::string& equation) { + auto lower_for_shapes_fn = + [&](absl::Span operands) -> std::vector { + return BuildEinsumBackward( + operands[0], + std::vector(operands.begin() + 1, operands.end()), + equation); + }; + + std::vector input_shapes; + input_shapes.push_back(GetXlaShape(grad_output)); + for (auto const& op : inputs) { + input_shapes.push_back(GetXlaShape(op)); + } + + return InferOutputShapes(absl::MakeSpan(input_shapes), lower_for_shapes_fn); +} +} // namespace + +torch::lazy::NodePtr EinsumBackward::Clone(torch::lazy::OpList operands) const { + std::vector inputs; + inputs.reserve(operands.size() - 1); + for (size_t i = 1; i < operands.size(); ++i) { + inputs.push_back(operands.at(i)); + } + + return torch::lazy::MakeNode(operands.at(0), inputs, + equation_); +} + +EinsumBackward::EinsumBackward(const torch::lazy::Value& grad_output, + const torch::lazy::OpList& inputs, + const std::string equation) + : XlaNode(xla_einsum_backward, GetOperandList(inputs, grad_output), + [&]() { return NodeOutputShapes(grad_output, inputs, equation); }, + /*num_outputs=*/inputs.size(), torch::lazy::MHash(equation)), + equation_(equation) {} + +XlaOpVector EinsumBackward::Lower(LoweringContext* loctx) const { + std::vector inputs; + auto& operand_list = operands(); + inputs.reserve(operand_list.size() - 1); + xla::XlaOp grad_output = loctx->GetOutputOp(operand_list[0]); + + for (size_t i = 1; i < operand_list.size(); ++i) { + inputs.push_back(loctx->GetOutputOp(operand_list[i])); + } + + std::vector ops = + BuildEinsumBackward(grad_output, absl::MakeSpan(inputs), equation_); + + return ReturnOps(absl::MakeSpan(ops), loctx); +} + +std::string EinsumBackward::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", equation=(" << equation_ << ")"; + return ss.str(); +} +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/einsum_backward.h b/torch_xla/csrc/ops/einsum_backward.h new file mode 100644 index 000000000000..a0b1bc7b4cff --- /dev/null +++ b/torch_xla/csrc/ops/einsum_backward.h @@ -0,0 +1,24 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class EinsumBackward : public XlaNode { + public: + EinsumBackward(const torch::lazy::Value& grad_output, + const torch::lazy::OpList& inputs, const std::string equation); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::string& equation() const { return equation_; } + + private: + const std::string equation_; +}; + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/einsum_utilities.h b/torch_xla/csrc/ops/einsum_utilities.h new file mode 100644 index 000000000000..1a58e91d3892 --- /dev/null +++ b/torch_xla/csrc/ops/einsum_utilities.h @@ -0,0 +1,91 @@ +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" + +namespace torch_xla { + +class EinsumUtilities { + public: + static std::vector BuildBackwardsEquations( + const std::string& equation) { + std::vector elements = ParseEquation(equation); + std::vector equations; + equations.push_back(elements[2] + "," + elements[1] + "->" + elements[0]); + equations.push_back(elements[0] + "," + elements[2] + "->" + elements[1]); + return equations; + } + + static std::string BuildBackwardsEquation(const std::string& equation) { + int split_index = equation.find("->"); + XLA_CHECK_NE(split_index, std::string::npos); + + std::string equation_input = equation.substr(0, split_index); + std::string equation_output = + equation.substr(split_index + 2, equation.size() - split_index - 2); + + std::string backward_equation = equation_output + "->" + equation_input; + return backward_equation; + } + + // An einsum equation is invalid if there are indices in one of the inputs or + // output which are not in any other input or output. This is because such + // equations lead to a failure when attempting to execute einsum backward on + // XLA. + static bool EquationIsValid(const std::string& equation) { + // Elements represent the inputs and outputs of an equation + // For example, in "i,j->ij" the elements are {"i", "j", "ij"} + std::vector elements = ParseEquation(equation); + + for (size_t i = 0; i < elements.size(); i++) { + for (char c : elements[i]) { + // We use j to skip searching the element for its own characters + size_t j = 0; + + // For each element, we want to see if the chars in that element are + // contained in some other element. For example, "i" is contained in + // "ij", "j" is contained in "ij", and the characters in "ij" are + // contained in "i" and "j" respectively, so the equation is valid. As a + // counter-example, if we have "ik,ij->i", that is not valid, because + // not all characters of "ik" or "ij" are contained in another element. + if (std::all_of(elements.cbegin(), elements.cend(), + [&c, &i, &j](std::string element) { + return j++ == i || + std::find(element.begin(), element.end(), c) == + element.end(); + })) { + return false; + } + } + } + return true; + } + + private: + // Breaks an einsum equation string down into its "elements", e.g. "i,j->ij" + // will be decomposed into {"i", "j", "ij"} + static std::vector ParseEquation(const std::string& equation) { + int split_index_one = equation.find(","); + int split_index_two = equation.find("->"); + XLA_CHECK_NE(split_index_two, std::string::npos); + + std::vector elements; + + elements.push_back(equation.substr(0, split_index_one)); + + if (split_index_one == std::string::npos) { + elements.push_back(equation.substr( + split_index_two + 1, equation.size() - split_index_two - 1)); + } else { + elements.push_back(equation.substr( + split_index_one + 1, split_index_two - split_index_one - 1)); + elements.push_back(equation.substr( + split_index_two + 2, equation.size() - split_index_two - 2)); + } + + return elements; + } +}; +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/infer_output_shape.cpp b/torch_xla/csrc/ops/infer_output_shape.cpp index d9af8eae8ade..4411c441e268 100644 --- a/torch_xla/csrc/ops/infer_output_shape.cpp +++ b/torch_xla/csrc/ops/infer_output_shape.cpp @@ -18,4 +18,27 @@ xla::Shape InferOutputShape(absl::Span input_shapes, return XlaHelpers::ShapeOfXlaOp(result); } +xla::Shape InferOutputShapes(absl::Span input_shapes, + const LowerForShapesFn& core_lowering_fn) { + xla::XlaBuilder b("InferOutputShape"); + std::vector parameters; + for (size_t parameter_number = 0; parameter_number < input_shapes.size(); + ++parameter_number) { + parameters.push_back(xla::Parameter(&b, parameter_number, + input_shapes[parameter_number], + absl::StrCat("p", parameter_number))); + } + std::vector results = core_lowering_fn(parameters); + + xla::Shape output_shape; + if (results.size() == 2) { + output_shape = + xla::ShapeUtil::MakeTupleShape({XlaHelpers::ShapeOfXlaOp(results[0]), + XlaHelpers::ShapeOfXlaOp(results[1])}); + } else { + output_shape = XlaHelpers::ShapeOfXlaOp(results[0]); + } + return output_shape; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/infer_output_shape.h b/torch_xla/csrc/ops/infer_output_shape.h index ee77388785ad..b099d06b2e34 100644 --- a/torch_xla/csrc/ops/infer_output_shape.h +++ b/torch_xla/csrc/ops/infer_output_shape.h @@ -7,9 +7,14 @@ namespace torch_xla { using LowerForShapeFn = std::function operands)>; +using LowerForShapesFn = std::function( + absl::Span operands)>; // Compute the output shape for the given input shapes and lowering. xla::Shape InferOutputShape(absl::Span input_shapes, const LowerForShapeFn& core_lowering_fn); +xla::Shape InferOutputShapes(absl::Span input_shapes, + const LowerForShapesFn& core_lowering_fn); + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index 6f1c162aa372..33af10d1402b 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -11,6 +11,7 @@ const OpKindWrapper xla_collective_permute("xla::collective_permute"); const OpKindWrapper xla_cross_replica_sum("xla::cross_replica_sum"); const OpKindWrapper xla_device_data("xla::device_data"); const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update"); +const OpKindWrapper xla_einsum_backward("xla::einsum_backward"); const OpKindWrapper xla_generic_slice("xla::generic_slice"); const OpKindWrapper xla_get_dimensions_size("xla::xla_get_dimensions_size"); const OpKindWrapper xla_moving_average("xla::moving_average"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index 590228363f59..13c53ab4f7c0 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -36,6 +36,7 @@ extern const OpKindWrapper xla_collective_permute; extern const OpKindWrapper xla_cross_replica_sum; extern const OpKindWrapper xla_device_data; extern const OpKindWrapper xla_diagonal_view_update; +extern const OpKindWrapper xla_einsum_backward; extern const OpKindWrapper xla_generic_slice; extern const OpKindWrapper xla_get_dimensions_size; extern const OpKindWrapper xla_moving_average; diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index a30dfed57b91..b774fa58c751 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -7,12 +7,14 @@ #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/ops/einsum_utilities.h" #include "torch_xla/csrc/tensor_util.h" namespace torch_xla { @@ -516,4 +518,49 @@ xla::XlaOp BuildLogsumexp(xla::XlaOp input, return logs + max_in_dim; } +xla::XlaOp BuildEinsum(absl::Span operands, + const std::string& equation) { + if (operands.size() == 1) { + return xla::Einsum( + operands[0], equation, + xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT); + } else if (operands.size() == 2) { + return xla::Einsum( + operands[0], operands[1], equation, + xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT, + XlaHelpers::PromoteType(XlaHelpers::TypeOfXlaOp(operands[0]), + XlaHelpers::TypeOfXlaOp(operands[1]))); + } +} + +std::vector BuildEinsumBackward(const xla::XlaOp& grad_output, + absl::Span inputs, + const std::string& equation) { + std::vector result; + if (inputs.size() == 1) { + std::string backward_equation = + EinsumUtilities::BuildBackwardsEquation(equation); + result.push_back(xla::Einsum(grad_output, backward_equation)); + } else if (inputs.size() == 2) { + std::vector equations = + EinsumUtilities::BuildBackwardsEquations(equation); + + xla::PrimitiveType type = XlaHelpers::PromoteType( + XlaHelpers::TypeOfXlaOp(grad_output), + XlaHelpers::TypeOfXlaOp(inputs[0]), XlaHelpers::TypeOfXlaOp(inputs[1])); + + result.push_back(xla::Einsum( + grad_output, inputs[1], equations[0], + xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT, + type)); + + result.push_back(xla::Einsum( + inputs[0], grad_output, equations[1], + xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT, + type)); + } + + return result; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h index da0f082e2efc..8bc7d069f56c 100644 --- a/torch_xla/csrc/reduction.h +++ b/torch_xla/csrc/reduction.h @@ -100,4 +100,11 @@ xla::XlaOp BuildLogsumexp(xla::XlaOp input, absl::Span dimensions, bool keep_reduced_dimensions); +xla::XlaOp BuildEinsum(absl::Span operands, + const std::string& equation); + +std::vector BuildEinsumBackward(const xla::XlaOp& grad_output, + absl::Span inputs, + const std::string& equation); + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 69ef750ab23e..6232960c4ed3 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -553,6 +553,11 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensorPtr einsum(const std::string& equation, absl::Span tensors); + static std::tuple einsum_backward( + const XLATensorPtr& grad_output, + const absl::Span tensors, + const std::string& equation); + static XLATensorPtr elu_backward(const XLATensorPtr& grad_output, const at::Scalar& alpha, const at::Scalar& scale, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 4eccb64cca90..4571ac391360 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -46,6 +46,8 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" #include "torch_xla/csrc/ops/discrete_uniform.h" +#include "torch_xla/csrc/ops/einsum.h" +#include "torch_xla/csrc/ops/einsum_backward.h" #include "torch_xla/csrc/ops/expand.h" #include "torch_xla/csrc/ops/exponential.h" #include "torch_xla/csrc/ops/flip.h" @@ -1053,6 +1055,39 @@ XLATensorPtr XLATensor::div(const XLATensorPtr& input, return input->CreateFrom(input_value / other_value, scalar_type); } +XLATensorPtr XLATensor::einsum(const std::string& equation, + const absl::Span tensors) { + std::vector irs; + irs.reserve(tensors.size()); + for (const XLATensorPtr& tensor : tensors) { + irs.push_back(tensor->GetIrValue()); + } + + return tensors[0]->CreateFrom(torch::lazy::MakeNode(irs, equation)); +} + +std::tuple XLATensor::einsum_backward( + const XLATensorPtr& grad_output, + const absl::Span tensors, const std::string& equation) { + std::vector irs; + irs.reserve(tensors.size()); + for (const XLATensorPtr& tensor : tensors) { + irs.push_back(tensor->GetIrValue()); + } + + torch::lazy::NodePtr node = torch::lazy::MakeNode( + grad_output->GetIrValue(), irs, equation); + + if (node->num_outputs() == 2) { + return std::make_tuple( + grad_output->CreateFrom(torch::lazy::Value(node, 0)), + grad_output->CreateFrom(torch::lazy::Value(node, 1))); + } else { + return std::make_tuple(grad_output->CreateFrom(torch::lazy::Value(node, 0)), + XLATensorPtr()); + } +} + XLATensorPtr XLATensor::eq(const XLATensorPtr& input, const at::Scalar& other) { return DispatchComparisonOp(at::aten::eq, input, other); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 4e3d48c3def9..fc14f65ab886 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -352,6 +352,7 @@ symint: - new_empty_strided - view autograd: + - einsum - max_pool2d - max_pool3d - native_layer_norm