Skip to content

Commit

Permalink
Generic dispatching in Visitor (apache#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin committed Nov 17, 2022
1 parent a0b8528 commit bafde35
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 112 deletions.
78 changes: 56 additions & 22 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,19 @@ class ExprFunctor<R(const Expr& n, Args...)> {
}
};


/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
*/
class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
class ExprVisitor : public ExprFunctor<void(const Expr&)> {
public:
/*!
* \brief Generic dispatcher for Expr.
* \param expr The expr to be visited.
*/
void VisitExpr(const Expr& expr) override;
// specific leaf level visitor functions
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const VarNode* op) override;
Expand All @@ -157,13 +163,36 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;

virtual void VisitType(const Type& t);
virtual void VisitSpan(const Span& span);
/*!
* \brief Generic dispatcher for bindings.
* \param binding The binding to be visited.
*/
virtual void VisitBinding(const Binding& binding);
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);

/*!
* \brief Generic dispatcher for binding blocks.
* \param block The binding block to be visited.
*/
virtual void VisitBindingBlock(const BindingBlock& block);
virtual void VisitDataflowBlock(const DataflowBlock& block);
// specific leaf level visitor functions
virtual void VisitBindingBlock_(const BindingBlockNode* block);
virtual void VisitBindingBlock_(const DataflowBlockNode* block);

/*!
* \brief Generic dispatcher for visiting the var definition site.
* \param var The var to be visited.
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual void VisitVarDef(const Var& var);
// specific leaf level visitor functions
virtual void VisitVarDef_(const VarNode* var);
virtual void VisitVarDef_(const DataflowVarNode* var);

virtual void VisitType(const Type& t);
virtual void VisitSpan(const Span& span);
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand Down Expand Up @@ -205,20 +234,35 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
*/
virtual Type VisitType(const Type& t);

/*!
* \brief Generic dispatcher for bindings.
* \param binding The binding to be visited.
*/
virtual void VisitBinding(const Binding& binding);
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);

/*!
* \brief Generic dispatcher for binding blocks.
* \param block The binding block to be visited.
* \return The binding block after transformation.
*/
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
// specific leaf level visitor functions
virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block);
virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block);

/*!
* \brief Rewrite the var definition site.
* \brief Generic dispatcher for rewriting the var definition site.
* \param var The var to be visited.
* \return The var after post-order rewritten.
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual Var VisitVarDef(const Var& var);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
// specific leaf level visitor functions
virtual Var VisitVarDef_(const VarNode* var);
virtual Var VisitVarDef_(const DataflowVarNode* var);

protected:
class ExprNormalizer;
Expand Down Expand Up @@ -265,16 +309,6 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};

// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
*/
class DataflowMutator : public ExprMutator {
public:
void VisitBinding(const Binding& binding) final;

virtual void VisitDataflowVarBinding(const VarBinding& binding);
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_EXPR_FUNCTOR_H_
2 changes: 1 addition & 1 deletion src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class VMShapeLowerMutator : public ExprMutator {
return ret_mod_;
}

void VisitMatchShape(const MatchShape& binding) override {
void VisitBinding_(const MatchShapeNode* binding) override {
Expr shape = ExprMutator::VisitExpr(binding->value);
static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape");
auto store_shape_attr = make_object<ShapeHeapAttrs>();
Expand Down
Loading

0 comments on commit bafde35

Please sign in to comment.