-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【pir】Add pir_grad branch for paddle.static.gradient for test #57956
Changes from all commits
bce9b3b
c2341a5
4d30fdd
c94252d
3aa6686
3b3b5ea
cae57c1
d52fe87
9e5a0b1
2218be2
b190b2f
2ce9d92
02040b1
615c487
b48c163
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
# limitations under the License. | ||
|
||
|
||
from functools import wraps | ||
|
||
import paddle | ||
|
||
|
||
|
@@ -64,9 +66,16 @@ def _switch_to_pir(self): | |
{"FLAGS_enable_new_ir_in_executor": True} | ||
) | ||
paddle.pir.register_paddle_dialect() | ||
paddle.static.Program = paddle.pir.Program | ||
|
||
paddle.base.Program = paddle.pir.Program | ||
paddle.base.program_guard = paddle.pir.core.program_guard | ||
# paddle.base.default_main_program = ( | ||
# paddle.pir.core.default_main_program | ||
# ) | ||
# paddle.base.default_startup_program = ( | ||
# paddle.pir.core.default_startup_program | ||
# ) | ||
Comment on lines
+72
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 没用的注释可以删掉 |
||
paddle.static.Program = paddle.pir.Program | ||
paddle.static.program_guard = paddle.pir.core.program_guard | ||
paddle.static.default_main_program = ( | ||
paddle.pir.core.default_main_program | ||
|
@@ -82,9 +91,14 @@ def _switch_to_old_ir(self): | |
paddle.framework.set_flags( | ||
{"FLAGS_enable_new_ir_in_executor": False} | ||
) | ||
paddle.static.Program = self.old_Program | ||
|
||
paddle.base.Program = self.old_Program | ||
paddle.base.program_guard = self.old_program_guard | ||
# paddle.base.default_main_program = self.old_default_main_program | ||
# paddle.base.default_startup_program = ( | ||
# self.old_default_startup_program | ||
# ) | ||
Comment on lines
+97
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此处注释为提醒遗留处理项 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议删除或者打开注释,或者是标记加TODO |
||
paddle.static.Program = self.old_Program | ||
paddle.static.program_guard = self.old_program_guard | ||
paddle.static.default_main_program = self.old_default_main_program | ||
paddle.static.default_startup_program = ( | ||
|
@@ -95,3 +109,13 @@ def _switch_to_old_ir(self): | |
"IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \ | ||
please set FLAGS_enable_pir_api = false" | ||
) | ||
|
||
|
||
def test_with_pir_api(func): | ||
@wraps(func) | ||
def impl(*args, **kwargs): | ||
func(*args, **kwargs) | ||
with IrGuard(): | ||
func(*args, **kwargs) | ||
|
||
return impl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里很奇怪,因为上午有pr已经这里有代码合入了,这里居然没有产生代码冲突?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是拉取同一个pr的代码