-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN] Refactor pass api of group fusion in CINN #55090
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM,TODO 项点已在对应代码做了标识~
在基于 CINN 仓库的分支上,我们仍在持续迭代,希望尽快推动本 PR 合入,完成 CINN 代码向 Paddle 主仓库的迁移~
paddle/cinn/common/bfs_visitor.h
Outdated
|
||
// breadth-first search visitor | ||
template <typename NodeType> | ||
class BfsVisitor final { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: 此处为了语义的精准,后续会将 BfsVisitor 命名更改为 BfsWalker
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
paddle/cinn/common/dfs_visitor.h
Outdated
|
||
// depth-first search visitor | ||
template <typename NodeType> | ||
class DfsVisitor final { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DfsVisitor, SccVisitor, TopoVisitor 同上,均会修改为 Walker
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
// input groups | ||
std::unordered_set<std::shared_ptr<Group>, | ||
SharedGroupHasher, | ||
SharedGroupComparator> | ||
producer_groups_; | ||
// output grous | ||
std::unordered_set<std::shared_ptr<Group>, | ||
SharedGroupHasher, | ||
SharedGroupComparator> | ||
consumer_groups_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: 后续此处会替换成有序容器,保证 Group 融合的稳定性
} | ||
}; | ||
|
||
class FusePass { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO:本文件中的所有 FusePass class 拆分成文件,方便外部用户参考
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
}; | ||
} | ||
|
||
static bool IsSameSize(FusePassCtxT* ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO:此处对原来的实现做了封装,后续会用新的 API 改写
std::vector<OpGroupPtr> candidates; | ||
for (int i = 0; i < consumers.size(); ++i) { | ||
const auto& consumer = consumers.at(i); | ||
if (!DetectFusabilityByKind(ctx, producer, consumer)) { | ||
break; | ||
} | ||
candidates.push_back(consumer); | ||
} | ||
if (candidates.size() == consumers.size() && | ||
producer.kind() == framework::kElementWise) { | ||
return; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO(长期TODO):为了对齐原有 pass 逻辑,在多个 Pass 内部有相对 Trick 的代码,随着后续的不断迭代,删除这些Trick
paddle/cinn/api/op_group.h
Outdated
namespace cinn { | ||
namespace api { | ||
|
||
using Comparator = hlir::framework::Graph::Group::SharedGroupComparator; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using别名直接放头文件里不大好吧,移到OpGroup
类里或者可以改为GraphGroupComparator
是不是更好
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx~
return !(*this == other); | ||
} | ||
|
||
OpGroup operator*() const { return OpGroup(*iter_); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
迭代器应该也支持->
操作吧?要不要加上?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
->
需要返回对象指针,但由于这里的迭代器都是返回临时创建的封装对象,如果返回对象指针的话对象的析构会是个问题。
目前迭代器主要是用来支持对返回的容器对象做遍历,缺少->
接口暂时不影响该功能的使用,后续数据结构升级时可以再加上
paddle/cinn/api/op_group.h
Outdated
|
||
using const_iterator = OpGroupListIterator; | ||
|
||
size_t size() const { return group_.lock()->producer_groups().size(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
个人感觉对lock返回值加个CHECK
,检查是否为空会比现在这样交给STL报异常好吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx~
|
||
size_t size() const { return group_.lock()->producer_groups().size(); } | ||
|
||
const_iterator begin() const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和STL保持一致,命名为cbegin
更好吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/cinn/api/op_group.h
Outdated
struct hash<cinn::api::OpGroup> { | ||
size_t operator()(const cinn::api::OpGroup& obj) const { | ||
return std::hash<int64_t>()( | ||
reinterpret_cast<uint64_t>(obj.GetGroup().get())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接hash指针是可以的吧,相比还要reinterpret_cast
的优势是可以避免在非64位机器上(比如某些边缘设备)上运行时不会报错。不用在意机器的位数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看了下std::hash<T*>内部的实现是通过reinterpret_cast
转为size_t
,这里就直接转为size_t
类型吧
|
||
std::vector<OpGroupList> fusionable_consumers; | ||
for (auto& candidate : consumer_candidates) { | ||
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果可达说明该consumer与其它group存在连接关系,因此不属于HorizontalFuse范围,只能VerticalFuse?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的
fusionable_consumers.push_back({candidate}); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前面这一大段代码好像跟DefaultInputFusePass一模一样啊,只是多了后面的trick代码?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
还有其他的一些区别:
- 函数开头从 ctx 中拿的类型不同。Horizontal Fuse 取的是
const auto& producer = ctx->PickOpGroup();
对 producer 的所有 consumer group 做处理;Input Fuse 取的是const auto& consumer_set = ctx->PickConsumersWithSameInputs();
没有 producer 可作为锚点 - 在
fusion_merge_pass.cc
中做 Input Fuse(fusion_merge_pass.cc
里调用的实际是FuseInputToConsumers
函数) 和 Horizontal Fuse 的时机不同,从这个角度看,也需要拆分成 Input Fuse 和 Horizontal Fuse 两个 Pass。目前这两个 Pass 共享了很多重复代码,后续可能对这两种 Pass 有不同的处理。
break; | ||
} | ||
candidates.push_back(consumer); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这candidates
单纯只是为了做判断?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,这里是为了对齐原有 pass 逻辑,不得已在 Pass 内部写的相对 Trick 的代码
public: | ||
DefaultRecomputeFusePass() : RecomputeFusePass() {} | ||
|
||
int Benefit() const override { return 100; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是为cost model预留的接口么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是为用户预留的优先级接口,用户可以添加自己的 Fusion Merge Pass,会根据用户设定的 Benefit 值作为优先级。
比如说在执行 RecomputeFuse 时,用户新增了一个 CustomRecomputeFuse,优先级设置为 101,那么执行顺序是 CustomRecomputeFuse (benefit = 101), DefaultRecomputeFuse (benefit = 100)
paddle/cinn/api/fuse_pass_context.h
Outdated
public: | ||
virtual ~FusePassCtx() {} | ||
|
||
virtual void EnableFuse(const OpGroup& first, const OpGroup& second) = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是不是叫MarkFusible
更合适?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MarkFusible
这个命名挺好的 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
赞同,已替换为MarkFusible
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* new group fuse pass api * fix header * update * change logic of get master node to fix bug * revert update for ReduceFuseReduce * modify according review * modify by review * refine * update * fix code-format
* new group fuse pass api * fix header * update * change logic of get master node to fix bug * revert update for ReduceFuseReduce * modify according review * modify by review * refine * update * fix code-format
PR types
New features
PR changes
Others
Description
Pcard-72384
重新设计了CINN中关于group融合的pass api接口,以提升group融合相关优化策略开发的灵活性,并降低开发和维护成本。
一、Group 融合 API 接口现状
1. 当前 Group 融合流程:
FusionMergePass:将OpFusionPass初步融合后的Group做进一步的融合处理
(1) 融合规则:
当前Group融合策略主要为垂直融合与水平融合(Recompute可以等价于是垂直融合的一种特殊处理情况)
(2) 融合方式:
2. 现有流程的主要问题
策略与机制耦合较深,导致开发者的学习和代码修改成本都偏高
二、Group 融合API接口重构方案
1. 重构后流程(核心理念:用户策略与内部机制解耦)
(1) 融合规则:(用户开发接口)
通过用户实现Pass标记Group间是否可融合,初步可开发Recompute,Vertical,Horizontal Fuse 等 Pass 对齐现有融合规则
(2) 融合方式:(框架内部实现)
2. API接口设计(开发者视角)
1. API接口示例
以HorizontalFuse和RecomputeFuse为例,用户开发Pass时主要需要实现如下代码,需要开发者关注的数据结构和接口主要包括:FusePassCtx, OpGroup,OpNode, 以及fuse_helper中提供的一些辅助工具类接口(如: DetectCycleIfFuse 环检测), 其他底层的数据结构和接口用户无需了解。
2. API接口原语
TODO
general_fusion_merge_pass
文件拆分:general_fusion_merge_pass内包含了多个的模块内容,后续需要进行拆分简化,提高代码的易读性和维护性。