[PIR & Inference] Add fused_weight_only_linear_pass #59366

Merged — 16 commits, Dec 4, 2023. Changes shown from 15 commits.
141 changes: 141 additions & 0 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
@@ -0,0 +1,141 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

inline int getSMVersion() {
[Reviewer] Call the platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) interface here.

[Author] Done.

int sm_version = 80;
#if defined(PADDLE_WITH_CUDA)
sm_version = paddle::platform::GetGPUComputeCapability(
paddle::platform::GetCurrentDeviceId());
#endif
[Reviewer] In an #else branch, PADDLE_THROW an error saying that the current Paddle build was not compiled with CUDA — wouldn't that be friendlier?

return sm_version;
}

class FusedWeightOnlyLinearPattern
: public pir::drr::DrrPatternBase<FusedWeightOnlyLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
//
// Source Pattern.
//
pir::drr::SourcePattern src = ctx->SourcePattern();
const auto &matmul =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_transpose_x")},
{"transpose_y", src.Attr("matmul_transpose_y")}});
src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));

const auto &add = src.Op("pd_op.add");
src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));

//
// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
bool matmul_trans_x = match_ctx.Attr<bool>("matmul_transpose_x");
[Reviewer] If the SM version is not one of the supported ones, the constraint needs to return false so that your pass does not take effect.

[Author] Done.

bool matmul_trans_y = match_ctx.Attr<bool>("matmul_transpose_y");
if (matmul_trans_x || matmul_trans_y) return false;

if (!(match_ctx.Tensor("w").Shape().size() == 2 &&
match_ctx.Tensor("x").Shape().size() >= 2 &&
match_ctx.Tensor("bias").Shape().size() == 1)) {
return false;
}

return true;
});
//
// Result Pattern.
//
pir::drr::ResultPattern res = src.ResultPattern();

// quantize weight
const auto &weight_only_int8_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "weight_only_int8";
});
int arch = getSMVersion();
const auto &weight_quantize_arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
return arch;
});

const auto &weight_quantize = res.Op(
"pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});

const auto &weight_dtype_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
[Reviewer] Don't write this attr out again — wouldn't reusing the earlier one be better for consistency? Otherwise a later change may update one copy and miss the other.
[&](const pir::drr::MatchContext &match_ctx) -> int { return arch; });
const auto &weight_only_linear =
res.Op("pd_op.weight_only_linear",
{{"weight_dtype", weight_dtype_attr},
{"arch", weight_only_linear_arch_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
&res.Tensor("weight_scale_tensor")},
{&res.Tensor("add_out")});
}
};
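For reference, the weight_only_int8 algorithm selected in the result pattern stores an int8 weight plus a per-channel float scale. A minimal, dependency-free sketch of that round-trip (illustrative only — not Paddle's actual weight_quantize kernel, which additionally rearranges the weight layout for the GPU GEMM):

```python
def quantize_int8(column):
    # Symmetric quantization: one float scale per output channel,
    # chosen so the largest-magnitude weight maps to +/-127.
    scale = max(abs(v) for v in column) / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in column]
    return q, scale

def dequantize_int8(q, scale):
    # Reconstruction error is bounded by half a quantization step.
    return [v * scale for v in q]

w_col = [0.5, -1.0, 0.25]
q, scale = quantize_int8(w_col)
restored = dequantize_int8(q, scale)
```

The per-channel scale tensor this produces plays the role of weight_scale_tensor in the fused pd_op.weight_only_linear call.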

class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
public:
FusedWeightOnlyLinearPass()
: pir::PatternRewritePass("fused_weight_only_linear_pass", 2) {}
[Reviewer] Set opt_level to 4 to mark this as a fairly aggressive optimization, and add a comment explaining opt_level=4 in pass.h.

[Author] Done.

[Author, @Wanglongzhi2001, Nov 30, 2023] @yuanlehome After I changed opt_level to 4 the unit test no longer passes, and raising the opt_level of other passes also breaks their unit tests. I don't see any logic in the PIR source where opt_level affects pass execution, so let's keep opt_level at 2 for now.

[Reviewer] Still set it to 4, and change the logic here.


pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(FusedWeightOnlyLinearPattern().Build(context));
return ps;
}

bool CanApplyOn(pir::Operation *op) const override {
int sm_version = getSMVersion();
if (sm_version != 70 && sm_version != 75 && sm_version != 80 &&
sm_version != 86) {
return false;
}
return op->num_regions() > 0;
}

private:
pir::FrozenRewritePatternSet patterns_;
};

} // namespace

namespace pir {
std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass() {
return std::make_unique<FusedWeightOnlyLinearPass>();
}
} // namespace pir

REGISTER_IR_PASS(fused_weight_only_linear_pass, FusedWeightOnlyLinearPass);
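The intended architecture gate in CanApplyOn boils down to a set-membership test; it can be sketched in Python (hypothetical helper names, not Paddle API):

```python
# SM versions with weight-only GEMM kernel support in this pass:
# Volta (70), Turing (75), Ampere (80, 86).
SUPPORTED_SM_VERSIONS = {70, 75, 80, 86}

def can_apply(sm_version: int, num_regions: int) -> bool:
    # Mirror of CanApplyOn: bail out on unsupported architectures,
    # then only match ops that own at least one region.
    if sm_version not in SUPPORTED_SM_VERSIONS:
        return False
    return num_regions > 0
```

Writing the check as membership in a single set avoids the easy-to-miss bug of chaining `!=` comparisons with `||`, which is always true.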
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass();

} // namespace pir
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -41,6 +41,7 @@
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
@@ -96,6 +97,7 @@ USE_PIR_PASS(dead_code_elimination_pass);
USE_PIR_PASS(attention_fuse_pass);
USE_PIR_PASS(fused_gemm_epilogue_pass);
USE_PIR_PASS(fused_dropout_add_pass);
USE_PIR_PASS(fused_weight_only_linear_pass);
USE_PIR_PASS(fused_linear_param_grad_add_pass);
USE_PIR_PASS(inplace_pass);
USE_PIR_PASS(replace_fetch_with_shadow_output_pass);
1 change: 1 addition & 0 deletions paddle/pir/pass/pass.h
@@ -59,6 +59,7 @@ struct PassInfo {
// opt_level=1: constant fold, cse, memory optimize, etc.
// opt_level=2: the fusion logical pass.
// opt_level=3: layout, etc.
// opt_level=4: aggressive optimizations.
uint8_t opt_level;

// The list which pass depends on.
96 changes: 96 additions & 0 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -0,0 +1,96 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest

import paddle
from paddle.base import core

np.random.seed(2013)

import os
import re


def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
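get_cuda_version encodes major.minor as major * 1000 + minor * 10, so CUDA 11.2 becomes 11020 — matching the 11020 threshold used in the skipIf condition. A self-contained check of that parsing (the sample string is hypothetical nvcc output):

```python
import re

def parse_nvcc_release(text: str) -> int:
    # Same encoding as get_cuda_version above: "release 11.2," -> 11020.
    match = re.search(r'release (\S+),', text)
    if not match:
        return -1
    integer, decimal = match.group(1).split('.')
    return int(integer) * 1000 + int(float(decimal) * 10)

sample = "Cuda compilation tools, release 11.2, V11.2.152"
```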


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
[Reviewer] The missing coverage is probably because the test is skipped; coverage-ci's CUDA version is 10.2. Does this unit test pass when you run it locally?

[Author] It passes locally for me.

[Author] So should I manually add a C++ unit test instead? weight_only_linear really does require CUDA >= 11.2, and changing the CUDA version of the coverage-ci cluster does not seem realistic.

[Reviewer] If you can verify it passes, coverage-ci can be exempted.

[Author] OK.

[Reviewer] Let's look at the windows-inference CI result; its CUDA version is 11.2. No need to add a C++ unit test: if windows-inference CI runs this test and it passes but coverage is still short, coverage-ci can be exempted.

or paddle.device.cuda.get_device_capability()[0] < 8,
"weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
def build_ir_progam(self):
with paddle.pir_utils.IrGuard():
self.pir_program = paddle.static.Program()
with paddle.pir.core.program_guard(self.pir_program):
x = paddle.static.data(
name='x', shape=[3, 64, 64], dtype=self.dtype
)
w = paddle.static.data(
name="w", shape=[64, 64], dtype=self.dtype
)
bias_ = paddle.static.data(
name="bias", shape=[64], dtype=self.dtype
)
bias = paddle.assign(bias_)
res1 = paddle.matmul(x=x, y=w)
out = paddle.add(res1, bias)

self.pass_list = ['fused_weight_only_linear_pass']
self.feeds = {
"x": np.random.random((3, 64, 64)).astype(self.dtype),
"w": np.random.random((64, 64)).astype(self.dtype),
"bias": np.random.random(64).astype(self.dtype),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.weight_only_linear": 1,
"pd_op.weight_quantize": 1,
"pd_op.matmul": 0,
"pd_op.add": 0,
}

def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float32'
self.build_ir_progam()

def test_check_output(self):
self.check_pass_correct()


class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float16'
self.build_ir_progam()


if __name__ == "__main__":
unittest.main()