[AutoParallel] Custom op support auto parallel #58553
Conversation
LGTM
    'dist_custom_relu_op.cc',
    'dist_custom_relu_op_dup.cc',
    'dist_custom_relu_op.cu',
Are these kernel files any different from the original implementation? Can we reuse the original kernel code instead? It should be reachable via a relative path.
OK, I'll give it a try~ Thanks!
I tried it and it works. Thanks!!
- const std::vector<std::pair<size_t, size_t>>& InputRange();
- const std::vector<std::pair<size_t, size_t>>& OutputRange();
+ const std::vector<std::pair<size_t, size_t>>& InputRange() const;
+ const std::vector<std::pair<size_t, size_t>>& OutputRange() const;
Thanks to Huan for the const guarantee here.
You're welcome~
std::vector<Tensor>* all_inputs = ctx.AllMutableInput();

#ifdef PADDLE_WITH_DISTRIBUTE
It would be better to extract the code inside #ifdef PADDLE_WITH_DISTRIBUTE into a separate function.
Ah, this was done to keep the format consistent with PHI's api.cc~
I still think it needs to be factored out. The coupling between the distributed logic and the custom-op logic is too heavy here; if this ever needs debugging, the maintenance cost will likely be high.
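For illustration, a minimal sketch of the suggested split. The helper name PrepareCtxForAutoParallel is hypothetical, and the real signatures in custom_operator.cc may differ:

```cpp
#ifdef PADDLE_WITH_DISTRIBUTE
// Hypothetical helper: all auto-parallel preprocessing lives here,
// instead of being inlined into the generic custom-op run path.
static void PrepareCtxForAutoParallel(paddle::CustomOpKernelContext* ctx) {
  std::vector<Tensor>* all_inputs = ctx->AllMutableInput();
  // ... dist-tensor conversion / spmd inference on *all_inputs ...
}
#endif

void run_custom_op_impl(const paddle::OpMetaInfo& op_info,
                        bool is_forward,
                        bool is_double_grad,
                        paddle::CustomOpKernelContext& ctx) {  // NOLINT
#ifdef PADDLE_WITH_DISTRIBUTE
  PrepareCtxForAutoParallel(&ctx);  // single call site keeps coupling low
#endif
  // ... common execution path shared with the non-distributed build ...
}
```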
void run_custom_op_impl(paddle::OpMetaInfo op_info,
                        bool is_forward,
                        bool is_double_grad,
                        paddle::CustomOpKernelContext& ctx) {  // NOLINT
For variables that are written to, the convention in the custom-op codebase is to pass them as pointers:
paddle::CustomOpKernelContext* ctx
paddle::OpMetaInfo op_info is also written to, isn't it? It should be changed to a pointer as well.
I'll change paddle::OpMetaInfo op_info to const paddle::OpMetaInfo& op_info~
ctx still has to be modified inside run_custom_op_impl, but I find heavy pointer usage inconvenient, so I added the NOLINT.
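For reference, the signature shape agreed on in this thread would look like this (a sketch mirroring the discussion above, not necessarily the final merged code):

```cpp
// op_info is only read after the change, so it becomes a const reference;
// ctx is still mutated inside the function, so it stays a mutable
// reference (with NOLINT), since pervasive pointer syntax was judged
// inconvenient here.
void run_custom_op_impl(const paddle::OpMetaInfo& op_info,
                        bool is_forward,
                        bool is_double_grad,
                        paddle::CustomOpKernelContext& ctx);  // NOLINT
```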
} else {
  for (size_t j = pair.first; j < pair.second; j++) {
    *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor(
        current_process_mesh, out_dim[0], out_dtype[0]);
Aren't the subscripts here wrong? Shouldn't they be j?
Thanks a lot! Otherwise this would have been a big pitfall!
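The agreed fix, sketched against the quoted diff: the hard-coded 0 subscripts become the loop index j.

```cpp
} else {
  for (size_t j = pair.first; j < pair.second; j++) {
    // Index with j so each output slot gets its own dims and dtype,
    // rather than always reusing those of the first output.
    *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor(
        current_process_mesh, out_dim[j], out_dtype[j]);
  }
}
```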
LGTM~ Great work!
LGTM
LGTM
current_process_mesh =
    paddle::holds_alternative<phi::distributed::TensorDistAttr>(
        spmd_info.first[0])
        ? paddle::get<0>(spmd_info.first[0]).process_mesh()
Using paddle::get directly requires a try/catch around it, or use the PADDLE_GET family of macros instead; otherwise, once it throws, the error is very hard to analyze. I suggest polishing this in a follow-up.
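A sketch of the two options named above. The exact PADDLE_GET_CONST form is assumed from its usual (Type, value) macro usage elsewhere in the codebase:

```cpp
// Option 1: wrap the raw paddle::get in a try/catch so that a wrong
// variant alternative yields an actionable message instead of an
// opaque bad_variant_access propagating up.
try {
  current_process_mesh =
      paddle::get<0>(spmd_info.first[0]).process_mesh();
} catch (const paddle::bad_variant_access& e) {
  PADDLE_THROW(phi::errors::InvalidArgument(
      "spmd_info.first[0] does not hold a TensorDistAttr: %s", e.what()));
}

// Option 2: let the PADDLE_GET_CONST macro perform the checked access.
current_process_mesh =
    PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.first[0])
        .process_mesh();
```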
LGTM
"have the same mesh.", | ||
input.name())); | ||
} else { | ||
PADDLE_ENFORCE_EQ( |
On my side, running a custom op that has an optional input fails. I traced it to this spot: input.impl().get() returns a null pointer. Does this code take custom ops' optional inputs into account?
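A sketch of the kind of guard being asked for, assuming an optional input that was not passed surfaces as a tensor with a null impl(); the placement is hypothetical and would belong wherever input.impl().get() is dereferenced:

```cpp
for (auto& input : *all_inputs) {
  // An absent optional input arrives as an uninitialized tensor;
  // skip it rather than dereferencing a null impl().
  if (!input.initialized() || input.impl() == nullptr) {
    continue;
  }
  // ... existing same-mesh check using input.impl().get() ...
}
```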
PR types
Others
PR changes
Others
Description
Custom operators now support AutoParallel.
Pcard-73145