Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Refactor pass api of group fusion in CINN #55090

Merged
merged 14 commits into from
Jul 13, 2023

Conversation

zyfncg
Copy link
Contributor

@zyfncg zyfncg commented Jul 3, 2023

PR types

New features

PR changes

Others

Description

Pcard-72384

重新设计了CINN中关于group融合的pass api接口,以提升group融合相关优化策略开发的灵活性,并降低开发和维护成本。

一、Group 融合 API 接口现状

1. 当前 Group 融合流程:

FusionMergePass:将OpFusionPass初步融合后的Group做进一步的融合处理
image

(1) 融合规则:

当前Group融合策略主要为垂直融合与水平融合(Recompute可以等价于是垂直融合的一种特殊处理情况)

  • 垂直融合:根据Group的OpPatternKind等信息判断Producer是否可以与Consumer融合(Recompute是将Producer同时融合到多个Consumer中)
  • 水平融合:根据Group的OpPatternKind等信息判断Producer的不同Consumer之间是否可融合

(2) 融合方式:

  • 对Graph上的所有Group进行遍历并且以贪心的方式进行融合,融合后更新Group继续尝试进行融合,直到图中的Group不再更新

2. 现有流程的主要问题

策略与机制耦合较深,导致开发者的学习和代码修改成本都偏高

  1. 开发者在写Group融合策略时既需要学习策略的编写方式,也需要学习Group的整个融合流程,以确保策略能够按照预想的方式执行。修改代码时也是同样的情况。
  2. 开发者在对Group的可融合关系进行判断时会直接操作Node、NodeData等底层数据接口,同样引入了较多需要了解的概念和接口。
  3. 新增融合策略时需要修改现有逻辑代码,容易影响到现有执行流程,不符合软件开发的开闭原则(开发新功能时 , 尽量不修改原有的代码 , 尽量使用扩展增加新功能)

二、Group 融合API接口重构方案

1. 重构后流程(核心理念:用户策略与内部机制解耦)

image

(1) 融合规则:(用户开发接口)

通过用户实现Pass标记Group间是否可融合,初步可开发Recompute,Vertical,Horizontal Fuse 等 Pass 对齐现有融合规则

  1. 用户通过Pass对Group是否可融合进行标记时,仅能指定 两个Group 之间是否可融合 或者 某个Group 是否需要Recompute,无法一次性标记超过两个以上的Group间融合关系。
  2. Recompute策略和Vertical融合会解耦处理,即Recompute会单独通过Pass和Merge进行处理,不再依赖Vertical Fuse进行。

(2) 融合方式:(框架内部实现)

  1. 遍历Group并通过用户Pass对Group的可融合情况进行Tag标记(不实际融合)
  2. 完成标记后启动FuseMerge对标记为可融合的Group进行融合操作
  3. 重复过程 1、2,直到Group不再更新

2. API接口设计(开发者视角)

1. API接口示例

以HorizontalFuse和RecomputeFuse为例,用户开发Pass时主要需要实现如下代码,需要开发者关注的数据结构和接口主要包括:FusePassCtx, OpGroup,OpNode, 以及fuse_helper中提供的一些辅助工具类接口(如: DetectCycleIfFuse 环检测), 其他底层的数据结构和接口用户无需了解。

class DefaultHorizontalFusePass final : public HorizontalFusePass {
 public:
  void operator()(LightwareFusePassCtx* ctx) const override {
    // 选取当前要处理的Group
    const auto& producer        = ctx->PickOpGroup();
    const OpGroupList consumers = producer->Consumers();
    // 跳过处理
    if (consumers.size() <= 1) {
      return;
    }
    // 遍历 Consumer 对
    for (int i = 0; i < consumers.size(); ++i) {
      const auto& src = consumers.at(i);
      for (int j = i + 1; j < consumers.size(); ++j) {
        const auto& dst = consumers.at(j);
        // 成环检测
        if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) {
          continue;
        }
        // 可融合判断(根据OpGroup提供的各项信息,自行定义可融合规则)
        if (!HorizontalFuseUtil<LightwareFusePassCtx>::DetectFusabilityByKind(ctx, src, dst)) {
          continue;
        }
        // 标记为可融合
        ctx->EnableFuse(src, dst);
        return;
      }
    }
  }
};

class DefaultRecomputeFusePass final : public RecomputeFusePass {
 public:
  void operator()(LightwareFusePassCtx* ctx) const override {
    // 选取节点
    const auto& producer        = ctx->PickOpGroup();
    const OpGroupList consumers = producer->Consumers();
    
    // 融合判断策略
    if (consumers.size() <= 1) {
      return;
    }
    std::vector<OpGroupPtr> candidates;
    for (int i = 0; i < consumers.size(); ++i) {
      const auto& consumer = consumers.at(i);
      if (!DetectFusabilityByKind(ctx, producer, consumer)) {
        continue;
      }
      candidates.push_back(consumer);
    }
    if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) {
      for (const auto& consumer : consumers) {
        // 标记可融合
        ctx->EnableFuse(producer, consumer);
      }
    }
  }
};

