Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Autogen Prim Operants] Autogen prim eager and static tensor operants #50558

Merged
merged 8 commits into from
Feb 17, 2023

Conversation

jiahy0825
Copy link
Contributor

@jiahy0825 jiahy0825 commented Feb 16, 2023

PR types

New features

PR changes

Others

Describe

Auto-generate tensor operants codes, code example:

eager_tensor_operants.h

// Generated by paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

#pragma once

#include "paddle/phi/api/include/operants_base.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"


namespace paddle {

namespace prim {

using Tensor = paddle::experimental::Tensor;
using TensorOperantsBase = paddle::operants::TensorOperantsBase;

class EagerTensorOperants : public TensorOperantsBase {
 private:
  DISABLE_COPY_AND_ASSIGN(EagerTensorOperants);

 public:
  EagerTensorOperants() = default;

  Tensor multiply(const Tensor& x, const Tensor& y);

};

}  // namespace prim
}  // namespace paddle

eager_tensor_operants.cc

// Generated by paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"

#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"


namespace paddle {

namespace prim {

Tensor EagerTensorOperants::multiply(const Tensor& x, const Tensor& y) {
  return ::multiply_ad_func(x, y);
}


}  // namespace prim
}  // namespace paddle

static_tensor_operants.h

// Generated by paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

#pragma once

#include "paddle/phi/api/include/operants_base.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/macros.h"


namespace paddle {

namespace prim {

using Tensor = paddle::experimental::Tensor;
using TensorOperantsBase = paddle::operants::TensorOperantsBase;

class StaticTensorOperants : public TensorOperantsBase {
 private:
  DISABLE_COPY_AND_ASSIGN(StaticTensorOperants);

 public:
  StaticTensorOperants() = default;

  Tensor multiply(const Tensor& x, const Tensor& y);

};

}  // namespace prim
}  // namespace paddle

static_tensor_operants.cc

// Generated by paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py

#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"

#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"


namespace paddle {

namespace prim {
using DescTensor = paddle::prim::DescTensor;

Tensor StaticTensorOperants::multiply(const Tensor& x, const Tensor& y) {
  return paddle::prim::multiply<DescTensor>(x, y);
}


}  // namespace prim
}  // namespace paddle

@paddle-bot
Copy link

paddle-bot bot commented Feb 16, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Comment on lines +173 to +175
api_prims = yaml.safe_load(f)
# white list temporarily
api_prims = ('add', 'subtract', 'multiply', 'divide')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

第一句的api_prims赋值有什么作用,另外api_prims = ('add', 'subtract', 'multiply', 'divide')这个应该会多次做出修改,是不是可以作为一个全局变量放在外边,方便修改维护

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在只重载了四则运算,api_prims = ('add', 'subtract', 'multiply', 'divide') 只是一个短期的白名单。后续会直接将这行代码删掉,根据 api_prims = yaml.safe_load(f) 生成 prim 的 api。

api_prims = ('add', 'subtract', 'multiply', 'divide') is a temporary whitelist solution because tensor operants only overload arithmetic operators. In the future, this line of code will be deleted and generate prim api according to api_prims = yaml.safe_load(f)

YuanRisheng
YuanRisheng previously approved these changes Feb 17, 2023
@@ -0,0 +1,273 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022 -> 2023

Copy link
Contributor Author

@jiahy0825 jiahy0825 Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx

api_prims = ('add', 'subtract', 'multiply', 'divide')

for api in apis:
eager_api = EagerPrimAPI(api, api_prims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从长期来看 tensor operants 和 EagerPrimAPI 是不解耦开比较好些?Operants本身和PrimAPI不是一个强绑定的关系,如果以后要新增一套支持C++组网训练的API,可能operants还需要从prim里面分出来

Copy link
Contributor Author

@jiahy0825 jiahy0825 Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@jiahy0825 jiahy0825 merged commit e89baf9 into PaddlePaddle:develop Feb 17, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants