
[AutoParallel] Custom op support auto parallel #58553

Conversation

wanghuancoder
Contributor

@wanghuancoder wanghuancoder commented Nov 1, 2023

PR types

Others

PR changes

Others

Description

Custom operators now support AutoParallel.
Pcard-73145

@wanghuancoder wanghuancoder changed the title Custom op support auto parallel [AutoParallel] Custom op support auto parallel Nov 1, 2023
@paddle-bot paddle-bot bot added the contributor External developers label Nov 1, 2023
@paddle-bot paddle-bot bot removed the contributor External developers label Nov 3, 2023
GhostScreaming
GhostScreaming previously approved these changes Nov 3, 2023
Contributor

@GhostScreaming GhostScreaming left a comment

LGTM

Comment on lines 60 to 62
'dist_custom_relu_op.cc',
'dist_custom_relu_op_dup.cc',
'dist_custom_relu_op.cu',
Contributor

Are these kernel files any different from the original implementation? If not, can the original kernel code be reused? It should be reachable via a relative path.

Contributor Author

OK, I'll give it a try. Thanks!

Contributor Author

I tried it, and it works. Thanks!!

Comment on lines -127 to +128
const std::vector<std::pair<size_t, size_t>>& InputRange();
const std::vector<std::pair<size_t, size_t>>& OutputRange();
const std::vector<std::pair<size_t, size_t>>& InputRange() const;
const std::vector<std::pair<size_t, size_t>>& OutputRange() const;
Contributor

Thanks for adding the const guarantees here, Huan.

Contributor Author

You're welcome~
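
For reference, a minimal usage sketch of what the added const qualifiers buy (the helper below is hypothetical and not part of this PR, and the include path is an assumption): read-only callers can now take the kernel context by const reference.

// Hypothetical read-only helper; it compiles only because InputRange() and
// OutputRange() are now const-qualified member functions.
#include "glog/logging.h"
#include "paddle/phi/api/ext/op_meta_info.h"  // assumed header for CustomOpKernelContext

void LogRanges(const paddle::CustomOpKernelContext& ctx) {
  const auto& in_range = ctx.InputRange();    // OK on a const context
  const auto& out_range = ctx.OutputRange();  // OK on a const context
  VLOG(4) << "input ranges: " << in_range.size()
          << ", output ranges: " << out_range.size();
}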


std::vector<Tensor>* all_inputs = ctx.AllMutableInput();

#ifdef PADDLE_WITH_DISTRIBUTE
Contributor

It would be better to extract the code inside #ifdef PADDLE_WITH_DISTRIBUTE into a separate function.

Contributor Author

Ah, this was done to stay consistent with the format in PHI's api.cc~

Contributor

I still think it should be split out. The coupling between the distributed logic and the custom-op path here is too heavy; if we ever need to debug this later, the maintenance cost is likely to be high.
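
A sketch of what the reviewer is asking for, with a hypothetical helper name (PrepareCtxForAutoParallel is not taken from this PR): hide the distributed preprocessing behind one function so the guarded region in the custom-op call path shrinks to a couple of lines and can be read and debugged in isolation.

#ifdef PADDLE_WITH_DISTRIBUTE
// Hypothetical helper: check whether any input is a DistTensor, run SPMD
// inference, and convert/reshard the inputs held in *ctx. Returns true when
// the op should take the auto-parallel path.
bool PrepareCtxForAutoParallel(const paddle::OpMetaInfo& op_info,
                               paddle::CustomOpKernelContext* ctx);
#endif

void run_custom_op_impl(paddle::OpMetaInfo op_info,
                        bool is_forward,
                        bool is_double_grad,
                        paddle::CustomOpKernelContext& ctx) {  // NOLINT
  std::vector<Tensor>* all_inputs = ctx.AllMutableInput();
#ifdef PADDLE_WITH_DISTRIBUTE
  bool run_auto_parallel = PrepareCtxForAutoParallel(op_info, &ctx);
  (void)run_auto_parallel;  // consulted later when building the outputs
#endif
  // ... dispatch to the user-registered kernel as before ...
}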

Comment on lines 454 to 457
void run_custom_op_impl(paddle::OpMetaInfo op_info,
bool is_forward,
bool is_double_grad,
paddle::CustomOpKernelContext& ctx) { // NOLINT
Contributor

In the custom-op codebase, parameters that are written to should be passed as pointers:
paddle::CustomOpKernelContext* ctx
paddle::OpMetaInfo op_info is also written to, isn't it? It should be changed to pointer form as well.

Contributor Author

I'll change paddle::OpMetaInfo op_info to const paddle::OpMetaInfo& op_info~
ctx still needs to be modified inside run_custom_op_impl, but I find pointers inconvenient to use that pervasively, so I added NOLINT instead.
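
A sketch of the signature the thread lands on (body elided): op_info becomes a read-only const reference, while ctx stays a mutable reference with a NOLINT rather than being turned into a pointer.

void run_custom_op_impl(const paddle::OpMetaInfo& op_info,
                        bool is_forward,
                        bool is_double_grad,
                        paddle::CustomOpKernelContext& ctx) {  // NOLINT
  // op_info is only read here; ctx is still written to, and a non-const
  // reference was judged more convenient than a pointer given how often it
  // is used in the body.
  // ...
}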

} else {
for (size_t j = pair.first; j < pair.second; j++) {
*(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor(
current_process_mesh, out_dim[0], out_dtype[0]);
Contributor

Isn't the index here wrong? Shouldn't it be j?

Contributor Author

Thanks a lot! Otherwise this would have been a nasty pitfall!
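
The fix being pointed at, as a sketch: index the per-output dims and dtype with j as well, rather than always element 0 (whether plain j or an offset such as j - pair.first is correct depends on how out_dim and out_dtype are laid out in the surrounding code).

} else {
  for (size_t j = pair.first; j < pair.second; j++) {
    // use j for the per-output metadata too, not a hard-coded 0
    *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor(
        current_process_mesh, out_dim[j], out_dtype[j]);
  }
}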


std::vector<Tensor>* all_inputs = ctx.AllMutableInput();

#ifdef PADDLE_WITH_DISTRIBUTE
Contributor

I still think it should be split out. The coupling between the distributed logic and the custom-op path here is too heavy; if we ever need to debug this later, the maintenance cost is likely to be high.

jiahy0825
jiahy0825 previously approved these changes Nov 6, 2023
Contributor

@jiahy0825 jiahy0825 left a comment

LGTM~ Great work!

Contributor

@XieYunshen XieYunshen left a comment

LGTM

Contributor

@GhostScreaming GhostScreaming left a comment

LGTM

current_process_mesh =
paddle::holds_alternative<phi::distributed::TensorDistAttr>(
spmd_info.first[0])
? paddle::get<0>(spmd_info.first[0]).process_mesh()
Contributor

Using paddle::get directly needs a try/catch around it, or use the PADDLE_GET family of macros instead; otherwise, once it throws, the failure is very hard to analyze. Suggest tightening this up in a follow-up.
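
A sketch of the suggested hardening (the error text is illustrative, and only the branch quoted above is shown): wrap the variant access so a failed get is re-raised with context instead of surfacing as an opaque bad-variant error; the PADDLE_GET-style macros mentioned above wrap the same pattern.

try {
  if (paddle::holds_alternative<phi::distributed::TensorDistAttr>(
          spmd_info.first[0])) {
    current_process_mesh =
        paddle::get<0>(spmd_info.first[0]).process_mesh();
  }  // ... else branch elided ...
} catch (const std::exception& e) {
  // Re-raise with context so a bad variant access is easy to attribute.
  PADDLE_THROW(phi::errors::InvalidArgument(
      "Unexpected alternative held by spmd_info.first[0]: %s", e.what()));
}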

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@wanghuancoder wanghuancoder merged commit fe862dd into PaddlePaddle:develop Nov 7, 2023
28 checks passed
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Nov 7, 2023
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
"have the same mesh.",
input.name()));
} else {
PADDLE_ENFORCE_EQ(
Contributor

On my side, executing a custom operator with an optional input breaks. I traced it to this spot: input.impl().get() comes back as a null pointer. Was the optional Input case of custom operators considered here?
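
A sketch of one possible guard for the reported case (not a shipped fix): an optional input that was not passed arrives as an undefined Tensor, i.e. input.impl() is a null shared_ptr, so skip the mesh check for it instead of dereferencing it.

for (auto& input : *all_inputs) {
  if (!input.defined()) {
    // Optional input that the caller did not provide: nothing to check.
    continue;
  }
  // ... existing PADDLE_ENFORCE_EQ mesh-consistency check on input ...
}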

danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request Nov 28, 2023