2. API接口原语

Class method description
OpGroup kind() Get the Kind of group
producers() Get producer groups of current group
consumers() Get consumer groups of current group
WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) Visit the op_nodes in the group and execute the VisitOpNode function for each OpNode
OpNode kind() Get the Kind of op_node
inputs() Get input tensors of op_node
outputs() Get output tensors of op_node
GetAttr(const std::string& attr_name) Get attribute of op_node by attr name
TensorNode shape() Get shape of tensor
producer() Get the producer op_node of tensor
consumers() Get the consumer op_nodes of tensor
Shape numel() Get total number of elements in the shape
other methods are same with std::vector<int64_t>
LightwareFusePassCtx PickOpGroup() Get the current group in the pass context
void EnableFuse(const OpGroup& first, const OpGroup& second) Mark the two groups which can fuse togather
fuse_helper() Get the fuse_helper provided by pass context
InputFusePassCtx PickConsumersWithSameInputs() Get all consumer groups for input tensors of graph
void EnableFuse(const OpGroup& first, const OpGroup& second) Mark the two groups which can fuse togather
fuse_helper() Get the fuse_helper provided by pass context
FuseHelper DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) Whether there is cycle in graph after fusing two groups

TODO

  1. group 融合稳定性增强:目前group层由于大量无序容器的使用导致group的融合顺序存在一定的随机性,不同的融合顺序可能会有一定的性能差异,因此我们需要尝试使用有序容器来将这些group的访问和融合顺序固定下来。
  2. fuse_helper辅助接口替换:为了降低初版PR合入的风险,我们通过fuse_helper包装了原来的一些融合判断策略接口,这些接口后续都需要使用新的api接口进行重新实现。
  3. general_fusion_merge_pass文件拆分:general_fusion_merge_pass内包含了多个的模块内容,后续需要进行拆分简化,提高代码的易读性和维护性。
  4. shared_ptr的循环引用消除。

@paddle-bot
Copy link

paddle-bot bot commented Jul 3, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

jiahy0825
jiahy0825 previously approved these changes Jul 5, 2023
Copy link
Contributor

@jiahy0825 jiahy0825 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM,TODO 项点已在对应代码做了标识~
在基于 CINN 仓库的分支上,我们仍在持续迭代,希望尽快推动本 PR 合入,完成 CINN 代码向 Paddle 主仓库的迁移~


// breadth-first search visitor
template <typename NodeType>
class BfsVisitor final {
Copy link
Contributor

@jiahy0825 jiahy0825 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: 此处为了语义的精准,后续会将 BfsVisitor 命名更改为 BfsWalker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


// depth-first search visitor
template <typename NodeType>
class DfsVisitor final {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DfsVisitor, SccVisitor, TopoVisitor 同上,均会修改为 Walker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Comment on lines +184 to +193
// input groups
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
producer_groups_;
// output grous
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
consumer_groups_;
Copy link
Contributor

@jiahy0825 jiahy0825 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: 后续此处会替换成有序容器,保证 Group 融合的稳定性

}
};

class FusePass {
Copy link
Contributor

@jiahy0825 jiahy0825 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO:本文件中的所有 FusePass class 拆分成文件,方便外部用户参考

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

};
}

static bool IsSameSize(FusePassCtxT* ctx,
Copy link
Contributor

@jiahy0825 jiahy0825 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO:此处对原来的实现做了封装,后续会用新的 API 改写

Comment on lines +744 to +755
std::vector<OpGroupPtr> candidates;
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!DetectFusabilityByKind(ctx, producer, consumer)) {
break;
}
candidates.push_back(consumer);
}
if (candidates.size() == consumers.size() &&
producer.kind() == framework::kElementWise) {
return;
}
Copy link
Contributor

@jiahy0825 jiahy0825 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO(长期TODO):为了对齐原有 pass 逻辑,在多个 Pass 内部有相对 Trick 的代码,随着后续的不断迭代,删除这些Trick

@jiahy0825
Copy link
Contributor

正确性与性能的验证结果

Resnet

  • 正确性对齐
    生成子图完全一致,收敛曲线重叠
image
  • 性能对齐
    可能与机器和配置有关,未达到周报中测出的最优 ips 值。但是相同机器和配置下,两者性能差距 1.4%,可被认为是误差。可以等本 PR 合入后,依赖 CE 进一步测试是否达到最优性能。
    origin:2163.85
    modified:2133.11

Bert

  • 正确性对齐
    收敛曲线几乎重叠
image
  • 性能对齐
    但是相同机器和配置下,两者性能差距 2.2%,可被认为是误差,并且几乎达到周报中的最优 ips 值(1385)。
    origin:1366.36
    modified:1336.82

