[PIR & Inference] Add fused_weight_only_linear_pass #59366
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
Force-pushed from bddc1f7 to 267f5f1
@@ -0,0 +1,135 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Move the pass implementation into the current fusion directory.
}
};

class MatmulToWeightOnlyLinearPass : public pir::Pass {
Implement this by inheriting from pir::PatternRewritePass.
void Run(pir::Operation *op) override {
  pir::GreedyRewriteConfig cfg;
  cfg.use_top_down_traversal = true;
  cfg.max_iterations = 10;
  pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}

bool CanApplyOn(pir::Operation *op) const override {
  return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}
Once you inherit from pir::PatternRewritePass, these two interfaces are no longer needed.
@@ -0,0 +1,86 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
For the unit test, follow the pattern of test/ir/pir/fused_pass/test_conv2d_fuse_pass.py.
@@ -1138,3 +1139,75 @@ TEST(constant_folding, ConstantFolding_Combine) {
  CHECK_EQ(pm.Run(&program), true);
  // EXPECT_EQ(program.block()->size(), 6u);
}

void BuildWeightOnlyLinearProgram(pir::Program *program,
A Python unit test is enough; please delete this C++ unit test.
With my current code, if I don't write a C++ unit test, the C++ coverage check in the coverage CI fails.
Implement it following the new comments; if the coverage CI still fails after it runs, I'll take a look. In theory there shouldn't be any uncovered code.
src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));

//
// Constraints.
});

const auto &weight_only_linear_arch_attr = res.Attr(
    [](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
Is SM 80 currently the only supported architecture?
No; I only wrote 80 because I saw that this parameter's default value for the op is 80.
Paddle/paddle/phi/api/yaml/ops.yaml, lines 2821 to 2830 at 32af85e:
- op : weight_only_linear
  args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80)
  output : Tensor(out)
  infer_meta :
    func : WeightOnlyLinearInferMeta
  kernel :
    func : weight_only_linear
    data_type : x
  optional: bias
  backward: weight_only_linear_grad
The logic here was indeed not quite right. I've changed it to use the detected current architecture, which is consistent with what the Python API does when arch takes its default value of None.
paddle::framework::Scope *scope) {
  pir::Builder builder = pir::Builder(ctx, program->block());

  pir::Type fp32_dtype = pir::Float32Type::get(ctx);
The unit test should also cover the float16 case.
Done.

namespace {

inline int getSMVersion() {
Call the platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) interface.
Done.
// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
  bool matmul_trans_x = match_ctx.Attr<bool>("matmul_transpose_x");
If the SM version is not among the supported ones, the constraint needs to return false so that your pass does not take effect.
Done.
class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 public:
  FusedWeightOnlyLinearPass()
      : pir::PatternRewritePass("fused_weight_only_linear_pass", 2) {}
Done.
@yuanlehome After I changed opt_level here to 4, the unit test fails; increasing the opt_level of other passes also breaks their unit tests. But in the PIR source code I couldn't find any logic where opt_level affects pass execution, so let's keep opt_level at 2 for now.
@yuanlehome Hi, after the change the C++ coverage check did not pass.
int sm_vesion = getSMVersion();
if (sm_vesion != 70 || sm_vesion != 80 || sm_vesion != 86 ||
    sm_vesion != 75) {
  return false;
}
Sorry, I just remembered: this check should go in the CanApplyOn interface, which exists specifically to restrict where the pass applies.
As for the coverage CI: once CI finishes, hold off on pushing another commit; I'll check which lines aren't covered.
OK.
@yuanlehome CI has finished running; could you take a look at the CI coverage?
OK, I see it. The coverage check does seem to have an issue; I'll look into fixing it tomorrow.
Force-pushed from f603421 to 32e95c1

@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or get_cuda_version() < 11020
The coverage shortfall is probably because the test was skipped: the coverage CI's CUDA version is 10.2. Can you verify that this unit test passes locally?
The unit test passes locally for me.
> The coverage shortfall is probably because the test was skipped: the coverage CI's CUDA version is 10.2. Can you verify that this unit test passes locally?

So should I manually add a C++ unit test after all? weight_only_linear really does require CUDA version >= 11.2, and changing the CUDA version of the coverage CI cluster doesn't seem realistic.
If it can be verified to pass, the coverage CI can be exempted.
> If it can be verified to pass, the coverage CI can be exempted.

OK.
Let's look at the windows-inference CI result; its CUDA version is 11.2, so there is no need to add a C++ unit test. If the windows-inference CI runs this unit test and it passes but coverage is still insufficient, the coverage CI can be exempted.
@yuanlehome CI has passed; please take a look. However, because my local machine doesn't have enough GPU memory, I haven't yet run real inference tests of this pass on top of PaddleNLP's weight-only large models such as LLaMA.
LGTM
As I understand it, you can equally verify this pass on any model that contains a matmul op.
OK, thanks. I had originally planned to test the speedup on a large model.
LGTM
If the machine is the blocker, you can ask your mentor to verify it.
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
  return "weight_only_int8";
});
// int arch = getSMVersion();
Really sorry, there's a typo here: this line should be uncommented, and the 80 below should be changed to this arch. I'll fix it right away.
Fix it in the next PR.
OK, thanks.
#if defined(PADDLE_WITH_CUDA)
  sm_version = paddle::platform::GetGPUComputeCapability(
      paddle::platform::GetCurrentDeviceId());
#endif
How about adding an #else branch that uses PADDLE_THROW to report that the current Paddle build was compiled without CUDA? Wouldn't that be friendlier?
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
  return "weight_only_int8";
});
// int arch = getSMVersion();
remove
// int arch = getSMVersion();
const auto &weight_quantize_arch_attr =
    res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
      return 80;
The open-source version currently only has the SM 80 weight-only linear kernel. Later it actually splits: SM 70 has a special weight-only implementation, while SM 75/80/86/89 share one. If this is hardcoded here, I think a TODO comment should be added.

bool CanApplyOn(pir::Operation *op) const override {
  int sm_vesion = getSMVersion();
  if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
For now, only allow SM 80 in CanApplyOn, and add some comments.
return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
Don't write this attr twice; wouldn't reusing the earlier one be better for consistency? Otherwise a later change might update one copy and miss the other.
Address the newly added comments together in the next PR; let's merge this one first.
OK.
In the next PR's description, remember to reference this PR and note that it is a follow-up implementation. (Preferably submit and merge the PR before Dec 10.)
OK.
* [Inference]Add matmul_to_weight_only_linear_pass
* fix test and rename pass
* fix the comment of test
* fix ci
* fix: fix test
* refactor: refactor pass and test
* refactor: refactor pass
* refactor: add fp16 test
* refactor: refactor pass
* refactor: refactor the opt_level
* fix: fix typo
* fix: fix ci compile error when without gpu
* refactor: refactor pass and test
* fix: fix conflict
* fix: fix conflict
* refactor: refactor opt_level in pass_test to 4
PR types
New features
PR changes
APIs
Description
Add a PIR pass that converts the matmul operator into the weight_only_linear operator.