[Tensor Operants & Prim] Tensor arithmetic operants support right scalar type #50563

Merged: 21 commits, Feb 23, 2023
8 changes: 4 additions & 4 deletions paddle/fluid/prim/api/api.yaml
@@ -1,13 +1,14 @@
- add
- subtract
- multiply
- divide
- unsqueeze
- pow
Contributor:
maybe we don't need this?

Contributor Author:
Taking addition as an example: although users no longer need to call the add function and can invoke the + operator directly, the overloaded + operator still relies on the add function underneath. We therefore need to keep these entries so that the add function declarations are generated automatically.
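For illustration, here is a standalone toy sketch (not Paddle code; all names are made up) of the pattern described above: the overloaded operator is only a thin wrapper over a named add function, and the scalar operand is broadcast to a same-shaped constant tensor (the role full_like plays in this PR) before the existing tensor-tensor add is reused.

#include <cstddef>
#include <iostream>
#include <vector>

// Toy stand-in for a tensor; illustrates why the add entry must stay in
// api.yaml even though users write `x + s`: operator+ delegates to add().
struct ToyTensor {
  std::vector<float> data;

  // tensor + tensor (stand-in for the generated add function)
  ToyTensor add(const ToyTensor& y) const {
    ToyTensor out = *this;
    for (std::size_t i = 0; i < data.size(); ++i) out.data[i] += y.data[i];
    return out;
  }

  // tensor + scalar: broadcast the scalar to a constant tensor of the same
  // shape, then reuse the tensor-tensor add.
  ToyTensor add(float y) const {
    return add(ToyTensor{std::vector<float>(data.size(), y)});
  }

  // the operator relies on add(), mirroring the explanation above
  ToyTensor operator+(float y) const { return add(y); }
};

int main() {
  ToyTensor x{{1.f, 2.f, 3.f}};
  ToyTensor z = x + 2.5f;                        // dispatches to add(float)
  for (float v : z.data) std::cout << v << ' ';  // prints: 3.5 4.5 5.5
  return 0;
}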

- exp
- scale
- multiply
- matmul
- expand
- divide
- sum
- add
- abs
- assign
- concat
@@ -24,4 +25,3 @@
- scatter_nd_add
- tile
- transpose
- subtract
49 changes: 49 additions & 0 deletions paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
@@ -48,6 +48,14 @@ class EagerTensorOperants : public TensorOperantsBase {
public:
EagerTensorOperants() = default;

Tensor add(const Tensor& x, const Scalar& y);

Tensor subtract(const Tensor& x, const Scalar& y);

Tensor multiply(const Tensor& x, const Scalar& y);

Tensor divide(const Tensor& x, const Scalar& y);

"""


@@ -73,6 +81,22 @@ class EagerTensorOperants : public TensorOperantsBase {

namespace prim {

Tensor EagerTensorOperants::add(const Tensor& x, const Scalar& y) {
return ::add_ad_func(x, ::full_like_ad_func(x, y));
}

Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return ::subtract_ad_func(x, ::full_like_ad_func(x, y));
}

Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return ::multiply_ad_func(x, ::full_like_ad_func(x, y));
}

Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
return ::divide_ad_func(x, ::full_like_ad_func(x, y));
}

"""


@@ -112,6 +136,14 @@ class StaticTensorOperants : public TensorOperantsBase {
public:
StaticTensorOperants() = default;

Tensor add(const Tensor& x, const Scalar& y);

Tensor subtract(const Tensor& x, const Scalar& y);

Tensor multiply(const Tensor& x, const Scalar& y);

Tensor divide(const Tensor& x, const Scalar& y);

"""


@@ -128,6 +160,7 @@ class StaticTensorOperants : public TensorOperantsBase {
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"

#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

"""
@@ -139,6 +172,22 @@ class StaticTensorOperants : public TensorOperantsBase {
namespace prim {
using DescTensor = paddle::prim::DescTensor;

Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}

Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}

Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::prim::multiply<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}

Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}

"""


@@ -61,9 +61,7 @@ void gather_grad(const Tensor& x,
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
auto tmp = out.pow(2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = grad_out * tmp;
auto grad_x_tmp = grad_out * (out.pow(2.0) * -1.0 + 1.0);
set_output<T>(grad_x_tmp, grad_x);
}
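For reference, the rewrite above follows directly from the tanh derivative: d(tanh(x))/dx = 1 - tanh(x)^2 = 1 - out^2, so grad_x = grad_out * (1 - out^2). With the new scalar operants this is written inline as grad_out * (out.pow(2.0) * -1.0 + 1.0) instead of going through a separate scale<T> call.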

@@ -203,10 +201,7 @@ void divide_grad(const Tensor& x,
Tensor* dy) {
if (dy) {
// dy = -(x/y^2) * dout
auto tmp0 = y.pow(2.0);
auto tmp1 = x / tmp0;
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = tmp2 * out_grad;
auto dy_res = x / y.pow(2.0) * -1.0 * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
@@ -247,8 +242,7 @@
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto x_grad_tmp = out_grad * div_x / out;
auto x_grad_tmp = out_grad / 2.0 / out;
set_output<T>(x_grad_tmp, x_grad);
}
}
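Likewise, the rewritten divide_grad and sqrt_grad bodies are the textbook derivatives expressed with scalar operants: for division, d(x/y)/dy = -x/y^2, so dy = -(x / y^2) * dout, written above as x / y.pow(2.0) * -1.0 * out_grad; for sqrt, d(sqrt(x))/dx = 1/(2*sqrt(x)) = 1/(2*out), so x_grad = out_grad / 2.0 / out.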
58 changes: 43 additions & 15 deletions paddle/fluid/prim/api/manual_prim/static_prim_api.cc
@@ -67,23 +67,51 @@ Tensor full<DescTensor>(const IntArray& shape,
framework::OpDesc* op = block->AppendOp();
op->SetType("fill_constant");
op->SetAttr("shape", shape.GetData());
PADDLE_ENFORCE_EQ(
((dtype == DataType::FLOAT32) || (dtype == DataType::FLOAT64) ||
(dtype == DataType::FLOAT16)),
true,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
switch (dtype) {
case phi::DataType::FLOAT16:
op->SetAttr("str_value", std::to_string(value.to<float>()));
break;
case phi::DataType::FLOAT32:
op->SetAttr("value", value.to<float>());
break;
case phi::DataType::FLOAT64:
op->SetAttr("str_value", std::to_string(value.to<double>()));
break;
case phi::DataType::BOOL:
op->SetAttr("str_value", std::to_string(value.to<bool>()));
break;
case phi::DataType::INT8:
op->SetAttr("str_value", std::to_string(value.to<int8_t>()));
break;
case phi::DataType::UINT8:
op->SetAttr("str_value", std::to_string(value.to<uint8_t>()));
break;
case phi::DataType::INT16:
op->SetAttr("str_value", std::to_string(value.to<int16_t>()));
break;
case phi::DataType::UINT16:
op->SetAttr("str_value", std::to_string(value.to<uint16_t>()));
break;
case phi::DataType::INT32:
op->SetAttr("str_value", std::to_string(value.to<int32_t>()));
break;
case phi::DataType::UINT32:
op->SetAttr("str_value", std::to_string(value.to<uint32_t>()));
break;
case phi::DataType::INT64:
op->SetAttr("str_value", std::to_string(value.to<int64_t>()));
break;
case phi::DataType::UINT64:
op->SetAttr("str_value", std::to_string(value.to<uint64_t>()));
break;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"We support "
"bool/float16/float32/float64/int8/int16/int32/int64/uint8/uint16/"
"uint32/uint64 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}

op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
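A note on this change (my reading of the diff, not taken from the PR description): full<DescTensor> previously accepted only float16/float32/float64. Because the scalar operants materialize the right-hand scalar via full for a tensor of whatever dtype x has, the switch above extends fill_constant to bool and the integer types, passing str_value for every dtype except float32, which keeps the plain float value attribute. For example, adding 1 to an int64 tensor would now lower to a fill_constant op with dtype INT64 and str_value "1", followed by an elementwise add.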
2 changes: 2 additions & 0 deletions paddle/fluid/prim/tests/CMakeLists.txt
@@ -29,6 +29,8 @@ cc_test_old(
prim_utils
operator
elementwise_mul_op
elementwise_add_op
fill_constant_op
activation_op
phi_api
phi_dygraph_api
41 changes: 31 additions & 10 deletions paddle/fluid/prim/tests/test_static_prim.cc
@@ -35,6 +35,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -43,6 +44,7 @@ PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
#endif
@@ -192,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
target_block,
grad_sub_block));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(3));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(6));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[1]->Type(), "scale");
ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0],
grad_ops[0]->Outputs().at("Out")[0]);
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("scale")),
static_cast<float>(-1.0));
ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[1]->GetAttr("bias")),
static_cast<float>(1.0));
ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y")[0],
grad_ops[1]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[3]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[3]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[4]->Type(), "elementwise_add");
ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
grad_ops[3]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[5]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[5]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[5]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[5]->Inputs().at("Y")[0],
grad_ops[4]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[5]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[5]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}
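Why grad_ops.size() grows from 3 to 6: with scalar operants, the composite rule grad_x = grad_out * (out.pow(2.0) * -1.0 + 1.0) lowers to the op sequence checked above (the constant values are inferred from the composite rule; the test only asserts their dtype):

// [0] pow             : out^2
// [1] fill_constant   : constant tensor holding -1.0 (FP32)
// [2] elementwise_mul : out^2 * (-1.0)
// [3] fill_constant   : constant tensor holding 1.0 (FP32)
// [4] elementwise_add : out^2 * (-1.0) + 1.0
// [5] elementwise_mul : grad_out * (out^2 * (-1.0) + 1.0)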

TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
@@ -368,8 +387,10 @@ TEST(StaticPrim, TestFlags) {

} // namespace prim
} // namespace paddle
USE_OP_ITSELF(fill_constant);
USE_OP_ITSELF(tanh);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(pow);
USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(scale);
13 changes: 13 additions & 0 deletions paddle/phi/api/include/tensor.h
@@ -548,6 +548,14 @@ class PADDLE_API Tensor final {

Tensor operator/(const Tensor& other) const;

Tensor operator+(const Scalar& other) const;

Tensor operator-(const Scalar& other) const;

Tensor operator*(const Scalar& other) const;

Tensor operator/(const Scalar& other) const;

/* Part 8: Autograd methods */

/**
@@ -663,6 +671,11 @@ class PADDLE_API Tensor final {
Tensor divide(const Tensor& y) const;
Tensor multiply(const Tensor& y) const;
Tensor subtract(const Tensor& y) const;
Tensor add(const Scalar& y) const;
Tensor divide(const Scalar& y) const;
Tensor multiply(const Scalar& y) const;
Tensor subtract(const Scalar& y) const;

Tensor exp() const;
Tensor floor() const;
Tensor gather_nd(const Tensor& index) const;
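For context, a hypothetical usage sketch of the right-scalar overloads declared above (not part of this PR; assumes a Paddle C++ build, and the enclosing namespace of Tensor may differ across Paddle versions):

// `x` is assumed to be an existing floating-point tensor created elsewhere.
void scalar_rhs_demo(const paddle::experimental::Tensor& x) {
  auto a = x + 1.0;     // operator+(const Scalar&), backed by Tensor::add(Scalar)
  auto b = x - 0.5;     // operator-(const Scalar&)
  auto c = x * 2.0;     // operator*(const Scalar&)
  auto d = x / 4.0;     // operator/(const Scalar&)
  auto e = x.add(3.0);  // named form, equivalent to x + 3.0
}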