fix test and rename pass
Wanglongzhi2001 committed Nov 26, 2023
1 parent 7e217dc commit bddc1f7
Showing 4 changed files with 21 additions and 28 deletions.
@@ -12,16 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/pir/transforms/matmul_to_weight_only_linear_pass.h"
+#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/pir/pass/pass.h"
 #include "paddle/pir/pass/pass_registry.h"
 #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
 
 namespace {
 
-class MatmulToWeightOnlyLinearPattern
-    : public pir::drr::DrrPatternBase<MatmulToWeightOnlyLinearPattern> {
+class FusedWeightOnlyLinearPattern
+    : public pir::drr::DrrPatternBase<FusedWeightOnlyLinearPattern> {
  public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    //
@@ -65,7 +65,7 @@ class MatmulToWeightOnlyLinearPattern
     });
 
     const auto &arch_attr = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> std::any { return 86; });
+        [](const pir::drr::MatchContext &match_ctx) -> std::any { return 80; });
 
     const auto &weight_quantize =
         res.Op("pd_op.weight_quantize",
@@ -80,30 +80,26 @@
         });
 
     const auto &weight_only_linear_arch_attr = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> int { return 86; });
+        [](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
     const auto &weight_only_linear =
         res.Op("pd_op.weight_only_linear",
                {{"weight_dtype", weight_dtype_attr},
                 {"arch", weight_only_linear_arch_attr}});
-    weight_only_linear(
-        {
-            &res.Tensor("x"),
-            &res.Tensor("quanted_weight_tensor"),
-            &res.Tensor("bias"),
-            &res.Tensor("weight_scale_tensor"),
-        },
-        {&res.Tensor("add_out")});
+    weight_only_linear({&res.Tensor("x"),
+                        &res.Tensor("quanted_weight_tensor"),
+                        &res.Tensor("bias"),
+                        &res.Tensor("weight_scale_tensor")},
+                       {&res.Tensor("add_out")});
   }
 };
 
-class MatmulToWeightOnlyLinearPass : public pir::Pass {
+class FusedWeightOnlyLinearPass : public pir::Pass {
  public:
-  MatmulToWeightOnlyLinearPass()
-      : pir::Pass("matmul_to_weight_only_linear_pass", 2) {}
+  FusedWeightOnlyLinearPass() : pir::Pass("fused_weight_only_linear_pass", 2) {}
 
   bool Initialize(pir::IrContext *context) override {
     pir::RewritePatternSet ps(context);
-    ps.Add(MatmulToWeightOnlyLinearPattern().Build(context));
+    ps.Add(FusedWeightOnlyLinearPattern().Build(context));
     patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
     return true;
   }
@@ -126,10 +122,9 @@ class MatmulToWeightOnlyLinearPass {
 } // namespace
 
 namespace pir {
-std::unique_ptr<Pass> CreateMatmulToWeightOnlyLinearPass() {
-  return std::make_unique<MatmulToWeightOnlyLinearPass>();
+std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass() {
+  return std::make_unique<FusedWeightOnlyLinearPass>();
 }
 } // namespace pir
 
-REGISTER_IR_PASS(matmul_to_weight_only_linear_pass,
-                 MatmulToWeightOnlyLinearPass);
+REGISTER_IR_PASS(fused_weight_only_linear_pass, FusedWeightOnlyLinearPass);
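
The name passed to REGISTER_IR_PASS is the string handle that Python-side code uses to schedule the pass. A minimal usage sketch of the renamed pass, mirroring the updated test further down in this commit (it assumes `main_program` is a pir program, built under paddle.pir_utils.IrGuard(), that already contains the matmul + add pattern this pass matches):

    import paddle

    # Sketch: apply the renamed fusion pass to an existing pir program.
    # `main_program` is assumed to be constructed as in the test below.
    pm = paddle.pir.PassManager()
    pm.add_pass('fused_weight_only_linear_pass')  # must match the REGISTER_IR_PASS name
    pm.run(main_program)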

paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h
@@ -21,6 +21,6 @@ namespace pir {
 
 class Pass;
 
-IR_API std::unique_ptr<Pass> CreateMatmulToWeightOnlyLinearPass();
+IR_API std::unique_ptr<Pass> CreateFusedWeightOnlyLinearPass();
 
 } // namespace pir

paddle/fluid/pybind/pir.cc (4 changes: 2 additions & 2 deletions)
@@ -41,8 +41,8 @@
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"
+#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
 #include "paddle/fluid/pir/transforms/inplace_pass.h"
-#include "paddle/fluid/pir/transforms/matmul_to_weight_only_linear_pass.h"
 #include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
 #include "paddle/fluid/pybind/control_flow_api.h"
 #include "paddle/phi/core/enforce.h"
@@ -92,7 +92,7 @@ USE_PIR_PASS(dead_code_elimination_pass);
 USE_PIR_PASS(attention_fuse_pass);
 USE_PIR_PASS(fused_gemm_epilogue_pass);
 USE_PIR_PASS(fused_dropout_add_pass);
-USE_PIR_PASS(matmul_to_weight_only_linear_pass);
+USE_PIR_PASS(fused_weight_only_linear_pass);
 USE_PIR_PASS(fused_linear_param_grad_add_pass);
 USE_PIR_PASS(inplace_pass);
 USE_PIR_PASS(replace_fetch_with_shadow_output_pass);

@@ -38,7 +38,7 @@ def get_cuda_version():
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    core.is_compiled_with_cuda() and get_cuda_version() < 11020,
     "weight_only_linear needs CUDA version greater than 11.2",
 )
 class TestMatmulToWeightOnly(unittest.TestCase):
@@ -74,9 +74,7 @@ def test_matmul_to_weight_only(self):
                 fetch_list=[res2],
             )
             pm = paddle.pir.PassManager()
-            pm.add_pass(
-                'matmul_to_weight_only_linear_pass'
-            )  # apply pass to elimitate dead code
+            pm.add_pass('fused_weight_only_linear_pass')
             pm.run(main_program)
             op_names = [op.name() for op in main_program.global_block().ops]
             self.assertTrue('pd_op.weight_only_linear' in op_names)
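
The reworked guard skips the test only on CUDA builds whose toolkit predates 11.2, encoded as 11020 (major * 1000 + minor * 10). A standalone sketch of an equivalent guard; the `get_cuda_version` body and the `paddle.base.core` import path are assumptions standing in for the helpers defined in this test file:

    import unittest

    import paddle
    from paddle.base import core  # assumption: import path may differ across versions

    def get_cuda_version():
        # Assumed equivalent of the helper in this test file:
        # encode "11.2" as 11020, or return -1 on a CPU-only build.
        ver = paddle.version.cuda()
        if ver == 'False':
            return -1
        major, minor = ver.split('.')[:2]
        return int(major) * 1000 + int(minor) * 10

    @unittest.skipIf(
        core.is_compiled_with_cuda() and get_cuda_version() < 11020,
        "weight_only_linear needs CUDA version greater than 11.2",
    )
    class TestNeedsRecentCuda(unittest.TestCase):
        def test_placeholder(self):
            pass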

