Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Python Expression Precedence #12148

Merged
merged 2 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode {
/*! \brief Decorators of function. */
Array<ExprDoc> decorators;
/*! \brief The return type of function. */
ExprDoc return_type{nullptr};
Optional<ExprDoc> return_type{NullOpt};
/*! \brief The body of function. */
Array<StmtDoc> body;

Expand Down Expand Up @@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc {
* \param body The body of function.
*/
explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body);
Optional<ExprDoc> return_type, Array<StmtDoc> body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode);
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,15 @@ class FunctionDoc(StmtDoc):
name: IdDoc
args: Sequence[AssignDoc]
decorators: Sequence[ExprDoc]
return_type: ExprDoc
return_type: Optional[ExprDoc]
body: Sequence[StmtDoc]

def __init__(
self,
name: IdDoc,
args: List[AssignDoc],
decorators: List[ExprDoc],
return_type: ExprDoc,
return_type: Optional[ExprDoc],
body: List[StmtDoc],
):
self.__init_handle_by_constructor__(
Expand Down
4 changes: 2 additions & 2 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) {
}

FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
n->name = name;
n->args = args;
Expand Down Expand Up @@ -345,7 +345,7 @@ TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value)
TVM_REGISTER_NODE_TYPE(FunctionDocNode);
TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
.set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
return FunctionDoc(name, args, decorators, return_type, body);
});

Expand Down
179 changes: 167 additions & 12 deletions src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,114 @@ namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Operator precedence
*
* This is based on
* https://docs.python.org/3/reference/expressions.html#operator-precedence
junrushao marked this conversation as resolved.
Show resolved Hide resolved
*/
enum class ExprPrecedence : int32_t {
/*! \brief Unknown precedence */
kUnkown = 0,
/*! \brief Lambda Expression */
kLambda = 1,
/*! \brief Conditional Expression */
kIfThenElse = 2,
/*! \brief Boolean OR */
kBooleanOr = 3,
/*! \brief Boolean AND */
kBooleanAnd = 4,
/*! \brief Boolean NOT */
kBooleanNot = 5,
/*! \brief Comparisons */
kComparison = 6,
/*! \brief Bitwise OR */
kBitwiseOr = 7,
/*! \brief Bitwise XOR */
kBitwiseXor = 8,
/*! \brief Bitwise AND */
kBitwiseAnd = 9,
/*! \brief Shift Operators */
kShift = 10,
/*! \brief Addition and subtraction */
kAdd = 11,
/*! \brief Multiplication, division, floor division, remainder */
kMult = 12,
/*! \brief Positive negative and bitwise NOT */
kUnary = 13,
/*! \brief Exponentiation */
kExp = 14,
/*! \brief Index access, attribute access, call and atom expression */
kIdentity = 15,
};

