[AutoParallel]: Support multi kernels for DistTensor. #57321
Conversation
Support multi kernels for DistTensor. Sparse kernel is not supported now.
@@ -1014,7 +1014,29 @@ def gene_base_api_code(self, inplace_flag=False):
    api_func_name += '_'

    if len(self.kernel['func']) > 1:
        # auto parallel branch, all apis contains this branch default
        # 1. only works for the ops contains single kernel
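The `len(self.kernel['func']) > 1` check is where the generator distinguishes single-kernel ops from ops that register several kernels. A minimal sketch of that idea, assuming hypothetical names (`gene_kernel_dispatch` and the emitted C++ strings are illustrative, not Paddle's actual generator API):

```python
# Hypothetical sketch: branch the generated dispatch code on how many
# kernels an op registers. `kernel['func']` mirrors the structure seen
# in the diff above; everything else is illustrative.
def gene_kernel_dispatch(kernel):
    funcs = kernel['func']
    if len(funcs) > 1:
        # Multiple kernels registered for one op (e.g. a dense variant
        # and a dense-param/sparse-grad variant): emit one branch per
        # kernel function name.
        return "\n".join(
            f'if (kernel_name == "{name}") return {name}_impl(args);'
            for name in funcs
        )
    # Single-kernel op: call the only kernel directly.
    return f"return {funcs[0]}_impl(args);"

print(gene_kernel_dispatch(
    {'func': ['adagrad', 'adagrad_dense_param_sparse_grad']}
))
```

The generated auto-parallel branch in this PR only had the single-kernel case before; the change extends it to the multi-kernel case.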
These comments could be moved as a whole to before the if condition, keeping only one copy. Also, for cases that are already supported, please clean the comments up to match the actual situation.
Done, thx.
@@ -426,6 +426,62 @@ def test_adamax_for_dist_tensor(self):
        self.check_tensor_eq(local_inf_norm_out, dist_inf_norm_out)
        self.check_tensor_eq(local_master_param_out, dist_master_param_out)

    # multi kernel functions
    def test_adagrad_for_dist_tensor(self):
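The tests follow a fixed pattern: run the op once on plain local tensors and once on DistTensors, then compare the outputs with `check_tensor_eq`. A self-contained sketch of that pattern, using NumPy arrays as stand-ins for both tensor kinds and a hand-written adagrad step (not Paddle's kernel):

```python
# Illustrative sketch of the local-vs-dist comparison pattern used in
# these unit tests. NumPy arrays stand in for local and dist tensors;
# adagrad_step is a simplified reference update, not Paddle's kernel.
import numpy as np

def check_tensor_eq(local_out, dist_out, rtol=1e-5):
    # In the real tests this compares a local tensor against a
    # DistTensor result; here both are plain arrays.
    np.testing.assert_allclose(local_out, dist_out, rtol=rtol)

def adagrad_step(param, grad, moment, lr=0.01, eps=1e-6):
    # One adagrad update: accumulate squared gradients, scale the step.
    moment = moment + grad * grad
    param = param - lr * grad / (np.sqrt(moment) + eps)
    return param, moment

local_param, local_moment = adagrad_step(np.ones(4), np.full(4, 0.5), np.zeros(4))
dist_param, dist_moment = adagrad_step(np.ones(4), np.full(4, 0.5), np.zeros(4))
check_tensor_eq(local_param, dist_param)
check_tensor_eq(local_moment, dist_moment)
```

The point of the test is not the optimizer math but that the DistTensor code path dispatches to the right kernel and produces numerically identical results.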
Later we could move the API tests into a separate file, e.g. something like test_api_dist_branch.py; the semantic scope of test_dist_tensor no longer seems to cover them.
Added a new test_api_dist_branch.py and migrated both the generated multi-type API unit tests and the multi-kernel unit tests into it.
… support_multi_kernel
Same here: we need to confirm whether the newly added unit tests actually ran in CI.
Got it. If the unit tests didn't run, I'll fix that together in PR 57293.
    and not self.api.endswith("_double_grad")
    and not self.api.endswith("_triple_grad")
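The condition under discussion filters APIs by name suffix. A minimal sketch, assuming a hypothetical `needs_dist_branch` helper (the `is_sparse` flag and function name are illustrative; only the `endswith` checks come from the diff):

```python
# Sketch of the skip condition being reviewed: besides excluding sparse
# ops, it also excludes higher-order (double/triple) gradient APIs.
# `needs_dist_branch` and `is_sparse` are hypothetical names.
def needs_dist_branch(api, is_sparse=False):
    return (
        not is_sparse
        and not api.endswith("_double_grad")
        and not api.endswith("_triple_grad")
    )

print(needs_dist_branch("matmul_grad"))         # regular grad API
print(needs_dist_branch("matmul_double_grad"))  # higher-order, skipped
```

This is what the reviewer points out below: the check covers more than just sparse kernels.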
This doesn't only skip sparse ops, it seems; it also checks for higher-order differentiation and auto parallel.
Got it, I'll fix it in PR 57293.
LGTM
…7321)
* [AutoParallel]: Support multi kernels for DistTensor. Sparse kernel is not supported now.
* Polish code with review comments.
* Add testcases.
PR types
Others
PR changes
Others
Description
Pcard-73145
Support multi kernels for DistTensor. Sparse kernel is not supported now.