Skip to content

Commit

Permalink
【CINN】refactor ir_visitor (PaddlePaddle#55171)
Browse files Browse the repository at this point in the history
This PR delete middle ir_visitor class and thus we can avoid middle virtual function call and codes look more clean
pcard-72718
  • Loading branch information
Courtesy-Xs authored and cqulilujia committed Jul 24, 2023
1 parent 8e67259 commit 016aa4b
Show file tree
Hide file tree
Showing 15 changed files with 126 additions and 104 deletions.
4 changes: 3 additions & 1 deletion paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ using namespace ::cinn::ir; // NOLINT

FeatureExtractor::FeatureExtractor() {}

void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
void FeatureExtractor::Visit(const Expr *x) {
IRVisitorRequireReImpl::Visit(x);
}

Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/cost_model/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
namespace cinn {
namespace auto_schedule {

class FeatureExtractor : public ir::IRVisitor {
class FeatureExtractor : public ir::IRVisitorRequireReImpl<void> {
public:
FeatureExtractor();
Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target);
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/auto_schedule/search_space/search_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ bool operator<(const SearchState& left, const SearchState& right) {
}

// Visit every node by expanding all of their fields in dfs order
class DfsWithExprsFields : public ir::IRVisitor {
class DfsWithExprsFields : public ir::IRVisitorRequireReImpl<void> {
protected:
#define __m(t__) \
void Visit(const ir::t__* x) override { \
Expand All @@ -85,7 +85,7 @@ class DfsWithExprsFields : public ir::IRVisitor {
NODETY_FORALL(__m)
#undef __m

void Visit(const Expr* expr) override { IRVisitor::Visit(expr); }
void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); }
};

// Generate a reduce hash of a AST tree by combining hash of each AST node
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/backends/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
namespace cinn {
namespace backends {

class LLVMIRVisitor : public ir::IRVisitorBase<llvm::Value *> {
class LLVMIRVisitor : public ir::IRVisitorRequireReImpl<llvm::Value *> {
public:
LLVMIRVisitor() = default;

using ir::IRVisitorBase<llvm::Value *>::Visit;
using ir::IRVisitorRequireReImpl<llvm::Value *>::Visit;
#define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0;
NODETY_FORALL(__m)
#undef __m
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/backends/modular.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
namespace cinn {
namespace backends {

class ModularEvaluator : public ir::IRVisitorBase<ModularEntry> {
class ModularEvaluator : public ir::IRVisitorRequireReImpl<ModularEntry> {
public:
explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map)
: mod_map_(mod_map) {}

ModularEntry Eval(const Expr& e) {
return ir::IRVisitorBase<ModularEntry>::Visit(&e);
return ir::IRVisitorRequireReImpl<ModularEntry>::Visit(&e);
}

ModularEntry Visit(const ir::IntImm* op) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ir/collect_ir_nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace ir {

namespace {

struct IrNodesCollector : public IRVisitor {
struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
using teller_t = std::function<bool(const Expr*)>;
using handler_t = std::function<void(const Expr*)>;

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ir/ir_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
VLOG(5) << "Not equal on Expr, someone not defined";
}
bool equal = lhs->node_type() == rhs->node_type();
equal = equal && IRVisitorBase<bool, const Expr*>::Visit(&lhs, &rhs);
equal = equal && IRVisitorRequireReImpl<bool, const Expr*>::Visit(&lhs, &rhs);

if (!equal) {
VLOG(5) << "Not equal on Expr, lhs:[type:"
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ir/ir_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace ir {

// Determine whether two ir AST trees are euqal by comparing their struct and
// fields of each node through dfs visitor
class IrEqualVisitor : public IRVisitorBase<bool, const Expr*> {
class IrEqualVisitor : public IRVisitorRequireReImpl<bool, const Expr*> {
public:
explicit IrEqualVisitor(bool allow_name_suffix_diff = false)
: allow_name_suffix_diff_(allow_name_suffix_diff) {}
Expand Down
Loading

0 comments on commit 016aa4b

Please sign in to comment.