#define DOC_PRECEDENCE_ENTRY(RefType, Precedence) \
{ RefType::ContainerType::RuntimeTypeIndex(), ExprPrecedence::Precedence }
junrushao marked this conversation as resolved.
Show resolved Hide resolved

ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
// Key is the value of OperationDocNode::Kind
static const std::vector<ExprPrecedence> op_kind_precedence = []() {
using OpKind = OperationDocNode::Kind;
std::map<OpKind, ExprPrecedence> raw_table = {
{OpKind::kUSub, ExprPrecedence::kUnary}, //
{OpKind::kInvert, ExprPrecedence::kUnary}, //
{OpKind::kAdd, ExprPrecedence::kAdd}, //
{OpKind::kSub, ExprPrecedence::kAdd}, //
{OpKind::kMult, ExprPrecedence::kMult}, //
{OpKind::kDiv, ExprPrecedence::kMult}, //
{OpKind::kFloorDiv, ExprPrecedence::kMult}, //
{OpKind::kMod, ExprPrecedence::kMult}, //
{OpKind::kPow, ExprPrecedence::kExp}, //
{OpKind::kLShift, ExprPrecedence::kShift}, //
{OpKind::kRShift, ExprPrecedence::kShift}, //
{OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, //
{OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, //
{OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, //
{OpKind::kLt, ExprPrecedence::kComparison}, //
{OpKind::kLtE, ExprPrecedence::kComparison}, //
{OpKind::kEq, ExprPrecedence::kComparison}, //
{OpKind::kNotEq, ExprPrecedence::kComparison}, //
{OpKind::kGt, ExprPrecedence::kComparison}, //
{OpKind::kGtE, ExprPrecedence::kComparison}, //
{OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, //
};

std::vector<ExprPrecedence> table;
table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1);
junrushao marked this conversation as resolved.
Show resolved Hide resolved

for (const auto& kv : raw_table) {
table[static_cast<int>(kv.first)] = kv.second;
}

return table;
}();

// Key is the type index of Doc
static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = {
DOC_PRECEDENCE_ENTRY(LiteralDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IdDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(AttrAccessDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IndexDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(CallDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(LambdaDoc, kLambda), //
DOC_PRECEDENCE_ENTRY(TupleDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(ListDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(DictDoc, kIdentity), //
};

if (const auto* op_doc = doc.as<OperationDocNode>()) {
ExprPrecedence precedence = op_kind_precedence[static_cast<int>(op_doc->kind)];
junrushao marked this conversation as resolved.
Show resolved Hide resolved
ICHECK(precedence != ExprPrecedence::kUnkown)
<< "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown";
return precedence;
} else if (doc_type_precedence.find(doc->type_index()) != doc_type_precedence.end()) {
return doc_type_precedence.at(doc->type_index());
junrushao marked this conversation as resolved.
Show resolved Hide resolved
} else {
ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown";
throw;
}
}

class PythonDocPrinter : public DocPrinter {
public:
explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {}
Expand Down Expand Up @@ -98,6 +206,42 @@ class PythonDocPrinter : public DocPrinter {
}
}

/*!
* \brief Print expression and add parenthesis if needed.
*/
void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence doc_precedence = GetExprPrecedence(doc);
if (doc_precedence < parent_precedence ||
(parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
output_ << "(";
PrintDoc(doc);
output_ << ")";
} else {
PrintDoc(doc);
}
}

/*!
* \brief Print expression and add parenthesis if doc has lower precedence than parent.
*/
void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence parent_precedence = GetExprPrecedence(parent);
return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
}

/*!
* \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
*
* This function should be used to print an child expression that needs to be wrapped
* by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
* and the `b` and `c` in `a if b else c`.
*/
void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) {
PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence*/ true);
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
Expand Down Expand Up @@ -161,12 +305,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; }

void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
output_ << "." << doc->name;
}

void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
if (doc->indices.size() == 0) {
output_ << "[()]";
} else {
Expand Down Expand Up @@ -226,29 +370,38 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
// Unary Operators
ICHECK_EQ(doc->operands.size(), 1);
output_ << OperatorToString(doc->kind);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
} else if (doc->kind == OpKind::kPow) {
// Power operator is different than other binary operators
// It's right-associative and binds less tightly than unary operator on its right.
// https://docs.python.org/3/reference/expressions.html#the-power-operator
// https://docs.python.org/3/reference/expressions.html#operator-precedence
ICHECK_EQ(doc->operands.size(), 2);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " ** ";
PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
junrushao marked this conversation as resolved.
Show resolved Hide resolved
} else if (doc->kind < OpKind::kBinaryEnd) {
// Binary Operator
ICHECK_EQ(doc->operands.size(), 2);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
output_ << " " << OperatorToString(doc->kind) << " ";
PrintDoc(doc->operands[1]);
PrintChildExprConservatively(doc->operands[1], doc);
} else if (doc->kind == OpKind::kIfThenElse) {
ICHECK_EQ(doc->operands.size(), 3)
<< "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
PrintDoc(doc->operands[1]);
PrintChildExpr(doc->operands[1], doc);
output_ << " if ";
PrintDoc(doc->operands[0]);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " else ";
PrintDoc(doc->operands[2]);
PrintChildExprConservatively(doc->operands[2], doc);
} else {
LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
throw;
}
}

void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
PrintDoc(doc->callee);
PrintChildExpr(doc->callee, doc);

output_ << "(";

Expand Down Expand Up @@ -285,7 +438,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
output_ << "lambda ";
PrintJoinedDocs(doc->args, ", ");
output_ << ": ";
PrintDoc(doc->body);
PrintChildExpr(doc->body, doc);
}

void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
Expand Down Expand Up @@ -444,8 +597,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);
if (doc->return_type.defined()) {
output_ << " -> ";
PrintDoc(doc->return_type.value());
}

output_ << ":";

Expand Down
47 changes: 29 additions & 18 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,31 @@

import pytest

import tvm
from tvm.script.printer.doc import (
LiteralDoc,
IdDoc,
AssertDoc,
AssignDoc,
AttrAccessDoc,
IndexDoc,
CallDoc,
OperationKind,
OperationDoc,
ClassDoc,
DictDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
IdDoc,
IfDoc,
IndexDoc,
LambdaDoc,
TupleDoc,
ListDoc,
DictDoc,
LiteralDoc,
OperationDoc,
OperationKind,
ReturnDoc,
ScopeDoc,
SliceDoc,
StmtBlockDoc,
AssignDoc,
IfDoc,
TupleDoc,
WhileDoc,
ForDoc,
ScopeDoc,
ExprStmtDoc,
AssertDoc,
ReturnDoc,
FunctionDoc,
ClassDoc,
)


Expand Down Expand Up @@ -450,6 +451,13 @@ def test_return_doc():
[IdDoc("test"), IdDoc("test2")],
],
)
@pytest.mark.parametrize(
"return_type",
[
None,
LiteralDoc(None),
],
)
@pytest.mark.parametrize(
"body",
[
Expand All @@ -458,9 +466,8 @@ def test_return_doc():
[ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
],
)
def test_function_doc(args, decorators, body):
def test_function_doc(args, decorators, return_type, body):
name = IdDoc("name")
return_type = LiteralDoc(None)

doc = FunctionDoc(name, args, decorators, return_type, body)

Expand Down Expand Up @@ -504,3 +511,7 @@ def test_stmt_doc_comment():
comment = "test comment"
doc.comment = comment
assert doc.comment == comment


if __name__ == "__main__":
tvm.testing.main()
Loading