From 46b0ef369040f694f8479a63c11b9f7bceecec90 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 3 Nov 2023 08:50:55 +0000 Subject: [PATCH 1/3] allow pir::Program dynamically add attribute --- paddle/fluid/pybind/pir.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 3b2976deb1f88..b5adfe76f9162 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -119,7 +119,8 @@ std::string GetValueInfo(Value v) { } void BindProgram(py::module *m) { - py::class_> program(*m, "Program", R"DOC( + py::class_> program( + *m, "Program", py::dynamic_attr(), R"DOC( Create Python Program. Program is an abstraction of model structure, divided into computational graphs and weights. The Program has a main block that stores the computational graphs. From d71d13fab4c9ec1f60ca7f35124981072190a2c3 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Fri, 3 Nov 2023 09:31:03 +0000 Subject: [PATCH 2/3] add seed for pir::Program --- python/paddle/__init__.py | 3 ++- python/paddle/pir/__init__.py | 1 + python/paddle/pir/program_patch.py | 34 ++++++++++++++++++++++++++++++ test/ir/new_ir/test_ir_pybind.py | 7 ++++++ 4 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 python/paddle/pir/program_patch.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9c6484a1d4611..92ac2fcbb5c34 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -33,11 +33,12 @@ # the illogical implement in the monkey-patch methods later. from .framework import monkey_patch_variable from .framework import monkey_patch_math_tensor -from .pir import monkey_patch_opresult +from .pir import monkey_patch_opresult, monkey_patch_program monkey_patch_variable() monkey_patch_math_tensor() monkey_patch_opresult() +monkey_patch_program() from .framework import ( disable_signal_handler, diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index 145eb103918bf..81a398339dfb5 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -38,5 +38,6 @@ from . import core from .math_op_patch import monkey_patch_opresult +from .program_patch import monkey_patch_program __all__ = [] diff --git a/python/paddle/pir/program_patch.py b/python/paddle/pir/program_patch.py new file mode 100644 index 0000000000000..4de46a647259a --- /dev/null +++ b/python/paddle/pir/program_patch.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import Program + +_already_patch_program = False + +global_prog_seed = 0 + + +def monkey_patch_program(): + def global_seed(self, seed=0): + global global_prog_seed + global_prog_seed = seed + self._seed = global_prog_seed + + Program.global_seed = global_seed + global global_prog_seed + Program._seed = global_prog_seed + + global _already_patch_program + if not _already_patch_program: + _already_patch_program = True diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index d1a0e1de1f878..202e5f2dbe903 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -198,6 +198,13 @@ def test_get_output_intermediate_value(self): results = unsqueeze_op.get_output_intermediate_value() self.assertEqual(results, [False, True]) + def test_prog_seed(self): + p = pir.Program() + self.assertEqual(p._seed, 0) + + p.global_seed(10) + self.assertEqual(p._seed, 10) + if __name__ == "__main__": unittest.main() From 09dccac345867c1efef31e66676e577d20d7f262 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 6 Nov 2023 02:53:33 +0000 Subject: [PATCH 3/3] polish code --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 763e8db830bc1..813cb34826eb2 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2395,6 +2395,7 @@ struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber { return op_inputs; } }; + OpTranslator::OpTranslator() { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect();