support custom vjp trait #57106

Merged Sep 13, 2023 · 40 commits

Commits
7de4b36  test prim custom vjp in New IR (cyber-pioneer, Sep 6, 2023)
94cea80  add a new CustomVjpTrait to represent whether an op has custom vjp (lxd-cumt, Sep 6, 2023)
892b0cd  add has_custom_vjp_op_list to represent ops that have custom vjp (lxd-cumt, Sep 6, 2023)
37d77c4  parse has_custom_vjp_op_list and autogen CustomVjpTrait for those ops (lxd-cumt, Sep 6, 2023)
3945ade  add pybind to support checking whether an op has custom vjp in python… (lxd-cumt, Sep 6, 2023)
c5bd381  fix conflicts (lxd-cumt, Sep 6, 2023)
b8e9706  add test (lxd-cumt, Sep 7, 2023)
31230dd  add test for add op custom vjp (lxd-cumt, Sep 7, 2023)
aae5244  add pybind to support checking whether an op has custom vjp in python… (lxd-cumt, Sep 7, 2023)
60033d9  fix bugs (lxd-cumt, Sep 7, 2023)
8338fca  polish code (lxd-cumt, Sep 7, 2023)
09ca597  fix bugs (lxd-cumt, Sep 7, 2023)
5f8935a  generate custom_vjp trait based on op list from gen.py (lxd-cumt, Sep 7, 2023)
401d26b  delete has_custom_vjp_op_list (lxd-cumt, Sep 7, 2023)
17d9529  fix bugs (lxd-cumt, Sep 7, 2023)
4fd9fa3  use currently defined list CUSTOM_VJP and VJP_COMPS rather than defin… (lxd-cumt, Sep 7, 2023)
93b2322  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 8, 2023)
a0417b6  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 8, 2023)
d6cb7b5  fix ctest (lxd-cumt, Sep 8, 2023)
20dfa96  fix bugs (lxd-cumt, Sep 8, 2023)
ccbd60b  divide vjp into prim_vjp and custom vjp (lxd-cumt, Sep 8, 2023)
87fa37b  fix conflicts (lxd-cumt, Sep 8, 2023)
c614b85  add code comments (lxd-cumt, Sep 8, 2023)
7cfaa72  add code comments (lxd-cumt, Sep 8, 2023)
eee9c41  polish codes (lxd-cumt, Sep 8, 2023)
3b0ae3e  polish code comments (lxd-cumt, Sep 8, 2023)
bfd1f5a  polish codes (lxd-cumt, Sep 8, 2023)
ff3861b  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 8, 2023)
d447056  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 8, 2023)
ebc8754  fix conflicts with <rename ir into pir> (lxd-cumt, Sep 11, 2023)
877cc7c  fix bugs (lxd-cumt, Sep 11, 2023)
877b9ac  add code comments (lxd-cumt, Sep 11, 2023)
d96340d  fix bugs (lxd-cumt, Sep 11, 2023)
8403398  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 11, 2023)
54952c1  add custom vjp trait support in new folder (lxd-cumt, Sep 11, 2023)
3d75a25  fix bugs (lxd-cumt, Sep 11, 2023)
e3ffbc0  add another example for unit testing (lxd-cumt, Sep 11, 2023)
c26425e  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 12, 2023)
219ca2a  fix bugs (lxd-cumt, Sep 12, 2023)
1925e3f  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (lxd-cumt, Sep 12, 2023)
19 changes: 19 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -15,6 +15,8 @@
import argparse
import logging
import os
import pathlib
import sys

import yaml
from op_build_gen import gen_build_func_str
@@ -30,6 +32,12 @@
vjp_interface_implementation_gen_op_list,
)

# import from paddle/fluid/primitive/codegen/gen.py
sys.path.append(
str(pathlib.Path(__file__).resolve().parents[3] / 'primitive/codegen')
)
Review comment (Contributor):
Couldn't a relative import be used here?

Reply (Contributor Author):
An __init__.py is currently missing here, so a relative import would require adding one first. The Python style guide also notes:

Absolute imports are recommended, as they are usually more readable and tend to be better behaved (or at least give better error messages) if the import system is incorrectly configured.

Hence the absolute path.

import gen as vjp_gen

# =====================================
# String Template for h file code gen
# =====================================
@@ -54,6 +62,7 @@
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -818,6 +827,12 @@ def OpGenerator(
ops_declare_list = []  # all op class declarations are stored in this list
ops_defined_list = []  # all op class definitions are stored in this list
ops_vjp_defined_list = []  # all op vjp static interface definitions

# (4) parse name of ops which have custom vjp rules
custom_vjp_op_name_list = []
for custom_vjp in vjp_gen.CUSTOM_VJP:
custom_vjp_op_name_list.append(custom_vjp[:-5]) # cut _grad

for key, op_info in op_info_items.items():
# get op inputs info
op_input_name_list = op_info.input_name_list
@@ -873,6 +888,10 @@
op_interfaces += ["paddle::dialect::VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(op_info)

# if op has custom vjp rule, then append a CustomVjpTrait to it
if op_info.op_phi_name[0] in custom_vjp_op_name_list:
op_traits += ["paddle::dialect::CustomVjpTrait"]

# check op inputs and mutable_attributes grad semantics
input_grad_semantics = get_input_grad_semantic(op_info, op_info_items)
mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
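A standalone sketch of the "_grad"-stripping step above (the CUSTOM_VJP contents are copied from gen.py; the rest is illustrative only):

CUSTOM_VJP = ['gelu_grad']  # backward-op names, each ending in "_grad"

# len('_grad') == 5, so slicing off the last five characters recovers the
# forward op name that should carry CustomVjpTrait.
custom_vjp_op_name_list = [name[:-5] for name in CUSTOM_VJP]
assert custom_vjp_op_name_list == ['gelu']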
38 changes: 38 additions & 0 deletions paddle/fluid/pir/dialect/operator/trait/custom_vjp.h
@@ -0,0 +1,38 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Custom VJP stands for manually implemented backward rules for composite
operators. CustomVjpTrait will be added to those composite operators that
define custom vjp rules. Finally, by calling has_custom_vjp(op), users can
check whether an operator has a CustomVjpTrait, and thus whether a custom
vjp rule is defined for that operator.
*/

#pragma once

#include "paddle/pir/core/op_base.h"

namespace paddle {
namespace dialect {
class CustomVjpTrait : public pir::OpTraitBase<CustomVjpTrait> {
public:
explicit CustomVjpTrait(pir::Operation *op)
: pir::OpTraitBase<CustomVjpTrait>(op) {}
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait)
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/trait/trait.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait)
9 changes: 7 additions & 2 deletions paddle/fluid/primitive/codegen/gen.py
@@ -31,7 +31,7 @@

# import from paddle/fluid/pir/dialect/op_generator/api_gen.py
sys.path.append(
-    str(pathlib.Path(__file__).resolve().parents[2] / 'ir/dialect/op_generator')
+    str(pathlib.Path(__file__).resolve().parents[2] / 'pir/dialect/op_generator')
)

# fmt: on
@@ -67,7 +67,12 @@
'slice_double_grad',
'layer_norm_grad',
]
- VJP_COMPS = ['divide_grad', 'sum_grad', 'gelu_grad']
+ PRIM_VJP = ['divide_grad', 'sum_grad']  # vjp list of primitive ops
+ CUSTOM_VJP = ['gelu_grad']  # custom vjp list of composite ops
+ VJP_COMPS = PRIM_VJP + CUSTOM_VJP

BACKENDS = [
'add_n',
'mean',
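The split leaves the combined list unchanged for existing consumers of VJP_COMPS, while op_gen.py imports only the custom subset. A minimal sketch of that invariant:

# Concatenation must reproduce the pre-split VJP_COMPS exactly, so existing
# consumers of VJP_COMPS are unaffected by introducing PRIM_VJP/CUSTOM_VJP.
PRIM_VJP = ['divide_grad', 'sum_grad']
CUSTOM_VJP = ['gelu_grad']
VJP_COMPS = PRIM_VJP + CUSTOM_VJP
assert VJP_COMPS == ['divide_grad', 'sum_grad', 'gelu_grad']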
16 changes: 16 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -196,6 +196,7 @@ limitations under the License. */
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/fluid/pybind/eager_utils.h"
@@ -751,6 +752,21 @@ void BindVjp(pybind11::module *m) {
if (vjp_interface_impl == nullptr) return false;
return true;
});

m->def(
"has_custom_vjp",
[](pir::Operation &op) -> py::bool_ {
return op.info().HasTrait<paddle::dialect::CustomVjpTrait>();
},
R"DOC(
Return whether an op has custom vjp rules.

Args:
op (pir::Operation): op to be checked

Returns:
out (bool): True means that the op has custom vjp rules, False means it does not.
)DOC");
}
PYBIND11_MODULE(libpaddle, m) {
BindImperative(&m);
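A minimal usage sketch of the new binding, condensed from the unit test below (assumes a Paddle build with the new IR and this binding available):

import paddle
from paddle import ir, nn
from paddle.base.core import has_custom_vjp

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, paddle.static.Program()):
    x = paddle.static.data('x', [2, 3, 3], dtype='float32')
    out = nn.GELU()(x)

# Translate the legacy program into the new IR; the last op is gelu.
newir_program = ir.translate_to_new_ir(main_program.desc)
gelu_op = newir_program.global_block().ops[-1]
print(has_custom_vjp(gelu_op))  # True: gelu_grad is listed in CUSTOM_VJP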
65 changes: 65 additions & 0 deletions test/prim/new_ir_prim/test_custom_vjp_trait.py
@@ -0,0 +1,65 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import ir, nn
from paddle.base.core import has_custom_vjp

paddle.enable_static()


def get_gelu_program_new_ir():
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data('x', [2, 3, 3], dtype='float32')
net = nn.GELU()
out = net(x)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program


def get_multiply_program_new_ir():
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data('x', [2, 3, 3], dtype='float32')
y = paddle.static.data('y', [2, 3, 3], dtype='float32')
out = paddle.multiply(x, y)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program


class TestCustomVjpTrait(unittest.TestCase):
Review comment (Contributor):
The unit test needs a negative case.

Reply (Contributor Author):
OK, will add it in the next PR.

def test_gelu_op_custom_vjp_trait(self):
newir_program = get_gelu_program_new_ir()
op = newir_program.global_block().ops[-1]
self.assertEqual(op.name(), "pd_op.gelu")
self.assertEqual(has_custom_vjp(op), True)

def test_multiply_op_custom_vjp_trait(self):
newir_program = get_multiply_program_new_ir()
op = newir_program.global_block().ops[-1]
self.assertEqual(op.name(), "pd_op.multiply")
self.assertEqual(has_custom_vjp(op), False)


if __name__ == "__main__":
unittest.main()