-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR & Inference] Add fused_weight_only_linear_pass #59366
Changes from all commits
273a310
9623b19
29958e0
5beab51
7584624
0bda343
51e12c7
8592ff5
7ffa712
59a8dad
df4e75d
7d2a390
65678d1
32e95c1
98e6c13
a09533c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" | ||
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" | ||
#include "paddle/fluid/platform/device/gpu/gpu_info.h" | ||
#include "paddle/fluid/platform/place.h" | ||
#include "paddle/pir/pass/pass.h" | ||
#include "paddle/pir/pass/pass_registry.h" | ||
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" | ||
|
||
namespace { | ||
|
||
// Returns the SM (compute capability) of the current GPU, e.g. 80 for A100.
// Falls back to 80 when Paddle is built without CUDA so that the pass
// machinery stays well-defined on CPU-only builds.
inline int getSMVersion() {
#if defined(PADDLE_WITH_CUDA)
  return paddle::platform::GetGPUComputeCapability(
      paddle::platform::GetCurrentDeviceId());
#else
  // No device to query; assume the most common supported arch (SM 8.0).
  return 80;
#endif
}
|
||
class FusedWeightOnlyLinearPattern | ||
: public pir::drr::DrrPatternBase<FusedWeightOnlyLinearPattern> { | ||
public: | ||
void operator()(pir::drr::DrrPatternContext *ctx) const override { | ||
// | ||
// Source Pattern. | ||
// | ||
pir::drr::SourcePattern src = ctx->SourcePattern(); | ||
const auto &matmul = | ||
src.Op("pd_op.matmul", | ||
{{"transpose_x", src.Attr("matmul_transpose_x")}, | ||
{"transpose_y", src.Attr("matmul_transpose_y")}}); | ||
src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w")); | ||
|
||
const auto &add = src.Op("pd_op.add"); | ||
src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias")); | ||
|
||
// | ||
// Constraints. | ||
// | ||
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { | ||
bool matmul_trans_x = match_ctx.Attr<bool>("matmul_transpose_x"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sm不在支持的那几个里面,约束需要返回false,你的pass不能生效 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
bool matmul_trans_y = match_ctx.Attr<bool>("matmul_transpose_y"); | ||
if (matmul_trans_x || matmul_trans_y) return false; | ||
|
||
if (!(match_ctx.Tensor("w").Shape().size() == 2 && | ||
match_ctx.Tensor("x").Shape().size() >= 2 && | ||
match_ctx.Tensor("bias").Shape().size() == 1)) { | ||
return false; | ||
} | ||
|
||
return true; | ||
}); | ||
// | ||
// Result Pattern. | ||
// | ||
pir::drr::ResultPattern res = src.ResultPattern(); | ||
|
||
// quantize weight | ||
const auto &weight_only_int8_attr = | ||
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { | ||
return "weight_only_int8"; | ||
}); | ||
// int arch = getSMVersion(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 实在不好意思,这里有一个 typo,应该把这里取消注释,然后下面的 80 改成这个 arch,我立马修改一下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 下个PR改~ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,谢谢~ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
const auto &weight_quantize_arch_attr = | ||
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any { | ||
return 80; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 开源版本目前是只有80架构的weightonly linear 但是后面其实是分 70有一个特殊的weightonly,75 80 86 89后用一个weightonly 如果这里hardcode了,我觉得需要加一个注释TODO |
||
}); | ||
|
||
const auto &weight_quantize = res.Op( | ||
"pd_op.weight_quantize", | ||
{{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}}); | ||
weight_quantize({&res.Tensor("w")}, | ||
{&res.Tensor("quanted_weight_tensor"), | ||
&res.Tensor("weight_scale_tensor")}); | ||
|
||
const auto &weight_dtype_attr = | ||
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { | ||
return "int8"; | ||
}); | ||
|
||
const auto &weight_only_linear_arch_attr = res.Attr( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的attr就不要重复写,复用前面的是不是更好保持一致?不然后续可能出现漏改的情况 |
||
[&](const pir::drr::MatchContext &match_ctx) -> int { return 80; }); | ||
const auto &weight_only_linear = | ||
res.Op("pd_op.weight_only_linear", | ||
{{"weight_dtype", weight_dtype_attr}, | ||
{"arch", weight_only_linear_arch_attr}}); | ||
weight_only_linear({&res.Tensor("x"), | ||
&res.Tensor("quanted_weight_tensor"), | ||
&res.Tensor("bias"), | ||
&res.Tensor("weight_scale_tensor")}, | ||
{&res.Tensor("add_out")}); | ||
} | ||
}; | ||
|
||
class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { | ||
public: | ||
FusedWeightOnlyLinearPass() | ||
: pir::PatternRewritePass("fused_weight_only_linear_pass", 4) {} | ||
|
||
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { | ||
pir::RewritePatternSet ps(context); | ||
ps.Add(FusedWeightOnlyLinearPattern().Build(context)); | ||
return ps; | ||
} | ||
|
||
bool CanApplyOn(pir::Operation *op) const override { | ||
int sm_vesion = getSMVersion(); | ||
if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里先只允许80 canapplyon,然后加上一些注释 |
||
sm_vesion != 75) { | ||
return false; | ||
} | ||
return op->num_regions() > 0; | ||
} | ||
|
||
private: | ||
pir::FrozenRewritePatternSet patterns_; | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace pir {
// Factory entry point declared in fused_weight_only_linear_pass.h; returns an
// owning pointer to a freshly constructed pass instance.
std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass() {
  return std::make_unique<FusedWeightOnlyLinearPass>();
}
}  // namespace pir
|
||
REGISTER_IR_PASS(fused_weight_only_linear_pass, FusedWeightOnlyLinearPass); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include "paddle/pir/core/dll_decl.h" | ||
|
||
namespace pir {

class Pass;

// Creates the pass that fuses `matmul + add` into
// `weight_quantize + weight_only_linear` on supported GPU architectures.
IR_API std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass();

}  // namespace pir
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
from pass_test import PassTest | ||
|
||
import paddle | ||
from paddle.base import core | ||
|
||
np.random.seed(2013) | ||
|
||
import os | ||
import re | ||
|
||
|
||
def get_cuda_version():
    """Return the CUDA toolkit version reported by ``nvcc`` encoded as
    ``major * 1000 + minor * 10`` (e.g. 11.2 -> 11020), or -1 when the
    version cannot be determined (nvcc missing or unparseable output)."""
    output = os.popen("nvcc --version").read()
    found = re.search(r'release (\S+),', output)
    if not found:
        return -1
    major, minor = str(found.group(1)).split('.')
    return int(major) * 1000 + int(float(minor) * 10)
