From f3bbb032fbf4b8ef692396a69d0da99c1416fcad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 6 Jul 2023 10:54:48 +0800 Subject: [PATCH] [CINN] Re-Implement operator = for two Expr Tree (#55145) * optimize expr operator = implementation * fix codestyle --- paddle/cinn/hlir/framework/op_lowering_util.cc | 1 + paddle/cinn/ir/ir_visitor.cc | 5 +++-- paddle/fluid/framework/paddle2cinn/cinn_compiler.cc | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc index 06ec448820362..77443cc86d025 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/op_lowering_util.cc @@ -825,6 +825,7 @@ bool CanbeInline(Node* node, } auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + for (auto consumer : consumers) { if (op_pattern_dict[consumer->op()] == framework::kReduction) { return false; diff --git a/paddle/cinn/ir/ir_visitor.cc b/paddle/cinn/ir/ir_visitor.cc index 83090fc9e75d6..50d81b839bc41 100644 --- a/paddle/cinn/ir/ir_visitor.cc +++ b/paddle/cinn/ir/ir_visitor.cc @@ -16,6 +16,7 @@ #include +#include "paddle/cinn/ir/ir_compare.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/utils/string.h" @@ -25,8 +26,8 @@ namespace ir { bool operator==(Expr a, Expr b) { if (a.get() == b.get()) return true; - // TODO(Superjomn) implement with a more accurate one - return utils::GetStreamCnt(a) == utils::GetStreamCnt(b); + IrEqualVisitor cmp; + return cmp.Compare(a, b); } bool operator!=(Expr a, Expr b) { return !(a == b); } diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 8af88595fd06d..4424b75cef179 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -112,6 +112,7 @@ const CinnCompiledObject &CinnCompiler::Compile( auto compiled_res = CompileGraph(graph, input_tensors, target, compiled_num, stream); + std::unique_lock guard(lock_); // double check cache_by_struct_ if (!cache_by_struct_.count(cur_key_by_struct)) {