add op lowering for _prelu_kernel_backward (#5724)
* add op lowering for _prelu_kernel_backward

* remove comment line
zpcore authored and golechwierowicz committed Jan 12, 2024
1 parent 96967d8 commit 32fd41e
Showing 9 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -268,6 +268,7 @@ supported:
- pow.Tensor_Scalar
- pow.Tensor_Tensor
- _prelu_kernel
- _prelu_kernel_backward
- prod
- prod.dim_int
- _propagate_xla_data
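A brief note on what the yaml entry does (an assumption based on the usual torch_xla codegen flow, sketched here rather than quoted from the generated file): listing an op under `supported:` makes the torchgen-based codegen declare it on the generated XLANativeFunctions struct, which the aten_xla_type.cpp change further below then implements.

#include <tuple>
#include <ATen/ATen.h>

// Assumption: roughly the declaration codegen would emit for the new entry.
// The real generated header is produced at build time and is not part of
// this diff; the signature matches the override implemented below.
struct XLANativeFunctions {
  static std::tuple<at::Tensor, at::Tensor> _prelu_kernel_backward(
      const at::Tensor& grad_output, const at::Tensor& self,
      const at::Tensor& weight);
  // ... declarations for the other supported ops ...
};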
16 changes: 16 additions & 0 deletions test/cpp/test_aten_xla_tensor_1.cpp
@@ -3397,6 +3397,22 @@ TEST_F(AtenXlaTensorTest, TestPrelu) {
ExpectCounterChanged("xla::_prelu_kernel", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestPreluBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::prelu(inputs[0], inputs[1]);
};
torch::Tensor input = torch::rand(
{5, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor weight = torch::rand({3}, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
TestBackward({input, weight}, device, testfn);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_prelu_kernel_backward",
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestHardshrink) {
torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::hardshrink(input);
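Outside the test harness, the same path can be exercised from plain libtorch code. A minimal standalone sketch (illustrative only): on CPU tensors this runs the eager kernels, while with XLA tensors the same calls route through the lowering added in this commit.

#include <torch/torch.h>

int main() {
  // torch::prelu dispatches to _prelu_kernel; autograd then calls
  // _prelu_kernel_backward, the op lowered for XLA in this commit.
  torch::Tensor input = torch::rand({5, 3}, torch::requires_grad());
  torch::Tensor weight = torch::rand({3}, torch::requires_grad());
  torch::Tensor loss = torch::prelu(input, weight).sum();
  loss.backward();
  // input.grad() and weight.grad() now hold the PReLU gradients.
  return 0;
}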
15 changes: 15 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2350,6 +2350,15 @@ at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self,
      tensor_methods::prelu(self_tensor, weight_tensor));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::_prelu_kernel_backward(
    const at::Tensor& grad_output, const at::Tensor& self,
    const at::Tensor& weight) {
  TORCH_LAZY_FN_COUNTER("xla::");

  XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output);
  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
  XLATensorPtr weight_tensor = bridge::GetXlaTensor(weight);

  auto outputs = tensor_methods::prelu_backward(grad_output_tensor, self_tensor,
                                                weight_tensor);
  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
}

at::Tensor XLANativeFunctions::prod(const at::Tensor& self,
                                    c10::optional<at::ScalarType> dtype) {
  TORCH_LAZY_FN_COUNTER("xla::");
13 changes: 13 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -235,6 +235,19 @@ xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight) {
  return xla::Select(xla::Gt(input, zero), input, product);
}

std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
                                           xla::XlaOp weight) {
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  const xla::Shape& weight_shape = ShapeHelper::ShapeOfXlaOp(weight);

  xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
  xla::XlaOp grad_input = xla::Mul(weight, grad);
  xla::XlaOp grad_weight = xla::Mul(input, grad);

  return {xla::Select(xla::Gt(input, zero), grad, grad_input),
          xla::Select(xla::Gt(input, zero), zero, grad_weight)};
}

xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); }

xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
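For reference, the two xla::Select calls in BuildPreluBackward implement the elementwise PReLU gradients: since PReLU(x) = x for x > 0 and weight * x otherwise, the per-element gradient w.r.t. the input is grad (or weight * grad), and the gradient w.r.t. the weight is 0 (or input * grad). A minimal scalar sketch of that math follows; the names are illustrative and not part of the change, and any reduction of the weight gradient over broadcast dimensions is presumably handled upstream of this op.

#include <cassert>

// Per-element reference for the Selects above:
//   PReLU(x)     = x > 0 ? x : w * x
//   d loss / d x = x > 0 ? g : w * g
//   d loss / d w = x > 0 ? 0 : x * g   (before any reduction to w's shape)
struct PreluGrads {
  float grad_input;
  float grad_weight;
};

inline PreluGrads PreluBackwardScalarRef(float g, float x, float w) {
  return x > 0 ? PreluGrads{g, 0.0f} : PreluGrads{w * g, x * g};
}

int main() {
  // x > 0: the gradient passes straight through; the weight collects nothing.
  PreluGrads pos = PreluBackwardScalarRef(/*g=*/2.0f, /*x=*/1.5f, /*w=*/0.25f);
  assert(pos.grad_input == 2.0f && pos.grad_weight == 0.0f);
  // x <= 0: the input gradient is scaled by w; the weight collects x * g.
  PreluGrads neg = PreluBackwardScalarRef(/*g=*/2.0f, /*x=*/-1.5f, /*w=*/0.25f);
  assert(neg.grad_input == 0.5f && neg.grad_weight == -3.0f);
  return 0;
}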
3 changes: 3 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -21,6 +21,9 @@ xla::XlaOp BuildRelu(xla::XlaOp input);

xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight);

std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
                                           xla::XlaOp weight);

std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
                                   const at::Scalar& upper, bool training,
                                   xla::XlaOp rng_seed);
19 changes: 19 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -132,6 +132,25 @@ torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
                   GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad,
                                   const torch::lazy::Value& input,
                                   const torch::lazy::Value& weight) {
  auto lower_fn = [](const XlaNode& node,
                     LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_grad = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
    xla::XlaOp xla_weight = loctx->GetOutputOp(node.operand(2));
    return node.ReturnOps(BuildPreluBackward(xla_grad, xla_input, xla_weight),
                          loctx);
  };

  return GenericOp(
      torch::lazy::OpKind(at::aten::_prelu_kernel_backward),
      {grad, input, weight},
      xla::ShapeUtil::MakeTupleShape({GetXlaShape(grad), GetXlaShape(input)}),
      std::move(lower_fn), /*num_outputs=*/2);
}

torch::lazy::NodePtr LogSigmoid(const torch::lazy::Value& input) {
  auto lower_fn = [](const XlaNode& node,
                     LoweringContext* loctx) -> XlaOpVector {
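Since PreluBackward is built with a tuple shape and /*num_outputs=*/2, its two results are addressed by output index rather than through a separate tuple-unpacking op; the tensor_methods.cpp change further down does exactly that. A hedged sketch of the pattern (hypothetical helper name and include path, shown only to illustrate the indexing):

#include <utility>

#include "torch_xla/csrc/ops/ops.h"

// Hypothetical helper, not part of this commit: split the two-output
// PreluBackward node into separate IR values.
std::pair<torch::lazy::Value, torch::lazy::Value> PreluBackwardValues(
    const torch::lazy::Value& grad, const torch::lazy::Value& input,
    const torch::lazy::Value& weight) {
  torch::lazy::NodePtr node = torch_xla::PreluBackward(grad, input, weight);
  // Output 0: grad_input (shape of grad). Output 1: the per-element
  // weight-gradient collector (shape of input), matching the tuple shape above.
  return {torch::lazy::Value(node, 0), torch::lazy::Value(node, 1)};
}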
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -101,6 +101,10 @@ torch::lazy::NodePtr Sqrt(const torch::lazy::Value& input);
torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
                           const torch::lazy::Value& weight);

torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad,
                                   const torch::lazy::Value& input,
                                   const torch::lazy::Value& weight);

torch::lazy::NodePtr Pow(const torch::lazy::Value& input,
                         const torch::lazy::Value& exponent);

9 changes: 9 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2065,6 +2065,15 @@ XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight) {
  return input->CreateFrom(Prelu(input->GetIrValue(), weight->GetIrValue()));
}

std::tuple<XLATensorPtr, XLATensorPtr> prelu_backward(
    const XLATensorPtr& grad, const XLATensorPtr& input,
    const XLATensorPtr& weight) {
  torch::lazy::NodePtr node = PreluBackward(
      grad->GetIrValue(), input->GetIrValue(), weight->GetIrValue());
  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
                         input->CreateFrom(torch::lazy::Value(node, 1)));
}

XLATensorPtr prod(const XLATensorPtr& input, std::vector<int64_t> dimensions,
                  bool keep_reduced_dimensions,
                  c10::optional<at::ScalarType> dtype) {
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -645,6 +645,10 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent);

XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight);

std::tuple<XLATensorPtr, XLATensorPtr> prelu_backward(
    const XLATensorPtr& grad_out, const XLATensorPtr& input,
    const XLATensorPtr& weight);

XLATensorPtr prod(const XLATensorPtr& input, std::vector<int64_t> dimensions,
                  bool keep_reduced_dimensions,
                  c10::optional<at::ScalarType> dtype);
