Skip to content

Commit

Permalink
[Relay] parser/pretty printer roundtripping (apache#3536)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 6, 2019
1 parent c01934c commit fe66e75
Show file tree
Hide file tree
Showing 16 changed files with 1,827 additions and 1,166 deletions.
220 changes: 168 additions & 52 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-import
# pylint: disable=invalid-name, unused-argument
"""A parser for Relay's text format."""
from __future__ import absolute_import

import sys
from ast import literal_eval

from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict

import tvm

Expand All @@ -32,6 +32,23 @@
from . import ty
from . import op

PYTHON_VERSION = sys.version_info.major
# Import the generated Relay grammar (visitor, parser and lexer). If any of the
# modules is missing, fail with a hint about the build flag named in the message.
try:
    from .grammar.py3.RelayVisitor import RelayVisitor
    from .grammar.py3.RelayParser import RelayParser
    from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
    # `Exeption` was a typo: it raised NameError and hid the intended message.
    raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")

# Import the ANTLR runtime pieces used by this module. If the runtime is not
# installed, fail with the pip command for the running Python major version.
try:
    from antlr4 import InputStream, CommonTokenStream
    from antlr4.error.ErrorListener import ErrorListener
except ImportError:
    # The first literal ends with a space so the two sentences do not run together.
    raise Exception("Couldn't find ANTLR runtime. " +
                    "Try running `pip{version} install antlr4-python{version}-runtime`."
                    .format(version=PYTHON_VERSION))

sys.setrecursionlimit(10000)

class ParseError(Exception):
"""Exception type for parse errors."""
Expand All @@ -41,21 +58,50 @@ def __init__(self, message):
super(ParseError, self).__init__()
self.message = message

PYTHON_VERSION = sys.version_info.major
try:
from .grammar.py3.RelayVisitor import RelayVisitor
from .grammar.py3.RelayParser import RelayParser
from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
def __repr__(self):
return "ParseError({})".format(self.message)

try:
from antlr4 import ParserRuleContext, InputStream, CommonTokenStream
from antlr4.tree.Tree import TerminalNode
except ImportError:
raise ParseError("Couldn't find ANTLR runtime." +
"Try running `pip{version} install antlr4-python{version}-runtime`."
.format(version=PYTHON_VERSION))
def __str__(self):
return repr(self)

class OpWrapper:
    """Marker base class for operator wrappers.

    Subclasses define `__call__(args, attrs, type_args)`; the parser's `call`
    helper checks for this base class to decide whether to invoke the wrapper
    instead of building an `expr.Call` directly.
    """

class ExprOp(OpWrapper):
    """Wrap an operator so that calling the wrapper builds an `expr.Call`.

    This is the default wrapper. The attrs are handed to `expr.Call` as-is,
    without any conversion (the original note: it "does not handle attrs well").
    """

    def __init__(self, operator):
        # The callee that will be placed in the resulting Call node.
        self.operator = operator

    def __call__(self, args, attrs, type_args):
        """Return `expr.Call(self.operator, args, attrs, type_args)`.

        Raises
        ------
        Exception
            If constructing the call fails. The message names the operator and
            the attrs; the original failure is attached as the explicit cause.
        """
        try:
            return expr.Call(self.operator, args, attrs, type_args)
        except Exception as err:
            # Same message as before; `from err` records the root cause explicitly.
            raise Exception(str(self.operator) + " " + str(attrs)) from err

class FuncOp(OpWrapper):
    """Call a Python operator constructor with the attrs as keyword arguments.

    Each attr value is first converted to a plain Python value (see `convert`).
    Tvm should provide this in the future, as this is pretty similar to what
    op.get is providing.
    """

    def __init__(self, operator):
        # A Python callable that builds the operator call.
        self.operator = operator

    def convert(self, v):
        """Convert a parsed attr value into a plain Python value.

        Tuples are converted element by element, `expr.Constant` values are
        reduced to a Python scalar, and strings pass through unchanged.
        Anything else raises `Exception` carrying the offending value.
        """
        if isinstance(v, tuple):
            return tuple(self.convert(item) for item in v)
        if isinstance(v, expr.Constant):
            return v.data.asnumpy().item()
        if isinstance(v, str):
            return v
        raise Exception(v)

    def __call__(self, args, attrs, type_args):
        """Invoke the operator with `args` and the converted attrs.

        `type_args` is accepted for interface parity with other wrappers but is
        not used. A `TupleWrapper` result is unwrapped with `astuple()`.
        """
        kwargs = {}
        if attrs is not None:
            for key, value in attrs.items():
                kwargs[key] = self.convert(value)
        result = self.operator(*args, **kwargs)
        if isinstance(result, expr.TupleWrapper):
            return result.astuple()
        return result

BINARY_OPS = {
RelayParser.MUL: op.multiply,
Expand All @@ -70,16 +116,34 @@ def __init__(self, message):
RelayParser.NE: op.not_equal,
}

# Operators that are built by calling their Python constructor instead of going
# through `op.get`. Keys are operator names as they appear in the program text;
# `visitOpIdent` wraps the mapped function in a `FuncOp` when the name is found here.
# Note that "nn.dropout" maps to `op.nn.dropout_raw`, not `op.nn.dropout`.
FUNC_OPS = {
"nn.conv2d": op.nn.conv2d,
"nn.batch_norm": op.nn.batch_norm,
"nn.dense": op.nn.dense,
"nn.bias_add": op.nn.bias_add,
"nn.max_pool2d": op.nn.max_pool2d,
"nn.global_max_pool2d": op.nn.global_max_pool2d,
"nn.avg_pool2d": op.nn.avg_pool2d,
"nn.global_avg_pool2d": op.nn.global_avg_pool2d,
"nn.softmax": op.nn.softmax,
"reshape": op.reshape,
"nn.conv2d_transpose": op.nn.conv2d_transpose,
"concatenate": op.concatenate,
"nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros,
"split": op.split,
}

# Name prefixes of the scalar type families.
# NOTE(review): presumably matched against type identifiers such as "int32" when
# parsing types; the code that uses this list is not visible here - confirm.
TYPE_PREFIXES = [
"int",
"uint",
"float",
"bool",
]

T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
T = ty.TypeVar("T")
# Scope = Deque[Tuple[str, T]]
# Scopes = Deque[Scope[T]]

def lookup(scopes, name):
# type: (Scopes[T], str) -> Optional[T]
Expand Down Expand Up @@ -108,6 +172,8 @@ def _wrapper(*args, **kwargs):
ast = f(*args, **kwargs)
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
if isinstance(ast, tvm.relay.expr.TupleWrapper):
ast = ast.astuple()
ast.set_span(sp)
return ast
return _wrapper
Expand Down Expand Up @@ -179,6 +245,9 @@ def mk_typ(self, name, kind):
self.type_param_scopes[0].appendleft((name, typ))
return typ

def visitProjection(self, ctx):
    """Build a `TupleGetItem` from a projection.

    The sub-expression is visited first, then the NAT token that gives the index.
    """
    tuple_value = self.visit(ctx.expr())
    field_index = self.visit(ctx.NAT())
    return expr.TupleGetItem(tuple_value, field_index)

def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
Expand Down Expand Up @@ -213,12 +282,15 @@ def visitTerminal(self, node):
if node_text == "False":
return False
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
if node_type == RelayLexer.QUOTED_STRING:
return literal_eval(node_text)

raise ParseError("todo: {}".format(node_text))
raise ParseError("todo: `{}`".format(node_text))

def visit_list(self, ctx_list):
    # type: (List[ParserRuleContext]) -> List[Any]
    """Visit every context in `ctx_list`, in order, and return the results as a list.

    `ctx_list` must be a real `list`; anything else trips the assertion.
    """
    assert isinstance(ctx_list, list)

    return list(map(self.visit, ctx_list))

Expand All @@ -232,6 +304,11 @@ def getType_(self, ctx):
return self.visit(ctx)

def visitProg(self, ctx):
self.meta = None
if ctx.METADATA():
header, data = str(ctx.METADATA()).split('\n', 1)
assert header == "METADATA:"
self.meta = tvm.load_json(data)
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
Expand All @@ -245,11 +322,14 @@ def visitProg(self, ctx):
# Exprs
def visitOpIdent(self, ctx):
# type: (RelayParser.OpIdentContext) -> op.Op
return op.get(ctx.CNAME().getText())
op_name = ctx.CNAME().getText()
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
return ExprOp(op.get(op_name))

# pass through
def visitParens(self, ctx):
# type: (RelayParser.ParensContext) -> expr.Expr
def visitParen(self, ctx):
# type: (RelayParser.ParenContext) -> expr.Expr
return self.visit(ctx.expr())

# pass through
Expand Down Expand Up @@ -283,25 +363,17 @@ def visitTuple(self, ctx):
tup = self.visit_list(ctx.expr())
return expr.Tuple(tup)

# Currently doesn't support mutable sequencing.
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
raise ParseError("Mutation is currently unsupported.")

if ctx.var() is None or ctx.var().ident() is None:
if ctx.var() is None:
# anonymous identity
ident = "_"
type_ = None
var = self.mk_var(ident, type_)
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError("Only local ids may be used in `let`s.")
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())

