Skip to content

Commit

Permalink
Add expr precedence
Browse files Browse the repository at this point in the history
  • Loading branch information
yelite committed Jul 21, 2022
1 parent f52ac4d commit d88bcc7
Show file tree
Hide file tree
Showing 2 changed files with 449 additions and 10 deletions.
173 changes: 163 additions & 10 deletions src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,114 @@ namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Operator precedence
*
* This is based on
* https://docs.python.org/3/reference/expressions.html#operator-precedence
*/
enum class ExprPrecedence : int32_t {
/*! \brief Unknown precedence */
kUnkown = 0,
/*! \brief Lambda Expression */
kLambda = 1,
/*! \brief Conditional Expression */
kIfThenElse = 2,
/*! \brief Boolean OR */
kBooleanOr = 3,
/*! \brief Boolean AND */
kBooleanAnd = 4,
/*! \brief Boolean NOT */
kBooleanNot = 5,
/*! \brief Comparisons */
kComparison = 6,
/*! \brief Bitwise OR */
kBitwiseOr = 7,
/*! \brief Bitwise XOR */
kBitwiseXor = 8,
/*! \brief Bitwise AND */
kBitwiseAnd = 9,
/*! \brief Shift Operators */
kShift = 10,
/*! \brief Addition and subtraction */
kAdd = 11,
/*! \brief Multiplication, division, floor division, remainder */
kMult = 12,
/*! \brief Positive negative and bitwise NOT */
kUnary = 13,
/*! \brief Exponentiation */
kExp = 14,
/*! \brief Index access, attribute access, call and atom expression */
kIdentity = 15,
};

#define DOC_PRECEDENCE_ENTRY(RefType, Precedence) \
{ RefType::ContainerType::RuntimeTypeIndex(), ExprPrecedence::Precedence }

ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
// Key is the value of OperationDocNode::Kind
static const std::vector<ExprPrecedence> op_kind_precedence = []() {
using OpKind = OperationDocNode::Kind;
std::map<OpKind, ExprPrecedence> raw_table = {
{OpKind::kUSub, ExprPrecedence::kUnary}, //
{OpKind::kInvert, ExprPrecedence::kUnary}, //
{OpKind::kAdd, ExprPrecedence::kAdd}, //
{OpKind::kSub, ExprPrecedence::kAdd}, //
{OpKind::kMult, ExprPrecedence::kMult}, //
{OpKind::kDiv, ExprPrecedence::kMult}, //
{OpKind::kFloorDiv, ExprPrecedence::kMult}, //
{OpKind::kMod, ExprPrecedence::kMult}, //
{OpKind::kPow, ExprPrecedence::kExp}, //
{OpKind::kLShift, ExprPrecedence::kShift}, //
{OpKind::kRShift, ExprPrecedence::kShift}, //
{OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, //
{OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, //
{OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, //
{OpKind::kLt, ExprPrecedence::kComparison}, //
{OpKind::kLtE, ExprPrecedence::kComparison}, //
{OpKind::kEq, ExprPrecedence::kComparison}, //
{OpKind::kNotEq, ExprPrecedence::kComparison}, //
{OpKind::kGt, ExprPrecedence::kComparison}, //
{OpKind::kGtE, ExprPrecedence::kComparison}, //
{OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, //
};

std::vector<ExprPrecedence> table;
table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1);

for (const auto& kv : raw_table) {
table[static_cast<int>(kv.first)] = kv.second;
}

return table;
}();

// Key is the type index of Doc
static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = {
DOC_PRECEDENCE_ENTRY(LiteralDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IdDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(AttrAccessDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IndexDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(CallDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(LambdaDoc, kLambda), //
DOC_PRECEDENCE_ENTRY(TupleDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(ListDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(DictDoc, kIdentity), //
};

if (const auto* op_doc = doc.as<OperationDocNode>()) {
ExprPrecedence precedence = op_kind_precedence[static_cast<int>(op_doc->kind)];
ICHECK(precedence != ExprPrecedence::kUnkown)
<< "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown";
return precedence;
} else if (doc_type_precedence.find(doc->type_index()) != doc_type_precedence.end()) {
return doc_type_precedence.at(doc->type_index());
} else {
ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown";
throw;
}
}

class PythonDocPrinter : public DocPrinter {
public:
explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {}
Expand Down Expand Up @@ -98,6 +206,42 @@ class PythonDocPrinter : public DocPrinter {
}
}

/*!
* \brief Print expression and add parenthesis if needed.
*/
void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence doc_precedence = GetExprPrecedence(doc);
if (doc_precedence < parent_precedence ||
(parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
output_ << "(";
PrintDoc(doc);
output_ << ")";
} else {
PrintDoc(doc);
}
}

/*!
* \brief Print expression and add parenthesis if doc has lower precedence than parent.
*/
void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence parent_precedence = GetExprPrecedence(parent);
return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
}

/*!
* \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
*
* This function should be used to print an child expression that needs to be wrapped
* by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
* and the `b` and `c` in `a if b else c`.
*/
void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) {
PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence*/ true);
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
Expand Down Expand Up @@ -161,12 +305,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; }

void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
output_ << "." << doc->name;
}

void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
if (doc->indices.size() == 0) {
output_ << "[()]";
} else {
Expand Down Expand Up @@ -226,29 +370,38 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
// Unary Operators
ICHECK_EQ(doc->operands.size(), 1);
output_ << OperatorToString(doc->kind);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
} else if (doc->kind == OpKind::kPow) {
// Power operator is different than other binary operators
// It's right-associative and binds less tightly than unary operator on its right.
// https://docs.python.org/3/reference/expressions.html#the-power-operator
// https://docs.python.org/3/reference/expressions.html#operator-precedence
ICHECK_EQ(doc->operands.size(), 2);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " ** ";
PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
} else if (doc->kind < OpKind::kBinaryEnd) {
// Binary Operator
ICHECK_EQ(doc->operands.size(), 2);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
output_ << " " << OperatorToString(doc->kind) << " ";
PrintDoc(doc->operands[1]);
PrintChildExprConservatively(doc->operands[1], doc);
} else if (doc->kind == OpKind::kIfThenElse) {
ICHECK_EQ(doc->operands.size(), 3)
<< "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
PrintDoc(doc->operands[1]);
PrintChildExpr(doc->operands[1], doc);
output_ << " if ";
PrintDoc(doc->operands[0]);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " else ";
PrintDoc(doc->operands[2]);
PrintChildExprConservatively(doc->operands[2], doc);
} else {
LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
throw;
}
}

void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
PrintDoc(doc->callee);
PrintChildExpr(doc->callee, doc);

output_ << "(";

Expand Down Expand Up @@ -285,7 +438,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
output_ << "lambda ";
PrintJoinedDocs(doc->args, ", ");
output_ << ": ";
PrintDoc(doc->body);
PrintChildExpr(doc->body, doc);
}

void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
Expand Down
Loading

0 comments on commit d88bcc7

Please sign in to comment.