From fad945a579ce440cbeaa17c699ac9c8a5eab0256 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 21 Apr 2018 23:25:58 -0700 Subject: [PATCH] REPL + Operators + Env Refactor (#42) * Start on REPL * Clean up evaluator after rebase * Add support for registering primitives Refactor Environment interface to better disambiguate between Globals and Intrinsics. Start on Python code to produce operator specific JIT compilers which will specialize an operator based on the type signature of the operator * Debugging test crashes * Fix testing error and setup primitive test case * Add a bunch of type annotations * Add more .pyi files to stub C++ written functions * Getting closer to primitive evaluation, need type information * Continue refactoring needed to call Tensor program w/ Primitives * Add Tensor values, and improve calling into eval * Add notes for tomorrow * Fix linting, prepare for early merge * Fix rebase * Fix linting * Add typed_ast to package list * Fix failing test cases * Trigger rebuild --- relay/include/relay/anf.h | 3 +- relay/include/relay/environment.h | 24 ++- relay/include/relay/evaluator.h | 55 +------ relay/include/relay/ir/base.h | 9 +- relay/include/relay/ir/expr.h | 8 +- relay/include/relay/ir/item.h | 9 +- relay/include/relay/ir/type.h | 2 +- relay/include/relay/ir/value.h | 20 +++ relay/include/relay/py_functor.h | 2 +- relay/python/relay/env.py | 29 ++++ relay/python/relay/env.pyi | 10 ++ relay/python/relay/expr.py | 28 +--- relay/python/relay/make.py | 16 +- relay/python/relay/make.pyi | 29 ++++ relay/python/relay/operators.py | 30 ++++ relay/python/relay/relay.py | 242 +++++++++++++++++++---------- relay/python/relay/repl.py | 24 +++ relay/src/relay/anf.cc | 3 +- relay/src/relay/environment.cc | 142 ++++++++++++----- relay/src/relay/evaluator.cc | 245 +++++++++++++++++++++++++++--- relay/src/relay/forward_ad.cc | 14 +- relay/src/relay/ir.cc | 36 ++--- relay/src/relay/ir/expr.cc | 10 +- relay/src/relay/ir/item.cc | 2 +- relay/src/relay/ir/value.cc | 71 ++++++--- relay/src/relay/pretty_printer.cc | 2 +- relay/src/relay/reverse_ad.cc | 226 +++++++++++++-------------- 27 files changed, 881 insertions(+), 410 deletions(-) create mode 100644 relay/python/relay/env.py create mode 100644 relay/python/relay/env.pyi create mode 100644 relay/python/relay/make.pyi create mode 100644 relay/python/relay/operators.py create mode 100644 relay/python/relay/repl.py diff --git a/relay/include/relay/anf.h b/relay/include/relay/anf.h index 54ae117d566e5..315e0269098b4 100644 --- a/relay/include/relay/anf.h +++ b/relay/include/relay/anf.h @@ -41,7 +41,8 @@ struct LetList { Expr plug(const Expr & expr) const { Expr ret = expr; for (auto rit = lets.rbegin(); rit != lets.rend(); rit++) { - ret = LetNode::make(rit->first.lid, rit->second, ret); + // TODO(@jroesch) Fix the type of the binder. + ret = LetNode::make(rit->first.lid, BaseTypeNode::Float(32), rit->second, ret); } return ret; } diff --git a/relay/include/relay/environment.h b/relay/include/relay/environment.h index b8e6bbb0d987d..dfbf90934818d 100644 --- a/relay/include/relay/environment.h +++ b/relay/include/relay/environment.h @@ -6,14 +6,16 @@ #ifndef NNVM_RELAY_ENVIRONMENT_H_ #define NNVM_RELAY_ENVIRONMENT_H_ -#include -#include #include #include "ir.h" namespace nnvm { namespace relay { +// class InternMap { +// // fill me in +// }; + struct Environment; /*! \brief Integer literal `0`, `1000`. 
*/ @@ -21,15 +23,18 @@ class EnvironmentNode : public ValueNode { private: std::unordered_map global_map_; std::unordered_map intrinsic_map_; - // What if there are two globalid with the same name? - // This should be fixed in the python code, - // But I havent take much look into it, so I will just hack around. inline void add_global(const std::string & str, GlobalId id); + inline void add_intrinsic(const std::string & str, IntrinsicId id); + inline void register_primitive(Primitive p); public: + // This map contains all items *except* Primitives. tvm::Map items; + // This map *only* contains primitives. + tvm::Map intrinsics; + EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { @@ -38,10 +43,17 @@ class EnvironmentNode : public ValueNode { TVM_DLL static Environment make(tvm::Map items); + // Add an item to the Enviroment. void add(const Item& item); + GlobalId global_id(const std::string & str); + IntrinsicId intrinsic_id(const std::string & str); + + // We can lookup a GlobalId, IntrinsicId. Item lookup(const GlobalId & id); - Item lookup(const std::string & str); + Item lookup(const IntrinsicId & id); + Item lookup_global(const std::string & str); + Item lookup_intrinsic(const std::string & str); static constexpr const char* _type_key = "nnvm.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/relay/include/relay/evaluator.h b/relay/include/relay/evaluator.h index 382ce97f30819..e7e6a3740df89 100644 --- a/relay/include/relay/evaluator.h +++ b/relay/include/relay/evaluator.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file evaluator.h - * \brief An evaluator for the Relay IR + * \brief An evaluator for Relay IR */ #ifndef NNVM_RELAY_EVALUATOR_H_ #define NNVM_RELAY_EVALUATOR_H_ @@ -16,58 +16,7 @@ namespace nnvm { namespace relay { -struct Stack { - // Because we know variables will always be unique we can collapse this into a single mapping? 
- std::vector> stack; - Stack() { - push(); - } - void push() { - stack.push_back({}); - } - void pop() { - stack.pop_back(); - } -}; - -struct StackFrame { - Stack & stack; - explicit StackFrame(Stack & stack) : stack(stack) { - stack.push(); - } - ~StackFrame() { - stack.pop(); - } -}; - -class Evaluator : public ExprFunctor { - public: - Environment env; - Stack stack; - Evaluator(); - Evaluator(Environment env) : env(env) {} - void extend(const LocalId & id, Value v); - Value Eval(const Expr& expr); - Value VisitExpr_(const LocalIdNode* op) override; - Value VisitExpr_(const GlobalIdNode* op) override; - Value VisitExpr_(const IntrinsicIdNode* op) override; - Value VisitExpr_(const FloatLitNode* op) override; - Value VisitExpr_(const BoolLitNode* op) override; - Value VisitExpr_(const IntLitNode* op) override; - Value VisitExpr_(const TensorLitNode* op) override; - Value VisitExpr_(const ProductLitNode* op) override; - Value VisitExpr_(const CastNode* op) override; - Value VisitExpr_(const ParamNode* op) override; - Value VisitExpr_(const FunctionNode* op) override; - Value VisitExpr_(const CallNode* op) override; - Value VisitExpr_(const DebugNode* op) override; - Value VisitExpr_(const UnaryOpNode* op) override; - Value VisitExpr_(const BinaryOpNode* op) override; - Value VisitExpr_(const LetNode* op) override; - Value VisitExpr_(const ZeroNode* op) override; - Value VisitExpr_(const ProjectionNode* op) override; - Value VisitExpr_(const IfNode* op) override; -}; +Value evaluate(Environment env, Expr e); } // namespace relay } // namespace nnvm diff --git a/relay/include/relay/ir/base.h b/relay/include/relay/ir/base.h index f2523e873a497..5fb3c595b9847 100644 --- a/relay/include/relay/ir/base.h +++ b/relay/include/relay/ir/base.h @@ -36,6 +36,10 @@ #define RELAY_DEFINE_TYPE(TypeName, NodeName) \ RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type) +/*! 
\brief Macro to make it easy to define node ref type given node */ +#define RELAY_DEFINE_ITEM(TypeName, NodeName) \ + RELAY_DEFINE_NODE_REF(TypeName, NodeName, Item) + namespace nnvm { namespace relay { @@ -66,7 +70,7 @@ RELAY_DEFINE_NODE_REF(Span, SpanNode, ::tvm::NodeRef); struct Node : public tvm::Node { public: - Span span; + // Span span; Node() {} }; @@ -107,7 +111,7 @@ class LocalIdNode : public ExprNode { LocalIdNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("span", &span); + // v->Visit("span", &span); v->Visit("name", &name); } @@ -181,7 +185,6 @@ struct Type : public NodeRef { struct ItemNode : Node { public: - GlobalId id; static constexpr const char* _type_key = "nnvm.Item"; TVM_DECLARE_BASE_NODE_INFO(ItemNode, Node); }; diff --git a/relay/include/relay/ir/expr.h b/relay/include/relay/ir/expr.h index 6002a2b9ef79a..f6e429feec74e 100644 --- a/relay/include/relay/ir/expr.h +++ b/relay/include/relay/ir/expr.h @@ -219,16 +219,18 @@ class Function; class FunctionNode : public ExprNode { public: tvm::Array params; + Type ret_type; Expr body; FunctionNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("params", ¶ms); + v->Visit("ret_type", &ret_type); v->Visit("body", &body); } - TVM_DLL static Function make(tvm::Array params, Expr body); + TVM_DLL static Function make(tvm::Array params, Type ret_type, Expr body); static constexpr const char* _type_key = "nnvm.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); @@ -351,6 +353,7 @@ class Let; class LetNode : public ExprNode { public: LocalId id; + Type type; Expr value; Expr body; @@ -358,11 +361,12 @@ class LetNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("id", &id); + v->Visit("type", &type); v->Visit("value", &value); v->Visit("body", &body); } - TVM_DLL static Let make(LocalId id, Expr value, Expr body); + TVM_DLL static Let make(LocalId id, Type type, Expr value, Expr body); static constexpr const char* _type_key = "nnvm.Let"; TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); diff --git a/relay/include/relay/ir/item.h b/relay/include/relay/ir/item.h index 594612dbb4fe6..46f020486f34e 100644 --- a/relay/include/relay/ir/item.h +++ b/relay/include/relay/ir/item.h @@ -19,6 +19,7 @@ class Primitive; class PrimitiveNode : public ItemNode { public: + IntrinsicId id; Type type; PrimitiveNode() {} @@ -28,20 +29,22 @@ class PrimitiveNode : public ItemNode { v->Visit("type", &type); } - TVM_DLL static Primitive make(GlobalId id, Type type); + TVM_DLL static Primitive make(IntrinsicId id, Type type); static constexpr const char* _type_key = "nnvm.Primitive"; TVM_DECLARE_NODE_TYPE_INFO(PrimitiveNode, ItemNode); }; -TVM_DEFINE_NODE_REF(Primitive, PrimitiveNode); +RELAY_DEFINE_ITEM(Primitive, PrimitiveNode); class Defn; class DefnNode : public ItemNode { public: + GlobalId id; Type type; Expr body; + DefnNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { @@ -56,7 +59,7 @@ class DefnNode : public ItemNode { TVM_DECLARE_NODE_TYPE_INFO(DefnNode, ItemNode); }; -TVM_DEFINE_NODE_REF(Defn, DefnNode); +RELAY_DEFINE_ITEM(Defn, DefnNode); } // namespace relay diff --git a/relay/include/relay/ir/type.h b/relay/include/relay/ir/type.h index a0deb7db4422d..9075fdfa252b3 100644 --- a/relay/include/relay/ir/type.h +++ b/relay/include/relay/ir/type.h @@ -145,7 +145,7 @@ class ShapeNode : public TypeNode { ShapeNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { - // v->Visit("dims", &id); + v->Visit("dims", &dims); } TVM_DLL static Shape make(tvm::Array dims); diff --git 
a/relay/include/relay/ir/value.h b/relay/include/relay/ir/value.h index c15bf04879bd4..f50f38a2ca5dc 100644 --- a/relay/include/relay/ir/value.h +++ b/relay/include/relay/ir/value.h @@ -130,6 +130,26 @@ struct ProductValueNode : ValueNode { RELAY_DEFINE_VALUE(ProductValue, ProductValueNode); +class TensorValue; + +/*! \brief Product literal (x, ... y). */ +struct TensorValueNode : ValueNode { + TVMArrayHandle data; + + TensorValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", reinterpret_cast(data)); + } + + TVM_DLL static TensorValue make(TVMArrayHandle data); + + static constexpr const char* _type_key = "nnvm.TensorValue"; + TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); +}; + +RELAY_DEFINE_VALUE(TensorValue, TensorValueNode); + } // namespace relay } // namespace nnvm diff --git a/relay/include/relay/py_functor.h b/relay/include/relay/py_functor.h index 7bd82a75f0daa..fdfede8b84ca8 100644 --- a/relay/include/relay/py_functor.h +++ b/relay/include/relay/py_functor.h @@ -131,7 +131,7 @@ class FunctorNode if (visit_function != nullptr) { return visit_function(fn->params, fn->body, args); } else { - return FunctionNode::make(fn->params, fn->body); + return FunctionNode::make(fn->params, fn->ret_type, fn->body); } } diff --git a/relay/python/relay/env.py b/relay/python/relay/env.py new file mode 100644 index 0000000000000..1bdf9a264cf57 --- /dev/null +++ b/relay/python/relay/env.py @@ -0,0 +1,29 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck +"""A global environment storing everything needed to interpret or compile a Realy program.""" +from typing import Union, Dict +from tvm._ffi.function import _init_api +from .base import register_nnvm_node, NodeBase +from .expr import GlobalId, IntrinsicId, Item + +_init_api("nnvm.relay.env", __name__) + +@register_nnvm_node +class Environment(NodeBase): + """The global Relay environment containing definitions, + primitives, options, and more. + """ + items: Dict[GlobalId, Item] + + def add(self, func: GlobalId) -> None: + return Environment_add(self, func) + + def global_id(self, name: str) -> GlobalId: + return Environment_global_id(self, name) + + def lookup(self, ident: Union[GlobalId, IntrinsicId]) -> Item: + if isinstance(ident, IntrinsicId): + intrin_id = self.intrinsic_id(ident) + return Environment_lookup_intrinsic(self, intrin_id) + else: + global_id = self.global_id(ident) + return Environment_lookup_global(self, global_id) diff --git a/relay/python/relay/env.pyi b/relay/python/relay/env.pyi new file mode 100644 index 0000000000000..1be795affd90e --- /dev/null +++ b/relay/python/relay/env.pyi @@ -0,0 +1,10 @@ +from typing import Union, Tuple, Dict +from nnvm.relay.expr import GlobalId, IntrinsicId, Item + +class Environment(): ... + +def Environment_add(self: Environment, func: GlobalId) -> None: ... +def Environment_global_id(self: Environment, name: str) -> GlobalId: ... +def Environment_intrinsic_id(self: Environment, name: str) -> IntrinsicId: ... +def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ... +def Environment_lookup_intrinsic(self: Environment, id: GlobalId) -> Item: ... 
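
The env.py and env.pyi additions above expose the refactored Environment to Python: globals and intrinsics are now interned and looked up through separate maps, as described in the commit message. A minimal usage sketch follows, assuming these bindings are built and importable and reusing only the constructors declared in make.pyi in this patch (Environment, GlobalId, Defn, Function, BoolType, IntLit); the definition built here is a throwaway illustration, not part of the patch.

    # Sketch only: assumes the nnvm.relay bindings from this patch are importable.
    from nnvm.relay.make import (Environment, GlobalId, Defn, Function,
                                 BoolType, IntLit)

    env = Environment({})                     # empty global environment

    # A trivial zero-argument definition whose body is the literal 1.
    func = Function([], BoolType(), IntLit(1))
    defn = Defn(GlobalId("main"), BoolType(), func)
    env.add(defn)                             # interns "main" and stores the Defn

    # Plain strings take the global path; interning is idempotent, so the
    # same name always maps to the same GlobalId.
    main_id = env.global_id("main")
    item = env.lookup("main")                 # dispatches to Environment_lookup_global

Intrinsic names would instead be interned via Environment_intrinsic_id and resolved through Environment_lookup_intrinsic, which is the split the refactor introduces.
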
diff --git a/relay/python/relay/expr.py b/relay/python/relay/expr.py index 867db3f892e83..f61eac5fe95ec 100644 --- a/relay/python/relay/expr.py +++ b/relay/python/relay/expr.py @@ -13,29 +13,6 @@ class Span(NodeBase): pass -@register_nnvm_node -class Environment(NodeBase): - """Global Environment - """ - - items: Dict["GlobalId", "Item"] - - def add(self, func: "GlobalId") -> None: - return make.Environment_add(self, func) - - def global_id(self, name: str) -> "GlobalId": - return make.Environment_global_id(self, name) - - def lookup(self, ident: Union["GlobalId", str]) -> "Item": - if isinstance(ident, str): - return make.Environment_lookup_str(self, ident) - else: - return make.Environment_lookup(self, ident) - - def ilookup(self, _: "InstrinsicId") -> None: - assert False - - class Item(NodeBase): """Base class of all expressions. """ @@ -218,6 +195,7 @@ class Param(Expr): @register_nnvm_node class Function(Expr): params: Tuple[Param, ...] + ret_type: Type body: Expr @@ -306,6 +284,7 @@ class FnValue(Value): # unsorted + @register_nnvm_node class If(Expr): guard: Expr @@ -323,6 +302,9 @@ class TensorType(Type): dtype: BaseType shape: Shape +@register_nnvm_node +class TensorValue(Type): + pass def alpha_eq(left: Union[Expr, Type], right: Union[Expr, Type]) -> bool: if isinstance(left, Expr) and isinstance(right, Expr): diff --git a/relay/python/relay/make.py b/relay/python/relay/make.py index f8b9646eb6991..0b7e0d90e5d6d 100644 --- a/relay/python/relay/make.py +++ b/relay/python/relay/make.py @@ -1,22 +1,30 @@ -"""FFI constructors of all relay AST nodes.""" +"""The constructors for all Relay AST nodes exposed from C++. + This module includes MyPy type signatures for all of the + exposed modules. +""" +# pylint: disable-all +# wildcard-import +from typing import Dict, List from tvm._ffi.function import _init_api +from . import expr +from . import env _init_api("nnvm.make", __name__) #pylint: disable=invalid-name -def IntType(bits, lanes=1): +def IntType(bits: int, lanes: int=1) -> expr.Type: """A wrapper for constructing the Int base type.""" return _IntType(bits, lanes) #pylint: disable=invalid-name -def FloatType(bits, lanes=1): +def FloatType(bits: int, lanes: int=1) -> expr.Type: """A wrapper for constructing the Float base type.""" return _FloatType(bits, lanes) #pylint: disable=invalid-name -def BoolType(lanes=1): +def BoolType(lanes: int =1) -> expr.Type: """A wrapper for constructing the Bool base type.""" return _BoolType(lanes) diff --git a/relay/python/relay/make.pyi b/relay/python/relay/make.pyi new file mode 100644 index 0000000000000..b2e65ece66497 --- /dev/null +++ b/relay/python/relay/make.pyi @@ -0,0 +1,29 @@ +from typing import Dict, List +import nnvm.relay.expr as expr +import nnvm.relay.env as env + +def _IntType(bits: int, lanes: int) -> expr.Type: ... +def _FloatType(bits: int, lanes: int) -> expr.Type: ... +def _BoolType(lanes: int) -> expr.Type: ... +# this is annoying, we should talk about how to fix it +def IntType(bits: int, lanes: int=1) -> expr.Type: ... +def FloatType(bits: int, lanes: int=1) -> expr.Type: ... +def BoolType(lanes: int =1) -> expr.Type: ... + +def Environment(items: Dict[expr.GlobalId, expr.Item]) -> env.Environment: ... +def Function(params: List[expr.Param], ret_type: Type, body: expr.Expr) -> expr.Function: ... +def Defn(id: expr.GlobalId, ty: expr.Type, body: expr.Function) -> expr.Defn: ... +def LocalId(name: str) -> expr.Defn: ... +def GlobalId(name: str) -> expr.Defn: ... +def IntrinsicId(name: str) -> expr.Defn: ... 
+def Let(id: expr.LocalId, ty: expr.Type, value: expr.Expr, body: expr.Expr) -> expr.Defn: ... +def IntLit(value: int) -> expr.IntLit: ... +def FloatLit(value: float) -> expr.FloatLit: ... +def TensorLit(value: List[expr.Expr]) -> expr.TensorLit: ... +def BoolLit(value: bool) -> expr.BoolLit: ... +def String(value: str) -> expr.String: ... +# Can we type this to match the validation code. +def Attributes(attrs: Dict[expr.LocalId, expr.Expr]) -> expr.Attributes: ... +def Call(func: expr.Expr, args: List[expr.Expr], attrs: expr.Attributes) -> expr.Call: ... +def Shape(dims: List[expr.Expr]) -> expr.Shape: ... +def TensorType(dtype: str, shape: expr.Shape) -> expr.Type: ... diff --git a/relay/python/relay/operators.py b/relay/python/relay/operators.py new file mode 100644 index 0000000000000..210e0a742e658 --- /dev/null +++ b/relay/python/relay/operators.py @@ -0,0 +1,30 @@ +"""A module for exposing TVM operators to the evaluator.""" + +from typing import Any +import tvm +# import topi +# import numpy as np +from nnvm.relay.expr import Primitive + +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.context(__tgt__, 0) + + +@tvm.register_func("nnvm.relay.operators.compile_primitive") +def compile_primitive(_: Primitive) -> Any: + pass + # n = tvm.var("n") + # import pdb + # pdb.set_trace() + +# Input = tvm.placeholder((n, ), name='Input') +# Output = topi.tanh(Input) +# schedule = tvm.create_schedule(Output.op) +# tanh = tvm.build(schedule, [Input, Output], tgt, target_host=tgt_host, name="tanh") + +# a = tvm.nd.array(np.random.uniform(size=2).astype(Input.dtype), ctx) +# c = tvm.nd.array(np.zeros(2, dtype=Output.dtype), ctx) +# tanh(a, c) +# # np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + +# print(c.asnumpy()) diff --git a/relay/python/relay/relay.py b/relay/python/relay/relay.py index c03c44f1917b8..5dd15254efc72 100644 --- a/relay/python/relay/relay.py +++ b/relay/python/relay/relay.py @@ -1,13 +1,18 @@ -#pylint: disable=superfluous-parens +#pylint: disable-all +# LEAVE BRITTANY ALONE +# pylint: disable=superfluous-parens, undefined-variable """A decorator for rewriting Python code into Relay.""" -import ast -from ast import literal_eval +import typed_ast.ast3 as ast import inspect -# from typing import Dict, List +import numpy as np +import tvm +from typing import Dict, List from collections import OrderedDict # pylint: disable=wildcard-import, unused-wildcard-import import nnvm.relay.eval as re -from .make import * +from nnvm.relay.make import * +from . import expr +from .operators import __relay_tvm_context__ # This contains a global environment of all items. __relay_environment__ = Environment({}) @@ -17,51 +22,52 @@ def ensure(expected, value, msg): if value != expected: raise Exception("{}: found {}".format(msg, value)) -def compile_args_to_params(args): - """This function will convert Python arguments into Relay parameters.""" - ensure([], args.defaults, "relay decorator does not support default arguments") - ensure([], args.kw_defaults, - "relay decorator does not support keyword default arguments") - ensure(None, args.kwarg, "relay decorator does not support keyword arguments") - ensure([], args.kwonlyargs, - "relay decorator does not support keyword only arguments") - ensure(None, args.vararg, "relay decorator does not support variadic arguments") - params = [] - for arg in args.args: - ident = arg.arg - # All arguments must have types. 
- assert arg.annotation is not None - arg_ty = relay_type_from_annotation(arg.annotation) - param = Param(LocalId(ident), arg_ty) - params.append(param) - return params # TODO: Return actual type nodes (i.e., not always bools), once they're added. class TypeToRelay(ast.NodeVisitor): """Compiles a Python type to a Relay type.""" + def __init__(self, to_expr: ast.NodeVisitor) -> None: + self.to_expr = to_expr + def generic_visit(self, node): - return BoolType() + raise Exception("encountered generic case for {}".format(type(node))) + + def to_dtype(self, name: ast.Name) -> expr.BaseType: + # todo fix me + return FloatType(32) + + def to_shape(self, sh: ast.Tuple) -> expr.Shape: + # todo fix me + return Shape([IntLit(1)]) #pylint: disable=invalid-name - def visit_Subscript(self, node): - name = node.value.id - if name == "Tensor": - tensor_params = node.slice.value.elts - assert len(tensor_params) == 2 - # dtype = tensor_params[0].id - # shape = list(map(lambda x: int(x.node), tensor_params[1].elts)) - return BoolType() + def visit_Subscript(self, subscript: ast.Subscript) -> expr.Type: + name = subscript.value + + if isinstance(name, ast.Name) and name.id == "Tensor": + sl = subscript.slice + if isinstance(sl, ast.Index) and isinstance(sl.value, ast.Tuple): + tensor_params = sl.value.elts + assert len(tensor_params) == 2 + # Extract Python representation of `dtype` and `shape`. + dtype, shape = tensor_params + assert isinstance(dtype, ast.Name) + assert isinstance(shape, ast.Tuple) + return TensorType(self.to_dtype(dtype), self.to_shape(shape)) + else: + raise Exception( + "Tensor types must be indexed by a [dtype, shape]") else: raise Exception("Expected \"Tensor\"; got \"%s\"" % name) #pylint: disable=invalid-name - def visit_Attribute(self, node): + def visit_Attribute(self, node: ast.Attribute): typ = node.attr raise Exception("Unknown Relay type \"%s\"" % typ) #pylint: disable=invalid-name - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): """Visit names""" typ = node.id if typ == "int": @@ -76,28 +82,31 @@ def visit_Name(self, node): raise Exception("Unsupported Python builtin \"%s\"" % typ) -def relay_type_from_annotation(annotation): - return TypeToRelay().visit(annotation) - - # We inherit from NodeVisitor to write a pass over the AST. # -# Process a single definition and produce a single Relay Defunc. +# We will process a Python module and transform it into a series +# of Relay items. 
-class DefToRelay(ast.NodeVisitor): +class ModuleToRelay(ast.NodeVisitor): """Compiles a Python definition to a Relay definition.""" - # local_scopes: List[Dict[LocalId, Expr]] + local_scopes: List[Dict[expr.LocalId, expr.Expr]] + module: ast.Module + to_type_visitor: TypeToRelay - def __init__(self, python_def): + def __init__(self, module: ast.Module) -> None: self.local_scopes = [] - self.python_def = python_def + self.module = module + self.to_type_visitor = TypeToRelay(self) + + def to_type(self, annotation: ast.expr) -> expr.Type: + return self.to_type_visitor.visit(annotation) #pylint: disable=invalid-name - def visit_Name(self, name_node): + def visit_Name(self, name_node: ast.Name) -> expr.Expr: return self.translate_ident(name_node) #pylint: disable=invalid-name - def visit_Return(self, return_node): + def visit_Return(self, return_node: ast.Return) -> expr.Expr: value = return_node.value if value: return self.visit(value) @@ -105,8 +114,8 @@ def visit_Return(self, return_node): raise Exception("return must have a value") #pylint: disable=invalid-name - def visit_Num(self, num_node): - literal = literal_eval(num_node) + def visit_Num(self, num_node: ast.Num) -> expr.Expr: + literal = ast.literal_eval(num_node) if isinstance(literal, int): return IntLit(literal) elif isinstance(literal, float): @@ -115,45 +124,46 @@ def visit_Num(self, num_node): raise Exception("unknown numeric literal") #pylint: disable=invalid-name - def visit_Str(self, str_node): + def visit_Str(self, str_node: ast.Str) -> expr.Expr: s = str_node.s return String(s) #pylint: disable=invalid-name - def visit_List(self, list_node): + def visit_List(self, list_node: ast.List) -> expr.Expr: python_list = list_node.elts relay_list = [self.visit(elt) for elt in python_list] return TensorLit(relay_list) #pylint: disable=invalid-name - def visit_NameConstant(self, nc_node): + def visit_NameConstant(self, nc_node: ast.NameConstant) -> expr.Expr: singleton = nc_node.value if singleton is None: raise Exception("relay decorator does not support None") else: return BoolLit(singleton) - def visit_Call(self, call_node): - """Transform a Python call into a Relay call""" + def visit_Call(self, call_node: ast.Call) -> expr.Expr: + """Transform a Python call into a Relay call.""" func = call_node.func args = call_node.args keywords = call_node.keywords - if isinstance(func, ast.Attribute): + if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): if func.value.id == 'relay': relay_func = IntrinsicId(func.attr) else: raise Exception( - "only supported namespace is relay right now") # improve me + "only supported namespace is relay right now") # improve me else: raise Exception("unsupported calls") # TODO(joshpoll): Handle args relay_args = [self.visit(arg) for arg in args] - relay_attrs = Attributes({LocalId(kwd.arg): self.visit(kwd.value) for kwd in keywords}) + relay_attrs = Attributes( + {LocalId(kwd.arg): self.visit(kwd.value) for kwd in keywords}) return Call(relay_func, relay_args, relay_attrs) - def visit_Assign(self, assign_node): + def compile_assign(self, assign_node: ast.Assign) -> None: targets = assign_node.targets value = assign_node.value assert len(targets) == 1 @@ -161,55 +171,125 @@ def visit_Assign(self, assign_node): rhs = self.visit(value) self.local_scopes[-1][ident] = rhs + def visit_Assign(self, assign_node: ast.Assign) -> expr.Expr: + raise Exception("should not process assignments in this way") + # Need to put types some-where - def compile_stmt_seq_to_body(self, stmts): + def 
compile_stmt_seq_to_body(self, stmts: List[ast.stmt]) -> expr.Expr: """Compile a sequence of statements into a Relay expression.""" assert stmts ret = stmts[-1] stmts = stmts[0:-1] self.local_scopes.append(OrderedDict([])) for stmt in stmts: - self.visit(stmt) + if isinstance(stmt, ast.Assign): + self.compile_assign(stmt) + else: + raise Exception("hello world") cont = self.visit(ret) scope = self.local_scopes.pop() - for key in reversed(scope): - value = scope[key] - cont = Let(key, value, cont) + for key, value in reversed(scope.items()): + # TOOD(@jroesch): dummy type + cont = Let(key, IntType(32), value, cont) return cont - def translate_ident(self, ident): + def translate_ident(self, ident) -> expr.Expr: # import pdb; pdb.set_trace() return LocalId(ident.id) - def run(self): - """executes visitor""" - func = self.python_def + def compile_args_to_params(self, args): + """This function will convert Python arguments into Relay parameters.""" + ensure([], args.defaults, + "relay decorator does not support default arguments") + ensure([], args.kw_defaults, + "relay decorator does not support keyword default arguments") + ensure(None, args.kwarg, "relay decorator does not support keyword arguments") + ensure([], args.kwonlyargs, + "relay decorator does not support keyword only arguments") + ensure(None, args.vararg, "relay decorator does not support variadic arguments") + params = [] + for arg in args.args: + ident = arg.arg + # All arguments must have types. + assert arg.annotation != None + arg_ty = self.to_type(arg.annotation) + param = Param(LocalId(ident), arg_ty) + params.append(param) + return params + + def visit_FunctionDef(self, func: ast.FunctionDef) -> expr.Expr: name = func.name args = func.args body = func.body - params = compile_args_to_params(args) + returns = func.returns + params = self.compile_args_to_params(args) relay_body = self.compile_stmt_seq_to_body(body) - func = Function(params, relay_body) + # need to compute function signature + assert returns != None + func = Function(params, self.to_type(returns), relay_body) defunc = Defn(GlobalId(name), BoolType(), func) return defunc + def compile_single_function(self) -> expr.Defn: + """Executes visitor on a single function, and ensures there is nothing + else in the top level module. + """ + if len(self.module.body) != 1: + raise Exception("expected exactly one stmt") + + func = self.module.body[0] + + if isinstance(func, ast.FunctionDef): + return self.visit_FunctionDef(func) + else: + raise Exception("expected exactly one function definition") + -def compile_def_to_defn(func): - return DefToRelay(func).run() +def compile_def_to_defn(module: ast.Module) -> expr.Defn: + return ModuleToRelay(module).compile_single_function() def get_env(): return __relay_environment__ -def relay_compile(f): - """Compile the Python function to a Relay function.""" + +def compile_stmt(stmt): + assert False + + +def compile_func(f): + """Compile a single Python function to a Relay function. + + This is to support the Python decorator which transforms + a single function into Relay, we run the normal module + visitor populating the Environment with possibly many + functions. We then retreive said function from the + environment. + """ source = inspect.getsource(f) mod = ast.parse(source) - mod_body = mod.body - assert len(mod_body) == 1 # We only handle one function at a time. 
- func = mod_body[0] - return compile_def_to_defn(func) + return compile_def_to_defn(mod) + + +def relay_compile(f): + return compile_func(f) + + +def marshal_argument(arg, expected_ty) -> expr.Value: + """Convert Python values into the appropriate types + for the Relay evaluator. + """ + if isinstance(arg, int): + return IntValue(arg) + elif isinstance(arg, float): + return FloatValue(arg) + elif isinstance(arg, np.ndarray): + # np.random.uniform(size=2).astype(Input.dtype) + return TensorValue(tvm.nd.array(arg, __relay_tvm_context__)) + else: + raise Exception("unsupported argument type") + def relay(func): """ @@ -230,8 +310,16 @@ def relay(func): defn = relay_compile(func) get_env().add(defn) - def wrapper(*_): - return re.eval(get_env(), Call(defn.id, [])) + def wrapper(*py_args): + # need to fix type system + # fix me here + args = [] + assert len(py_args) == len(defn.body.params) + for arg, param in zip(py_args, defn.body.params): + args.append(marshal_argument(arg, param.type)) + attrs = Attributes({}) + import pdb; pdb.set_trace() + return re.invoke(get_env(), defn.id, args) return wrapper diff --git a/relay/python/relay/repl.py b/relay/python/relay/repl.py new file mode 100644 index 0000000000000..e240cd400b00b --- /dev/null +++ b/relay/python/relay/repl.py @@ -0,0 +1,24 @@ +# pylint: disable-all +# will add back once stablize +import ast +# import nnvm.relay.relay as relay +from code import InteractiveConsole +from nnvm.relay.relay import relay_compile + +class RelayConsole(InteractiveConsole): + def runsource(self, source, filename="", symbol="single"): + mod = ast.parse(source) + func = mod.body[0] + import pdb + pdb.set_trace() + relay_fn = relay.compile_def_to_defn(func) + return relay_fn + + +def repl(): + console = RelayConsole(None) # might need to fix this + console.interact("relay", "Bye!") + + +if __name__ == "__main__": + repl() diff --git a/relay/src/relay/anf.cc b/relay/src/relay/anf.cc index 6b30dae819e4b..7b4a2bb557979 100644 --- a/relay/src/relay/anf.cc +++ b/relay/src/relay/anf.cc @@ -89,7 +89,8 @@ PartialLocalId ANF::VisitExpr_(const ParamNode* f) { } PartialLocalId ANF::VisitExpr_(const FunctionNode* f) { - return ll.let(FunctionNode::make(f->params, ToANF(f->body))); + // TODO(@jroesch): fix the return type + return ll.let(FunctionNode::make(f->params, f->ret_type, ToANF(f->body))); } PartialLocalId ANF::VisitExpr_(const CallNode* f) { diff --git a/relay/src/relay/environment.cc b/relay/src/relay/environment.cc index 7dd6bfaa64ba7..7ce9dff839c33 100644 --- a/relay/src/relay/environment.cc +++ b/relay/src/relay/environment.cc @@ -13,7 +13,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; struct EnvError : dmlc::Error { - explicit EnvError(const std::string & msg) : dmlc::Error(msg) {} + explicit EnvError(const std::string &msg) : dmlc::Error(msg) {} }; Environment EnvironmentNode::make(tvm::Map items) { @@ -22,36 +22,90 @@ Environment EnvironmentNode::make(tvm::Map items) { return Environment(n); } -void EnvironmentNode::add_global(const std::string & str, GlobalId id) { +// todo throw an error +void EnvironmentNode::add_global(const std::string &str, GlobalId id) { global_map_[str] = id; } -GlobalId EnvironmentNode::global_id(const std::string & str) { +// todo throw an error +void EnvironmentNode::add_intrinsic(const std::string &str, IntrinsicId id) { + intrinsic_map_[str] = id; +} + +/*! + * \brief Get PackedFunction from global registry and + * report error if it does not exist + * \param name The name of the function. 
+ * \return The created PackedFunc. + */ +inline const PackedFunc &GetPackedFunc(const std::string &name) { + const PackedFunc *pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + +void EnvironmentNode::register_primitive(Primitive p) { + const PackedFunc &compile_primitive = + GetPackedFunc("nnvm.relay.operators.compile_primitive"); + compile_primitive(p); + // this->intrinsics.Set(prim->id, prim); +} + +GlobalId EnvironmentNode::global_id(const std::string &str) { if (global_map_.find(str) != global_map_.end()) { - return global_map_.at(str); - } else { - GlobalId id = GlobalIdNode::make(str); - this->add_global(str, id); - return id; - } + return global_map_.at(str); + } else { + GlobalId id = GlobalIdNode::make(str); + this->add_global(str, id); + return id; + } +} + +// todo: factor code into intern map class? +IntrinsicId EnvironmentNode::intrinsic_id(const std::string &str) { + if (intrinsic_map_.find(str) != intrinsic_map_.end()) { + return intrinsic_map_.at(str); + } else { + IntrinsicId id = IntrinsicIdNode::make(str); + this->add_intrinsic(str, id); + return id; + } } void EnvironmentNode::add(const Item &item) { // Should we first check to see if any Global of this name // been allocated and even disallow duplicate hints? - add_global(item->id->name, item->id); - this->items.Set(item->id, item); + if (const PrimitiveNode *p = item.as()) { + auto prim = GetRef(p); + register_primitive(prim); + } else if (const DefnNode *d = item.as()) { + auto def = GetRef(d); + add_global(def->id->name, def->id); + this->items.Set(def->id, item); + } else { + throw; // fix me + } } -Item EnvironmentNode::lookup(const GlobalId & id) { +Item EnvironmentNode::lookup(const GlobalId &id) { if (items.find(id) != items.end()) { - return items.at(id); - } else { - throw EnvError("there is no definition of " + id->name); - } + return items.at(id); + } else { + throw EnvError("there is no definition of " + id->name); + } +} + +Item EnvironmentNode::lookup(const IntrinsicId &id) { + if (intrinsics.find(id) != intrinsics.end()) { + return intrinsics.at(id); + } else { + throw EnvError("there is no definition of " + id->name); + } } -Item EnvironmentNode::lookup(const std::string & str) { +Item EnvironmentNode::lookup_intrinsic(const std::string &str) { throw; } + +Item EnvironmentNode::lookup_global(const std::string &str) { GlobalId id = this->global_id(str); return this->lookup(id); } @@ -62,37 +116,51 @@ TVM_REGISTER_API("nnvm.make.Environment") }); // TODO(jroesch): change the API namespace -TVM_REGISTER_API("nnvm.make.Environment_add") +TVM_REGISTER_API("nnvm.relay.env.Environment_add") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Item item = args[1]; env->add(item); }); -TVM_REGISTER_API("nnvm.make.Environment_lookup") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - GlobalId id = args[1]; - *ret = env->lookup(id); - }); +TVM_REGISTER_API("nnvm.relay.env.Environment_lookup_global") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalId id = args[1]; + *ret = env->lookup(id); + }); -TVM_REGISTER_API("nnvm.make.Environment_lookup_str") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - std::string str = args[1]; - *ret = env->lookup(str); - }); +TVM_REGISTER_API("nnvm.relay.env.Environment_lookup_intrinsic") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + IntrinsicId 
id = args[1]; + *ret = env->lookup(id); + }); + +TVM_REGISTER_API("nnvm.relay.env.Environment_global_id") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string str = args[1]; + *ret = env->global_id(str); + }); +TVM_REGISTER_API("nnvm.relay.env.Environment_intrinsic_id") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string str = args[1]; + *ret = env->intrinsic_id(str); + }); -TVM_REGISTER_API("nnvm.make.Environment_global_id") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - std::string str = args[1]; - *ret = env->global_id(str); - }); +TVM_REGISTER_API("nnvm.relay.env.Environment_register_primitive") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string str = args[1]; + *ret = env->global_id(str); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { + .set_dispatch([](const EnvironmentNode *node, + tvm::IRPrinter *p) { p->stream << "EnvironmentNode(" << node->items << ")"; }); diff --git a/relay/src/relay/evaluator.cc b/relay/src/relay/evaluator.cc index 4d36537e982fd..6fc79d0b27cc5 100644 --- a/relay/src/relay/evaluator.cc +++ b/relay/src/relay/evaluator.cc @@ -4,8 +4,9 @@ * \brief An evaluator for the Relay IR. */ -#include #include +#include "nnvm/relay/evaluator.h" +#include "nnvm/relay/type_functor.h" namespace nnvm { namespace relay { @@ -13,7 +14,133 @@ namespace relay { using namespace tvm::runtime; struct EvalError : dmlc::Error { - explicit EvalError(const std::string & msg) : Error(msg) {} + explicit EvalError(const std::string& msg) : Error(msg) {} +}; + +// This is something we likely need for the compiler. + +// struct ZeroInitialize : TypeFunctor { +// public: +// Value VisitType_(const BaseTypeNode* op) override { +// if (!op->type.is_scalar()) { +// throw EvalError("can not initialize vectorized values"); +// } + +// // TODO(@jroesch): support variable size values +// if (op->type.is_float()) { +// return FloatLitNode::make(0.0); +// } else if (op->type.is_int()) { +// return IntLitNode::make(0); +// } else if (op->type.is_bool()) { +// return BoolLitNode::make(false); +// } else { +// throw EvalError("unsupported base type should be scalar float, int, +// bool"); +// } +// } + +// Value VisitType_(const TypeVarNode* op) override { +// throw EvalError("It is impossible to zero initialize a type var."); +// } + +// Value VisitType_(const TypeIdNode* op) override { +// throw EvalError("It is impossible to zero initialize a type id."); +// } + +// Value VisitType_(const TypeQuantifierNode* op) override { +// throw EvalError("It is impossible to zero initialize a quantifier."); +// } + +// Value VisitType_(const TensorTypeNode* op) override { +// throw EvalError("TODO: need to initialize Tensor types"); +// } +// }; + +// Value zero_initialize(const Type & t) { +// return ZeroInitialize().VisitType(t); +// } + +struct Frame { + // In the efficient version this should seperate args, locals, and return + // address. + std::unordered_map locals; + Frame(tvm::Array params, tvm::Array args) { + if (params.size() != args.size()) { + // this should be an assertion at some point, and enforced by code + // generator. 
+ std::cout << params.size() << std::endl; + std::cout << args.size() << std::endl; + throw EvalError("parameter and argument size mismatch "); + } + + for (auto i = 0; i < params.size(); i++) { + auto param = params[i]; + auto arg = args[i]; + this->insert_local(param->id, arg); + } + } + + // We should probably create a new frame for each sequence of bindings, + // instead of this method. + void insert_local(const LocalId& id, Value v) { locals[id] = v; } +}; + +struct Stack { + std::vector frames; + Stack() : frames() {} + + Frame& current_frame() { return frames.back(); } + + void push(Frame fr) { frames.push_back(fr); } + + void pop() { frames.pop_back(); } + + Value lookup(const LocalId& local) { + for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { + if (frame->locals.find(local) != frame->locals.end()) { + return frame->locals.at(local); + } + } + throw dmlc::Error("internal error could not find"); + } +}; + +class Evaluator : public ExprFunctor { + public: + Environment env; + Stack stack; + Evaluator(); + Evaluator(Environment env) : env(env) {} + + void extend(const LocalId& id, Value v) { + this->stack.current_frame().insert_local(id, v); + } + + inline Value lookup(const LocalId& local) { + return this->stack.lookup(local); + } + + Value invoke(const FnValue& func, tvm::Array& args); + Value Eval(const Expr& expr); + Value VisitExpr_(const LocalIdNode* op) override; + Value VisitExpr_(const GlobalIdNode* op) override; + Value VisitExpr_(const IntrinsicIdNode* op) override; + Value VisitExpr_(const FloatLitNode* op) override; + Value VisitExpr_(const BoolLitNode* op) override; + Value VisitExpr_(const IntLitNode* op) override; + Value VisitExpr_(const TensorLitNode* op) override; + Value VisitExpr_(const ProductLitNode* op) override; + Value VisitExpr_(const CastNode* op) override; + Value VisitExpr_(const ParamNode* op) override; + Value VisitExpr_(const FunctionNode* op) override; + Value VisitExpr_(const CallNode* op) override; + Value VisitExpr_(const DebugNode* op) override; + Value VisitExpr_(const UnaryOpNode* op) override; + Value VisitExpr_(const BinaryOpNode* op) override; + Value VisitExpr_(const LetNode* op) override; + Value VisitExpr_(const ZeroNode* op) override; + Value VisitExpr_(const ProjectionNode* op) override; + Value VisitExpr_(const IfNode* op) override; }; Evaluator::Evaluator() : env() {} @@ -22,12 +149,16 @@ Value Evaluator::Eval(const Expr& expr) { return this->operator()(expr); } Value Evaluator::VisitExpr_(const LocalIdNode* local_node) { LocalId local_id = GetRef(local_node); - for (auto frame = this->stack.stack.rbegin(); frame != this->stack.stack.rend(); - frame++) { - if (frame->find(local_id) != frame->end()) { - return frame->at(local_id); + for (auto frame = this->stack.frames.rbegin(); + frame != this->stack.frames.rend(); frame++) { + if (frame->locals.find(local_id) != frame->locals.end()) { + return frame->locals.at(local_id); } } + + // missing local, there is also a bug in using environment code from python + // trace evaluator, how to return tensors across ask tq, + // a lot more work to do, should crash on the line with // fix me throw EvalError("Reference can not be found"); } @@ -45,15 +176,15 @@ Value Evaluator::VisitExpr_(const IntrinsicIdNode* op) { } Value Evaluator::VisitExpr_(const FloatLitNode* op) { - return Value(FloatValueNode::make(op->value).node_); // fix this reboxing + return FloatValueNode::make(op->value); } Value Evaluator::VisitExpr_(const BoolLitNode* op) { - return 
Value(BoolValueNode::make(op->value).node_); // fix this reboxing + return BoolValueNode::make(op->value); } Value Evaluator::VisitExpr_(const IntLitNode* op) { - return Value(IntValueNode::make(op->value).node_); // fix this reboxing + return IntValueNode::make(op->value); } Value Evaluator::VisitExpr_(const TensorLitNode* op) { @@ -64,9 +195,8 @@ Value Evaluator::VisitExpr_(const ProductLitNode* op) { if (op->fields.size() != 2) { throw EvalError("Unsupported Size"); } - return Value(ProductValueNode::make({ - Eval(Expr(op->fields[0].node_)), - Eval(Expr(op->fields[1].node_))})); + return ProductValueNode::make( + {Eval(Expr(op->fields[0].node_)), Eval(Expr(op->fields[1].node_))}); } Value Evaluator::VisitExpr_(const CastNode* op) { @@ -80,12 +210,55 @@ Value Evaluator::VisitExpr_(const ParamNode* op) { } Value Evaluator::VisitExpr_(const FunctionNode* op) { - throw EvalError("Function NYI"); + // TODO(@jroesch): compute free variables here and store them + // in closure value. + tvm::Map free_vars; + return FnValueNode::make(free_vars, GetRef(op)); +} + +// An efficient interpreter needs a faster way to access args, relative to stack +// pointer? +Value Evaluator::invoke(const FnValue& closure, tvm::Array& args) { + // TODO(@jroesch): map environment into a Frame + // + // In the VM we should support building a frame from free vars and parameters + // we should compute the frame layout statically. + if (closure->env.size() != 0) { + throw EvalError("Do not support Closures with non-empty environments yet."); + } + + // Get a reference to the function inside the closure. + auto func = closure->func; + + // Allocate a frame with the parameters, todo add free variables. + Frame fr(func->params, args); + + // Add a frame to the stack. + this->stack.push(fr); + + // Visit the body of the function with the new frame in scope. + auto ret = this->VisitExpr(func->body); + + // Clear the frame. 
+ this->stack.pop(); + + return ret; } Value Evaluator::VisitExpr_(const CallNode* op) { - // auto fn = this->VisitExpr(op->fn); - throw EvalError("Call NYI"); + std::cout << "Inside CallNode" << std::endl; + + tvm::Array args; + for (auto arg : op->args) { + args.push_back(this->VisitExpr(arg)); + } + + auto fn_val = this->VisitExpr(op->fn); + if (const FnValueNode* closure = fn_val.as()) { + return this->invoke(GetRef(closure), args); + } else { + throw EvalError("Type error, expected function value in the call position"); + } } Value Evaluator::VisitExpr_(const DebugNode* op) { @@ -96,14 +269,14 @@ Value Evaluator::VisitExpr_(const UnaryOpNode* op) { switch (op->op) { case UOp::NEG: { if (auto value = this->Eval(op->node).as()) { - return Value(FloatValueNode::make(-value->value).node_); + return FloatValueNode::make(-value->value); } else { throw EvalError("cannot eval"); } } case UOp::SQ: { if (auto value = this->Eval(op->node).as()) { - return Value(FloatValueNode::make(value->value * value->value).node_); + return FloatValueNode::make(value->value * value->value); } else { throw EvalError("cannot eval"); } @@ -119,7 +292,7 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { auto left = Eval(op->left).as(); auto right = Eval(op->right).as(); if (left && right) { - return Value(FloatValueNode::make(left->value + right->value).node_); + return FloatValueNode::make(left->value + right->value); } else { throw EvalError("cannot eval"); } @@ -128,7 +301,7 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { auto left = Eval(op->left).as(); auto right = Eval(op->right).as(); if (left && right) { - return Value(FloatValueNode::make(left->value * right->value)); + return FloatValueNode::make(left->value * right->value); } else { throw EvalError("cannot eval"); } @@ -137,7 +310,7 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { auto left = Eval(op->left).as(); auto right = Eval(op->right).as(); if (left && right) { - return Value(FloatValueNode::make(left->value - right->value)); + return FloatValueNode::make(left->value - right->value); } else { throw EvalError("cannot eval"); } @@ -146,7 +319,7 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { auto left = Eval(op->left).as(); auto right = Eval(op->right).as(); if (left && right) { - return Value(FloatValueNode::make(left->value / right->value)); + return FloatValueNode::make(left->value / right->value); } else { throw EvalError("cannot eval"); } @@ -157,21 +330,22 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { } Value Evaluator::VisitExpr_(const LetNode* op) { - stack.stack.back()[op->id] = Eval(op->value); + auto value = Eval(op->value); + this->extend(op->id, value); return Eval(op->body); } Value Evaluator::VisitExpr_(const ZeroNode* op) { - throw EvalError("cannot eval"); + throw EvalError("Currently we do not support evaluating a Zero node."); } Value Evaluator::VisitExpr_(const ProjectionNode* op) { - throw EvalError("cannot eval"); + throw EvalError("Currently we do not support evaluting a Projection node."); } Value Evaluator::VisitExpr_(const IfNode* op) { Value v = this->VisitExpr(op->guard); - if (const BoolValueNode *bv = v.as()) { + if (const BoolValueNode* bv = v.as()) { if (bv->value) { return this->VisitExpr(op->true_b); } else { @@ -182,6 +356,11 @@ Value Evaluator::VisitExpr_(const IfNode* op) { } } +Value evaluate(Environment env, Expr e) { + Evaluator eval(env); + return eval.Eval(e); +} + TVM_REGISTER_API("nnvm.eval.eval").set_body([](TVMArgs args, TVMRetValue* ret) { Environment env = 
args[0]; Expr expr = args[1]; @@ -189,5 +368,21 @@ TVM_REGISTER_API("nnvm.eval.eval").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = eval.Eval(expr); }); +TVM_REGISTER_API("nnvm.eval.invoke") + .set_body([](TVMArgs args, TVMRetValue* ret) { + // tood maybe tweak interface + Environment env = args[0]; + GlobalId id = args[1]; + tvm::Array relay_args = args[2]; + Evaluator eval(env); + auto fn_val = eval.VisitExpr(id); + if (const FnValueNode* closure = fn_val.as()) { + *ret = eval.invoke(GetRef(closure), relay_args); + } else { + throw EvalError( + "Type error, expected function value in the call position"); + } + }); + } // namespace relay } // namespace nnvm diff --git a/relay/src/relay/forward_ad.cc b/relay/src/relay/forward_ad.cc index 694a7f28e12b4..232b7d9039582 100644 --- a/relay/src/relay/forward_ad.cc +++ b/relay/src/relay/forward_ad.cc @@ -77,7 +77,8 @@ Expr ForwardAD::VisitExpr_(const DebugNode* op) { Expr ForwardAD::VisitExpr_(const UnaryOpNode* op) { Expr adop = AD(op->node); LocalId ref = gf->fresh(); - return LetNode::make(ref, adop, [&]() { + // TODO(@jroesch) fix me + return LetNode::make(ref, BaseTypeNode::Float(32), adop, [&]() { auto o = GetField(ref, 0); auto d = GetField(ref, 1); switch (op->op) { @@ -96,13 +97,17 @@ Expr ForwardAD::VisitExpr_(const UnaryOpNode* op) { Expr ForwardAD::VisitExpr_(const BinaryOpNode* op) { Expr left = AD(op->left); LocalId lref = gf->fresh(); - return LetNode::make(lref, left, [&]() { + + // TODO(@jroesch) fix me + return LetNode::make(lref, BaseTypeNode::Float(32), left, [&]() { auto lo = GetField(lref, 0); auto ld = GetField(lref, 1); Expr right = AD(op->right); LocalId rref = gf->fresh(); - return LetNode::make(rref, right, [&]() { + + // TODO(@jroesch) fix me + return LetNode::make(rref, BaseTypeNode::Float(32), right, [&]() { auto ro = GetField(rref, 0); auto rd = GetField(rref, 1); @@ -130,7 +135,8 @@ Expr ForwardAD::VisitExpr_(const BinaryOpNode* op) { } Expr ForwardAD::VisitExpr_(const LetNode* op) { - return LetNode::make(op->id, AD(op->value), AD(op->body)); + // TODO(@jroesch) fix me + return LetNode::make(op->id, BaseTypeNode::Float(32), AD(op->value), AD(op->body)); } TVM_REGISTER_API("nnvm.forward_ad.forward_ad").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/relay/src/relay/ir.cc b/relay/src/relay/ir.cc index 73a1b4da6ab88..02969bb86b1c2 100644 --- a/relay/src/relay/ir.cc +++ b/relay/src/relay/ir.cc @@ -13,47 +13,41 @@ namespace nnvm { namespace relay { -Expr GetField(const Expr & x, int index) { +Expr GetField(const Expr& x, int index) { return ProjectionNode::make(x, index); } -Expr Pair(const Expr & l, const Expr & r) { - return ProductLitNode::make({l, r}); -} +Expr Pair(const Expr& l, const Expr& r) { return ProductLitNode::make({l, r}); } -Expr Neg(const Expr & x) { - return UnaryOpNode::make(UOp::NEG, x); -} +Expr Neg(const Expr& x) { return UnaryOpNode::make(UOp::NEG, x); } -Expr Square(const Expr & x) { - return UnaryOpNode::make(UOp::SQ, x); -} +Expr Square(const Expr& x) { return UnaryOpNode::make(UOp::SQ, x); } -Expr Plus(const Expr & l, const Expr & r) { +Expr Plus(const Expr& l, const Expr& r) { return BinaryOpNode::make(BinOp::PLUS, l, r); } -Expr Sub(const Expr & l, const Expr & r) { +Expr Sub(const Expr& l, const Expr& r) { return BinaryOpNode::make(BinOp::SUB, l, r); } -Expr Mul(const Expr & l, const Expr & r) { +Expr Mul(const Expr& l, const Expr& r) { return BinaryOpNode::make(BinOp::MUL, l, r); } -Expr Div(const Expr & l, const Expr & r) { +Expr Div(const Expr& l, const Expr& r) { 
return BinaryOpNode::make(BinOp::DIV, l, r); } -Expr Float(double d) { - return FloatLitNode::make(d); -} +Expr Float(double d) { return FloatLitNode::make(d); } -Expr Seq(const Expr & l, const Expr & r) { +Expr Seq(const Expr& l, const Expr& r) { return CallNode::make( - FunctionNode::make({ParamNode::make(LocalIdNode::make("_"), ProductNode::make({}))}, r), - {l}, - AttributesNode::make({})); + // TODO(@jroesch): fix type annotation + FunctionNode::make( + {ParamNode::make(LocalIdNode::make("_"), ProductNode::make({}))}, + BaseTypeNode::Float(32), r), + {l}, AttributesNode::make({})); } } // namespace relay diff --git a/relay/src/relay/ir/expr.cc b/relay/src/relay/ir/expr.cc index 29d86ecd0f905..175a61fb7f23a 100644 --- a/relay/src/relay/ir/expr.cc +++ b/relay/src/relay/ir/expr.cc @@ -226,16 +226,17 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "ParamNode(" << node->id << ", " << node->type << ")"; }); -Function FunctionNode::make(tvm::Array params, Expr body) { +Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body) { std::shared_ptr n = std::make_shared(); n->params = std::move(params); + n->ret_type = std::move(ret_type); n->body = std::move(body); return Function(n); } TVM_REGISTER_API("nnvm.make.Function") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FunctionNode::make(args[0], args[1]); + *ret = FunctionNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -384,16 +385,17 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "ZeroNode(" << node->type << ")"; }); -Let LetNode::make(LocalId id, Expr value, Expr body) { +Let LetNode::make(LocalId id, Type type, Expr value, Expr body) { std::shared_ptr n = std::make_shared(); n->id = std::move(id); + n->type = std::move(type); n->value = std::move(value); n->body = std::move(body); return Let(n); } TVM_REGISTER_API("nnvm.make.Let").set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = LetNode::make(args[0], args[1], args[2]); + *ret = LetNode::make(args[0], args[1], args[2], args[3]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/relay/src/relay/ir/item.cc b/relay/src/relay/ir/item.cc index edb36042fa1e4..f36ee3fe3bcac 100644 --- a/relay/src/relay/ir/item.cc +++ b/relay/src/relay/ir/item.cc @@ -13,7 +13,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; -Primitive PrimitiveNode::make(GlobalId id, Type type) { +Primitive PrimitiveNode::make(IntrinsicId id, Type type) { std::shared_ptr n = std::make_shared(); n->id = id; n->type = type; diff --git a/relay/src/relay/ir/value.cc b/relay/src/relay/ir/value.cc index a5831ac4222df..b0542ce9cdff3 100644 --- a/relay/src/relay/ir/value.cc +++ b/relay/src/relay/ir/value.cc @@ -1,11 +1,11 @@ /*! * Copyright (c) 2018 by Contributors - * \file node.cc - * \brief Relay node data structure. + * \file value.cc + * \brief Relay evaluator values. 
*/ -#include -#include - +#include "nnvm/relay/ir.h" +#include "tvm/ir_functor.h" +#include "tvm/runtime/util.h" namespace nnvm { namespace relay { @@ -19,14 +19,15 @@ IntValue IntValueNode::make(int value) { } TVM_REGISTER_API("nnvm.make.IntValue") - .set_body([](TVMArgs args, - TVMRetValue *ret) { *ret = IntValueNode::make(args[0]); }); + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IntValueNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IntValueNode *node, tvm::IRPrinter *p) { - p->stream << "IntValueNode(" << node->value << ")"; - }); + p->stream << "IntValueNode(" << node->value << ")"; + }); BoolValue BoolValueNode::make(bool value) { std::shared_ptr n = std::make_shared(); @@ -35,14 +36,15 @@ BoolValue BoolValueNode::make(bool value) { } TVM_REGISTER_API("nnvm.make.BoolValue") - .set_body([](TVMArgs args, - TVMRetValue *ret) { *ret = BoolValueNode::make(args[0]); }); + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BoolValueNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const BoolValueNode *node, tvm::IRPrinter *p) { - p->stream << "BoolValueNode(" << node->value << ")"; - }); + p->stream << "BoolValueNode(" << node->value << ")"; + }); FloatValue FloatValueNode::make(double value) { std::shared_ptr n = std::make_shared(); @@ -51,14 +53,15 @@ FloatValue FloatValueNode::make(double value) { } TVM_REGISTER_API("nnvm.make.FloatValue") - .set_body([](TVMArgs args, - TVMRetValue *ret) { *ret = FloatValueNode::make(args[0]); }); + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FloatValueNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FloatValueNode *node, tvm::IRPrinter *p) { - p->stream << "FloatValueNode(" << node->value << ")"; - }); + p->stream << "FloatValueNode(" << node->value << ")"; + }); FnValue FnValueNode::make(tvm::Map env, Function func) { std::shared_ptr n = std::make_shared(); @@ -69,13 +72,13 @@ FnValue FnValueNode::make(tvm::Map env, Function func) { TVM_REGISTER_API("nnvm.make.FnValue") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FnValueNode::make(args[0], args[1]); - }); + *ret = FnValueNode::make(args[0], args[1]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FnValueNode *node, tvm::IRPrinter *p) { - p->stream << "FnValueNode(todo)"; - }); + p->stream << "FnValueNode(todo)"; + }); ProductValue ProductValueNode::make(tvm::Array value) { std::shared_ptr n = std::make_shared(); @@ -84,14 +87,32 @@ ProductValue ProductValueNode::make(tvm::Array value) { } TVM_REGISTER_API("nnvm.make.ProductValue") - .set_body([](TVMArgs args, - TVMRetValue *ret) { *ret = ProductValueNode::make(args[0]); }); + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ProductValueNode::make(args[0]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ProductValueNode *node, tvm::IRPrinter *p) { - p->stream << "ProductValueNode(" << node->fields << ")"; - }); + p->stream << "ProductValueNode(" << node->fields << ")"; + }); + +TensorValue TensorValueNode::make(DLTensor *data) { + std::shared_ptr n = std::make_shared(); + n->data = data; + return TensorValue(n); +} + +TVM_REGISTER_API("nnvm.make.TensorValue") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorValueNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorValueNode *node, + tvm::IRPrinter *p) { + p->stream << 
"TensorValueNode(TODO)"; + }); } // namespace relay } // namespace nnvm diff --git a/relay/src/relay/pretty_printer.cc b/relay/src/relay/pretty_printer.cc index b7a7d661b1c3f..f616f6555f588 100644 --- a/relay/src/relay/pretty_printer.cc +++ b/relay/src/relay/pretty_printer.cc @@ -28,7 +28,7 @@ void PrettyPrinter::VisitExpr_(const LocalIdNode* local, ostream & os) { } void PrettyPrinter::VisitExpr_(const GlobalIdNode* op, ostream & os) { - Item i = this->env->lookup(op->name); + Item i = this->env->lookup(GetRef(op)); if (const DefnNode* def = i.as()) { this->PrettyPrint(def->body, os); } else { diff --git a/relay/src/relay/reverse_ad.cc b/relay/src/relay/reverse_ad.cc index 38e7975208421..8e766dcc3440b 100644 --- a/relay/src/relay/reverse_ad.cc +++ b/relay/src/relay/reverse_ad.cc @@ -13,34 +13,29 @@ namespace relay { using namespace tvm::runtime; struct ReverseADError : dmlc::Error { - explicit ReverseADError(const std::string & msg) : Error(msg) {} + explicit ReverseADError(const std::string& msg) : Error(msg) {} }; -Type Unit() { - return ProductNode::make({}); -} +Type Unit() { return ProductNode::make({}); } Type BackProper() { return TypeQuantifierNode::make(TypeIdNode::make("_"), Unit(), Unit()); } -struct ReverseTypeFunctor : TypeFunctor { +struct ReverseTypeFunctor : TypeFunctor { Type VisitType_(const BaseTypeNode* op) override { Type optype(op->GetNodeRef().node_); if (op->type.is_float()) { - return ProductNode::make({optype, RefTypeNode::make(optype), BackProper()}); + return ProductNode::make( + {optype, RefTypeNode::make(optype), BackProper()}); } throw ReverseADError("Unknown Type"); } }; -Type ReverseType(const Type & t) { - return ReverseTypeFunctor()(t); -} +Type ReverseType(const Type& t) { return ReverseTypeFunctor()(t); } -Expr ReverseAD::AD(const Expr& expr) { - return (*this)(expr); -} +Expr ReverseAD::AD(const Expr& expr) { return (*this)(expr); } Expr ReverseAD::VisitExpr_(const LocalIdNode* local) { return Expr(local->GetNodeRef().node_); @@ -55,10 +50,11 @@ Expr ReverseAD::VisitExpr_(const IntrinsicIdNode* op) { } Expr ReverseAD::VisitExpr_(const FloatLitNode* op) { - return ProductLitNode::make({ - Expr(op->GetNodeRef().node_), - RefNode::make(FloatLitNode::make(0)), - FunctionNode::make({}, ProductLitNode::make({}))}); + return ProductLitNode::make({Expr(op->GetNodeRef().node_), + RefNode::make(FloatLitNode::make(0)), + // TODO(@jroesch): fix return type + FunctionNode::make({}, BaseTypeNode::Float(32), + ProductLitNode::make({}))}); } Expr ReverseAD::VisitExpr_(const BoolLitNode* op) { @@ -71,7 +67,7 @@ Expr ReverseAD::VisitExpr_(const IntLitNode* op) { Expr ReverseAD::VisitExpr_(const ProductLitNode* op) { std::vector vec; - for (const Expr & e : op->fields) { + for (const Expr& e : op->fields) { vec.push_back(AD(e)); } return ProductLitNode::make(vec); @@ -83,10 +79,11 @@ Expr ReverseAD::VisitExpr_(const TensorLitNode* op) { Expr ReverseAD::VisitExpr_(const FunctionNode* op) { std::vector params; - for (const Param & param : op->params) { + for (const Param& param : op->params) { params.push_back(ParamNode::make(param->id, ReverseType(param->type))); } - return FunctionNode::make(params, AD(op->body)); + // TODO(@jroesch): fix return type + return FunctionNode::make(params, BaseTypeNode::Float(32), AD(op->body)); } Expr ReverseAD::VisitExpr_(const DebugNode* op) { @@ -103,131 +100,126 @@ Expr ReverseAD::VisitExpr_(const ParamNode* op) { Expr ReverseAD::VisitExpr_(const CallNode* op) { std::vector args; - for (const Expr & arg : op->args) { + for (const Expr& 
arg : op->args) { args.push_back(AD(arg)); } tvm::Map attr(op->attrs->attributes); - for (const std::pair & p : op->attrs->attributes) { + for (const std::pair& p : op->attrs->attributes) { attr.Set(p.first, AD(attr[p.first])); } return CallNode::make(AD(op->fn), args, AttributesNode::make(attr)); } -Expr SVar() { - return RefNode::make(Float(0)); -} +Expr SVar() { return RefNode::make(Float(0)); } -Expr CallBP(const Expr & expr) { +Expr CallBP(const Expr& expr) { return CallNode::make(expr, {}, AttributesNode::make({})); } -Expr MkBP(const Expr & expr) { - return FunctionNode::make({}, expr); +Expr MkBP(const Expr& expr) { + // TODO(@jroesch): fix return type + return FunctionNode::make({}, BaseTypeNode::Float(32), expr); } Expr ReverseAD::VisitExpr_(const UnaryOpNode* op) { - return LetList::with([&](LetList & ll) { - auto adop = ll.let(AD(op->node)); - auto val = ll.let(GetField(adop.lid, 0)); - auto ref = ll.let(GetField(adop.lid, 1)); - auto bp = ll.let(GetField(adop.lid, 2)); - auto myref = ll.let(SVar()); - switch (op->op) { - case UOp::NEG: { - return ProductLitNode::make({ - Neg(val.lid), - myref.lid, - MkBP(Seq( - SetRefNode::make( - ref.lid, - Sub(ValRefNode::make(ref.lid), ValRefNode::make(myref.lid))), - CallBP(bp.lid)))}); - } - case UOp::SQ: { - return ProductLitNode::make({ - Square(val.lid), - myref.lid, - MkBP(Seq(SetRefNode::make( - ref.lid, - Plus( - ValRefNode::make(ref.lid), - Mul(ValRefNode::make(myref.lid), Mul(Float(2), val.lid)))), CallBP(bp.lid)))}); - } - default: - throw ReverseADError("UnaryOpNode"); - } - }, gf); + return LetList::with( + [&](LetList& ll) { + auto adop = ll.let(AD(op->node)); + auto val = ll.let(GetField(adop.lid, 0)); + auto ref = ll.let(GetField(adop.lid, 1)); + auto bp = ll.let(GetField(adop.lid, 2)); + auto myref = ll.let(SVar()); + switch (op->op) { + case UOp::NEG: { + return ProductLitNode::make( + {Neg(val.lid), myref.lid, + MkBP(Seq(SetRefNode::make(ref.lid, + Sub(ValRefNode::make(ref.lid), + ValRefNode::make(myref.lid))), + CallBP(bp.lid)))}); + } + case UOp::SQ: { + return ProductLitNode::make( + {Square(val.lid), myref.lid, + MkBP(Seq(SetRefNode::make(ref.lid, + Plus(ValRefNode::make(ref.lid), + Mul(ValRefNode::make(myref.lid), + Mul(Float(2), val.lid)))), + CallBP(bp.lid)))}); + } + default: + throw ReverseADError("UnaryOpNode"); + } + }, + gf); } Expr ReverseAD::VisitExpr_(const BinaryOpNode* op) { - return LetList::with([&](LetList & ll) { - auto left = ll.let(AD(op->left)); - auto leftval = ll.let(GetField(left.lid, 0)); - auto leftref = ll.let(GetField(left.lid, 1)); - auto leftbp = ll.let(GetField(left.lid, 2)); - auto right = ll.let(AD(op->right)); - auto rightval = ll.let(GetField(right.lid, 0)); - auto rightref = ll.let(GetField(right.lid, 1)); - auto rightbp = ll.let(GetField(right.lid, 2)); - auto myref = ll.let(SVar()); - auto WHATAMIGONNADOOOO = [&](const Expr & l, const Expr & r) { - return MkBP(Seq( - SetRefNode::make(leftref.lid, - Plus(ValRefNode::make(leftref.lid), Mul(ValRefNode::make(myref.lid), l))), Seq( - SetRefNode::make(rightref.lid, - Plus(ValRefNode::make(rightref.lid), Mul(ValRefNode::make(myref.lid), r))), Seq( - CallBP(leftbp.lid), - CallBP(rightbp.lid))))); - }; - switch (op->op) { - case BinOp::PLUS: { - return ProductLitNode::make({ - Plus(leftval.lid, rightval.lid), - myref.lid, - WHATAMIGONNADOOOO(Float(1), Float(1)) - }); - } - case BinOp::MUL: { - return ProductLitNode::make({ - Mul(leftval.lid, rightval.lid), - myref.lid, - WHATAMIGONNADOOOO(rightval.lid, leftval.lid) - }); - } - case 
BinOp::SUB: { - return ProductLitNode::make({ - Sub(leftval.lid, rightval.lid), - myref.lid, - WHATAMIGONNADOOOO(Float(1), Float(-1)) - }); - } - case BinOp::DIV: { - return ProductLitNode::make({ - Div(leftval.lid, rightval.lid), - myref.lid, - WHATAMIGONNADOOOO( - Div(Float(1), rightval.lid), - Div(Float(-1), Square(rightval.lid))) - }); - } - default: - throw ReverseADError("BinaryOpNode"); - } - }, gf); + return LetList::with( + [&](LetList& ll) { + auto left = ll.let(AD(op->left)); + auto leftval = ll.let(GetField(left.lid, 0)); + auto leftref = ll.let(GetField(left.lid, 1)); + auto leftbp = ll.let(GetField(left.lid, 2)); + auto right = ll.let(AD(op->right)); + auto rightval = ll.let(GetField(right.lid, 0)); + auto rightref = ll.let(GetField(right.lid, 1)); + auto rightbp = ll.let(GetField(right.lid, 2)); + auto myref = ll.let(SVar()); + auto WHATAMIGONNADOOOO = [&](const Expr& l, const Expr& r) { + return MkBP(Seq( + SetRefNode::make(leftref.lid, + Plus(ValRefNode::make(leftref.lid), + Mul(ValRefNode::make(myref.lid), l))), + Seq(SetRefNode::make(rightref.lid, + Plus(ValRefNode::make(rightref.lid), + Mul(ValRefNode::make(myref.lid), r))), + Seq(CallBP(leftbp.lid), CallBP(rightbp.lid))))); + }; + switch (op->op) { + case BinOp::PLUS: { + return ProductLitNode::make( + {Plus(leftval.lid, rightval.lid), myref.lid, + WHATAMIGONNADOOOO(Float(1), Float(1))}); + } + case BinOp::MUL: { + return ProductLitNode::make( + {Mul(leftval.lid, rightval.lid), myref.lid, + WHATAMIGONNADOOOO(rightval.lid, leftval.lid)}); + } + case BinOp::SUB: { + return ProductLitNode::make( + {Sub(leftval.lid, rightval.lid), myref.lid, + WHATAMIGONNADOOOO(Float(1), Float(-1))}); + } + case BinOp::DIV: { + return ProductLitNode::make( + {Div(leftval.lid, rightval.lid), myref.lid, + WHATAMIGONNADOOOO(Div(Float(1), rightval.lid), + Div(Float(-1), Square(rightval.lid)))}); + } + default: + throw ReverseADError("BinaryOpNode"); + } + }, + gf); } Expr ReverseAD::VisitExpr_(const LetNode* op) { - return LetNode::make(op->id, AD(op->value), AD(op->body)); + // TODO(@jroesch): fix return type + return LetNode::make(op->id, BaseTypeNode::Float(32), AD(op->value), + AD(op->body)); } Expr ReverseAD::VisitExpr_(const ProjectionNode* op) { return ProjectionNode::make(AD(op->tuple), op->field); } -TVM_REGISTER_API("nnvm.reverse_ad.reverse_ad").set_body([](TVMArgs args, TVMRetValue* ret) { - ReverseAD ad(std::make_shared()); - *ret = ad.AD(args[0]); -}); +TVM_REGISTER_API("nnvm.reverse_ad.reverse_ad") + .set_body([](TVMArgs args, TVMRetValue* ret) { + ReverseAD ad(std::make_shared()); + *ret = ad.AD(args[0]); + }); } // namespace relay } // namespace nnvm
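
Note (illustration, not part of the patch): the ReverseAD visitors above rewrite every float-typed expression into a three-field product of (primal value, adjoint reference, backprop thunk); the UnaryOpNode/BinaryOpNode cases accumulate into the adjoint refs and chain the thunks through Seq/CallBP. Below is a minimal standalone C++ sketch of that bookkeeping using plain closures instead of Relay IR nodes. The names Rev, Lift and Mul are illustrative only and do not exist in this codebase.

// rev_ad_sketch.cc -- standalone illustration of the value/ref/backprop triple.
#include <functional>
#include <iostream>
#include <memory>

// Each float-typed value becomes a triple, mirroring the ProductLitNode
// {value, ref, backprop} built by ReverseAD::VisitExpr_ above.
struct Rev {
  double val;                   // primal value
  std::shared_ptr<double> adj;  // adjoint accumulator (the "ref" field)
  std::function<void()> bp;     // backprop thunk (the "bp" field)
};

Rev Lift(double v) { return {v, std::make_shared<double>(0.0), [] {}}; }

// Mul plays the role of the BinOp::MUL case: its backprop scales the incoming
// adjoint by the opposite operand, then chains the operands' thunks.
Rev Mul(const Rev& l, const Rev& r) {
  Rev out{l.val * r.val, std::make_shared<double>(0.0), [] {}};
  out.bp = [l, r, out] {
    *l.adj += *out.adj * r.val;  // d(l*r)/dl = r
    *r.adj += *out.adj * l.val;  // d(l*r)/dr = l
    l.bp();
    r.bp();
  };
  return out;
}

int main() {
  Rev x = Lift(3.0), y = Lift(4.0);
  Rev z = Mul(x, y);
  *z.adj = 1.0;  // seed the output adjoint
  z.bp();        // run the chained backpropagators
  std::cout << "dz/dx = " << *x.adj << "  dz/dy = " << *y.adj << std::endl;
  return 0;
}

Seeding *z.adj with 1.0 and invoking z.bp() corresponds to calling the generated backpropagator on the AD-transformed program: for z = x * y at x = 3, y = 4 the adjoints come out as dz/dx = 4 and dz/dy = 3, which is what the MUL case of the IR transform computes symbolically.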