Skip to content

Commit

Permalink
REPL + Operators + Env Refactor (apache#42)
Browse files Browse the repository at this point in the history
* Start on REPL

* Clean up evaluator after rebase

* Add support for registering primitives

Refactor Environment interface to better disambiguate between Globals and
Intrinsics. Start on Python code to produce operator specific JIT compilers
which will specialize an operator based on the type signature of the operator

* Debugging test crashes

* Fix testing error and setup primitive test case

* Add a bunch of type annotations

* Add more .pyi files to stub C++ written functions

* Getting closer to primitive evaluation, need type information

* Continue refactoring needed to call Tensor program w/ Primitives

* Add Tensor values, and improve calling into eval

* Add notes for tomorrow

* Fix linting, prepare for early merge

* Fix rebase

* Fix linting

* Add typed_ast to package list

* Fix failing test cases

* Trigger rebuild
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 3172b3d commit fad945a
Show file tree
Hide file tree
Showing 27 changed files with 881 additions and 410 deletions.
3 changes: 2 additions & 1 deletion relay/include/relay/anf.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ struct LetList {
Expr plug(const Expr & expr) const {
Expr ret = expr;
for (auto rit = lets.rbegin(); rit != lets.rend(); rit++) {
ret = LetNode::make(rit->first.lid, rit->second, ret);
// TODO(@jroesch) Fix the type of the binder.
ret = LetNode::make(rit->first.lid, BaseTypeNode::Float(32), rit->second, ret);
}
return ret;
}
Expand Down
24 changes: 18 additions & 6 deletions relay/include/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,35 @@
#ifndef NNVM_RELAY_ENVIRONMENT_H_
#define NNVM_RELAY_ENVIRONMENT_H_

#include <tvm/ir_functor.h>
#include <unordered_map>
#include <string>
#include "ir.h"

namespace nnvm {
namespace relay {

// class InternMap {
// // fill me in
// };

struct Environment;

/*! \brief Integer literal `0`, `1000`. */
class EnvironmentNode : public ValueNode {
private:
std::unordered_map<std::string, GlobalId> global_map_;
std::unordered_map<std::string, IntrinsicId> intrinsic_map_;
// What if there are two globalid with the same name?
// This should be fixed in the python code,
// But I havent take much look into it, so I will just hack around.

inline void add_global(const std::string & str, GlobalId id);
inline void add_intrinsic(const std::string & str, IntrinsicId id);
inline void register_primitive(Primitive p);

public:
// This map contains all items *except* Primitives.
tvm::Map<GlobalId, Item> items;

// This map *only* contains primitives.
tvm::Map<IntrinsicId, Primitive> intrinsics;

EnvironmentNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
Expand All @@ -38,10 +43,17 @@ class EnvironmentNode : public ValueNode {

TVM_DLL static Environment make(tvm::Map<GlobalId, Item> items);

// Add an item to the Enviroment.
void add(const Item& item);

GlobalId global_id(const std::string & str);
IntrinsicId intrinsic_id(const std::string & str);

// We can lookup a GlobalId, IntrinsicId.
Item lookup(const GlobalId & id);
Item lookup(const std::string & str);
Item lookup(const IntrinsicId & id);
Item lookup_global(const std::string & str);
Item lookup_intrinsic(const std::string & str);

static constexpr const char* _type_key = "nnvm.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand Down
55 changes: 2 additions & 53 deletions relay/include/relay/evaluator.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2018 by Contributors
* \file evaluator.h
* \brief An evaluator for the Relay IR
* \brief An evaluator for Relay IR
*/
#ifndef NNVM_RELAY_EVALUATOR_H_
#define NNVM_RELAY_EVALUATOR_H_
Expand All @@ -16,58 +16,7 @@
namespace nnvm {
namespace relay {

struct Stack {
// Because we know variables will always be unique we can collapse this into a single mapping?
std::vector<std::unordered_map<LocalId, Value>> stack;
Stack() {
push();
}
void push() {
stack.push_back({});
}
void pop() {
stack.pop_back();
}
};

struct StackFrame {
Stack & stack;
explicit StackFrame(Stack & stack) : stack(stack) {
stack.push();
}
~StackFrame() {
stack.pop();
}
};

class Evaluator : public ExprFunctor<Value(const Expr& n)> {
public:
Environment env;
Stack stack;
Evaluator();
Evaluator(Environment env) : env(env) {}
void extend(const LocalId & id, Value v);
Value Eval(const Expr& expr);
Value VisitExpr_(const LocalIdNode* op) override;
Value VisitExpr_(const GlobalIdNode* op) override;
Value VisitExpr_(const IntrinsicIdNode* op) override;
Value VisitExpr_(const FloatLitNode* op) override;
Value VisitExpr_(const BoolLitNode* op) override;
Value VisitExpr_(const IntLitNode* op) override;
Value VisitExpr_(const TensorLitNode* op) override;
Value VisitExpr_(const ProductLitNode* op) override;
Value VisitExpr_(const CastNode* op) override;
Value VisitExpr_(const ParamNode* op) override;
Value VisitExpr_(const FunctionNode* op) override;
Value VisitExpr_(const CallNode* op) override;
Value VisitExpr_(const DebugNode* op) override;
Value VisitExpr_(const UnaryOpNode* op) override;
Value VisitExpr_(const BinaryOpNode* op) override;
Value VisitExpr_(const LetNode* op) override;
Value VisitExpr_(const ZeroNode* op) override;
Value VisitExpr_(const ProjectionNode* op) override;
Value VisitExpr_(const IfNode* op) override;
};
Value evaluate(Environment env, Expr e);

} // namespace relay
} // namespace nnvm
Expand Down
9 changes: 6 additions & 3 deletions relay/include/relay/ir/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
#define RELAY_DEFINE_TYPE(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Type)

/*! \brief Macro to make it easy to define node ref type given node */
#define RELAY_DEFINE_ITEM(TypeName, NodeName) \
RELAY_DEFINE_NODE_REF(TypeName, NodeName, Item)

namespace nnvm {
namespace relay {

Expand Down Expand Up @@ -66,7 +70,7 @@ RELAY_DEFINE_NODE_REF(Span, SpanNode, ::tvm::NodeRef);

struct Node : public tvm::Node {
public:
Span span;
// Span span;

Node() {}
};
Expand Down Expand Up @@ -107,7 +111,7 @@ class LocalIdNode : public ExprNode {
LocalIdNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
// v->Visit("span", &span);
v->Visit("name", &name);
}

Expand Down Expand Up @@ -181,7 +185,6 @@ struct Type : public NodeRef {

struct ItemNode : Node {
public:
GlobalId id;
static constexpr const char* _type_key = "nnvm.Item";
TVM_DECLARE_BASE_NODE_INFO(ItemNode, Node);
};
Expand Down
8 changes: 6 additions & 2 deletions relay/include/relay/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,18 @@ class Function;
class FunctionNode : public ExprNode {
public:
tvm::Array<Param> params;
Type ret_type;
Expr body;

FunctionNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("ret_type", &ret_type);
v->Visit("body", &body);
}

TVM_DLL static Function make(tvm::Array<Param> params, Expr body);
TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type, Expr body);

static constexpr const char* _type_key = "nnvm.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
Expand Down Expand Up @@ -351,18 +353,20 @@ class Let;
class LetNode : public ExprNode {
public:
LocalId id;
Type type;
Expr value;
Expr body;

LetNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("id", &id);
v->Visit("type", &type);
v->Visit("value", &value);
v->Visit("body", &body);
}

TVM_DLL static Let make(LocalId id, Expr value, Expr body);
TVM_DLL static Let make(LocalId id, Type type, Expr value, Expr body);

static constexpr const char* _type_key = "nnvm.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
Expand Down
9 changes: 6 additions & 3 deletions relay/include/relay/ir/item.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Primitive;

class PrimitiveNode : public ItemNode {
public:
IntrinsicId id;
Type type;

PrimitiveNode() {}
Expand All @@ -28,20 +29,22 @@ class PrimitiveNode : public ItemNode {
v->Visit("type", &type);
}

TVM_DLL static Primitive make(GlobalId id, Type type);
TVM_DLL static Primitive make(IntrinsicId id, Type type);

static constexpr const char* _type_key = "nnvm.Primitive";
TVM_DECLARE_NODE_TYPE_INFO(PrimitiveNode, ItemNode);
};

TVM_DEFINE_NODE_REF(Primitive, PrimitiveNode);
RELAY_DEFINE_ITEM(Primitive, PrimitiveNode);

class Defn;

class DefnNode : public ItemNode {
public:
GlobalId id;
Type type;
Expr body;

DefnNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
Expand All @@ -56,7 +59,7 @@ class DefnNode : public ItemNode {
TVM_DECLARE_NODE_TYPE_INFO(DefnNode, ItemNode);
};

TVM_DEFINE_NODE_REF(Defn, DefnNode);
RELAY_DEFINE_ITEM(Defn, DefnNode);


} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion relay/include/relay/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ShapeNode : public TypeNode {
ShapeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
// v->Visit("dims", &id);
v->Visit("dims", &dims);
}

TVM_DLL static Shape make(tvm::Array<Expr> dims);
Expand Down
20 changes: 20 additions & 0 deletions relay/include/relay/ir/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,26 @@ struct ProductValueNode : ValueNode {

RELAY_DEFINE_VALUE(ProductValue, ProductValueNode);

class TensorValue;

/*! \brief Product literal (x, ... y). */
struct TensorValueNode : ValueNode {
TVMArrayHandle data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", reinterpret_cast<void**>(data));
}

TVM_DLL static TensorValue make(TVMArrayHandle data);

static constexpr const char* _type_key = "nnvm.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};

RELAY_DEFINE_VALUE(TensorValue, TensorValueNode);

} // namespace relay
} // namespace nnvm

Expand Down
2 changes: 1 addition & 1 deletion relay/include/relay/py_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class FunctorNode
if (visit_function != nullptr) {
return visit_function(fn->params, fn->body, args);
} else {
return FunctionNode::make(fn->params, fn->body);
return FunctionNode::make(fn->params, fn->ret_type, fn->body);
}
}

Expand Down
29 changes: 29 additions & 0 deletions relay/python/relay/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# pylint: disable=no-else-return, unidiomatic-typecheck
"""A global environment storing everything needed to interpret or compile a Realy program."""
from typing import Union, Dict
from tvm._ffi.function import _init_api
from .base import register_nnvm_node, NodeBase
from .expr import GlobalId, IntrinsicId, Item

_init_api("nnvm.relay.env", __name__)

@register_nnvm_node
class Environment(NodeBase):
"""The global Relay environment containing definitions,
primitives, options, and more.
"""
items: Dict[GlobalId, Item]

def add(self, func: GlobalId) -> None:
return Environment_add(self, func)

def global_id(self, name: str) -> GlobalId:
return Environment_global_id(self, name)

def lookup(self, ident: Union[GlobalId, IntrinsicId]) -> Item:
if isinstance(ident, IntrinsicId):
intrin_id = self.intrinsic_id(ident)
return Environment_lookup_intrinsic(self, intrin_id)
else:
global_id = self.global_id(ident)
return Environment_lookup_global(self, global_id)
10 changes: 10 additions & 0 deletions relay/python/relay/env.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Union, Tuple, Dict
from nnvm.relay.expr import GlobalId, IntrinsicId, Item

class Environment(): ...

def Environment_add(self: Environment, func: GlobalId) -> None: ...
def Environment_global_id(self: Environment, name: str) -> GlobalId: ...
def Environment_intrinsic_id(self: Environment, name: str) -> IntrinsicId: ...
def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ...
def Environment_lookup_intrinsic(self: Environment, id: GlobalId) -> Item: ...
Loading

0 comments on commit fad945a

Please sign in to comment.