-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.91】 #52948
【Hackathon No.91】 #52948
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
通过将函数转为 AST 改写代码使得 动转静 下的 register_hook 能成功运行。 之后还可以提升的地方
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非常感谢您的贡献。这里方案有几个问题待商榷:
- PR支持了被to_static装饰下使用register_hook的场景,对于继承了nn.Layer的class,如果register_hook函数是self.__init__函数中调用的,则被@to_static装饰的forward函数里registerc_hook是否会正确触发呢?
python/paddle/jit/dy2static/utils.py
Outdated
@@ -650,6 +650,10 @@ def func_to_source_code(function, dedent=True): | |||
for line in source_code_list | |||
] | |||
source_code = ''.join(source_code_list) | |||
# check the 'register hook' in the source code | |||
if 'register_hook' in source_code: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是一个AST层面的变换,推荐写成一个单独的AstTransformer。因为source_code层面的str匹配是有风险的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用 NodeVisitor 来进行遍历修改节点
@@ -38,3 +41,69 @@ def pretty_source(source): | |||
|
|||
source_code = astor.to_source(ast_node, pretty_source=pretty_source) | |||
return source_code | |||
|
|||
|
|||
def modify_function_code(func, code_str='register_hook'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
几点建议哈:
- 建议将其以AstTransformer形式接入进来,可以参考dy2stat目录下其他Transformer
- 注释建议使用英文
- 移除不必要的 print等注释代码
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释改为英文、移除 print 等不必要的注释代码
python/paddle/fluid/framework.py
Outdated
"""do nothing but return a new variable.""" | ||
return x | ||
|
||
# class HookRemoveHelper: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议移除注释代码
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
是类似之前的这个 RFC 提到的 LinearNet 的例子吗,是可以被正确的触发的。 |
是的,RFC里是对params注册了hook,是否可以在test/unittest/dygraph_to_static/ 目录下添加一个test_tensor_hook.py,丰富下不同使用场景的单测单元case? |
增加了 test_tensor_hook.py 的测试代码,设定了几种形式,包括了在 Layer 的 forward 中对参数 register_hook,这里我把 varbase_patch_method 中 monkey_patch_varbase 中 register_hook 的 dygraph_only 装饰器注释了,发现代码都可以直接运行。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for overall
|
||
def backward_hook_wrapper(dy): | ||
"""call the backward hook in .""" | ||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里在函数内添加import numpy ,是因为此函数会在py_func里执行,且用户可能没有在模型代码里添加import ?
@@ -38,3 +41,84 @@ def pretty_source(source): | |||
|
|||
source_code = astor.to_source(ast_node, pretty_source=pretty_source) | |||
return source_code | |||
|
|||
|
|||
class RegisterHookVisitor(gast.NodeVisitor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如下 comment 可以考虑下个PR做优化。
- ast_utils.py 为公共文件,所有上层的Ast Transformer变换逻辑按照规范是独立一个文件的。
- 所有的AST 变换建议都继承 BaseTransformer,并提供
def transform(self):
方法,因为报错栈回溯逻辑是在基类BaseTransformer中实现的。
func_def.body = new_body | ||
|
||
|
||
def modify_function_code(func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同样下个PR可以优化:
- 动转静的AST Transfomer 逻辑是放在
ast_transformer.py
统一生效的,里面有一个list,会逐个应用生效
if dedent: | ||
source_code = textwrap.dedent(source_code) | ||
# return modified function source code if there is 'register_hook', otherwise return None | ||
source_code = modify_function_code(function) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如前面的comment,ast_to_func函数的职责只负责借助 AST 生成python module,因此建议下个PR可否将此处逻辑放到 ast_transformer.py
loss_jit = jit_layer(image_jit) | ||
loss_jit.backward() | ||
loss.backward() | ||
self.assertTrue( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
另外从框架内部单测规范上,我们更建议使用类似 np.testing.assert_allclose()
函数,而非self.assertTrue(xxx),详见:https://github.com/PaddlePaddle/community/blob/master/rfcs/CodeStyle/20220805_code_style_improvement_for_unittest.md#background
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for framework.py
备注:黑客松91题还未完成,仍需要PR优化 |
Hi @Aurelius84, check this new PR #53572 for updates. |
PR types
Others
PR changes
Others
Description
register_hook for static mode