Skip to content

Commit

Permalink
[Parser][Printer] More parser/printer improvements (#12)
Browse files Browse the repository at this point in the history
* Relax pretty printer initial prototype

* call into TVMScriptPrinter for PrimFuncs

* most round-trip tests pass

* address comments

* implement relax.output syntax for dataflow block outputs

* remove leftover comments

* fix Var constructor on ShapeExpr annotation

* add printing and parsing for simple PrimExpr and Call Attrs
  • Loading branch information
altanh authored and junrushao committed Feb 5, 2023
1 parent 3809561 commit 9731df5
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 60 deletions.
92 changes: 81 additions & 11 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,41 @@ def _tir_from_synr(

# NOTE: call_dps is an actual registered operator
class SpecialOp(Enum):
"""Relax operator calls that have special semantics handled by the parser."""
"""Relax operators that have special semantics handled by the parser."""

MATCH_SHAPE = "relax.match_shape"
CALL_PACKED = "relax.call_packed"
DATAFLOW = "relax.dataflow"
DATAFLOW_OUTPUT = "relax.output"


class ArithmeticOp(Enum):
"""Arithmetic operators that can desugar to either Relax or TIR PrimExpr operators."""

ADD = ast.BuiltinOp.Add
SUB = ast.BuiltinOp.Sub
MUL = ast.BuiltinOp.Mul
DIV = ast.BuiltinOp.Div
FLOOR_DIV = ast.BuiltinOp.FloorDiv


RELAX_ARITHMETIC_OP_MAP = {
ArithmeticOp.ADD: relay.op.get("add"),
ArithmeticOp.SUB: relay.op.get("subtract"),
ArithmeticOp.MUL: relay.op.get("multiply"),
ArithmeticOp.DIV: relay.op.get("divide"),
ArithmeticOp.FLOOR_DIV: relay.op.get("floor_divide"),
}

PRIMEXPR_ARITHMETIC_OP_MAP = {
ArithmeticOp.ADD: tir.Add,
ArithmeticOp.SUB: tir.Sub,
ArithmeticOp.MUL: tir.Mul,
ArithmeticOp.DIV: tir.Div,
ArithmeticOp.FLOOR_DIV: tir.FloorDiv,
}


class RelaxTransformer(Transformer):
def __init__(self, definition_scope):
super().__init__()
Expand Down Expand Up @@ -367,16 +394,25 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
"cannot introduce new dimension variables in this expression",
expr.span,
)

elif isinstance(expr, ast.Constant):
if not isinstance(expr.value, int):
self.report_error("only integer constants are supported", expr.span)
return tir.const(expr.value, "int32", self.to_tvm_span(expr.span))

elif isinstance(expr, ast.Call):
if not isinstance(expr.func_name, ast.Op):
self.report_error(
"only built-in operators can be used in dimension expressions",
expr.func_name.span,
)
op = PRIMEXPR_ARITHMETIC_OP_MAP[self.transform_expr(expr.func_name)]
# TODO(@altanh): it might not make sense to bind free variables
args = [self.parse_primexpr(arg, bind_free_vars) for arg in expr.params]
return op(*args, span=self.to_tvm_span(expr.span))

else:
# TODO(@altanh): parse (simple) PrimExprs
self.report_error(
"only dimension variable expressions are currently supported",
expr.span,
)
self.report_error(f"unsupported dimension expression: {expr}", expr.span)

def transform_module(self, mod: ast.Module) -> IRModule:
"""Transforms the given synr Module to a Relax IRModule.
Expand Down Expand Up @@ -750,7 +786,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
if isinstance(expr, ast.Attr):
if expr.field.name == "shape":
obj = self.transform_expr(expr.object)
return relay.Call(relay.op.get("shape_of"), [obj], span=self.to_tvm_span(expr.span))
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
return relay.Call(
relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
)
else:
# assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
op_name = self._parse_attrs_to_str(expr)
Expand Down Expand Up @@ -780,12 +819,30 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
)
op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(expr.params[1])]
elif isinstance(op, ArithmeticOp):
args = [self.transform_expr(arg) for arg in expr.params]
if all([isinstance(arg, tir.PrimExpr) for arg in args]):
return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
# otherwise it's just a normal Relax operator call
op = RELAX_ARITHMETIC_OP_MAP[op]
elif isinstance(op, (tvm.ir.Op, relay.Expr)):
args = [self.transform_expr(arg) for arg in expr.params]
else:
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)

if isinstance(op, rx.ExternFunc) or (
isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""
):
attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
kwargs = {}
for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
kwargs[key.value] = self.transform_expr(val)
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
else:
attrs = None
# TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass?
return relay.Call(op, args, span=self.to_tvm_span(expr.span))
return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))

elif isinstance(expr, ast.Tuple):
fields = [self.transform_expr(field) for field in expr.values]
Expand All @@ -812,14 +869,27 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, float):
return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, str):
# FIXME(@altanh): using StringImm seems to cause problems, but this loses span
return expr.value
elif expr.value is None:
return None
else:
self.report_error(
"unsupported constant expression (we currently only support int and float)",
f"unsupported constant expression: {expr}",
expr.span,
)

elif isinstance(expr, ast.Op):
# TODO(@altanh): might need to generalize from ArithmeticOp if we decide to support
# array slicing syntax
try:
return ArithmeticOp(expr.name)
except ValueError:
self.report_error(f"unsupported built-in operator: {expr.name}", expr.span)

else:
self.report_error("unsupported expression", expr.span)
self.report_error(f"unsupported expression: {expr}", expr.span)

def transform_block(self, block: ast.Block) -> rx.SeqExpr:
"""Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final
Expand All @@ -842,7 +912,7 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr:
parsed_stmt = self.transform_stmt(stmt)
if isinstance(parsed_stmt, rx.DataflowBlock):
if current_block:
# FIXME: span
# FIXME(@altanh): need to manually construct span start & end
blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span)))
current_block = []
blocks.append(parsed_stmt)
Expand Down
Loading

0 comments on commit 9731df5

Please sign in to comment.