|
||
|
||
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or get_cuda_version() < 11020
    or paddle.device.cuda.get_device_capability()[0] < 8,
    "weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
    """Checks that fused_weight_only_linear_pass rewrites a float32
    matmul + add program into weight_quantize + weight_only_linear."""

    def build_ir_progam(self):
        # NOTE(review): method name has a typo ("progam" -> "program"); kept
        # as-is because subclasses call it by this exact name.
        with paddle.pir_utils.IrGuard():
            self.pir_program = paddle.static.Program()
            with paddle.pir.core.program_guard(self.pir_program):
                x = paddle.static.data(
                    name='x', shape=[3, 64, 64], dtype=self.dtype
                )
                w = paddle.static.data(
                    name="w", shape=[64, 64], dtype=self.dtype
                )
                bias_ = paddle.static.data(
                    name="bias", shape=[64], dtype=self.dtype
                )
                # assign turns the bias feed into an op result — presumably
                # required for the pattern to match; TODO confirm.
                bias = paddle.assign(bias_)
                res1 = paddle.matmul(x=x, y=w)
                out = paddle.add(res1, bias)

        self.pass_list = ['fused_weight_only_linear_pass']
        self.feeds = {
            "x": np.random.random((3, 64, 64)).astype(self.dtype),
            "w": np.random.random((64, 64)).astype(self.dtype),
            "bias": np.random.random(64).astype(self.dtype),
        }
        self.fetch_list = [out]
        # Expected op counts AFTER the pass: fused ops present, originals gone.
        self.valid_op_map = {
            "pd_op.weight_only_linear": 1,
            "pd_op.weight_quantize": 1,
            "pd_op.matmul": 0,
            "pd_op.add": 0,
        }

    def setUp(self):
        self.place_runtime = "gpu"
        self.dtype = 'float32'
        self.build_ir_progam()

    def test_check_output(self):
        # Runs the program with and without the pass and compares outputs.
        self.check_pass_correct()
|
||
|
||
class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
    """Same fusion check as the Fp32 case, but on a float16 program."""

    def setUp(self):
        # Only the dtype differs from the parent fixture.
        self.dtype = 'float16'
        self.place_runtime = "gpu"
        self.build_ir_progam()
|
||
|
||
# Run all pass unit tests in this module.
if __name__ == "__main__":
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
调用这个platform::GetGPUComputeCapability(platform::GetCurrentDeviceId())接口
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.