var = self.mk_var(ident, type_)
var = self.visitVar(ctx.var())

self.enter_var_scope()
value = self.visit(ctx.expr(0))
Expand All @@ -326,7 +398,7 @@ def visitBinOp(self, ctx):
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident = ctx.ident().LOCAL_VAR()
ident = ctx.LOCAL_VAR()

if ident is None:
raise ParseError("Only local ids may be used in vars.")
Expand All @@ -344,19 +416,29 @@ def visitAttr(self, ctx):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return (ctx.CNAME().getText(), self.visit(ctx.expr()))

def visitAttrList(self, ctx):
def visitArgNoAttr(self, ctx):
    """Visit an argument list that has no attrs.

    Returns a pair `(vars, None)`, where `vars` is the list of visited variables.
    """
    var_contexts = ctx.varList().var()
    visited_vars = self.visit_list(var_contexts)
    return (visited_vars, None)

def visitAttrSeq(self, ctx):
    # type: (RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]
    """Visit every attr in the sequence and collect the (name, value) pairs into a dict."""
    name_value_pairs = self.visit_list(ctx.attr())
    return dict(name_value_pairs)

def visitArgWithAttr(self, ctx):
    """Visit an argument list that carries attrs.

    Returns a pair `(vars, attrs)`: the visited variables and the attr dict
    produced by `visitAttrSeq`.
    """
    visited_vars = self.visit_list(ctx.var())
    attrs = self.visitAttrSeq(ctx.attrSeq())
    return (visited_vars, attrs)

