Skip to content

Commit

Permalink
Op lowering for Einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
steventk-g committed Sep 14, 2022
1 parent 2c77849 commit e449a12
Show file tree
Hide file tree
Showing 21 changed files with 569 additions and 11 deletions.
97 changes: 88 additions & 9 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3834,7 +3834,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
Expand All @@ -3851,7 +3851,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
Expand All @@ -3867,7 +3867,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMulBackward) {
torch::Tensor a = torch::rand(
{3, 2, 5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor b = torch::rand(
{3, 5, 4}, torch::TensorOptions(torch::kFloat).requires_grad(true));
std::string equation = "bij,bjk->bik";
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::einsum(equation, inputs);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward({a, b}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) {
Expand All @@ -3884,8 +3901,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) {
AllClose(c, xla_c);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters());
ExpectCounterNotChanged("aten::einsum", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) {
Expand All @@ -3900,7 +3917,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonalBackward) {
torch::Tensor input = torch::rand(
{3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
std::string equation = "ii->i";
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::einsum(equation, inputs);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) {
Expand All @@ -3915,7 +3947,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonalBackward) {
torch::Tensor input = torch::rand(
{4, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
std::string equation = "...ii->...i";
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::einsum(equation, inputs);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) {
Expand All @@ -3930,7 +3977,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::permute", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermuteBackward) {
torch::Tensor input = torch::rand(
{2, 3, 4, 5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
std::string equation = "...ij->...ji";
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::einsum(equation, inputs);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward({input}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) {
Expand All @@ -3946,7 +4008,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxisBackward) {
torch::Tensor x = torch::rand(
{2, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor y =
torch::rand({4}, torch::TensorOptions(torch::kFloat).requires_grad(true));
std::string equation = "ijj,k->ik";
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::einsum(equation, inputs);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward({x, y}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBilinear) {
Expand Down
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#84583
45 changes: 45 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation) {

namespace aten_autograd_ops {

torch::Tensor EinsumAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, const c10::string_view equation,
at::TensorList tensors) {
std::string eq_str = std::string(equation);
ctx->saved_data["equation"] = eq_str;

torch::autograd::variable_list vars;
for (const torch::Tensor const& tensor : tensors) {
vars.push_back(tensor);
}
ctx->save_for_backward(vars);

std::vector<XLATensorPtr> xla_tensors =
bridge::GetXlaTensors(absl::MakeSpan(tensors));
XLATensorPtr output = XLATensor::einsum(eq_str, xla_tensors);
return bridge::AtenFromXlaTensor(output);
}

torch::autograd::variable_list EinsumAutogradFunction::backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
std::string equation = ctx->saved_data["equation"].toString()->string();
torch::autograd::variable_list tensors = ctx->get_saved_variables();
std::vector<XLATensorPtr> xla_tensors =
bridge::GetXlaTensors(absl::MakeSpan(tensors));

std::tuple<XLATensorPtr, XLATensorPtr> outputs = XLATensor::einsum_backward(
bridge::GetXlaTensor(grad_output[0]), xla_tensors, equation);

// For both einsum and max pool, we use "undef" as a placeholder for the
// non-tensor grad inputs, in this case the equation string.
torch::Tensor undef;
torch::autograd::variable_list grad_inputs = {
undef, bridge::AtenFromXlaTensor(std::get<0>(outputs))};

// einsum_backward will return a tuple with either one or two tensors defined.
// If both tensors in the tuple are defined, then we return both tensors.
// Otherwise, we only return the first tensor.
if (std::get<1>(outputs).defined()) {
grad_inputs.push_back(bridge::AtenFromXlaTensor(std::get<1>(outputs)));
}

return grad_inputs;
}

torch::Tensor MaxPool2dAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, torch::Tensor self,
torch::IntArrayRef kernel_size, torch::IntArrayRef stride,
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation);

namespace aten_autograd_ops {

struct EinsumAutogradFunction
: public torch::autograd::Function<EinsumAutogradFunction> {
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
c10::string_view equation,
at::TensorList tensors);
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output);
};

struct MaxPool2dAutogradFunction
: public torch::autograd::Function<MaxPool2dAutogradFunction> {
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
Expand Down
13 changes: 13 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "torch_xla/csrc/generated/XLANativeFunctions.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/einsum_utilities.h"
#include "torch_xla/csrc/ops/index_ops.h"
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/tensor_impl.h"
Expand Down Expand Up @@ -1070,6 +1071,18 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self,
bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor)));
}

at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
at::TensorList tensors) {
XLA_FN_COUNTER("xla::");
// Einsum operations with more than 2 operands, like bilinear operations, are
// not currently supported in XLA
if (tensors.size() > 2 ||
!EinsumUtilities::EquationIsValid(std::string(equation))) {
return at::native::einsum(equation, tensors);
}
return aten_autograd_ops::EinsumAutogradFunction::apply(equation, tensors);
}

at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output,
const at::Scalar& alpha,
const at::Scalar& scale,
Expand Down
9 changes: 7 additions & 2 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,12 @@ xla::PrimitiveType XlaHelpers::PromoteType(xla::PrimitiveType type1,
return type1;
}

xla::PrimitiveType XlaHelpers::PromoteType(xla::PrimitiveType type1,
xla::PrimitiveType type2,
xla::PrimitiveType type3) {
return PromoteType(PromoteType(type1, type2), type3);
}

std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(xla::XlaOp op1,
xla::XlaOp op2) {
xla::PrimitiveType type1 = TypeOfXlaOp(op1);
Expand All @@ -440,8 +446,7 @@ std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(
xla::PrimitiveType type1 = TypeOfXlaOp(op1);
xla::PrimitiveType type2 = TypeOfXlaOp(op2);
xla::PrimitiveType type3 = TypeOfXlaOp(op3);
xla::PrimitiveType result_type =
PromoteType(PromoteType(type1, type2), type3);
xla::PrimitiveType result_type = PromoteType(type1, type2, type3);
if (type1 != result_type) {
op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
}
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ class XlaHelpers {
static xla::PrimitiveType PromoteType(xla::PrimitiveType type1,
xla::PrimitiveType type2);

static xla::PrimitiveType PromoteType(xla::PrimitiveType type1,
xla::PrimitiveType type2,
xla::PrimitiveType type3);

// Performs type promotion to make sure both operations return the same type.
static std::pair<xla::XlaOp, xla::XlaOp> PromoteValues(xla::XlaOp op1,
xla::XlaOp op2);
Expand Down
54 changes: 54 additions & 0 deletions torch_xla/csrc/ops/einsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "torch_xla/csrc/ops/einsum.h"

#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::OpList& operands,
const std::string& equation) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildEinsum(operands, equation);
};

std::vector<xla::Shape> shapes;
for (auto const& op : operands) {
shapes.push_back(GetXlaShape(op));
}

return InferOutputShape(absl::MakeSpan(shapes), lower_for_shape_fn);
}

} // namespace

torch::lazy::NodePtr Einsum::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<Einsum>(operands, equation_);
}

Einsum::Einsum(const torch::lazy::OpList& operands, const std::string equation)
: XlaNode(torch::lazy::OpKind(at::aten::einsum), operands,
NodeOutputShape(operands, equation),
/*num_outputs=*/1, torch::lazy::MHash(equation)),
equation_(std::move(equation)) {}

XlaOpVector Einsum::Lower(LoweringContext* loctx) const {
std::vector<xla::XlaOp> inputs;
auto& operand_list = operands();
inputs.reserve(operand_list.size());
for (size_t i = 0; i < operand_list.size(); ++i) {
inputs.push_back(loctx->GetOutputOp(operand_list[i]));
}
return ReturnOp(BuildEinsum(absl::MakeSpan(inputs), equation_), loctx);
}

std::string Einsum::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", equation=(" << equation_ << ")";
return ss.str();
}

} // namespace torch_xla
23 changes: 23 additions & 0 deletions torch_xla/csrc/ops/einsum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class Einsum : public XlaNode {
public:
Einsum(const torch::lazy::OpList& operands, const std::string equation);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::string& equation() const { return equation_; }

private:
const std::string equation_;
};

} // namespace torch_xla
Loading

0 comments on commit e449a12

Please sign in to comment.