From 7babba86194c46cb4badff815d3881aea234bbca Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 13 Apr 2018 13:55:47 -0700 Subject: [PATCH] Add implementation of alpha equivalence (#35) * Refactor base types to reuse HalideIR::Type * Tweak BaseType constructors * Add start on alpha_eq for expressions * Fix free_vars.h * Add alpha_eq implementation and tests * Restore reverse_ad * Add test skeleton * Refactor type_visitor.h * Update free_type_vars.cc * Stub type alpha-equivalence * Fix style in AlphaEq * Fill in unfinished test cases * Get half of the tests passing * More changes, only 13 failures left * Fix a few more bugs * Fix Tensor and Product cases * Fix cast case * Fix test cases * Fix lint * Fix remaining test cases --- relay/include/relay/alpha_eq.h | 19 ++ relay/include/relay/expr_functor.h | 1 + relay/include/relay/expr_visitor.h | 97 +++++++-- relay/include/relay/free_type_vars.h | 6 - relay/include/relay/free_vars.h | 2 +- relay/include/relay/node.h | 94 ++++---- relay/include/relay/reverse_ad.h | 32 +-- relay/include/relay/type_functor.h | 46 ++-- relay/include/relay/type_visitor.h | 16 +- relay/python/relay/expr.py | 20 +- relay/python/relay/make.py | 17 ++ relay/python/relay/relay.py | 2 +- relay/python/relay/visitor.py | 2 +- relay/src/relay/alpha_eq.cc | 309 +++++++++++++++++++++++++++ relay/src/relay/expr_visitor.cc | 84 -------- relay/src/relay/forward_ad.cc | 2 +- relay/src/relay/free_type_vars.cc | 12 +- relay/src/relay/node.cc | 74 ++++--- relay/src/relay/reverse_ad.cc | 32 +-- relay/src/relay/typechecker.cc | 2 +- 20 files changed, 601 insertions(+), 268 deletions(-) create mode 100644 relay/include/relay/alpha_eq.h create mode 100644 relay/src/relay/alpha_eq.cc delete mode 100644 relay/src/relay/expr_visitor.cc diff --git a/relay/include/relay/alpha_eq.h b/relay/include/relay/alpha_eq.h new file mode 100644 index 0000000000000..da7a8834a23e6 --- /dev/null +++ b/relay/include/relay/alpha_eq.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alpha_eq.h + * \brief Check expressions for structural equivalence. + */ +#ifndef NNVM_RELAY_ALPHA_EQ_H_ +#define NNVM_RELAY_ALPHA_EQ_H_ + +#include "node.h" + +namespace nnvm { +namespace relay { + +bool alpha_eq(const Expr & e1, const Expr & e2); +bool alpha_eq(const Type & t1, const Type & t2); + +} // namespace relay +} // namespace nnvm +#endif // NNVM_RELAY_ALPHA_EQ_H_ diff --git a/relay/include/relay/expr_functor.h b/relay/include/relay/expr_functor.h index 8dac4d2dfbd23..b566db9ee186e 100644 --- a/relay/include/relay/expr_functor.h +++ b/relay/include/relay/expr_functor.h @@ -153,6 +153,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAY_EXPR_FUNCTOR_DISPATCH(DebugNode); RELAY_EXPR_FUNCTOR_DISPATCH(UnaryOpNode); RELAY_EXPR_FUNCTOR_DISPATCH(BinaryOpNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); diff --git a/relay/include/relay/expr_visitor.h b/relay/include/relay/expr_visitor.h index fbc10ecf069d4..94a9239e55157 100644 --- a/relay/include/relay/expr_visitor.h +++ b/relay/include/relay/expr_visitor.h @@ -12,27 +12,84 @@ namespace nnvm { namespace relay { -class ExprVisitor : public ExprFunctor { +template +class ExprVisitor : public ExprFunctor { public: - void VisitExpr_(const LocalIdNode* op) override; - void VisitExpr_(const GlobalIdNode* op) override; - void VisitExpr_(const IntrinsicIdNode* op) override; - void VisitExpr_(const FloatLitNode* op) override; - void VisitExpr_(const BoolLitNode* op) override; - void VisitExpr_(const IntLitNode* op) override; - void VisitExpr_(const TensorLitNode* op) override; - void VisitExpr_(const ProductLitNode* op) override; - void VisitExpr_(const CastNode* op) override; - void VisitExpr_(const ParamNode* op) override; - void VisitExpr_(const FunctionNode* op) override; - void VisitExpr_(const CallNode* op) override; - void VisitExpr_(const DebugNode* op) override; - void VisitExpr_(const UnaryOpNode* op) override; - void VisitExpr_(const BinaryOpNode* op) override; - void VisitExpr_(const LetNode* op) override; - void VisitExpr_(const ReverseNode* op) override; - void VisitExpr_(const AccumulateNode* op) override; - void VisitExpr_(const ZeroNode* op) override; + void VisitExpr_(const LocalIdNode* op, Args... args) override { return; } + + void VisitExpr_(const GlobalIdNode* op, Args... args) override { return; } + + void VisitExpr_(const IntrinsicIdNode* op, Args... args) override { return; } + + void VisitExpr_(const FloatLitNode* op, Args... args) override { return; } + + void VisitExpr_(const BoolLitNode* op, Args... args) override { return; } + + void VisitExpr_(const IntLitNode* op, Args... args) override { return; } + + void VisitExpr_(const TensorLitNode* op, Args... args) override { + // todo + return; + } + + void VisitExpr_(const ProductLitNode* op, Args... args) override { + // todo + return; + } + + void VisitExpr_(const CastNode* op, Args... args) override { + this->VisitExpr(op->node, args...); + } + + void VisitExpr_(const ParamNode* op, Args... args) override { + this->VisitExpr(op->id, args...); + } + + void VisitExpr_(const FunctionNode* op, Args... args) override { + for (auto param : op->params) { + this->VisitExpr(param, args...); + } + + this->VisitExpr(op->body, args...); + } + + void VisitExpr_(const CallNode* op, Args... args) override { + this->VisitExpr(op->fn, args...); + for (auto arg : op->args) { + this->VisitExpr(arg, args...); + } + } + + void VisitExpr_(const DebugNode* op, Args... args) override { + this->VisitExpr(op->node, args...); + } + + void VisitExpr_(const UnaryOpNode* op, Args... args) override { + this->VisitExpr(op->node, args...); + } + + void VisitExpr_(const BinaryOpNode* op, Args... args) override { + this->VisitExpr(op->left, args...); + this->VisitExpr(op->right, args...); + } + void VisitExpr_(const LetNode* op, Args... args) override { + this->VisitExpr(op->id, args...); + this->VisitExpr(op->value, args...); + } + + void VisitExpr_(const ReverseNode* op, Args... args) override { + this->VisitExpr(op->node, args...); + } + + void VisitExpr_(const AccumulateNode* op, Args... args) override { + // todo + return; + } + + void VisitExpr_(const ZeroNode* op, Args... args) override { + // todo + return; + } }; } // namespace relay diff --git a/relay/include/relay/free_type_vars.h b/relay/include/relay/free_type_vars.h index 09406f2f4dfda..425242303d52e 100644 --- a/relay/include/relay/free_type_vars.h +++ b/relay/include/relay/free_type_vars.h @@ -12,12 +12,6 @@ namespace nnvm { namespace relay { -struct FreeTypeVars : TypeVisitor { - public: - std::set free_vars; - void VisitType_(const TypeVar& op) override; -}; - std::set free_type_vars(const Type& e); } // namespace relay diff --git a/relay/include/relay/free_vars.h b/relay/include/relay/free_vars.h index 837a67902b506..a648de3aafe36 100644 --- a/relay/include/relay/free_vars.h +++ b/relay/include/relay/free_vars.h @@ -12,7 +12,7 @@ namespace nnvm { namespace relay { -struct FreeVars : ExprVisitor { +struct FreeVars : ExprVisitor<> { public: std::set free_vars; diff --git a/relay/include/relay/node.h b/relay/include/relay/node.h index fb15fe8324a99..e08cba198062d 100644 --- a/relay/include/relay/node.h +++ b/relay/include/relay/node.h @@ -8,35 +8,39 @@ #include #include +#include #include #include /*! \brief Macro to make it easy to define node ref type given node */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefPtr) \ - class TypeName : public NodeRefPtr { \ - public: \ - TypeName() {} \ - explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefPtr(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; \ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefPtr) \ + class TypeName : public NodeRefPtr { \ + public: \ + TypeName() {} \ + explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefPtr(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + using ContainerType = NodeName; \ + }; /*! \brief Macro to make it easy to define node ref type given node */ -#define RELAY_DEFINE_EXPR(TypeName, NodeName) \ - RELAY_DEFINE_NODE_REF(TypeName, NodeName, Expr) +#define RELAY_DEFINE_EXPR(TypeName, NodeName) \ + RELAY_DEFINE_NODE_REF(TypeName, NodeName, Expr) /*! \brief Macro to make it easy to define node ref type given node */ -#define RELAY_DEFINE_VALUE(TypeName, NodeName) \ - RELAY_DEFINE_NODE_REF(TypeName, NodeName, Value) +#define RELAY_DEFINE_VALUE(TypeName, NodeName) \ + RELAY_DEFINE_NODE_REF(TypeName, NodeName, Value) /*! \brief Macro to make it easy to define node ref type given node */ -#define RELAY_DEFINE_TYPE(TypeName, NodeName) RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type) +#define RELAY_DEFINE_TYPE(TypeName, NodeName) \ + RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type) namespace nnvm { namespace relay { +typedef HalideIR::Type HType; + struct Node : public tvm::Node {}; /*! @@ -141,8 +145,6 @@ class FloatValueNode : public ValueNode { RELAY_DEFINE_VALUE(FloatValue, FloatValueNode); - - // end move me /*! \brief Base type of the Relay type hiearchy. */ @@ -163,43 +165,36 @@ struct Type : public NodeRef { using ContainerType = TypeNode; }; -class IntType; +class BaseType; /*! \brief The type of integer values. */ -class IntTypeNode : public TypeNode { +class BaseTypeNode : public TypeNode { public: - unsigned width; - - IntTypeNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("width", reinterpret_cast(&width)); - } + HalideIR::Type type; - TVM_DLL static IntType make(int width); + BaseTypeNode() {} - static constexpr const char* _type_key = "nnvm.IntType"; - TVM_DECLARE_NODE_TYPE_INFO(IntTypeNode, TypeNode); -}; + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("dtype", &type); } -RELAY_DEFINE_TYPE(IntType, IntTypeNode); + TVM_DLL static BaseType make(HalideIR::Type type); -class BoolType; + /** Constructing an unsigned integer type */ + TVM_DLL static BaseType Int(int bits, int lanes = 1); -/*! \brief The type of boolean values. */ -class BoolTypeNode : public TypeNode { - public: - BoolTypeNode() {} + /** Constructing an unsigned integer type */ + TVM_DLL static BaseType UInt(int bits, int lanes = 1); - void VisitAttrs(tvm::AttrVisitor* v) final {} + /** Construct a floating-point type */ + TVM_DLL static BaseType Float(int bits, int lanes = 1); - TVM_DLL static BoolType make(); + /** Construct a boolean type */ + TVM_DLL static BaseType Bool(int lanes = 1); - static constexpr const char* _type_key = "nnvm.BoolType"; - TVM_DECLARE_NODE_TYPE_INFO(BoolTypeNode, TypeNode); + static constexpr const char* _type_key = "nnvm.BaseType"; + TVM_DECLARE_NODE_TYPE_INFO(BaseTypeNode, TypeNode); }; -RELAY_DEFINE_TYPE(BoolType, BoolTypeNode); +RELAY_DEFINE_TYPE(BaseType, BaseTypeNode); class TypeVar; @@ -362,13 +357,13 @@ class TensorLit; /*! \brief Tensor literal [t1, [x1, ..., xn]]. */ class TensorLitNode : public ExprNode { public: - tvm::Array data; + tvm::Array data; TensorLitNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } - TVM_DLL static TensorLit make(tvm::Array data); + TVM_DLL static TensorLit make(tvm::Array data); static constexpr const char* _type_key = "nnvm.TensorLit"; TVM_DECLARE_NODE_TYPE_INFO(TensorLitNode, ExprNode); @@ -381,13 +376,13 @@ class ProductLit; /*! \brief Product literal (x, ... y). */ class ProductLitNode : public ExprNode { public: - tvm::Array fields; + tvm::Array fields; ProductLitNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } - TVM_DLL static ProductLit make(tvm::Array data); + TVM_DLL static ProductLit make(tvm::Array data); static constexpr const char* _type_key = "nnvm.ProductLit"; TVM_DECLARE_NODE_TYPE_INFO(ProductLitNode, ExprNode); @@ -585,10 +580,7 @@ RELAY_DEFINE_EXPR(Debug, DebugNode) class UnaryOp; -enum UOp : int { - NEG = 0, - SQ = 1 -}; +enum UOp : int { NEG = 0, SQ = 1 }; /*! \brief Unary Operator. */ class UnaryOpNode : public ExprNode { @@ -879,9 +871,7 @@ struct hash { * \param id global id. * \return hash code. */ - size_t operator()(const nnvm::relay::LocalId& id) const { - return id.hash(); - } + size_t operator()(const nnvm::relay::LocalId& id) const { return id.hash(); } }; } // namespace std diff --git a/relay/include/relay/reverse_ad.h b/relay/include/relay/reverse_ad.h index fb1c03b8f02ad..895b515428d4b 100644 --- a/relay/include/relay/reverse_ad.h +++ b/relay/include/relay/reverse_ad.h @@ -19,22 +19,22 @@ namespace relay { struct ReverseAD : ExprFunctor { ReverseAD() {} Expr AD(const Expr& expr); - Expr VisitExpr_(const LocalId& op); - Expr VisitExpr_(const GlobalId& op); - Expr VisitExpr_(const IntrinsicId& op); - Expr VisitExpr_(const FloatLit& op); - Expr VisitExpr_(const BoolLit& op); - Expr VisitExpr_(const IntLit& op); - Expr VisitExpr_(const TensorLit& op); - Expr VisitExpr_(const ProductLit& op); - Expr VisitExpr_(const Cast& op); - Expr VisitExpr_(const Param& op); - Expr VisitExpr_(const Function& op); - Expr VisitExpr_(const Call& op); - Expr VisitExpr_(const Debug& op); - Expr VisitExpr_(const UnaryOp& op); - Expr VisitExpr_(const BinaryOp& op); - Expr VisitExpr_(const Let& op); + Expr VisitExpr_(const LocalIdNode* op); + Expr VisitExpr_(const GlobalIdNode* op); + Expr VisitExpr_(const IntrinsicIdNode* op); + Expr VisitExpr_(const FloatLitNode* op); + Expr VisitExpr_(const BoolLitNode* op); + Expr VisitExpr_(const IntLitNode* op); + Expr VisitExpr_(const TensorLitNode* op); + Expr VisitExpr_(const ProductLitNode* op); + Expr VisitExpr_(const CastNode* op); + Expr VisitExpr_(const ParamNode* op); + Expr VisitExpr_(const FunctionNode* op); + Expr VisitExpr_(const CallNode* op); + Expr VisitExpr_(const DebugNode* op); + Expr VisitExpr_(const UnaryOpNode* op); + Expr VisitExpr_(const BinaryOpNode* op); + Expr VisitExpr_(const LetNode* op); }; } // namespace relay diff --git a/relay/include/relay/type_functor.h b/relay/include/relay/type_functor.h index 5b042ba1b94d7..54aed667963fa 100644 --- a/relay/include/relay/type_functor.h +++ b/relay/include/relay/type_functor.h @@ -13,24 +13,23 @@ namespace nnvm { namespace relay { -template +template class TypeFunctor; // functions to be overriden. -#define TYPE_FUNCTOR_DEFAULT { \ - return VisitTypeDefault_(op, std::forward(args)...); \ - } +#define TYPE_FUNCTOR_DEFAULT \ + { return VisitTypeDefault_(op, std::forward(args)...); } -#define RELAY_EXPR_FUNCTOR_DISPATCH(OP, OPRef) \ - vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitType_(static_cast(n), \ - std::forward(args)...); \ - }); \ +#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); -template +template class TypeFunctor { -private: + private: using TSelf = TypeFunctor; using FType = tvm::IRFunctor; @@ -59,13 +58,15 @@ class TypeFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitType_(const IntType& op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const BoolType& op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeVar& op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeId& op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeQuantifier& op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const BaseTypeNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeVarNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeIdNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeQuantifierNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitTypeDefault_(const NodeRef & op, Args ...) { + virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); return R(); } @@ -75,11 +76,10 @@ class TypeFunctor { static FType InitVTable() { FType vtable; // Set dispatch - RELAY_EXPR_FUNCTOR_DISPATCH(IntTypeNode, IntType); - RELAY_EXPR_FUNCTOR_DISPATCH(BoolTypeNode, BoolType); - RELAY_EXPR_FUNCTOR_DISPATCH(TypeVarNode, TypeVar); - RELAY_EXPR_FUNCTOR_DISPATCH(TypeIdNode, TypeId); - RELAY_EXPR_FUNCTOR_DISPATCH(TypeQuantifierNode, TypeQuantifier); + RELAY_TYPE_FUNCTOR_DISPATCH(BaseTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeIdNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeQuantifierNode); return vtable; } }; diff --git a/relay/include/relay/type_visitor.h b/relay/include/relay/type_visitor.h index 330e5e88a3cc7..4019057514e59 100644 --- a/relay/include/relay/type_visitor.h +++ b/relay/include/relay/type_visitor.h @@ -12,14 +12,14 @@ namespace nnvm { namespace relay { -struct TypeVisitor : TypeFunctor { - void VisitType_(const IntType& op) override {} - void VisitType_(const BoolType& op) override {} - void VisitType_(const TypeVar& op) override {} - void VisitType_(const TypeId& op) override {} - void VisitType_(const TypeQuantifier& op) override { - this->VisitType(op->binderType); - this->VisitType(op->boundType); +template +struct TypeVisitor : TypeFunctor { + void VisitType_(const BaseTypeNode* op, Args... args) override {} + void VisitType_(const TypeVarNode* op, Args... args) override {} + void VisitType_(const TypeIdNode* op, Args... args) override {} + void VisitType_(const TypeQuantifierNode* op, Args... args) override { + this->VisitType(op->binderType, args...); + this->VisitType(op->boundType, args...); } }; diff --git a/relay/python/relay/expr.py b/relay/python/relay/expr.py index ebfdc1c52fb81..f403eee380af7 100644 --- a/relay/python/relay/expr.py +++ b/relay/python/relay/expr.py @@ -1,4 +1,4 @@ -#pylint: disable=no-else-return, unidiomatic-typecheck +# pylint: disable=no-else-return, unidiomatic-typecheck """All the expression nodes""" from enum import IntEnum from typing import Union @@ -181,7 +181,7 @@ class ProductLit(Expr): @register_nnvm_node -class IntType(Type): +class BaseType(Type): pass @@ -207,7 +207,7 @@ class IntrinsicId(Expr): @register_nnvm_node -class Param(NodeBase): +class Param(Expr): pass @@ -287,14 +287,28 @@ class Zero(Expr): class IntValue(Value): pass + @register_nnvm_node class FloatValue(Value): pass + @register_nnvm_node class BoolValue(Value): pass + @register_nnvm_node class FnValue(Value): pass + +# TODO move me + + +def alpha_eq(left: Union[Expr, Type], right: [Expr, Type]) -> bool: + if isinstance(left, Expr) and isinstance(right, Expr): + return bool(make._alpha_eq(left, right)) + elif isinstance(left, Type) and isinstance(right, Type): + return bool(make._type_alpha_eq(left, right)) + else: + return False diff --git a/relay/python/relay/make.py b/relay/python/relay/make.py index f3a83a33706e1..f8b9646eb6991 100644 --- a/relay/python/relay/make.py +++ b/relay/python/relay/make.py @@ -3,3 +3,20 @@ from tvm._ffi.function import _init_api _init_api("nnvm.make", __name__) + +#pylint: disable=invalid-name +def IntType(bits, lanes=1): + """A wrapper for constructing the Int base type.""" + return _IntType(bits, lanes) + + +#pylint: disable=invalid-name +def FloatType(bits, lanes=1): + """A wrapper for constructing the Float base type.""" + return _FloatType(bits, lanes) + + +#pylint: disable=invalid-name +def BoolType(lanes=1): + """A wrapper for constructing the Bool base type.""" + return _BoolType(lanes) diff --git a/relay/python/relay/relay.py b/relay/python/relay/relay.py index 89587e1023501..92667abe36668 100644 --- a/relay/python/relay/relay.py +++ b/relay/python/relay/relay.py @@ -4,7 +4,7 @@ import inspect # from typing import Dict, List from collections import OrderedDict -# pylint: disable=wildcard-import +# pylint: disable=wildcard-import, unused-wildcard-import import nnvm.relay.eval as re from .make import * diff --git a/relay/python/relay/visitor.py b/relay/python/relay/visitor.py index 5089f06e54fa9..b863235db7ae0 100644 --- a/relay/python/relay/visitor.py +++ b/relay/python/relay/visitor.py @@ -1,6 +1,6 @@ """Visitor shim for C++'s ExprFunctor.""" import inspect -#pylint: disable=wildcard-import +#pylint: disable=wildcard-import,unused-wildcard-import from .make import * # TODO(jroesch): Come back to do wrapper automatically diff --git a/relay/src/relay/alpha_eq.cc b/relay/src/relay/alpha_eq.cc new file mode 100644 index 0000000000000..0736a712431d3 --- /dev/null +++ b/relay/src/relay/alpha_eq.cc @@ -0,0 +1,309 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alpha_eq.cc + * \brief Compute the set of variables not bound in the expression. + */ +#include "nnvm/relay/alpha_eq.h" +#include "nnvm/relay/expr_visitor.h" +#include "nnvm/relay/type_visitor.h" + +namespace nnvm { +namespace relay { + +using namespace tvm::runtime; + +struct AlphaEq : ExprVisitor { + public: + tvm::Map eq_map; + bool equal; + AlphaEq() : eq_map(), equal(true) {} + + void VisitExpr_(const LocalIdNode* e1, const Expr& e2) override { + if (const LocalIdNode* id2 = e2.as()) { + auto local1 = GetRef(e1); + auto local2 = GetRef(id2); + // + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (local1 == local2) { + equal = true; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(local1) != eq_map.end()) { + equal = equal && eq_map[local1] == local2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitExpr_(const GlobalIdNode* g1, const Expr& e2) override { + if (const GlobalIdNode* g2 = e2.as()) { + equal = equal && g1 == g2; + } else { + equal = false; + } + } + + void VisitExpr_(const IntrinsicIdNode* i1, const Expr& e2) override { + if (const IntrinsicIdNode* i2 = e2.as()) { + equal = equal && i1 == i2; + } else { + equal = false; + } + } + + void VisitExpr_(const FloatLitNode* fl1, const Expr& e2) override { + if (const FloatLitNode* fl2 = e2.as()) { + equal = equal && fl1->value == fl2->value; + } else { + equal = false; + } + } + + void VisitExpr_(const BoolLitNode* bl1, const Expr& e2) override { + if (const BoolLitNode* bl2 = e2.as()) { + equal = equal && bl1->value == bl2->value; + } else { + equal = false; + } + } + + void VisitExpr_(const IntLitNode* il1, const Expr& e2) override { + if (const IntLitNode* il2 = e2.as()) { + equal = equal && il1->value == il2->value; + } else { + equal = false; + } + } + + void VisitExpr_(const TensorLitNode* tl1, const Expr& e2) override { + if (const TensorLitNode* tl2 = e2.as()) { + if (tl1->data.size() != tl2->data.size()) { + equal = false; + return; + } + + for (auto i = 0; i < tl1->data.size(); i++) { + this->VisitExpr(tl1->data[i], tl2->data[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const ProductLitNode* pl1, const Expr& e2) override { + if (const ProductLitNode* pl2 = e2.as()) { + if (pl1->fields.size() != pl2->fields.size()) { + equal = false; + return; + } + + for (auto i = 0; i < pl1->fields.size(); i++) { + this->VisitExpr(pl1->fields[i], pl2->fields[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const CastNode* op, const Expr& e2) override { + if (const CastNode* cast = e2.as()) { + equal = equal && alpha_eq(op->target, cast->target); + this->VisitExpr(op->node, cast->node); + } else { + equal = false; + } + } + + void VisitExpr_(const ParamNode* p1, const Expr& e2) override { + if (const ParamNode* p2 = e2.as()) { + eq_map.Set(p1->id, p2->id); + equal = equal && alpha_eq(p1->type, p2->type); + } else { + equal = false; + } + } + + void VisitExpr_(const FunctionNode* func1, const Expr& e2) override { + if (const FunctionNode* func2 = e2.as()) { + if (func1->params.size() != func2->params.size()) { + equal = false; + return; + } + + for (int i = 0; i < func1->params.size(); i++) { + this->VisitExpr(func1->params[i], func2->params[i]); + } + + this->VisitExpr(func1->body, func2->body); + } else { + equal = false; + } + } + + void VisitExpr_(const CallNode* op, const Expr& e2) override { + if (const CallNode* call = e2.as()) { + this->VisitExpr(op->fn, call->fn); + + if (op->args.size() != call->args.size()) { + equal = false; + return; + } + + for (int i = 0; i < op->args.size(); i++) { + this->VisitExpr(op->args[i], call->args[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const DebugNode* op, const Expr& e2) override { + if (const DebugNode* bop = e2.as()) { + this->VisitExpr(op->node, bop->node); + } else { + equal = false; + } + } + + void VisitExpr_(const UnaryOpNode* op, const Expr& e2) override { + if (const UnaryOpNode* bop = e2.as()) { + this->VisitExpr(op->node, bop->node); + } else { + equal = false; + } + } + + void VisitExpr_(const BinaryOpNode* op, const Expr& e2) override { + if (const BinaryOpNode* bop = e2.as()) { + this->VisitExpr(op->left, bop->left); + this->VisitExpr(op->right, bop->right); + } else { + equal = false; + } + } + + void VisitExpr_(const LetNode* op, const Expr& e2) override { + if (const LetNode* let = e2.as()) { + eq_map.Set(op->id, let->id); + this->VisitExpr(op->value, let->value); + this->VisitExpr(op->body, let->body); + } else { + equal = false; + } + } + + void VisitExpr_(const ReverseNode* op, const Expr& e2) override { + if (const ReverseNode* rev = e2.as()) { + this->VisitExpr(op->node, rev->node); + } else { + equal = false; + } + } + + void VisitExpr_(const AccumulateNode* op, const Expr& e2) override { + // todo + return; + } + + void VisitExpr_(const ZeroNode* z1, const Expr& e2) override { + if (const ZeroNode* z2 = e2.as()) { + equal = equal && alpha_eq(z1->type, z2->type); + } else { + equal = false; + } + } +}; + +bool alpha_eq(const Expr& e1, const Expr& e2) { + AlphaEq eq; + eq.VisitExpr(e1, e2); + return eq.equal; +} + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + TypeAlphaEq() : eq_map(), equal(true) {} + + void VisitType_(const BaseTypeNode* bt1, const Type& t2) override { + if (const BaseTypeNode* bt2 = t2.as()) { + const HType& t1 = bt1->type; + const HType& t2 = bt2->type; + equal = equal && t1 == t2; + return; + } else { + equal = false; + } + } + + void VisitType_(const TypeVarNode* bt1, const Type& t2) override { + if (const TypeVarNode* bt2 = t2.as()) { + equal = equal && bt1 == bt2; + return; + } else { + equal = false; + } + } + + void VisitType_(const TypeIdNode* ti1, const Type& t2) override { + if (const TypeIdNode* ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + equal = true; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const TypeQuantifierNode* op, const Type& t2) override { + // todo + return; + } +}; + +bool alpha_eq(const Type& t1, const Type& t2) { + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("nnvm.make._alpha_eq") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Expr e1 = args[0]; + Expr e2 = args[1]; + *ret = alpha_eq(e1, e2); + }); + +TVM_REGISTER_API("nnvm.make._type_alpha_eq") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = alpha_eq(t1, t2); + }); + +} // namespace relay +} // namespace nnvm diff --git a/relay/src/relay/expr_visitor.cc b/relay/src/relay/expr_visitor.cc deleted file mode 100644 index d3c250152d0f6..0000000000000 --- a/relay/src/relay/expr_visitor.cc +++ /dev/null @@ -1,84 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file expr_visitor.cc - * \brief A default visitor implementation using ExprFuntor, easy for - * constructing visitors which use mutable state internally. - */ - -#include -#include - -namespace nnvm { -namespace relay { - -using namespace tvm::runtime; - -void ExprVisitor::VisitExpr_(const LocalIdNode* op) { return; } - -void ExprVisitor::VisitExpr_(const GlobalIdNode* op) { return; } - -void ExprVisitor::VisitExpr_(const IntrinsicIdNode* op) { return; } - -void ExprVisitor::VisitExpr_(const FloatLitNode* op) { return; } - -void ExprVisitor::VisitExpr_(const BoolLitNode* op) { return; } - -void ExprVisitor::VisitExpr_(const IntLitNode* op) { return; } - -void ExprVisitor::VisitExpr_(const TensorLitNode* op) { - // todo - return; -} - -void ExprVisitor::VisitExpr_(const ProductLitNode* op) { - // todo - return; -} - -void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->node); } - -void ExprVisitor::VisitExpr_(const ParamNode* op) { this->VisitExpr(op->id); } - -void ExprVisitor::VisitExpr_(const FunctionNode* op) { - for (auto param : op->params) { - this->VisitExpr(param); - } - - this->VisitExpr(op->body); -} - -void ExprVisitor::VisitExpr_(const CallNode* op) { - this->VisitExpr(op->fn); - for (auto arg : op->args) { - this->VisitExpr(arg); - } -} - -void ExprVisitor::VisitExpr_(const DebugNode* op) { VisitExpr(op->node); } - -void ExprVisitor::VisitExpr_(const UnaryOpNode* op) { VisitExpr(op->node); } - -void ExprVisitor::VisitExpr_(const BinaryOpNode* op) { - VisitExpr(op->left); - VisitExpr(op->right); -} - -void ExprVisitor::VisitExpr_(const LetNode* op) { - VisitExpr(op->id); - VisitExpr(op->value); -} - -void ExprVisitor::VisitExpr_(const ReverseNode* op) { VisitExpr(op->node); } - -void ExprVisitor::VisitExpr_(const AccumulateNode* op) { - // todo - return; -} - -void ExprVisitor::VisitExpr_(const ZeroNode* op) { - // todo - return; -} - -} // namespace relay -} // namespace nnvm diff --git a/relay/src/relay/forward_ad.cc b/relay/src/relay/forward_ad.cc index 4213e5a4a2aa8..b2e9f068a6c5a 100644 --- a/relay/src/relay/forward_ad.cc +++ b/relay/src/relay/forward_ad.cc @@ -32,7 +32,7 @@ Expr ForwardAD::VisitExpr_(const IntrinsicIdNode* op) { } Expr ForwardAD::VisitExpr_(const FloatLitNode* op) { - return ProductLitNode::make({op->GetNodeRef(), FloatLitNode::make(0)}); + return ProductLitNode::make({GetRef(op), FloatLitNode::make(0)}); } Expr ForwardAD::VisitExpr_(const BoolLitNode* op) { diff --git a/relay/src/relay/free_type_vars.cc b/relay/src/relay/free_type_vars.cc index 816acd3962a38..d8586f0e57efe 100644 --- a/relay/src/relay/free_type_vars.cc +++ b/relay/src/relay/free_type_vars.cc @@ -12,9 +12,15 @@ namespace relay { using namespace tvm::runtime; -void FreeTypeVars::VisitType_(const TypeVar& op) { - free_vars.insert(op); -} +struct FreeTypeVars : TypeVisitor<> { + public: + std::set free_vars; + + void VisitType_(const TypeVarNode* op) override { + this->free_vars.insert(GetRef(op)); + } +}; + /*! \brief Compute the set of variables not bound in the expression e */ std::set free_type_vars(const Type& e) { diff --git a/relay/src/relay/node.cc b/relay/src/relay/node.cc index 843b492bb97da..d05ff209e8911 100644 --- a/relay/src/relay/node.cc +++ b/relay/src/relay/node.cc @@ -72,7 +72,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "BoolLitNode(" << node->value << ")"; }); -TensorLit TensorLitNode::make(tvm::Array value) { +TensorLit TensorLitNode::make(tvm::Array value) { std::shared_ptr n = std::make_shared(); n->data = value; return TensorLit(n); @@ -89,7 +89,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TensorLitNode(" << node->data << ")"; }); -ProductLit ProductLitNode::make(tvm::Array value) { +ProductLit ProductLitNode::make(tvm::Array value) { std::shared_ptr n = std::make_shared(); n->fields = value; return ProductLit(n); @@ -157,44 +157,54 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "CastNode(" << node->target << ", " << node->node << ")"; }); -IntType IntTypeNode::make(int value) { - std::shared_ptr n = std::make_shared(); - if (0 <= value) { - n->width = (unsigned)value; - return IntType(n); - } else { - throw "this shoudl be an error, what is right error to throw here?"; - } +BaseType BaseTypeNode::make(HalideIR::Type type) { + std::shared_ptr n = std::make_shared(); + n->type = std::move(type); + return BaseType(n); } -TVM_REGISTER_API("nnvm.make.IntType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 0) { - *ret = IntTypeNode::make(64); // fix me - } else { - *ret = IntTypeNode::make(args[0]); // fix me - } - }); +BaseType BaseTypeNode::Int(int bits, int lanes) { + return BaseTypeNode::make(HalideIR::Int(bits, lanes)); +} -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IntTypeNode *node, tvm::IRPrinter *p) { - p->stream << "IntType(" << node->width << ")"; - }); +BaseType BaseTypeNode::UInt(int bits, int lanes) { + return BaseTypeNode::make(HalideIR::UInt(bits, lanes)); +} -BoolType BoolTypeNode::make() { - std::shared_ptr n = std::make_shared(); - return BoolType(n); +BaseType BaseTypeNode::Float(int bits, int lanes) { + return BaseTypeNode::make(HalideIR::Float(bits, lanes)); } -TVM_REGISTER_API("nnvm.make.BoolType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = BoolTypeNode::make(); - }); +/** Construct a floating-point type */ +BaseType BaseTypeNode::Bool(int lanes) { + return BaseTypeNode::make(HalideIR::Bool(lanes)); +} + +TVM_REGISTER_API("nnvm.make.BaseType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + HalideIR::Type type = args[0]; + *ret = BaseTypeNode::make(type); + }); + +TVM_REGISTER_API("nnvm.make._IntType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BaseTypeNode::Int(args[0], args[1]); + }); + +TVM_REGISTER_API("nnvm.make._BoolType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BaseTypeNode::Bool(args[0]); + }); + +TVM_REGISTER_API("nnvm.make._FloatType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BaseTypeNode::Float(args[0], args[1]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const BoolTypeNode *node, - tvm::IRPrinter *p) { - p->stream << "BoolType()"; + .set_dispatch([](const BaseTypeNode *node, tvm::IRPrinter *p) { + p->stream << "BaseType(" << node->type << ")"; }); TypeVar TypeVarNode::make(int id) { diff --git a/relay/src/relay/reverse_ad.cc b/relay/src/relay/reverse_ad.cc index a1ec87dbd3898..9d6b20b0f8954 100644 --- a/relay/src/relay/reverse_ad.cc +++ b/relay/src/relay/reverse_ad.cc @@ -18,65 +18,65 @@ struct ReverseADError : dmlc::Error { Expr ReverseAD::AD(const Expr& expr) { return this->operator()(expr); } -Expr ReverseAD::VisitExpr_(const LocalId& local) { +Expr ReverseAD::VisitExpr_(const LocalIdNode* local) { throw ReverseADError("LocalIdNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const GlobalId& op) { +Expr ReverseAD::VisitExpr_(const GlobalIdNode* op) { throw ReverseADError("GlobalIdNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const IntrinsicId& op) { +Expr ReverseAD::VisitExpr_(const IntrinsicIdNode* op) { throw ReverseADError("IntrinsicIdNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const FloatLit& op) { +Expr ReverseAD::VisitExpr_(const FloatLitNode* op) { throw ReverseADError("FloatLitNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const BoolLit& op) { +Expr ReverseAD::VisitExpr_(const BoolLitNode* op) { throw ReverseADError("BoolLitNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const IntLit& op) { +Expr ReverseAD::VisitExpr_(const IntLitNode* op) { throw ReverseADError("IntLitNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const TensorLit& op) { +Expr ReverseAD::VisitExpr_(const TensorLitNode* op) { throw ReverseADError("TensorLitNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const ProductLit& op) { +Expr ReverseAD::VisitExpr_(const ProductLitNode* op) { throw ReverseADError("ProductLitNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const Cast& op) { +Expr ReverseAD::VisitExpr_(const CastNode* op) { return this->VisitExpr(op->node); } -Expr ReverseAD::VisitExpr_(const Param& op) { +Expr ReverseAD::VisitExpr_(const ParamNode* op) { throw ReverseADError("ParamNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const Function& op) { +Expr ReverseAD::VisitExpr_(const FunctionNode* op) { throw ReverseADError("FunctionNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const Call& op) { +Expr ReverseAD::VisitExpr_(const CallNode* op) { throw ReverseADError("CallNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const Debug& op) { +Expr ReverseAD::VisitExpr_(const DebugNode* op) { throw ReverseADError("DebugNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const UnaryOp& op) { +Expr ReverseAD::VisitExpr_(const UnaryOpNode* op) { throw ReverseADError("UnaryOpNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const BinaryOp& op) { +Expr ReverseAD::VisitExpr_(const BinaryOpNode* op) { throw ReverseADError("BinaryOpNode Reverse AD NYI"); } -Expr ReverseAD::VisitExpr_(const Let& op) { +Expr ReverseAD::VisitExpr_(const LetNode* op) { throw ReverseADError("LetNode Reverse AD NYI"); } diff --git a/relay/src/relay/typechecker.cc b/relay/src/relay/typechecker.cc index 50f042834d74a..534422c7efa40 100644 --- a/relay/src/relay/typechecker.cc +++ b/relay/src/relay/typechecker.cc @@ -90,7 +90,7 @@ Type Typechecker::VisitExpr_(const LetNode* op) { TVM_REGISTER_API("nnvm.tyck.check") .set_body([](TVMArgs args, TVMRetValue *ret) { - // *ret = BaseTypeNode::Int(32); + *ret = BaseTypeNode::Int(32); }); } // namespace relay