Skip to content

Commit

Permalink
refactor: move the SM-architecture check from the DRR pattern constraint into FusedWeightOnlyLinearPass::CanApplyOn, and deduplicate the test classes' setUp logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Wanglongzhi2001 committed Nov 30, 2023
1 parent c65442a commit f603421
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 32 deletions.
27 changes: 15 additions & 12 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace {

inline int getSMVersion() {
int sm_version = 80;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA)
sm_version = paddle::platform::GetGPUComputeCapability(
paddle::platform::GetCurrentDeviceId());
#endif
Expand Down Expand Up @@ -62,12 +62,6 @@ class FusedWeightOnlyLinearPattern
return false;
}

int sm_vesion = getSMVersion();
if (sm_vesion != 70 || sm_vesion != 80 || sm_vesion != 86 ||
sm_vesion != 75) {
return false;
}

return true;
});
//
Expand All @@ -81,14 +75,14 @@ class FusedWeightOnlyLinearPattern
return "weight_only_int8";
});
int arch = getSMVersion();
const auto &arch_attr =
const auto &weight_quantize_arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
return arch;
});

const auto &weight_quantize =
res.Op("pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
const auto &weight_quantize = res.Op(
"pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});
Expand All @@ -115,14 +109,23 @@ class FusedWeightOnlyLinearPattern
// Pass wrapper that registers FusedWeightOnlyLinearPattern and gates it on
// the GPU architecture: the fused weight-only kernels exist only for
// SM 70 / 75 / 80 / 86.
class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 public:
  FusedWeightOnlyLinearPass()
      : pir::PatternRewritePass("fused_weight_only_linear_pass", 2) {}

  // Build the pattern set applied by this pass (a single DRR pattern).
  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    ps.Add(FusedWeightOnlyLinearPattern().Build(context));
    return ps;
  }

  // Only run on supported SM architectures and on ops that own at least one
  // region (i.e. something the rewriter can actually walk).
  bool CanApplyOn(pir::Operation *op) const override {
    int sm_version = getSMVersion();
    // BUGFIX: the original used `||` between the `!=` comparisons, which is
    // always true (a value cannot equal all four archs at once), so the pass
    // never applied anywhere. The correct reject condition is "not any of
    // the supported archs", i.e. `&&`.
    if (sm_version != 70 && sm_version != 75 && sm_version != 80 &&
        sm_version != 86) {
      return false;
    }
    return op->num_regions() > 0;
  }

 private:
  pir::FrozenRewritePatternSet patterns_;
};
Expand Down
24 changes: 4 additions & 20 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,6 @@ def get_cuda_version():
"weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestMatmulToWeightOnlyPass_Fp32(PassTest):
@classmethod
def setUpClass(self):
self.main_program = paddle.static.Program()
self.feeds = None
self.fetch_list = None
self.valid_op_map = {}
self.pass_list = []
self.pir_program = None
self.place_runtime = "cpu"
self.dtype = 'float32'

def build_ir_progam(self):
with paddle.pir_utils.IrGuard():
self.pir_program = paddle.static.Program()
Expand Down Expand Up @@ -89,23 +78,18 @@ def build_ir_progam(self):

def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float32'
self.build_ir_progam()

def test_check_output(self):
self.check_pass_correct()


class TestMatmulToWeightOnlyPass_Fp16(TestMatmulToWeightOnlyPass_Fp32):
@classmethod
def setUpClass(self):
self.main_program = paddle.static.Program()
self.feeds = None
self.fetch_list = None
self.valid_op_map = {}
self.pass_list = []
self.pir_program = None
self.place_runtime = "cpu"
def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float16'
self.build_ir_progam()


if __name__ == "__main__":
Expand Down

0 comments on commit f603421

Please sign in to comment.