
Commit

fix: fix conflict
Wanglongzhi2001 committed Nov 30, 2023
1 parent 65678d1 commit 32e95c1
Showing 2 changed files with 3 additions and 10 deletions.
9 changes: 1 addition & 8 deletions paddle/fluid/pybind/pir.cc
@@ -41,15 +41,8 @@
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
 #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"
-<<<<<<< HEAD
-<<<<<<< HEAD
-#include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h"
-#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
-=======
->>>>>>> fix ci
-=======
 #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
->>>>>>> refactor: refactor pass and test
+#include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h"
 #include "paddle/fluid/pir/transforms/inplace_pass.h"
 #include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
 #include "paddle/fluid/pybind/control_flow_api.h"
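For reference, reassembling the kept and added lines of the hunk above gives the resolved include block in pir.cc: each header survives exactly once and all conflict markers are gone. This is only a reconstruction from the diff shown here, not additional code.

#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
#include "paddle/fluid/pir/transforms/infer_symbolic_shape_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
#include "paddle/fluid/pybind/control_flow_api.h"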
4 changes: 2 additions & 2 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -44,7 +44,7 @@ def get_cuda_version():
     or paddle.device.cuda.get_device_capability()[0] < 8,
     "weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
 )
-class TestMatmulToWeightOnlyPass_Fp32(PassTest):
+class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
     def build_ir_progam(self):
         with paddle.pir_utils.IrGuard():
             self.pir_program = paddle.static.Program()
@@ -85,7 +85,7 @@ def test_check_output(self):
         self.check_pass_correct()


-class TestMatmulToWeightOnlyPass_Fp16(TestMatmulToWeightOnlyPass_Fp32):
+class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float16'
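To make the rename easier to follow, here is a minimal sketch of how the two renamed classes relate, based only on the lines visible in the diff above. It is an illustration, not the actual test file: unittest.TestCase stands in for Paddle's PassTest base class, the Fp32 setUp values are an assumption mirroring the Fp16 override shown, and all method bodies are elided.

# Minimal sketch of the renamed test layout; see the diff above.
# Assumptions: unittest.TestCase stands in for Paddle's PassTest base class,
# and the Fp32 setUp values mirror the Fp16 override shown in the diff.
import unittest

import paddle


class TestFusedWeightOnlyLinearPass_Fp32(unittest.TestCase):
    def setUp(self):
        self.place_runtime = "gpu"
        self.dtype = 'float32'  # assumed Fp32 counterpart of the Fp16 value below

    def build_ir_progam(self):  # method name (including its typo) as in the diff
        with paddle.pir_utils.IrGuard():
            self.pir_program = paddle.static.Program()
            # ... program construction elided; see the full test file ...


# The Fp16 variant reuses everything and only overrides setUp to switch dtype.
class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
    def setUp(self):
        self.place_runtime = "gpu"
        self.dtype = 'float16'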
