[Tensor Operants & Prim] Tensor pow API uses elementwise_pow #50886

Merged — 2 commits — Feb 27, 2023
2 changes: 0 additions & 2 deletions paddle/fluid/prim/api/api.yaml
@@ -3,7 +3,6 @@
- multiply
- divide
- unsqueeze
- pow
- exp
- scale
- matmul
@@ -25,5 +24,4 @@
- scatter_nd_add
- tile
- transpose
- subtract
- pad
24 changes: 24 additions & 0 deletions paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
@@ -64,6 +64,10 @@ class EagerTensorOperants : public TensorOperantsBase {

Tensor divide(const Scalar& x, const Tensor& y);

Tensor pow(const Tensor& x, const Tensor& y);

Tensor pow(const Tensor& x, const Scalar& y);

"""


@@ -121,6 +125,14 @@ class EagerTensorOperants : public TensorOperantsBase {
return ::divide_ad_func(::full_like_ad_func(y, x), y);
}

Tensor EagerTensorOperants::pow(const Tensor& x, const Tensor& y) {
return ::elementwise_pow_ad_func(x, y);
}

Tensor EagerTensorOperants::pow(const Tensor& x, const Scalar& y) {
return ::elementwise_pow_ad_func(x, ::full_like_ad_func(x, y));
Contributor:

Why doesn't this use the pow interface here?

Contributor Author:

This comes from the composite operator (prim) requirement to use basic operators wherever possible; composing elementwise_pow with full_like avoids needing a separate pow operator (interface).
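For illustration, a minimal sketch of the equivalence (the wrapper name is made up; the `*_ad_func` calls are the ones used in this diff):

```cpp
// Sketch only: how the Scalar overload decomposes into basic ops,
//   x.pow(y)  ~>  elementwise_pow(x, full_like(x, y))
Tensor pow_via_basic_ops(const Tensor& x, const Scalar& y) {
  Tensor y_full = ::full_like_ad_func(x, y);    // materialize the scalar exponent
  return ::elementwise_pow_ad_func(x, y_full);  // reuse the basic elementwise op
}
```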

}

"""


@@ -176,6 +188,10 @@ class StaticTensorOperants : public TensorOperantsBase {

Tensor divide(const Scalar& x, const Tensor& y);

Tensor pow(const Tensor& x, const Tensor& y);

Tensor pow(const Tensor& x, const Scalar& y);

"""


@@ -236,6 +252,14 @@ class StaticTensorOperants : public TensorOperantsBase {
return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}

Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
return paddle::prim::elementwise_pow<DescTensor>(x, y);
}

Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}

"""


1 change: 1 addition & 0 deletions paddle/fluid/prim/tests/CMakeLists.txt
@@ -30,6 +30,7 @@ cc_test_old(
operator
elementwise_mul_op
elementwise_sub_op
elementwise_pow_op
fill_constant_op
activation_op
phi_api
43 changes: 24 additions & 19 deletions paddle/fluid/prim/tests/test_static_prim.cc
@@ -194,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
target_block,
grad_sub_block));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(4));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(5));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
@@ -204,36 +204,41 @@
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");

ASSERT_EQ(grad_ops[0]->Type(), "pow");
ASSERT_EQ(grad_ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Inputs().at("X")[0], "b");
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[0]->GetAttr("factor")),
static_cast<float>(2.0));
ASSERT_EQ(grad_ops[0]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[0]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow");
ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b");
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
grad_ops[1]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0],
grad_ops[2]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
grad_ops[3]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}

TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
@@ -376,7 +381,7 @@ TEST(StaticPrim, TestFlags) {
USE_OP_ITSELF(fill_constant);
USE_OP_ITSELF(tanh);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(pow);
USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_sub);
USE_OP_ITSELF(elementwise_pow);
USE_OP_ITSELF(scale);
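A note on why the expected grad-op count rises from 4 to 5 in the test above: the tanh composite backward computes dx = (1 - out^2) * dout, and the square that used to be a single pow op is now lowered to fill_constant + elementwise_pow. A rough sketch of the sequence the test asserts (pseudo helper names, not the real prim API):

```cpp
// Illustrative only — mirrors the op order asserted in TanhBackwardComposite.
Tensor tanh_grad_composite(const Tensor& out, const Tensor& dout) {
  Tensor two    = full_like(out, 2.0);        // grad_ops[0]: fill_constant
  Tensor out_sq = elementwise_pow(out, two);  // grad_ops[1]: elementwise_pow
  Tensor one    = full_like(out, 1.0);        // grad_ops[2]: fill_constant
  Tensor delta  = subtract(one, out_sq);      // grad_ops[3]: elementwise_sub
  return multiply(dout, delta);               // grad_ops[4]: elementwise_mul
}
```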
3 changes: 2 additions & 1 deletion paddle/phi/api/include/tensor.h
@@ -677,12 +677,13 @@ class PADDLE_API Tensor final {
Tensor divide(const Scalar& y) const;
Tensor multiply(const Scalar& y) const;
Tensor subtract(const Scalar& y) const;
Tensor pow(const Tensor& y) const;
Tensor pow(const Scalar& y) const;

Tensor exp() const;
Tensor floor() const;
Tensor gather_nd(const Tensor& index) const;
Tensor log() const;
Tensor pow(const Scalar& y) const;
Tensor roll(const IntArray& shifts, const std::vector<int64_t>& axis) const;
Tensor scatter(const Tensor& index,
const Tensor& updates,
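A brief usage sketch of the two overloads declared above (illustrative only; assumes already-initialized tensors of compatible shapes):

```cpp
#include "paddle/phi/api/include/tensor.h"

// Illustrative only — exercises the new pow overloads on paddle::Tensor.
paddle::Tensor elementwise_power(const paddle::Tensor& x, const paddle::Tensor& y) {
  return x.pow(y);    // Tensor exponent: dispatched to elementwise_pow
}

paddle::Tensor cube(const paddle::Tensor& x) {
  return x.pow(3.0);  // Scalar exponent: lowered to full_like + elementwise_pow
}
```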
42 changes: 41 additions & 1 deletion paddle/phi/api/yaml/generator/tensor_operants_gen.py
@@ -29,6 +29,8 @@

indent = " "

specific_ops_map = {"elementwise_pow": "pow"}
Contributor:

Consider adding a comment here to explain what this map does.

Contributor Author:

Will fix this in the next PR.
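For context (an inferred description, not an authoritative comment): `specific_ops_map` maps a base operator's API name to the Tensor-operant member name it is exposed under, so the generator walks the `elementwise_pow` API but emits members named `pow` that still route to the elementwise kernels, e.g. (as shown elsewhere in this diff):

```cpp
// Effect of the {"elementwise_pow": "pow"} entry — the emitted member is
// named pow but its body stays on elementwise_pow:
Tensor PhiTensorOperants::pow(const Tensor& x, const Tensor& y) {
  return paddle::experimental::elementwise_pow(x, y);
}
```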



operants_base_include = """// Generated by paddle/phi/api/yaml/generator/tensor_operants_gen.py

@@ -68,6 +70,10 @@ class TensorOperantsBase {
virtual Tensor multiply(const Scalar& x, const Tensor& y) = 0;

virtual Tensor subtract(const Scalar& x, const Tensor& y) = 0;

virtual Tensor pow(const Tensor& x, const Tensor& y) = 0;

virtual Tensor pow(const Tensor& x, const Scalar& y) = 0;
"""


@@ -143,6 +149,14 @@ class TensorOperantsBase {
return scale(-1.0, 0.0, true);
}

Tensor Tensor::pow(const Tensor& y) const {
return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
}

Tensor Tensor::pow(const Scalar& y) const {
return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
}

PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y) {
return paddle::OperantsManager::Instance().add(x, y);
}
@@ -211,6 +225,10 @@ class PhiTensorOperants : public TensorOperantsBase {

Tensor divide(const Scalar& x, const Tensor& y);

Tensor pow(const Tensor& x, const Tensor& y);

Tensor pow(const Tensor& x, const Scalar& y);

"""


@@ -267,6 +285,14 @@ class PhiTensorOperants : public TensorOperantsBase {
Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) {
return paddle::experimental::divide(paddle::experimental::full_like(y, x), y);
}

Tensor PhiTensorOperants::pow(const Tensor& x, const Tensor& y) {
return paddle::experimental::elementwise_pow(x, y);
}

Tensor PhiTensorOperants::pow(const Tensor& x, const Scalar& y) {
return paddle::experimental::elementwise_pow(x, paddle::experimental::full_like(x, y));
}
"""


@@ -359,6 +385,10 @@ class OperantsManager {

Tensor divide(const Scalar& x, const Tensor& y);

Tensor pow(const Tensor& x, const Tensor& y);

Tensor pow(const Tensor& x, const Scalar& y);

"""


@@ -512,8 +542,10 @@ def gene_operants_implementation(self):

"""

def gene_operants_manager_code(self):
def gene_operants_manager_code(self, is_specific_op=False):
func_name = self.get_api_func_name()
if is_specific_op:
func_name = specific_ops_map[func_name]
func_args = self.inputs['names'] + self.attrs['names']
func_args_code = ", ".join(func_args)
return f"""
@@ -552,11 +584,19 @@ def gene_operants_manager_code(self):
def gene_operants_manager_implementation(self):
func_name = self.get_api_func_name()
final_code = ""
# Codes for arithmetic operants
if func_name in ["add", "subtract", "multiply", "divide"]:
final_code += f"""
{self.get_return_type()} OperantsManager::{func_name}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code()}}}

{self.get_return_type()} OperantsManager::{func_name}(const Scalar& x, const Tensor& y) {{{self.gene_operants_manager_code()}}}
"""
# Codes for specific operants
if func_name in specific_ops_map.keys():
final_code += f"""
{self.get_return_type()} OperantsManager::{specific_ops_map[func_name]}(const Tensor& x, const Tensor& y) {{{self.gene_operants_manager_code(is_specific_op=True)}}}

{self.get_return_type()} OperantsManager::{specific_ops_map[func_name]}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code(is_specific_op=True)}}}
"""
# func declaration
if func_name[-1] != '_':
9 changes: 4 additions & 5 deletions paddle/phi/api/yaml/tensor_operants.yaml
@@ -1,14 +1,14 @@
# Attach operants to Tensor, this file should be consistent with the declaration in `tensor.h`
- add
- subtract
- multiply
- divide
- unsqueeze
- pow
- exp
- scale
- multiply
- matmul
- expand
- divide
- sum
- add
- abs
- assign
- elementwise_pow
@@ -22,4 +22,3 @@
- scatter
- scatter_nd_add
- tile
- subtract
(Changes to an additional test file; file path not shown in this extract.)
@@ -42,7 +42,8 @@
set(),
tuple(),
(
'pow',
'fill_constant',
'elementwise_pow',
'fill_constant',
'elementwise_sub',
'elementwise_mul',
(Changes to an additional test file; file path not shown in this extract.)
@@ -28,7 +28,8 @@ def setUp(self):
self.grad_sub_block = tuple()
self.desired_ops = 'tanh_grad'
self.desired_ops_no_skip = (
'pow',
'fill_constant',
'elementwise_pow',
'fill_constant',
'elementwise_sub',
'elementwise_mul',