diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 1f1c69e22f0a..e89e7ddd65e9 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -268,6 +268,7 @@ supported:
   - pow.Tensor_Scalar
   - pow.Tensor_Tensor
   - _prelu_kernel
+  - _prelu_kernel_backward
   - prod
   - prod.dim_int
   - _propagate_xla_data
diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp
index 5d281a626eca..7e3a65d5442c 100644
--- a/test/cpp/test_aten_xla_tensor_1.cpp
+++ b/test/cpp/test_aten_xla_tensor_1.cpp
@@ -3397,6 +3397,22 @@ TEST_F(AtenXlaTensorTest, TestPrelu) {
   ExpectCounterChanged("xla::_prelu_kernel", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestPreluBackward) {
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::prelu(inputs[0], inputs[1]);
+  };
+  torch::Tensor input = torch::rand(
+      {5, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
+  torch::Tensor weight = torch::rand({3}, torch::TensorOptions(torch::kFloat));
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward({input, weight}, device, testfn);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::_prelu_kernel_backward",
+                       cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestHardshrink) {
   torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
   torch::Tensor output = torch::hardshrink(input);
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 086eb24e51c1..49780fc8811e 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2350,6 +2350,21 @@ at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self,
       tensor_methods::prelu(self_tensor, weight_tensor));
 }
 
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::_prelu_kernel_backward(
+    const at::Tensor& grad_output, const at::Tensor& self,
+    const at::Tensor& weight) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+
+  XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output);
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
+  XLATensorPtr weight_tensor = bridge::GetXlaTensor(weight);
+
+  auto outputs = tensor_methods::prelu_backward(grad_output_tensor, self_tensor,
+                                                weight_tensor);
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
+                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
+}
+
 at::Tensor XLANativeFunctions::prod(const at::Tensor& self,
                                     c10::optional<at::ScalarType> dtype) {
   TORCH_LAZY_FN_COUNTER("xla::");
diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index 132bba18f950..88ce96cab999 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -235,6 +235,19 @@ xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight) {
   return xla::Select(xla::Gt(input, zero), input, product);
 }
 
+std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
+                                           xla::XlaOp weight) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  const xla::Shape& weight_shape = ShapeHelper::ShapeOfXlaOp(weight);
+
+  xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
+  xla::XlaOp grad_input = xla::Mul(weight, grad);
+  xla::XlaOp grad_weight = xla::Mul(input, grad);
+
+  return {xla::Select(xla::Gt(input, zero), grad, grad_input),
+          xla::Select(xla::Gt(input, zero), zero, grad_weight)};
+}
+
 xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); }
 
 xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 0753925e82f7..1b327e1bfc1c 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -21,6 +21,9 @@ xla::XlaOp BuildRelu(xla::XlaOp input);
 
 xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight);
 
+std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
+                                           xla::XlaOp weight);
+
 std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
                                    const at::Scalar& upper, bool training,
                                    xla::XlaOp rng_seed);
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 933149040beb..ae3fb83d54b5 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -132,6 +132,25 @@ torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
                    GetXlaShape(input), std::move(lower_fn));
 }
 
+torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad,
+                                   const torch::lazy::Value& input,
+                                   const torch::lazy::Value& weight) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_weight = loctx->GetOutputOp(node.operand(2));
+    return node.ReturnOps(BuildPreluBackward(xla_grad, xla_input, xla_weight),
+                          loctx);
+  };
+
+  return GenericOp(
+      torch::lazy::OpKind(at::aten::_prelu_kernel_backward),
+      {grad, input, weight},
+      xla::ShapeUtil::MakeTupleShape({GetXlaShape(grad), GetXlaShape(input)}),
+      std::move(lower_fn), /*num_outputs=*/2);
+}
+
 torch::lazy::NodePtr LogSigmoid(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index d8015ed5f633..c110fd3a32ca 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -101,6 +101,10 @@ torch::lazy::NodePtr Sqrt(const torch::lazy::Value& input);
 torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
                            const torch::lazy::Value& weight);
 
+torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad,
+                                   const torch::lazy::Value& input,
+                                   const torch::lazy::Value& weight);
+
 torch::lazy::NodePtr Pow(const torch::lazy::Value& input,
                          const torch::lazy::Value& exponent);
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index e036c0e70778..fa54741190d5 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2065,6 +2065,15 @@ XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight) {
   return input->CreateFrom(Prelu(input->GetIrValue(), weight->GetIrValue()));
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> prelu_backward(
+    const XLATensorPtr& grad, const XLATensorPtr& input,
+    const XLATensorPtr& weight) {
+  torch::lazy::NodePtr node = PreluBackward(
+      grad->GetIrValue(), input->GetIrValue(), weight->GetIrValue());
+  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
+                         input->CreateFrom(torch::lazy::Value(node, 1)));
+}
+
 XLATensorPtr prod(const XLATensorPtr& input, std::vector<int64_t> dimensions,
                   bool keep_reduced_dimensions,
                   c10::optional<at::ScalarType> dtype) {
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 88d6e8b44965..5a714170300e 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -645,6 +645,10 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent);
 
 XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight);
 
+std::tuple<XLATensorPtr, XLATensorPtr> prelu_backward(
+    const XLATensorPtr& grad_out, const XLATensorPtr& input,
+    const XLATensorPtr& weight);
+
 XLATensorPtr prod(const XLATensorPtr& input, std::vector<int64_t> dimensions,
                   bool keep_reduced_dimensions,
                   c10::optional<at::ScalarType> dtype);