Skip to content

Commit

Permalink
[TVMScript] Python Expression Precedence
Browse files Browse the repository at this point in the history
This PR addes:

- Awareness of expression (operator) precedence during Python code printing
(`(* 1 (+ 2 3))` prints as `1 * (2 + 3)`)

Tracking issue: #11912

This PR is in draft state because it's branched off an open PR
#12112.
  • Loading branch information
yelite authored and junrushao committed Jul 27, 2022
1 parent 584b0f3 commit 90b3a9c
Show file tree
Hide file tree
Showing 8 changed files with 2,475 additions and 26 deletions.
506 changes: 506 additions & 0 deletions include/tvm/script/printer/doc.h

Large diffs are not rendered by default.

165 changes: 164 additions & 1 deletion python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# under the License.
"""Doc types for TVMScript Unified Printer"""

from typing import List, Dict, Tuple, Optional, Union, Sequence
from enum import IntEnum, unique
from typing import Dict, List, Optional, Sequence, Tuple, Union

import tvm._ffi
import tvm.ir.container
Expand Down Expand Up @@ -100,6 +100,31 @@ def __iter__(self):
raise RuntimeError(f"{self.__class__} cannot be used as iterable.")


class StmtDoc(Doc):
"""Base class of statement doc"""

@property
def comment(self) -> Optional[str]:
# It has to call the dunder method to avoid infinite recursion
# pylint: disable=unnecessary-dunder-call
return self.__getattr__("comment")
# pylint: enable=unnecessary-dunder-call

@comment.setter
def comment(self, value):
return _ffi_api.StmtDocSetComment(self, value) # type: ignore


@tvm._ffi.register_object("script.printer.StmtBlockDoc")
class StmtBlockDoc(Doc):
"""The container doc that holds a list of StmtDoc."""

stmts: Sequence[StmtDoc]

def __init__(self, stmts: List[StmtDoc]):
self.__init_handle_by_constructor__(_ffi_api.StmtBlockDoc, stmts) # type: ignore


@tvm._ffi.register_object("script.printer.LiteralDoc")
class LiteralDoc(ExprDoc):
"""Doc that represents literal value"""
Expand Down Expand Up @@ -293,3 +318,141 @@ def __init__(
step: Optional[ExprDoc] = None,
):
self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore


@tvm._ffi.register_object("script.printer.AssignDoc")
class AssignDoc(StmtDoc):
"""Doc that represents assign statement."""

lhs: ExprDoc
rhs: Optional[ExprDoc]
annotation: Optional[ExprDoc]

def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[ExprDoc] = None):
# pylint: disable=line-too-long
self.__init_handle_by_constructor__(_ffi_api.AssignDoc, lhs, rhs, annotation) # type: ignore
# pylint: enable=line-too-long


@tvm._ffi.register_object("script.printer.IfDoc")
class IfDoc(StmtDoc):
"""Doc that represent if-then-else statement."""

predicate: ExprDoc
then_branch: Sequence[StmtDoc]
else_branch: Sequence[StmtDoc]

def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: List[StmtDoc]):
# pylint: disable=line-too-long
self.__init_handle_by_constructor__(_ffi_api.IfDoc, predicate, then_branch, else_branch) # type: ignore
# pylint: enable=line-too-long


@tvm._ffi.register_object("script.printer.WhileDoc")
class WhileDoc(StmtDoc):
"""Doc that represents while statement."""

predicate: ExprDoc
body: Sequence[StmtDoc]

def __init__(self, predicate: ExprDoc, body: List[StmtDoc]):
self.__init_handle_by_constructor__(_ffi_api.WhileDoc, predicate, body) # type: ignore


@tvm._ffi.register_object("script.printer.ForDoc")
class ForDoc(StmtDoc):
"""Doc that represents for statement."""

lhs: ExprDoc
rhs: ExprDoc
body: Sequence[StmtDoc]

def __init__(self, lhs: ExprDoc, rhs: ExprDoc, body: List[StmtDoc]):
self.__init_handle_by_constructor__(_ffi_api.ForDoc, lhs, rhs, body) # type: ignore


@tvm._ffi.register_object("script.printer.ScopeDoc")
class ScopeDoc(StmtDoc):
"""
Doc that represents special scopes.
Specificially, this means the with statment in Python:
with <rhs> as <lhs>:
<body...>
"""

lhs: Optional[ExprDoc]
rhs: ExprDoc
body: Sequence[StmtDoc]

def __init__(self, lhs: Optional[ExprDoc], rhs: ExprDoc, body: List[StmtDoc]):
self.__init_handle_by_constructor__(_ffi_api.ScopeDoc, lhs, rhs, body) # type: ignore


