Skip to content

Commit

Permalink
add test for __eq__
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoguoguo626807 committed Aug 1, 2023
1 parent a236d6c commit c9eb1f9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ PhiKernelInstruction::PhiKernelInstruction(
auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
std::unordered_set<::ir::Value> no_need_buffer_values;
for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id]));
no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id]));
}
SetNoNeedBuffer(no_need_buffer_values);
VLOG(6) << "finish process no need buffer";
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ void BindOperation(py::module *m) {

void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value");
value.def(
"get_defining_op", &Value::GetDefiningOp, return_value_policy::reference);
value
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("__eq__", &Value::operator==);
}

void BindOpOperand(py::module *m) {
Expand Down
13 changes: 9 additions & 4 deletions test/ir/new_ir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_ir_program():
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
y_s = paddle.add(x_s, y_s)
y_s = paddle.tanh(y_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program

Expand Down Expand Up @@ -84,6 +84,12 @@ def test_value(self):
matmul_op.result(0).set_stop_gradient(True)
self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)

result_set = set()
for opresult in matmul_op.results():
result_set.add(opresult)

# self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],)

self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
)
Expand All @@ -94,8 +100,7 @@ def test_value(self):
)

self.assertEqual(
tanh_op.operands()[0].source().get_defining_op(),
tanh_op.operands_source()[0].get_defining_op(),
tanh_op.operands()[0].source(), tanh_op.operands_source()[0]
)
self.assertEqual(add_op.result(0).use_empty(), True)

Expand Down

0 comments on commit c9eb1f9

Please sign in to comment.