namespace cinn {
namespace api {

using Comparator = hlir::framework::Graph::Group::SharedGroupComparator;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using别名直接放头文件里不大好吧,移到OpGroup类里或者可以改为GraphGroupComparator是不是更好

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx~

return !(*this == other);
}

OpGroup operator*() const { return OpGroup(*iter_); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

迭代器应该也支持->操作吧?要不要加上?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

->需要返回对象指针,但由于这里的迭代器都是返回临时创建的封装对象,如果返回对象指针的话对象的析构会是个问题。
目前迭代器主要是用来支持对返回的容器对象做遍历,缺少->接口暂时不影响该功能的使用,后续数据结构升级时可以再加上


using const_iterator = OpGroupListIterator;

size_t size() const { return group_.lock()->producer_groups().size(); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

个人感觉对lock返回值加个CHECK,检查是否为空会比现在这样交给STL报异常好吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx~


size_t size() const { return group_.lock()->producer_groups().size(); }

const_iterator begin() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和STL保持一致,命名为cbegin更好吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

STL中目前也可以通过begin函数的const标记来区分迭代器的类型

struct hash<cinn::api::OpGroup> {
size_t operator()(const cinn::api::OpGroup& obj) const {
return std::hash<int64_t>()(
reinterpret_cast<uint64_t>(obj.GetGroup().get()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接hash指针是可以的吧,相比还要reinterpret_cast的优势是可以避免在非64位机器上(比如某些边缘设备)上运行时不会报错。不用在意机器的位数。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看了下std::hash<T*>内部的实现是通过reinterpret_cast转为size_t,这里就直接转为size_t类型吧


std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果可达说明该consumer与其它group存在连接关系,因此不属于HorizontalFuse范围,只能VerticalFuse?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的

fusionable_consumers.push_back({candidate});
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

前面这一大段代码好像跟DefaultInputFusePass一模一样啊,只是多了后面的trick代码?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有其他的一些区别:

  1. 函数开头从 ctx 中拿的类型不同。Horizontal Fuse 取的是 const auto& producer = ctx->PickOpGroup(); 对 producer 的所有 consumer group 做处理;Input Fuse 取的是 const auto& consumer_set = ctx->PickConsumersWithSameInputs(); 没有 producer 可作为锚点
  2. fusion_merge_pass.cc 中做 Input Fuse( fusion_merge_pass.cc 里调用的实际是 FuseInputToConsumers 函数) 和 Horizontal Fuse 的时机不同,从这个角度看,也需要拆分成 Input Fuse 和 Horizontal Fuse 两个 Pass。目前这两个 Pass 共享了很多重复代码,后续可能对这两种 Pass 有不同的处理。

break;
}
candidates.push_back(consumer);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

candidates单纯只是为了做判断?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,这里是为了对齐原有 pass 逻辑,不得已在 Pass 内部写的相对 Trick 的代码

public:
DefaultRecomputeFusePass() : RecomputeFusePass() {}

int Benefit() const override { return 100; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是为cost model预留的接口么?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为用户预留的优先级接口,用户可以添加自己的 Fusion Merge Pass,会根据用户设定的 Benefit 值作为优先级。

比如说在执行 RecomputeFuse 时,用户新增了一个 CustomRecomputeFuse,优先级设置为 101,那么执行顺序是 CustomRecomputeFuse (benefit = 101), DefaultRecomputeFuse (benefit = 100)

public:
virtual ~FusePassCtx() {}

virtual void EnableFuse(const OpGroup& first, const OpGroup& second) = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是叫MarkFusible更合适?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MarkFusible 这个命名挺好的 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

赞同,已替换为MarkFusible

@jiahy0825
Copy link
Contributor

为了便于理解 general_fusion_merge_pass 中各个类的关系,可以参考下面的 UML 设计图。

用户新增 Fuse 策略时,可参考标黄的数据结构。

PassUML

@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Jul 10, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation Jul 10, 2023
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Jul 10, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation Jul 10, 2023
Copy link
Contributor

@thisjiang thisjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zyfncg zyfncg merged commit c80bf36 into PaddlePaddle:develop Jul 13, 2023
@zyfncg zyfncg deleted the cinn_pass_api branch July 13, 2023 04:58
cqulilujia pushed a commit to cqulilujia/Paddle that referenced this pull request Jul 24, 2023
* new group fuse pass api

* fix header

* update

* change logic of get master node to fix bug

* revert update for ReduceFuseReduce

* modify according review

* modify by review

* refine

* update

* fix code-format
wz1qqx pushed a commit to wz1qqx/Paddle that referenced this pull request Jul 31, 2023
* new group fuse pass api

* fix header

* update

* change logic of get master node to fix bug

* revert update for ReduceFuseReduce

* modify according review

* modify by review

* refine

* update

* fix code-format
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants