From 0605028704dc60a9709159e1b649379f49e2aa7a Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 15 Jul 2022 15:28:55 -0400 Subject: [PATCH] [TVMScript] StmtDoc Definitions This PR addes: - All StmtDoc subclasses - Python bindings for StmtDoc Tracking issue: https://github.com/apache/tvm/issues/11912 --- include/tvm/script/printer/doc.h | 506 ++++++++++++++++++ python/tvm/script/printer/doc.py | 163 ++++++ src/script/printer/doc.cc | 170 ++++++ .../unittest/test_tvmscript_printer_doc.py | 260 +++++++++ 4 files changed, 1099 insertions(+) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index f3f980e53f5e..db9e6e0b458c 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -119,6 +119,79 @@ class ExprDoc : public Doc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); }; +/*! + * \brief The base class of statement doc. + * + * \sa StmtDoc + */ +class StmtDocNode : public DocNode { + public: + /*! + * \brief The comment of this doc. + * + * The actual position of the comment depends on the type of Doc + * and also the DocPrinter implmenetation. It could be on the same + * line as the statment, or the line above, or inside the statement + * if it spans over multiple lines. + * */ + mutable Optional comment{NullOpt}; + + void VisitAttrs(AttrVisitor* v) { + DocNode::VisitAttrs(v); + v->Visit("comment", &comment); + } + + static constexpr const char* _type_key = "script.printer.StmtDoc"; + TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode); +}; + +/*! + * \brief Reference type of StmtDocNode. + * + * \sa StmtDocNode + */ +class StmtDoc : public Doc { + protected: + StmtDoc() = default; + + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtDoc, Doc, StmtDocNode); +}; + +/*! + * \brief The container doc that holds a list of StmtDoc. + * + * \sa StmtBlockDoc + */ +class StmtBlockDocNode : public DocNode { + public: + /*! \brief The list of statements. */ + Array stmts; + + void VisitAttrs(AttrVisitor* v) { + DocNode::VisitAttrs(v); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.printer.StmtBlockDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode); +}; + +/*! + * \brief Reference type of StmtBlockDocNode. + * + * \sa StmtBlockDocNode + */ +class StmtBlockDoc : public Doc { + public: + /*! + * \brief Constructor of StmtBlockDoc. + * \param stmts The list of statements. + */ + explicit StmtBlockDoc(Array stmts); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtBlockDoc, Doc, StmtBlockDocNode); +}; + /*! * \brief Doc that represents literal value. * @@ -219,6 +292,7 @@ class IdDoc : public ExprDoc { * \param name The name of identifier. */ explicit IdDoc(String name); + explicit IdDoc(std::nullptr_t) : ExprDoc(nullptr) {} TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode); }; @@ -640,6 +714,438 @@ class SliceDoc : public Doc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); }; +/*! + * \brief Doc that represents assign statement. + * + * \sa AssignDoc + */ +class AssignDocNode : public StmtDocNode { + public: + /*! \brief The left hand side of the assignment */ + ExprDoc lhs{nullptr}; + /*! + * \brief The right hand side of the assignment. + * + * If null, this doc represents declaration, e.g. `A: T.Buffer[(1,2)]` + * */ + Optional rhs; + /*! \brief The type annotation of this assignment. */ + Optional annotation; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + v->Visit("annotation", &annotation); + } + + static constexpr const char* _type_key = "script.printer.AssignDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of AssignDocNode. + * + * \sa AssignDoc + */ +class AssignDoc : public StmtDoc { + public: + /*! + * \brief Constructor of AssignDoc. + * \param lhs The left hand side of the assignment. + * \param rhs The right hand side of the assignment. + * \param annotation The type annotation of this assigment. + */ + explicit AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssignDoc, StmtDoc, AssignDocNode); +}; + +/*! + * \brief Doc that represent if-then-else statement. + * + * \sa IfDoc + */ +class IfDocNode : public StmtDocNode { + public: + /*! \brief The predicate of the if-then-else statement. */ + ExprDoc predicate{nullptr}; + /*! \brief The then branch of the if-then-else statement. */ + Array then_branch; + /*! \brief The else branch of the if-then-else statement. */ + Array else_branch; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("predicate", &predicate); + v->Visit("then_branch", &then_branch); + v->Visit("else_branch", &else_branch); + } + + static constexpr const char* _type_key = "script.printer.IfDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of IfDocNode. + * + * \sa IfDocNode + */ +class IfDoc : public StmtDoc { + public: + /*! + * \brief Constructor of IfDoc. + * \param predicate The predicate of the if-then-else statement. + * \param then_branch The then branch of the if-then-else statement. + * \param else_branch The else branch of the if-then-else statement. + */ + explicit IfDoc(ExprDoc predicate, Array then_branch, Array else_branch); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IfDoc, StmtDoc, IfDocNode); +}; + +/*! + * \brief Doc that represents while statement. + * + * \sa WhileDoc + */ +class WhileDocNode : public StmtDocNode { + public: + /*! \brief The predicate of the while statement. */ + ExprDoc predicate{nullptr}; + /*! \brief The body of the while statement. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("predicate", &predicate); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.WhileDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of WhileDocNode. + * + * \sa WhileDocNode + */ +class WhileDoc : public StmtDoc { + public: + /*! + * \brief Constructor of WhileDoc. + * \param predicate The predicate of the while statement. + * \param body The body of the while statement. + */ + explicit WhileDoc(ExprDoc predicate, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WhileDoc, StmtDoc, WhileDocNode); +}; + +/*! + * \brief Doc that represents for statement. + * + * Example: + * for 'lhs' in 'rhs': + * 'body...' + * + * \sa ForDoc + */ +class ForDocNode : public StmtDocNode { + public: + /*! \brief The left hand side of the assignment of iterating variable. */ + ExprDoc lhs{nullptr}; + /*! \brief The right hand side of the assignment of iterating variable. */ + ExprDoc rhs{nullptr}; + /*! \brief The body of the for statement. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.ForDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ForDocNode. + * + * \sa ForDocNode + */ +class ForDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ForDoc. + * \param lhs The left hand side of the assignment of iterating variable. + * \param rhs The right hand side of the assignment of iterating variable. + * \param body The body of the for statement. + */ + explicit ForDoc(ExprDoc lhs, ExprDoc rhs, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ForDoc, StmtDoc, ForDocNode); +}; + +/*! + * \brief Doc that represents special scopes. + * + * Specificially, this means the with statment in Python: + * + * with 'rhs' as 'lhs': + * 'body...' + * + * \sa ScopeDoc + */ +class ScopeDocNode : public StmtDocNode { + public: + /*! \brief The name of the scoped variable. */ + Optional lhs{NullOpt}; + /*! \brief The value of the scoped variable. */ + ExprDoc rhs{nullptr}; + /*! \brief The body of the scope doc. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("lhs", &lhs); + v->Visit("rhs", &rhs); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.ScopeDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ScopeDocNode. + * + * \sa ScopeDocNode + */ +class ScopeDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ScopeDoc. + * \param lhs The name of the scoped variable. + * \param rhs The value of the scoped variable. + * \param body The body of the scope doc. + */ + explicit ScopeDoc(Optional lhs, ExprDoc rhs, Array body); + + /*! + * \brief Constructor of ScopeDoc. + * \param rhs The value of the scoped variable. + * \param body The body of the scope doc. + */ + explicit ScopeDoc(ExprDoc rhs, Array body); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ScopeDoc, StmtDoc, ScopeDocNode); +}; + +/*! + * \brief Doc that represents an expression as statement. + * + * \sa ExprStmtDoc + */ +class ExprStmtDocNode : public StmtDocNode { + public: + /*! \brief The expression represented by this doc. */ + ExprDoc expr{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("expr", &expr); + } + + static constexpr const char* _type_key = "script.printer.ExprStmtDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ExprStmtDocNode. + * + * \sa ExprStmtDocNode + */ +class ExprStmtDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ExprStmtDoc. + * \param expr The expression represented by this doc. + */ + explicit ExprStmtDoc(ExprDoc expr); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprStmtDoc, StmtDoc, ExprStmtDocNode); +}; + +/*! + * \brief Doc that represents assert statement. + * + * \sa AssertDoc + */ +class AssertDocNode : public StmtDocNode { + public: + /*! \brief The expression to test. */ + ExprDoc test{nullptr}; + /*! \brief The optional error message when assertion failed. */ + Optional msg{NullOpt}; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("test", &test); + v->Visit("msg", &msg); + } + + static constexpr const char* _type_key = "script.printer.AssertDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of AssertDocNode. + * + * \sa AssertDocNode + */ +class AssertDoc : public StmtDoc { + public: + /*! + * \brief Constructor of AssertDoc. + * \param test The expression to test. + * \param msg The optional error message when assertion failed. + */ + explicit AssertDoc(ExprDoc test, Optional msg = NullOpt); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode); +}; + +/*! + * \brief Doc that represents return statement. + * + * \sa ReturnDoc + */ +class ReturnDocNode : public StmtDocNode { + public: + /*! \brief The value to return. */ + ExprDoc value{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.printer.ReturnDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ReturnDocNode. + * + * \sa ReturnDocNode + */ +class ReturnDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ReturnDoc. + * \param value The value to return. + */ + explicit ReturnDoc(ExprDoc value); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReturnDoc, StmtDoc, ReturnDocNode); +}; + +/*! + * \brief Doc that represents function definition. + * + * \sa FunctionDoc + */ +class FunctionDocNode : public StmtDocNode { + public: + /*! \brief The name of function. */ + IdDoc name{nullptr}; + /*! + * \brief The arguments of function. + * + * The `lhs` means argument name, + * `annotation` means argument type, + * and `rhs` means default value. + */ + Array args; + /*! \brief Decorators of function. */ + Array decorators; + /*! \brief The return type of function. */ + ExprDoc return_type{nullptr}; + /*! \brief The body of function. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("decorators", &decorators); + v->Visit("return_type", &return_type); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.FunctionDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of FunctionDocNode. + * + * \sa FunctionDocNode + */ +class FunctionDoc : public StmtDoc { + public: + /*! + * \brief Constructor of FunctionDoc. + * \param name The name of function.. + * \param args The arguments of function. + * \param decorators The decorator of function. + * \param return_type The return type of function. + * \param body The body of function. + */ + explicit FunctionDoc(IdDoc name, Array args, Array decorators, + ExprDoc return_type, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode); +}; + +/*! + * \brief Doc that represents class definition. + * + * \sa ClassDoc + */ +class ClassDocNode : public StmtDocNode { + public: + /*! \brief The name of class. */ + IdDoc name{nullptr}; + /*! \brief Decorators of class. */ + Array decorators; + /*! \brief The body of class. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("decorators", &decorators); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.ClassDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ClassDocNode. + * + * \sa ClassDocNode + */ +class ClassDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ClassDoc. + * \param name The name of class. + * \param decorators The decorator of class. + * \param body The body of class. + */ + explicit ClassDoc(IdDoc name, Array decorators, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode); +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index acdb63dcf250..0ffdce075fa0 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -100,6 +100,31 @@ def __iter__(self): raise RuntimeError(f"{self.__class__} cannot be used as iterable.") +class StmtDoc(Doc): + """Base class of statement doc""" + + @property + def comment(self) -> Optional[str]: + # It has to call the dunder method to avoid infinite recursion + # pylint: disable=unnecessary-dunder-call + return self.__getattr__("comment") + # pylint: enable=unnecessary-dunder-call + + @comment.setter + def comment(self, value): + return _ffi_api.StmtDocSetComment(self, value) # type: ignore + + +@tvm._ffi.register_object("script.printer.StmtBlockDoc") +class StmtBlockDoc(Doc): + """The container doc that holds a list of StmtDoc.""" + + stmts: Sequence[StmtDoc] + + def __init__(self, stmts: List[StmtDoc]): + self.__init_handle_by_constructor__(_ffi_api.StmtBlockDoc, stmts) # type: ignore + + @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): """Doc that represents literal value""" @@ -293,3 +318,141 @@ def __init__( step: Optional[ExprDoc] = None, ): self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore + + +@tvm._ffi.register_object("script.printer.AssignDoc") +class AssignDoc(StmtDoc): + """Doc that represents assign statement.""" + + lhs: ExprDoc + rhs: Optional[ExprDoc] + annotation: Optional[ExprDoc] + + def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[ExprDoc] = None): + # pylint: disable=line-too-long + self.__init_handle_by_constructor__(_ffi_api.AssignDoc, lhs, rhs, annotation) # type: ignore + # pylint: enable=line-too-long + + +@tvm._ffi.register_object("script.printer.IfDoc") +class IfDoc(StmtDoc): + """Doc that represent if-then-else statement.""" + + predicate: ExprDoc + then_branch: Sequence[StmtDoc] + else_branch: Sequence[StmtDoc] + + def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: List[StmtDoc]): + # pylint: disable=line-too-long + self.__init_handle_by_constructor__(_ffi_api.IfDoc, predicate, then_branch, else_branch) # type: ignore + # pylint: enable=line-too-long + + +@tvm._ffi.register_object("script.printer.WhileDoc") +class WhileDoc(StmtDoc): + """Doc that represents while statement.""" + + predicate: ExprDoc + body: Sequence[StmtDoc] + + def __init__(self, predicate: ExprDoc, body: List[StmtDoc]): + self.__init_handle_by_constructor__(_ffi_api.WhileDoc, predicate, body) # type: ignore + + +@tvm._ffi.register_object("script.printer.ForDoc") +class ForDoc(StmtDoc): + """Doc that represents for statement.""" + + lhs: ExprDoc + rhs: ExprDoc + body: Sequence[StmtDoc] + + def __init__(self, lhs: ExprDoc, rhs: ExprDoc, body: List[StmtDoc]): + self.__init_handle_by_constructor__(_ffi_api.ForDoc, lhs, rhs, body) # type: ignore + + +@tvm._ffi.register_object("script.printer.ScopeDoc") +class ScopeDoc(StmtDoc): + """ + Doc that represents special scopes. + + Specificially, this means the with statment in Python: + + with as : + + """ + + lhs: Optional[ExprDoc] + rhs: ExprDoc + body: Sequence[StmtDoc] + + def __init__(self, lhs: Optional[ExprDoc], rhs: ExprDoc, body: List[StmtDoc]): + self.__init_handle_by_constructor__(_ffi_api.ScopeDoc, lhs, rhs, body) # type: ignore + + +@tvm._ffi.register_object("script.printer.ExprStmtDoc") +class ExprStmtDoc(StmtDoc): + """Doc that represents an expression as statement.""" + + expr: ExprDoc + + def __init__(self, expr: ExprDoc): + self.__init_handle_by_constructor__(_ffi_api.ExprStmtDoc, expr) # type: ignore + + +@tvm._ffi.register_object("script.printer.AssertDoc") +class AssertDoc(StmtDoc): + """Doc that represents assert statement.""" + + test: ExprDoc + msg: Optional[ExprDoc] + + def __init__(self, test: ExprDoc, msg: Optional[ExprDoc] = None): + self.__init_handle_by_constructor__(_ffi_api.AssertDoc, test, msg) # type: ignore + + +@tvm._ffi.register_object("script.printer.ReturnDoc") +class ReturnDoc(StmtDoc): + """Doc that represents return statement.""" + + value: ExprDoc + + def __init__(self, value: ExprDoc): + self.__init_handle_by_constructor__(_ffi_api.ReturnDoc, value) # type: ignore + + +@tvm._ffi.register_object("script.printer.FunctionDoc") +class FunctionDoc(StmtDoc): + """Doc that represents function definition.""" + + name: IdDoc + args: Sequence[AssignDoc] + decorators: Sequence[ExprDoc] + return_type: ExprDoc + body: Sequence[StmtDoc] + + def __init__( + self, + name: IdDoc, + args: List[AssignDoc], + decorators: List[ExprDoc], + return_type: ExprDoc, + body: List[StmtDoc], + ): + # pylint: disable=line-too-long + self.__init_handle_by_constructor__(_ffi_api.FunctionDoc, name, args, decorators, return_type, body) # type: ignore + # pylint: enable=line-too-long + + +@tvm._ffi.register_object("script.printer.ClassDoc") +class ClassDoc(StmtDoc): + """Doc that represents class definition.""" + + name: IdDoc + decorators: Sequence[ExprDoc] + body: Sequence[StmtDoc] + + def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]): + # pylint: disable=line-too-long + self.__init_handle_by_constructor__(_ffi_api.ClassDoc, name, decorators, body) # type: ignore + # pylint: enable=line-too-long diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index ed81f9d2dd26..bfff0cfad4fe 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include +#include #include #include @@ -38,6 +40,12 @@ ExprDoc ExprDocNode::Call(Array args, Array kwargs_ return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values); } +StmtBlockDoc::StmtBlockDoc(Array stmts) { + ObjectPtr n = make_object(); + n->stmts = stmts; + this->data_ = std::move(n); +} + LiteralDoc::LiteralDoc(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; @@ -115,6 +123,99 @@ SliceDoc::SliceDoc(Optional start, Optional stop, Optionaldata_ = std::move(n); } +AssignDoc::AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation) { + CHECK(rhs.defined() || annotation.defined()) + << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc."; + CHECK(lhs->IsInstance() || annotation == nullptr) + << "ValueError: annotation can only be nonnull if lhs is an identifier."; + + ObjectPtr n = make_object(); + n->lhs = lhs; + n->rhs = rhs; + n->annotation = annotation; + this->data_ = std::move(n); +} + +IfDoc::IfDoc(ExprDoc predicate, Array then_branch, Array else_branch) { + CHECK(!then_branch.empty() || !else_branch.empty()) + << "ValueError: At least one of the then branch or else branch needs to be non-empty."; + + ObjectPtr n = make_object(); + n->predicate = predicate; + n->then_branch = then_branch; + n->else_branch = else_branch; + this->data_ = std::move(n); +} + +WhileDoc::WhileDoc(ExprDoc predicate, Array body) { + ObjectPtr n = make_object(); + n->predicate = predicate; + n->body = body; + this->data_ = std::move(n); +} + +ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array body) { + ObjectPtr n = make_object(); + n->lhs = lhs; + n->rhs = rhs; + n->body = body; + this->data_ = std::move(n); +} + +ScopeDoc::ScopeDoc(Optional lhs, ExprDoc rhs, Array body) { + ObjectPtr n = make_object(); + n->lhs = lhs; + n->rhs = rhs; + n->body = body; + this->data_ = std::move(n); +} + +ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) { + ObjectPtr n = make_object(); + n->lhs = NullOpt; + n->rhs = rhs; + n->body = body; + this->data_ = std::move(n); +} + +ExprStmtDoc::ExprStmtDoc(ExprDoc expr) { + ObjectPtr n = make_object(); + n->expr = expr; + this->data_ = std::move(n); +} + +AssertDoc::AssertDoc(ExprDoc test, Optional msg) { + ObjectPtr n = make_object(); + n->test = test; + n->msg = msg; + this->data_ = std::move(n); +} + +ReturnDoc::ReturnDoc(ExprDoc value) { + ObjectPtr n = make_object(); + n->value = value; + this->data_ = std::move(n); +} + +FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decorators, + ExprDoc return_type, Array body) { + ObjectPtr n = make_object(); + n->name = name; + n->args = args; + n->decorators = decorators; + n->return_type = return_type; + n->body = body; + this->data_ = std::move(n); +} + +ClassDoc::ClassDoc(IdDoc name, Array decorators, Array body) { + ObjectPtr n = make_object(); + n->name = name; + n->decorators = decorators; + n->body = body; + this->data_ = std::move(n); +} + TVM_REGISTER_NODE_TYPE(DocNode); TVM_REGISTER_NODE_TYPE(ExprDocNode); @@ -125,6 +226,15 @@ TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") .set_body_method, Array, Array>( &ExprDocNode::Call); +TVM_REGISTER_NODE_TYPE(StmtDocNode); +TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment") + .set_body_typed([](StmtDoc doc, Optional comment) { doc->comment = comment; }); + +TVM_REGISTER_NODE_TYPE(StmtBlockDocNode); +TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { + return StmtBlockDoc(stmts); +}); + TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); @@ -185,6 +295,66 @@ TVM_REGISTER_GLOBAL("script.printer.SliceDoc") .set_body_typed([](Optional start, Optional stop, Optional step) { return SliceDoc(start, stop, step); }); + +TVM_REGISTER_NODE_TYPE(AssignDocNode); +TVM_REGISTER_GLOBAL("script.printer.AssignDoc") + .set_body_typed([](ExprDoc lhs, Optional rhs, Optional annotation) { + return AssignDoc(lhs, rhs, annotation); + }); + +TVM_REGISTER_NODE_TYPE(IfDocNode); +TVM_REGISTER_GLOBAL("script.printer.IfDoc") + .set_body_typed([](ExprDoc predicate, Array then_branch, Array else_branch) { + return IfDoc(predicate, then_branch, else_branch); + }); + +TVM_REGISTER_NODE_TYPE(WhileDocNode); +TVM_REGISTER_GLOBAL("script.printer.WhileDoc") + .set_body_typed([](ExprDoc predicate, Array body) { + return WhileDoc(predicate, body); + }); + +TVM_REGISTER_NODE_TYPE(ForDocNode); +TVM_REGISTER_GLOBAL("script.printer.ForDoc") + .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array body) { + return ForDoc(lhs, rhs, body); + }); + +TVM_REGISTER_NODE_TYPE(ScopeDocNode); +TVM_REGISTER_GLOBAL("script.printer.ScopeDoc") + .set_body_typed([](Optional lhs, ExprDoc rhs, Array body) { + return ScopeDoc(lhs, rhs, body); + }); + +TVM_REGISTER_NODE_TYPE(ExprStmtDocNode); +TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { + return ExprStmtDoc(expr); +}); + +TVM_REGISTER_NODE_TYPE(AssertDocNode); +TVM_REGISTER_GLOBAL("script.printer.AssertDoc") + .set_body_typed([](ExprDoc test, Optional msg = NullOpt) { + return AssertDoc(test, msg); + }); + +TVM_REGISTER_NODE_TYPE(ReturnDocNode); +TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { + return ReturnDoc(value); +}); + +TVM_REGISTER_NODE_TYPE(FunctionDocNode); +TVM_REGISTER_GLOBAL("script.printer.FunctionDoc") + .set_body_typed([](IdDoc name, Array args, Array decorators, + ExprDoc return_type, Array body) { + return FunctionDoc(name, args, decorators, return_type, body); + }); + +TVM_REGISTER_NODE_TYPE(ClassDocNode); +TVM_REGISTER_GLOBAL("script.printer.ClassDoc") + .set_body_typed([](IdDoc name, Array decorators, Array body) { + return ClassDoc(name, decorators, body); + }); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 4ff6a0f547d7..040a82901059 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -14,6 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +In this test file, we want to make sure the Python code can construct +Doc objects, then access and modify their attributes correctly. +""" + import pytest from tvm.script.printer.doc import ( @@ -29,6 +34,17 @@ ListDoc, DictDoc, SliceDoc, + StmtBlockDoc, + AssignDoc, + IfDoc, + WhileDoc, + ForDoc, + ScopeDoc, + ExprStmtDoc, + AssertDoc, + ReturnDoc, + FunctionDoc, + ClassDoc, ) @@ -244,3 +260,247 @@ def test_expr_doc_call_with(args, kwargs): assert doc.callee == target assert list(doc.args) == args assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs + + +@pytest.mark.parametrize( + "stmts", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_stmt_block_doc(stmts): + doc = StmtBlockDoc(stmts) + + assert list(doc.stmts) == stmts + + +@pytest.mark.parametrize( + "lhs, rhs, annotation", + [ + (IdDoc("x"), IdDoc("y"), None), + (IdDoc("x"), None, IdDoc("int")), + (IdDoc("x"), IdDoc("y"), IdDoc("int")), + ], +) +def test_assign_doc(lhs, rhs, annotation): + doc = AssignDoc(lhs, rhs, annotation) + + assert doc.lhs == lhs + assert doc.rhs == rhs + assert doc.annotation == annotation + + +@pytest.mark.parametrize( + "lhs, rhs, annotation", + [ + (IdDoc("x"), None, None), + (TupleDoc([IdDoc("x"), IdDoc("y")]), None, IdDoc("int")), + (TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("u"), IdDoc("int")), + ], +) +def test_invalid_assign_doc(lhs, rhs, annotation): + with pytest.raises(ValueError) as e: + AssignDoc(lhs, rhs, annotation) + assert "AssignDoc" in str(e.value) + + +@pytest.mark.parametrize( + "else_branch", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +@pytest.mark.parametrize( + "then_branch", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_if_doc(then_branch, else_branch): + predicate = IdDoc("x") + + if not then_branch and not else_branch: + with pytest.raises(ValueError) as e: + IfDoc(predicate, then_branch, else_branch) + assert "IfDoc" in str(e.value) + return + else: + doc = IfDoc(predicate, then_branch, else_branch) + + assert doc.predicate == predicate + assert list(doc.then_branch) == then_branch + assert list(doc.else_branch) == else_branch + + +@pytest.mark.parametrize( + "body", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_while_doc(body): + predicate = IdDoc("x") + + doc = WhileDoc(predicate, body) + + assert doc.predicate == predicate + assert list(doc.body) == body + + +@pytest.mark.parametrize( + "body", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_for_doc(body): + lhs = IdDoc("x") + rhs = IdDoc("y") + + doc = ForDoc(lhs, rhs, body) + + assert doc.lhs == lhs + assert doc.rhs == rhs + assert list(doc.body) == body + + +@pytest.mark.parametrize( + "lhs", + [ + None, + IdDoc("x"), + ], +) +@pytest.mark.parametrize( + "body", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_scope_doc(lhs, body): + rhs = IdDoc("y") + + doc = ScopeDoc(lhs, rhs, body) + + assert doc.lhs == lhs + assert doc.rhs == rhs + assert list(doc.body) == body + + +def test_expr_stmt_doc(): + expr = IdDoc("x") + + doc = ExprStmtDoc(expr) + + assert doc.expr == expr + + +@pytest.mark.parametrize( + "msg", + [ + None, + LiteralDoc("msg"), + ], +) +def test_assert_doc(msg): + test = IdDoc("x") + + doc = AssertDoc(test, msg) + + assert doc.test == test + assert doc.msg == msg + + +def test_return_doc(): + value = IdDoc("x") + + doc = ReturnDoc(value) + + assert doc.value == value + + +@pytest.mark.parametrize( + "args", + [ + [], + [AssignDoc(IdDoc("x"), None, IdDoc("int"))], + [ + AssignDoc(IdDoc("x"), None, IdDoc("int")), + AssignDoc(IdDoc("y"), LiteralDoc(1), IdDoc("int")), + ], + ], +) +@pytest.mark.parametrize( + "decorators", + [ + [], + [IdDoc("test")], + [IdDoc("test"), IdDoc("test2")], + ], +) +@pytest.mark.parametrize( + "body", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_function_doc(args, decorators, body): + name = IdDoc("name") + return_type = LiteralDoc(None) + + doc = FunctionDoc(name, args, decorators, return_type, body) + + assert doc.name == name + assert list(doc.args) == args + assert list(doc.decorators) == decorators + assert doc.return_type == return_type + assert list(doc.body) == body + + +@pytest.mark.parametrize( + "decorators", + [ + [], + [IdDoc("test")], + [IdDoc("test"), IdDoc("test2")], + ], +) +@pytest.mark.parametrize( + "body", + [ + [], + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + ], +) +def test_class_doc(decorators, body): + name = IdDoc("name") + + doc = ClassDoc(name, decorators, body) + + assert doc.name == name + assert list(doc.decorators) == decorators + assert list(doc.body) == body + + +def test_stmt_doc_comment(): + doc = ExprStmtDoc(IdDoc("x")) + assert doc.comment is None + + comment = "test comment" + doc.comment = comment + assert doc.comment == comment