def visitArgList(self,
ctx # type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None

return (var_list, attr_list)

def visitMeta(self, ctx):
    """Look up a metadata entry by type key and index.

    The type key is the CNAME token text and the index is the visited NAT
    token. The entry is read from `self.meta`, which `visitProg` fills from
    the program's METADATA section (and leaves as None when there is none).

    Raises
    ------
    ParseError
        If no metadata has been loaded, so the reference cannot be resolved.
    """
    type_key = str(ctx.CNAME())
    index = int(self.visit(ctx.NAT()))
    # Without a METADATA section `self.meta` is None; report that clearly
    # instead of failing with "'NoneType' object is not subscriptable".
    meta = getattr(self, "meta", None)
    if meta is None:
        raise ParseError(
            "Cannot resolve meta reference `{}[{}]`: the program has no METADATA section."
            .format(type_key, index))
    return meta[type_key][index]

def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""
Expand All @@ -365,7 +447,7 @@ def mk_func(self, ctx):
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
type_params = ctx.typeParamSeq()
type_params = ctx.typeParamList()

if type_params is not None:
type_params = type_params.ident()
Expand Down Expand Up @@ -405,18 +487,25 @@ def visitDefn(self, ctx):
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)

self.module[ident] = self.mk_func(ctx)

def visitCallNoAttr(self, ctx):
    """Visit a call argument list that has no attrs.

    Returns a pair `(args, None)`, where `args` is the list of visited expressions.
    """
    arg_contexts = ctx.exprList().expr()
    visited_args = self.visit_list(arg_contexts)
    return (visited_args, None)