@tvm._ffi.register_object("script.printer.ExprStmtDoc")
class ExprStmtDoc(StmtDoc):
"""Doc that represents an expression as statement."""

expr: ExprDoc

def __init__(self, expr: ExprDoc):
self.__init_handle_by_constructor__(_ffi_api.ExprStmtDoc, expr) # type: ignore


@tvm._ffi.register_object("script.printer.AssertDoc")
class AssertDoc(StmtDoc):
"""Doc that represents assert statement."""

test: ExprDoc
msg: Optional[ExprDoc]

def __init__(self, test: ExprDoc, msg: Optional[ExprDoc] = None):
self.__init_handle_by_constructor__(_ffi_api.AssertDoc, test, msg) # type: ignore


@tvm._ffi.register_object("script.printer.ReturnDoc")
class ReturnDoc(StmtDoc):
"""Doc that represents return statement."""

value: ExprDoc

def __init__(self, value: ExprDoc):
self.__init_handle_by_constructor__(_ffi_api.ReturnDoc, value) # type: ignore


@tvm._ffi.register_object("script.printer.FunctionDoc")
class FunctionDoc(StmtDoc):
"""Doc that represents function definition."""

name: IdDoc
args: Sequence[AssignDoc]
decorators: Sequence[ExprDoc]
return_type: ExprDoc
body: Sequence[StmtDoc]

def __init__(
self,
name: IdDoc,
args: List[AssignDoc],
decorators: List[ExprDoc],
return_type: ExprDoc,
body: List[StmtDoc],
):
# pylint: disable=line-too-long
self.__init_handle_by_constructor__(_ffi_api.FunctionDoc, name, args, decorators, return_type, body) # type: ignore
# pylint: enable=line-too-long


@tvm._ffi.register_object("script.printer.ClassDoc")
class ClassDoc(StmtDoc):
"""Doc that represents class definition."""

name: IdDoc
decorators: Sequence[ExprDoc]
body: Sequence[StmtDoc]

def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]):
# pylint: disable=line-too-long
self.__init_handle_by_constructor__(_ffi_api.ClassDoc, name, decorators, body) # type: ignore
# pylint: enable=line-too-long
22 changes: 22 additions & 0 deletions src/script/printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
PrintTypedDoc(GetRef<AssignDoc>(doc_node));
} else if (const auto* doc_node = doc.as<IfDocNode>()) {
PrintTypedDoc(GetRef<IfDoc>(doc_node));
} else if (const auto* doc_node = doc.as<WhileDocNode>()) {
PrintTypedDoc(GetRef<WhileDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ForDocNode>()) {
PrintTypedDoc(GetRef<ForDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssertDocNode>()) {
PrintTypedDoc(GetRef<AssertDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
} else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
Expand Down
63 changes: 59 additions & 4 deletions src/script/printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,22 @@ class DocPrinter {
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;

/*!
* \brief Virtual method to print a IdDoc
* \brief Virtual method to print an IdDoc
*/
virtual void PrintTypedDoc(const IdDoc& doc) = 0;

/*!
* \brief Virtual method to print a AttrAccessDoc
* \brief Virtual method to print an AttrAccessDoc
*/
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;

/*!
* \brief Virtual method to print a IndexDoc
* \brief Virtual method to print an IndexDoc
*/
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;

/*!
* \brief Virtual method to print a OperationDoc
* \brief Virtual method to print an OperationDoc
*/
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;

Expand Down Expand Up @@ -133,6 +133,61 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;

/*!
* \brief Virtual method to print a StmtBlockDoc
*/
virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssignDoc
*/
virtual void PrintTypedDoc(const AssignDoc& doc) = 0;

/*!
* \brief Virtual method to print an IfDoc
*/
virtual void PrintTypedDoc(const IfDoc& doc) = 0;

/*!
* \brief Virtual method to print a WhileDoc
*/
virtual void PrintTypedDoc(const WhileDoc& doc) = 0;

/*!
* \brief Virtual method to print a ForDoc
*/
virtual void PrintTypedDoc(const ForDoc& doc) = 0;

/*!
* \brief Virtual method to print a ScopeDoc
*/
virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;

/*!
* \brief Virtual method to print an ExprStmtDoc
*/
virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;

/*!
* \brief Virtual method to print an AssertDoc
*/
virtual void PrintTypedDoc(const AssertDoc& doc) = 0;

/*!
* \brief Virtual method to print a ReturnDoc
*/
virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;

/*!
* \brief Virtual method to print a FunctionDoc
*/
virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;

/*!
* \brief Virtual method to print a ClassDoc
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
Expand Down
Loading

0 comments on commit 90b3a9c

Please sign in to comment.