Skip to content

Commit

Permalink
Op lowering for Einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
steventk-g committed Sep 6, 2022
1 parent 5059c3b commit 674f2cd
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 14 deletions.
97 changes: 88 additions & 9 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3834,7 +3834,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
Expand All @@ -3851,7 +3851,7 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
Expand All @@ -3867,7 +3867,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

// Runs the backward check for the batched matrix-multiply einsum equation
// "bij,bjk->bik" on every device, then verifies the counters: no "aten::.*"
// counter changed and "xla::einsum" did.
TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMulBackward) {
  const auto grad_options =
      torch::TensorOptions(torch::kFloat).requires_grad(true);
  torch::Tensor lhs = torch::rand({3, 2, 5}, grad_options);
  torch::Tensor rhs = torch::rand({3, 5, 4}, grad_options);
  const std::string equation = "bij,bjk->bik";

  // Function under test: einsum over whatever inputs TestBackward supplies.
  auto einsum_fn =
      [&equation](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::einsum(equation, inputs);
  };

  ForEachDevice([&](const torch::Device& device) {
    TestBackward({lhs, rhs}, device, einsum_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) {
Expand All @@ -3884,8 +3901,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) {
AllClose(c, xla_c);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::bmm", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("aten::einsum", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) {
Expand All @@ -3900,7 +3917,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

// Runs the backward check for the diagonal-extraction einsum equation
// "ii->i" on every device, then verifies the counters: no "aten::.*" counter
// changed and "xla::einsum" did.
TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonalBackward) {
  const auto grad_options =
      torch::TensorOptions(torch::kFloat).requires_grad(true);
  torch::Tensor square = torch::rand({3, 3}, grad_options);
  const std::string equation = "ii->i";

  // Function under test: einsum over whatever inputs TestBackward supplies.
  auto einsum_fn =
      [&equation](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::einsum(equation, inputs);
  };

  ForEachDevice([&](const torch::Device& device) {
    TestBackward({square}, device, einsum_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) {
Expand All @@ -3915,7 +3947,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::diagonal", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

// Runs the backward check for the batched diagonal einsum equation
// "...ii->...i" on every device, then verifies the counters: no "aten::.*"
// counter changed and "xla::einsum" did.
TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonalBackward) {
  const auto grad_options =
      torch::TensorOptions(torch::kFloat).requires_grad(true);
  torch::Tensor batched = torch::rand({4, 3, 3}, grad_options);
  const std::string equation = "...ii->...i";

  // Function under test: einsum over whatever inputs TestBackward supplies.
  auto einsum_fn =
      [&equation](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::einsum(equation, inputs);
  };

  ForEachDevice([&](const torch::Device& device) {
    TestBackward({batched}, device, einsum_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) {
Expand All @@ -3930,7 +3977,22 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::permute", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

// Runs the backward check for the batched transpose einsum equation
// "...ij->...ji" on every device, then verifies the counters: no "aten::.*"
// counter changed and "xla::einsum" did.
TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermuteBackward) {
  const auto grad_options =
      torch::TensorOptions(torch::kFloat).requires_grad(true);
  torch::Tensor batched = torch::rand({2, 3, 4, 5}, grad_options);
  const std::string equation = "...ij->...ji";

  // Function under test: einsum over whatever inputs TestBackward supplies.
  auto einsum_fn =
      [&equation](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::einsum(equation, inputs);
  };

  ForEachDevice([&](const torch::Device& device) {
    TestBackward({batched}, device, einsum_fn, /*rtol=*/1e-3, /*atol=*/1e-4);
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) {
Expand All @@ -3946,7 +4008,24 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::mul", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

// Runs the backward check for the two-operand einsum equation "ijj,k->ik"
// (first operand repeats an axis) on every device, then verifies the
// counters: no "aten::.*" counter changed and "xla::einsum" did.
TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxisBackward) {
  const auto grad_options =
      torch::TensorOptions(torch::kFloat).requires_grad(true);
  torch::Tensor repeated_axis = torch::rand({2, 3, 3}, grad_options);
  torch::Tensor vec = torch::rand({4}, grad_options);
  const std::string equation = "ijj,k->ik";

  // Function under test: einsum over whatever inputs TestBackward supplies.
  auto einsum_fn =
      [&equation](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
    return torch::einsum(equation, inputs);
  };

  ForEachDevice([&](const torch::Device& device) {
    TestBackward({repeated_axis, vec}, device, einsum_fn, /*rtol=*/1e-3,
                 /*atol=*/1e-4);
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::einsum", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBilinear) {
Expand Down
37 changes: 37 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,43 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation) {

namespace aten_autograd_ops {

// Forward pass of the einsum autograd function.
//
// Records the equation and the input tensors on `ctx` so backward() can read
// them back, then evaluates the einsum through XLATensor::einsum and returns
// the result as an ATen tensor.
torch::Tensor EinsumAutogradFunction::forward(
    torch::autograd::AutogradContext* ctx, const c10::string_view equation,
    at::TensorList tensors) {
  // Copy the inputs into a variable_list, the form save_for_backward takes.
  torch::autograd::variable_list saved_inputs(tensors.begin(), tensors.end());
  ctx->save_for_backward(saved_inputs);
  ctx->saved_data["equation"] = equation;

  auto xla_inputs = bridge::GetXlaTensors(absl::MakeSpan(tensors));
  auto result = XLATensor::einsum(equation, xla_inputs);
  return bridge::AtenFromXlaTensor(result);
}

// Backward pass of the einsum autograd function.
//
// Reads back the equation and input tensors that forward() stored on `ctx`,
// computes the input gradients with XLATensor::einsum_backward from
// grad_output[0], and returns them positionally aligned with forward()'s
// arguments (equation, tensors...).
//
// The autograd framework expects one gradient slot per forward() argument,
// including non-tensor ones; the slot for the `equation` string must hold an
// undefined tensor. Without that leading placeholder the first tensor
// gradient would sit in the equation's slot and the list would be one entry
// short.
torch::autograd::variable_list EinsumAutogradFunction::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::variable_list grad_output) {
  auto equation = ctx->saved_data["equation"].toString()->string_view();
  auto tensors = ctx->get_saved_variables();

  auto xla_tensors = bridge::GetXlaTensors(absl::MakeSpan(tensors));

  auto outputs = XLATensor::einsum_backward(
      bridge::GetXlaTensor(grad_output[0]), xla_tensors, equation);

  // Placeholder gradient for the non-tensor `equation` argument.
  torch::Tensor undefined_grad;
  torch::autograd::variable_list grad_inputs = {
      undefined_grad, bridge::AtenFromXlaTensor(std::get<0>(outputs))};

  // The second tuple element is only defined when a second operand gradient
  // was produced; append it only in that case.
  if (std::get<1>(outputs).defined()) {
    grad_inputs.push_back(bridge::AtenFromXlaTensor(std::get<1>(outputs)));
  }

  return grad_inputs;
}

torch::Tensor MaxPool2dAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, torch::Tensor self,
torch::IntArrayRef kernel_size, torch::IntArrayRef stride,
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ bool IsNonTrivialDilation(at::IntArrayRef dilation);

namespace aten_autograd_ops {

// Custom autograd function for einsum: forward() evaluates the einsum over
// `tensors` according to `equation`, and backward() produces the gradients
// from the incoming `grad_output`.
struct EinsumAutogradFunction
    : torch::autograd::Function<EinsumAutogradFunction> {
  // Evaluates the einsum described by `equation` over `tensors`.
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               c10::string_view equation,
                               at::TensorList tensors);

  // Computes input gradients from `grad_output` using state stored on `ctx`.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output);
};

struct MaxPool2dAutogradFunction
: public torch::autograd::Function<MaxPool2dAutogradFunction> {
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
Expand Down
18 changes: 14 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,9 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {

// Chooses the dtype to use given a source dtype and an optional explicit
// request.
//
// - If `opt_dtype` holds a value, that value is returned.
// - Otherwise, if `src_dtype` is an integral type (bool included), at::kLong
//   is returned.
// - Otherwise the (empty) `opt_dtype` is returned unchanged.
//
// The body previously carried the same return expression twice, which left
// the second statement unreachable; it is now a single decision chain.
c10::optional<at::ScalarType> PromoteIntegralType(
    at::ScalarType src_dtype, const c10::optional<at::ScalarType>& opt_dtype) {
  if (opt_dtype.has_value()) {
    return opt_dtype;
  }
  if (at::isIntegralType(src_dtype, /*includeBool=*/true)) {
    return at::kLong;
  }
  // No explicit dtype and a non-integral source: hand back the empty optional.
  return opt_dtype;
}

bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) {
Expand Down Expand Up @@ -1069,6 +1068,17 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self,
bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor)));
}

// XLA entry point for einsum.
//
// Equations with at most two operands go through EinsumAutogradFunction;
// anything with more operands is delegated to at::native::einsum.
at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
                                      at::TensorList tensors) {
  XLA_FN_COUNTER("xla::");
  const bool handled_by_xla = tensors.size() <= 2;
  if (handled_by_xla) {
    return aten_autograd_ops::EinsumAutogradFunction::apply(equation, tensors);
  }
  // Einsum operations with more than 2 operands, like bilinear operations,
  // are not currently supported in XLA.
  return at::native::einsum(equation, tensors);
}

at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output,
const at::Scalar& alpha,
const at::Scalar& scale,
Expand Down
69 changes: 69 additions & 0 deletions torch_xla/csrc/ops/einsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "torch_xla/csrc/ops/einsum.h"

#include <stdexcept>
#include <string>

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace {

// Builds the XLA einsum op for `equation` over `operands`.
//
// - One operand: calls the single-operand xla::Einsum overload with default
//   precision.
// - Two operands: calls the two-operand xla::Einsum overload with default
//   precision, passing as the final argument the result of
//   XlaHelpers::PromoteType applied to the two operand types.
//
// Throws std::invalid_argument for any other operand count. Previously
// control fell off the end of this non-void function in that case, which is
// undefined behaviour.
xla::XlaOp BuildEinsum(absl::Span<const xla::XlaOp> operands,
                       const std::string& equation) {
  if (operands.size() == 1) {
    return xla::Einsum(
        operands[0], equation,
        xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT);
  }
  if (operands.size() == 2) {
    return xla::Einsum(
        operands[0], operands[1], equation,
        xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT,
        XlaHelpers::PromoteType(XlaHelpers::TypeOfXlaOp(operands[0]),
                                XlaHelpers::TypeOfXlaOp(operands[1])));
  }
  throw std::invalid_argument(
      "BuildEinsum supports only 1 or 2 operands, got " +
      std::to_string(operands.size()) + " for equation \"" + equation + "\"");
}

// Computes the output shape of an einsum node by collecting the XLA shape of
// every operand and running BuildEinsum over them through InferOutputShape.
xla::Shape NodeOutputShape(const torch::lazy::OpList& operands,
                           const std::string& equation) {
  std::vector<xla::Shape> operand_shapes;
  for (const auto& operand : operands) {
    operand_shapes.push_back(GetXlaShape(operand));
  }

  // Lowering callback handed to InferOutputShape; it builds the same op the
  // real lowering would.
  auto build_for_shape =
      [&equation](absl::Span<const xla::XlaOp> shape_ops) -> xla::XlaOp {
    return BuildEinsum(shape_ops, equation);
  };

  return InferOutputShape(absl::MakeSpan(operand_shapes), build_for_shape);
}

} // namespace

// Creates a new Einsum node over `operands` that reuses this node's equation.
torch::lazy::NodePtr Einsum::Clone(torch::lazy::OpList operands) const {
  torch::lazy::NodePtr cloned =
      torch::lazy::MakeNode<Einsum>(operands, equation_);
  return cloned;
}

// Constructs an einsum IR node over `operands` for `equation`.
//
// The node has a single output whose shape comes from NodeOutputShape, and
// the equation is folded into the node hash via torch::lazy::MHash.
//
// `equation` is taken by non-const value so that std::move below really
// moves; with a `const std::string` parameter the move silently degraded to a
// copy. Top-level const on a by-value parameter is not part of the function
// signature, so this definition still matches the existing declaration. The
// XlaNode base is initialised before `equation_`, so `equation` is still
// intact when NodeOutputShape and MHash read it.
Einsum::Einsum(const torch::lazy::OpList& operands, std::string equation)
    : XlaNode(torch::lazy::OpKind(at::aten::einsum), operands,
              NodeOutputShape(operands, equation),
              /*num_outputs=*/1, torch::lazy::MHash(equation)),
      equation_(std::move(equation)) {}

// Lowers this node: fetches the XLA op for each operand from `loctx`, builds
// the einsum over them with the stored equation, and returns it via ReturnOp.
XlaOpVector Einsum::Lower(LoweringContext* loctx) const {
  const auto& node_operands = operands();
  std::vector<xla::XlaOp> lowered_operands;
  lowered_operands.reserve(node_operands.size());
  for (const auto& operand : node_operands) {
    lowered_operands.push_back(loctx->GetOutputOp(operand));
  }
  xla::XlaOp einsum_op =
      BuildEinsum(absl::MakeSpan(lowered_operands), equation_);
  return ReturnOp(einsum_op, loctx);
}

// Returns the base node description followed by ", equation=(<equation>)".
std::string Einsum::ToString() const {
  std::stringstream description;
  description << XlaNode::ToString();
  description << ", equation=(" << equation_ << ")";
  return description.str();
}

} // namespace torch_xla
23 changes: 23 additions & 0 deletions torch_xla/csrc/ops/einsum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

// IR node representing an einsum over its operands, parameterised by the
// einsum equation string.
class Einsum : public XlaNode {
 public:
  // Creates a node over `operands` for `equation`. The parameter is declared
  // without top-level const: const on a by-value parameter has no effect in a
  // declaration and only obscures that callers pass a copy.
  Einsum(const torch::lazy::OpList& operands, std::string equation);

  // Returns a new Einsum node over `operands` with the same equation.
  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  // Emits the XLA op(s) for this node into `loctx`.
  XlaOpVector Lower(LoweringContext* loctx) const override;

  // Human-readable description including the equation.
  std::string ToString() const override;

  // The einsum equation this node was constructed with.
  const std::string& equation() const { return equation_; }

 private:
  const std::string equation_;
};

} // namespace torch_xla
Loading

0 comments on commit 674f2cd

Please sign in to comment.