def visitCallWithAttr(self, ctx):
    """Visit a call argument list that carries attrs.

    Returns a pair `(args, attrs)`: the visited expressions and the result of
    visiting the attr sequence.
    """
    visited_args = self.visit_list(ctx.expr())
    attrs = self.visit(ctx.attrSeq())
    return (visited_args, attrs)

def call(self, func, args, attrs, type_args):
    """Build a call to `func`.

    An ordinary callee becomes an `expr.Call` node. An `OpWrapper` is invoked
    instead, so the wrapper decides how the call is constructed.
    """
    if not isinstance(func, OpWrapper):
        return expr.Call(func, args, attrs, type_args)
    return func(args, attrs, type_args)

@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())

func = visited_exprs[0]
args = visited_exprs[1:]

return expr.Call(func, args, None, None)
func = self.visit(ctx.expr())
args, attrs = self.visit(ctx.callList())
return self.call(func, args, attrs, [])

@spanify
def visitIfElse(self, ctx):
Expand All @@ -438,9 +527,7 @@ def visitIfElse(self, ctx):
def visitGraph(self, ctx):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if ctx.ident().GRAPH_VAR() is None:
raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:])
graph_nid = int(ctx.GRAPH_VAR().getText()[1:])

self.enter_var_scope()
value = self.visit(ctx.expr(0))
Expand Down Expand Up @@ -500,15 +587,18 @@ def visitParensShape(self, ctx):
# type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape())

def visitShapeSeq(self, ctx):
# type: (RelayParser.ShapeSeqContext) -> List[int]
def visitShapeList(self, ctx):
# type: (RelayParser.ShapeListContext) -> List[int]
return self.visit_list(ctx.shape())

def visitTensor(self, ctx):
    """Visit every element expression and return the results as a tuple."""
    elements = self.visit_list(ctx.expr())
    return tuple(elements)

def visitTensorType(self, ctx):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""

shape = self.visit(ctx.shapeSeq())
shape = self.visit(ctx.shapeList())
dtype = self.visit(ctx.type_())

if not isinstance(dtype, ty.TensorType):
Expand Down Expand Up @@ -536,11 +626,37 @@ def make_parser(data):
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
lexer.addErrorListener(StrictErrorListener(data))
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)
p = RelayParser(token_stream)
p.addErrorListener(StrictErrorListener(data))
return p

__source_name_counter__ = 0

class StrictErrorListener(ErrorListener):
    """Error listener that fails eagerly on every report and includes the program text.

    Each callback raises an `Exception` right away. The message always contains
    the full program text given at construction time.
    """

    def __init__(self, text):
        super(StrictErrorListener, self).__init__()
        # The full program text, echoed in every error message.
        self.text = text

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """Raise on a syntax error, reporting where it occurred and why."""
        # Keep the original prefix and program text, then append the location
        # and the recognizer's message, which were previously discarded.
        raise Exception("Syntax Error in:\n" + self.text +
                        "\nat line {}, column {}: {}".format(line, column, msg))

    def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
        """Raise when an ambiguity is reported."""
        raise Exception("Ambiguity Error in:\n" + self.text)

    def reportAttemptingFullContext(self,
                                    recognizer,
                                    dfa,
                                    startIndex,
                                    stopIndex,
                                    conflictingAlts,
                                    configs):
        """Raise when an attempt at full-context parsing is reported."""
        raise Exception("Attempting Full Context in:\n" + self.text)

    def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
        """Raise when a context sensitivity is reported."""
        raise Exception("Context Sensitivity in:\n" + self.text)

def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
Expand Down
Loading

0 comments on commit fe66e75

Please sign in to comment.