From e6173430f491c1d88d2ab77ce0ab43a8c602df30 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 21 Dec 2022 00:42:29 -0500 Subject: [PATCH] [REFACTOR][ARCH] Introduce StructInfo M0 (#314) * [IR] Introduce StructInfo * StructInfoFunctor and Analysis Support * [TVMScript] Parse type/shape annotation with StructInfo * remove runtime type assign * Remove type/shape during parsing (#2) * Normalizer prep: simple checks and legacy function renaming. * Struct info deduction in BlockBuilder. * Two TODOs * StructInfo Normalizer Fixes (#3) * StructInfo AST Fix * Fix Extern Func Deduction and shape mutator. * Update VoidStructInfo & globalvar (#4) * Fix passes and proper sinfo propagation. * Refactor EraseToWellDefined to Enable Remapping * [WIP] First stab at symbolic param tracking * Update EraseToWellDefined to support symbolic shape return (#5) * fix R.shape with ndim (#6) * Remove update shape/type * Address review comment, AnnotateTypeShape=>AnnotateStructInfo * Update include/tvm/script/ir_builder/relax/frame.h Co-authored-by: Ruihang Lai * Address comments * Update printer to use structinfo (#7) * Update Error mechanism to prep for obj loc based reporting * Symbolic shape aware function call return value derivation. The main flow works as follows: - Match and populate shape_var_map and var_map by visit each pair of param and call arguments. - Call EraseToWellDefined to map the ret parameter to new result. * [ANALYSIS] Refactor well-form to only look at struct info. * Update comments according to reviews. * Update include/tvm/relax/struct_info.h Co-authored-by: Ruihang Lai Co-authored-by: Siyuan Feng Co-authored-by: Tianqi Chen Co-authored-by: Ruihang Lai --- include/tvm/ir/diagnostic.h | 27 + include/tvm/ir/expr.h | 8 + include/tvm/ir/type.h | 2 +- include/tvm/relax/analysis.h | 246 ++- include/tvm/relax/block_builder.h | 37 +- include/tvm/relax/expr.h | 90 +- include/tvm/relax/expr_functor.h | 27 +- include/tvm/relax/op_attr_types.h | 10 + include/tvm/relax/struct_info.h | 417 +++++ include/tvm/relax/struct_info_functor.h | 151 ++ include/tvm/relax/type.h | 17 +- include/tvm/script/ir_builder/relax/frame.h | 4 + include/tvm/script/ir_builder/relax/ir.h | 65 +- python/tvm/ir/expr.py | 12 + python/tvm/relax/__init__.py | 11 + python/tvm/relax/analysis/analysis.py | 156 +- python/tvm/relax/block_builder.py | 68 +- python/tvm/relax/expr.py | 18 +- python/tvm/relax/expr_functor.py | 11 +- python/tvm/relax/struct_info.py | 231 +++ python/tvm/relax/ty.py | 13 +- python/tvm/script/ir_builder/relax/ir.py | 150 +- python/tvm/script/parser/relax/__init__.py | 4 +- python/tvm/script/parser/relax/entry.py | 104 +- python/tvm/script/parser/relax/parser.py | 63 +- python/tvm/script/parser_v1/parser.py | 1403 ----------------- src/ir/diagnostic.cc | 28 + src/ir/type.cc | 3 +- src/printer/relax_script_printer.cc | 130 +- src/printer/text_printer.h | 18 +- src/relax/analysis/shape_analysis.cc | 55 + src/relax/analysis/struct_info_analysis.cc | 933 +++++++++++ src/relax/analysis/well_formed.cc | 247 +-- src/relax/ir/block_builder.cc | 636 ++++---- src/relax/ir/expr.cc | 179 ++- src/relax/ir/expr_functor.cc | 149 +- src/relax/ir/struct_info.cc | 249 +++ src/relax/ir/struct_info_functor.cc | 130 ++ src/relax/ir/type.cc | 7 +- src/relax/ir/type_analysis.cc | 7 +- src/relax/op/op.cc | 136 +- src/relax/op/op_common.h | 3 +- src/relax/op/tensor/binary.cc | 107 +- src/relax/op/tensor/binary.h | 3 +- src/relax/op/tensor/ternary.cc | 92 +- src/relax/op/tensor/ternary.h | 2 +- src/relax/op/tensor/unary.cc | 59 +- src/relax/transform/fuse_ops.cc | 3 +- src/relax/transform/lambda_lift.cc | 6 +- src/relax/transform/normalize.cc | 22 +- src/relax/transform/to_non_dataflow.cc | 5 +- src/script/ir_builder/ir/ir.cc | 17 + src/script/ir_builder/relax/frame.cc | 3 +- src/script/ir_builder/relax/ir.cc | 194 ++- src/script/ir_builder/relax/utils.h | 1 + .../test_analysis_struct_info_analysis.py | 562 +++++++ .../python/relax/test_analysis_well_formed.py | 4 +- tests/python/relax/test_ast_printer.py | 2 +- tests/python/relax/test_blockbuilder.py | 29 +- tests/python/relax/test_expr.py | 15 +- tests/python/relax/test_expr_functor.py | 26 +- tests/python/relax/test_printer.py | 7 +- tests/python/relax/test_struct_info.py | 222 +++ .../test_transform_canonicalize_bindings.py | 1 + .../relax/test_transform_lambda_lift.py | 5 +- .../python/relax/test_tvmscript_ir_builder.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 13 + tests/python/relax/test_vm.py | 6 +- 68 files changed, 4971 insertions(+), 2692 deletions(-) create mode 100644 include/tvm/relax/struct_info.h create mode 100644 include/tvm/relax/struct_info_functor.h create mode 100644 python/tvm/relax/struct_info.py delete mode 100644 python/tvm/script/parser_v1/parser.py create mode 100644 src/relax/analysis/shape_analysis.cc create mode 100644 src/relax/analysis/struct_info_analysis.cc create mode 100644 src/relax/ir/struct_info.cc create mode 100644 src/relax/ir/struct_info_functor.cc create mode 100644 tests/python/relax/test_analysis_struct_info_analysis.py create mode 100644 tests/python/relax/test_struct_info.py diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 41130a5be0aa..3a62e5a704bd 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -58,6 +58,14 @@ class DiagnosticNode : public Object { DiagnosticLevel level; /*! \brief The span at which to report an error. */ Span span; + /*! + * \brief The object location at which to report an error. + * + * The object loc provides a location when span is not always + * available during transformation. The error reporter can + * still pick up loc->span if necessary. + */ + ObjectRef loc; /*! \brief The diagnostic message. */ String message; @@ -86,6 +94,18 @@ class Diagnostic : public ObjectRef { static DiagnosticBuilder Warning(Span span); static DiagnosticBuilder Note(Span span); static DiagnosticBuilder Help(Span span); + // variants uses object location + static DiagnosticBuilder Bug(ObjectRef loc); + static DiagnosticBuilder Error(ObjectRef loc); + static DiagnosticBuilder Warning(ObjectRef loc); + static DiagnosticBuilder Note(ObjectRef loc); + static DiagnosticBuilder Help(ObjectRef loc); + // variants uses object ptr. + static DiagnosticBuilder Bug(const Object* loc); + static DiagnosticBuilder Error(const Object* loc); + static DiagnosticBuilder Warning(const Object* loc); + static DiagnosticBuilder Note(const Object* loc); + static DiagnosticBuilder Help(const Object* loc); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode); }; @@ -104,6 +124,11 @@ class DiagnosticBuilder { /*! \brief The span of the diagnostic. */ Span span; + /*! + * \brief The object location at which to report an error. + */ + ObjectRef loc; + template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; @@ -117,6 +142,8 @@ class DiagnosticBuilder { DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} + DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {} + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } private: diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index aaa7b45809b4..70fed8d9e3d0 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -376,6 +376,13 @@ class RelayExprNode : public BaseExprNode { */ mutable Optional shape_ = Optional(); + /*! + * \brief Stores the result of structure information of the + * expression that encapsulate both static shape and + * runtime information such as shape. + */ + mutable Optional struct_info_ = Optional(); + /*! * \return The checked_type */ @@ -471,6 +478,7 @@ class GlobalVarNode : public RelayExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 12bb127d016b..dc1dbca289e1 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -132,7 +132,7 @@ class PrimType : public Type { * \brief Constructor * \param dtype The corresponding dtype. */ - TVM_DLL explicit PrimType(runtime::DataType dtype); + TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index f922a9f5c2b0..4467d51f8ace 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -19,22 +19,266 @@ /*! * \file tvm/relax/analysis.h - * \brief The set of Relax specific analysis passes. + * \brief The set of Relax specific analysis on IR. */ #ifndef TVM_RELAX_ANALYSIS_H_ #define TVM_RELAX_ANALYSIS_H_ +#include #include #include #include +#include #include #include +#include #include namespace tvm { namespace relax { +//----------------------------------- +// Shape expression analysis +//---------------------------------- +/*! + * \brief Can prove the two symbolic shape arrays equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * \return The prove result. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana); + +/*! + * \brief Can prove the two symbolic shape expressions equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana); + +//----------------------------------- +// Foundational StructInfo analysis +//----------------------------------- +/*! + * \brief Get the corresponding static type from a given struct info. + * \param info The struct info. + * \return the corresponding static type. + */ +TVM_DLL Type GetStaticType(const StructInfo& info); + +/*! + * \brief Get the corresponding struct info from static type. + * \param type The input type + * \return the corresponding struct info. + */ +TVM_DLL StructInfo StructInfoFromType(const Type& type); + +// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_ +/*! + * \brief Get the corresponding struct info from static type. + * \param type The input type + * \param shape_hint The shape hint + * \return the corresponding struct info. + */ +TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional shape_hint); + +/*! + * \brief Get the corresponding legacy shape hint from struct info + * \param info The struct info. + * \return the corresponding legacy shape hint. + */ +TVM_DLL Optional GetLegacyShapeHint(const StructInfo& info); + +/*! + * \return Derive the call's ret value struct info from inputs. + * \param func_info The function struct info. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The derived struct info of the call. + * \note call->op field is ignored during derivation and we only rely on information + * presented by func_sinfo. + */ +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); + +/*! + * \brief Erase the info to a corresponding more coarse grained + * struct info that is still well-defined(with all the vars in scope). + * + * When we are returning a StructInfo to another scope, + * it is important to remember that StructInfo may carry + * dependencies on var that is not defined the other scope. + * + * In such cases, it is important to call EraseToWellDefined to get + * another StructInfo that **only** contains the vars that are defined + * in the target scope. + * + * For example, consider the following function + * + * \code + * + * @R.function + * def f(x: R.Tensor[(n, m)]): + * k = tir.Var("k", "int64") + * v0 = opaque_fn(x) + * v1 = match_cast(v0, R.Tensor[(n, k)]) + * v2 : R.Tensor[(n + 1, k + 2)] = pad(v1) + * return v2 + * + * \endcode + * + * In the above code, the return value y have shape `(n + 1, k + 2)`, + * However, at the level of function signature, only n, m are defined, + * k is undefined here. + * + * When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}), + * we will obtain R.Tensor(ndim=2), which is an erased info that does not depend + * on k(which is undefined from parameter signature). + * + * However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}), + * Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined. + * + * We can also make these var map to return a different expression. + * For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m}) + * will give us R.Tensor[(3, m)], where n get replaced by 2. + * + * Use this function in the following scenarios: + * - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr + * - Decide the deduced return struct_info of a function that can be fully decided by params. + * + * \param info The struct info. + * \param f_shape_var_map callback function to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param f_var_defined callback function to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo +EraseToWellDefined(const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const Var& var)> f_var_map = nullptr, + arith::Analyzer* ana = nullptr); + +/*! + * \brief EraseToWellDefined variant with map. + * \param info The struct info. + * \param f_shape_var_map callback function to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param f_var_defined callback function to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana = nullptr); + +/*! + * \brief Fine grained result of base check. + * + * This analysis comes with different levels of checking failures + * that can help to customize the compilation decisions. + * + * For a given pair of lhs_struct_info, rhs_struct_info. We adopt + * the following terminology: + * - LSet = {value | value mactches lhs_struct_info} + * - RSet = {value | value mactches rhs_struct_info} + * + * See the definition of each level below. + */ +enum class BaseCheckResult { + /*! + * \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty + */ + kFailL0 = 0, + /*! + * \brief LSet is not superset of RSet by only looking at static information. + * + * \note This level will trigger static type checking error when lhs is param and rhs is arg. + */ + kFailL1 = 1, + /*! + * \brief WLSet is not superset of RSet because of mismatch in value information. + * + * L1-level mismatches in params of FuncStructInfo is categorized as + * If lhs is FuncStructInfo, then L1-level mismatch in its params + * is categorized as L2-level mismatch for lhs. + * + * Design considerations for functions: + * - (a) We want to be able to erase type/value in function signature + * when we unify function struct info and preserve simpler representations. + * - (b) We automatically insert match_cast at function boundary, so + * we can erase (int)->int argument as (object)->int. + * The input shape/type mismatch will be detected by runtime checks at function boundary. + * This behavior is also consistent with the PackedFunc behavior. + * + * \note This level means there is no problem about static known information. + * It is OK for the checker to do best effort and return this value. + */ + kFailL2 = 2, + /*! \brief LSet is superset of RSet. */ + kPass = 3 +}; + +/*! + * \brief Run a base check to see if base subsumes derived. + * + * This function returns fine-grained base-check result on reasons of failure. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + * + * \sa BaseCheckResult + */ +TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Check the relation of two struct info to see if one subsumes another one. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + */ +TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Unify the two struct info their least common ancestor. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The unified information. + */ +TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, + arith::Analyzer* ana = nullptr); +//----------------------------------- +// General IR analysis +//---------------------------------- /*! * \brief Check if the IRModule is well formed. * diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 95f5ded0ab24..e9392b4fb1d2 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -105,6 +105,12 @@ class BlockBuilderNode : public Object { */ virtual void UpdateFunction(const GlobalVar& gv, BaseFunc function) = 0; + /*! + * \brief Report an error during transformation construction. + * \param diagnostic The diagnostic information. + */ + virtual void ReportFatal(const Diagnostic& diagnostic) = 0; + //------------------------------- // Scope management //------------------------------- @@ -116,6 +122,23 @@ class BlockBuilderNode : public Object { */ virtual Optional LookupBinding(const Var& var) = 0; + /*! + * \brief Begin a new scope, with optional parameters that + * are visible within the scope. + * + * \param params Parameters that are visible within the scope. + * + * \note This function should be called when new scope is introduced + * (function, seq) to properly track the variable availability + * and help the best effort deduction. + * + * \sa EndScope + */ + virtual void BeginScope(Optional> params) = 0; + + /*! \brief End the previously defined scope. */ + virtual void EndScope() = 0; + /*! \brief Begin to build a DataflowBlock. */ virtual void BeginDataflowBlock() = 0; @@ -202,12 +225,20 @@ class BlockBuilderNode : public Object { * \param expr The input expression. * \return The normalized expression. * - * \note Invariant: If any of the sub expr have a shape field, - * they are required to already be in the normal form. - * This is because we cannot normalize shape in argument values. + * \note Invariant: If any of the sub expr have struct_info field. + * they must have already been normalized. */ virtual Expr Normalize(const Expr& expr) = 0; + /*! + * \brief Normalize argument to a call or another IRNode. + * \param expr The input expression. + * \return The normalized expression. + * + * \note This function will create a binding var for non-leaf expressions such as Call. + */ + virtual Expr NormalizeArgument(const Expr& expr) = 0; + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.BlockBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e6f1571fb9c3..54a409fc5e61 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -37,6 +37,59 @@ using Expr = RelayExpr; using ExprNode = RelayExprNode; using relay::Id; +/*! + * \brief Base type of all structure information. + * + * StructInfo stores possible structure information + * deduced during compile-time. It encapsulates + * both static type and runtime information such + * as shape. + * + * StructInfo of each non-primitive Expr can be + * deduced during compilation in a "best-effort" manner. + * + * When struct_info appears in function parameter and return + * signatures. They will imply a runtime check that matches + * the structure information with the value. + * + * When it appears in Expr, they follow "assume-semantics", + * which means the compiler will take the deduced information as it is + * and only do best effort prove and checks. + * + * Each struct info can be uniquely erased to a static-type. + * The compiler will still compile the code(with less information) + * when we erase to the static type. + * + * If an StructInfo contains an Expr field, then that field + * must be normalized already through NormalizeArg. + * This invariant will be checked in constructors + * and help us to simplify our assumption + * during struct info deduction. + */ +class StructInfoNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "StructInfo"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 5; + TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); +}; + +/*! + * \brief Managed reference to StructInfoNode. + * \sa StructInfoNode + */ +class StructInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode); +}; + /*! * \brief Call corresponds to callable invocation. * Corresponds to operation in computational graph terminology. @@ -85,6 +138,7 @@ class CallNode : public ExprNode { v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { @@ -163,6 +217,7 @@ class IfNode : public ExprNode { v->Visit("span", &span); v->Visit("shape_", &shape_); v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { @@ -218,6 +273,7 @@ class TupleNode : public ExprNode { v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { @@ -275,6 +331,7 @@ class TupleGetItemNode : public ExprNode { v->Visit("index", &index); v->Visit("span", &span); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); } @@ -324,6 +381,7 @@ class ShapeExprNode : public ExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("values", &values); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -362,6 +420,7 @@ class RuntimeDepShapeNode : public ExprNode { public: void VisitAttrs(AttrVisitor* v) { v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -399,10 +458,11 @@ class VarNode : public ExprNode { const String& name_hint() const { return vid->name_hint; } void VisitAttrs(AttrVisitor* v) { - v->Visit("_checked_type_", &checked_type_); v->Visit("vid", &vid); - v->Visit("span", &span); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { @@ -443,9 +503,10 @@ class DataflowVarNode : public VarNode { public: void VisitAttrs(AttrVisitor* v) { v->Visit("vid", &vid); - v->Visit("span", &span); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); } bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { @@ -500,6 +561,7 @@ class ConstantNode : public ExprNode { v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { @@ -691,6 +753,7 @@ class SeqExprNode : public ExprNode { v->Visit("blocks", &blocks); v->Visit("body", &body); v->Visit("shape_", &shape_); + v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -739,8 +802,9 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_shape", &ret_shape); v->Visit("_checked_type_", &checked_type_); v->Visit("shape_", &shape_); - v->Visit("span", &span); + v->Visit("struct_info_", &struct_info_); v->Visit("attrs", &attrs); + v->Visit("span", &span); } bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { @@ -808,6 +872,7 @@ class ExternFuncNode : public BaseFuncNode { void VisitAttrs(AttrVisitor* v) { v->Visit("global_symbol", &global_symbol); + v->Visit("struct_info_", &struct_info_); v->Visit("span", &span); } @@ -830,23 +895,6 @@ class ExternFunc : public BaseFunc { TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); }; -/*! - * \brief Update the type of an Expr. - * \param expr The Expr whose type to be updated. - * \param type The type assigned to the checked_type_ of \p expr. - * \note We ensure idempotence, that is we can only update the checked_type_ of an Expr if it's - * nullptr. - */ -void UpdateType(Expr expr, Type type); - -/*! - * \brief Update the shape of an Expr. - * \param expr The Expr whose shape to be updated. - * \param shape The shape assigned to the shape_ of \p expr. - * \note We ensure idempotence, that is we can only update the shape_ of an Expr if it's nullptr. - */ -void UpdateShape(Expr expr, Optional shape); - } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 0021b04b4881..306e961badb8 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -253,6 +253,7 @@ class ExprVisitor : public ExprFunctor { virtual void VisitType(const Type& t); virtual void VisitSpan(const Span& span); + virtual void VisitPrimExpr(const PrimExpr& expr); private: using TSelf = ExprVisitor; @@ -305,6 +306,13 @@ class ExprMutatorBase : public ExprFunctor { * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); + + /*! + * \brief Used to visit the PrimExpr inside of expressions. + * + * Can be overloaded to transform the shape expressions. + */ + virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); }; /*! @@ -387,10 +395,14 @@ class ExprMutator : public ExprMutatorBase { /*! * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. - * \param expr The expr to be visited. + * + * \param body_expr The body to be visited. + * \param params Optional parameters that are visible within the scope. * \return The expr after visiting. + * + * \note The body_expr must be an SeqExpr in the normal form. */ - Expr VisitWithNewScope(const Expr& expr); + Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); /*! * \brief Look up the value bound to a variable. @@ -412,14 +424,13 @@ class ExprMutator : public ExprMutatorBase { } /*! - * \brief Create a new var with specified shape and type if the original var's shape or type does + * \brief Create a new var with specified struct_info if the original var's shape or type does * not match with the specified ones. * \param var The var to be updated. - * \param shape The specified shape. - * \param type The specified type. - * \return The var filled with \p shape and \p type. + * \param struct_info The struct info to be updated. + * \return The var filled with struct_info */ - Var WithShapeAndType(Var var, Optional shape, Type type); + Var WithStructInfo(Var var, StructInfo struct_info); /*! \brief Internal block builder to emit bindings during rewriting. */ BlockBuilder builder_; @@ -793,7 +804,7 @@ class PyExprMutatorNode : public Object, public ExprMutator { using ExprMutator::LookupBinding; using ExprMutator::var_remap_; using ExprMutator::VisitWithNewScope; - using ExprMutator::WithShapeAndType; + using ExprMutator::WithStructInfo; void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } static constexpr const char* _type_key = "expr_functor.PyExprMutator"; diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index d935c50575ba..3722d8eb4512 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAX_OP_ATTR_TYPES_H_ #define TVM_RELAX_OP_ATTR_TYPES_H_ +#include #include #include #include @@ -44,6 +45,15 @@ namespace relax { using FInferShape = runtime::TypedPackedFunc; +/*! + * \brief Infer output struct info given the call + * + * \param call The call expression to be derived. + * \param ctx The builder context. + */ +using FInferStructInfo = + runtime::TypedPackedFunc; + /*! * \brief Infer the output type for operators. This function will * be invoked to fill the \p checked_type_ field of expressions. diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h new file mode 100644 index 000000000000..6efe8fdef2e0 --- /dev/null +++ b/include/tvm/relax/struct_info.h @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_STRUCT_INFO_H_ +#define TVM_RELAX_STRUCT_INFO_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Opaque object. + */ +class ObjectStructInfoNode : public StructInfoNode { + public: + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ObjectStructInfoNode. + * \sa ObjectStructInfoNode + */ +class ObjectStructInfo : public StructInfo { + public: + TVM_DLL ObjectStructInfo(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimStructInfoNode : public StructInfoNode { + public: + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + + static constexpr const char* _type_key = "relax.PrimStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to PrimStructInfoNode. + * \sa PrimStructInfoNode + */ +class PrimStructInfo : public StructInfo { + public: + TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); +}; + +/*! + * \brief StructInfo of shape value. + */ +class ShapeStructInfoNode : public StructInfoNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + Optional> values; + /*! + * \brief The number of dimension of the shape, can be unknown. + * \sa kUnknownDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownDim; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.ShapeStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ShapeStructInfoNode. + * \sa ShapeStructInfoNode + */ +class ShapeStructInfo : public StructInfo { + public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + /*! + * \brief Construction with known unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownDim + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); +}; + +/*! + * \brief StructInfo of Tensor. + */ +class TensorStructInfoNode : public StructInfoNode { + public: + /*! + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + */ + Optional shape; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownDim; } + + /*! \return Whether the struct info contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { + return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(shape); + hash_reduce(dtype); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.TensorStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TensorStructInfoNode. + * \sa TensorStructInfoNode + */ +class TensorStructInfo : public StructInfo { + public: + /*! + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param span The span of the AST. + * + * \note shape must already be normalized. + */ + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span()); + + /*! + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param span The span of the AST. + */ + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); +}; + +/*! + * \brief StructInfo of Tuple. + */ +class TupleStructInfoNode : public StructInfoNode { + public: + /*! \brief The struct info of tuple fields. */ + Array fields; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.TupleStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TupleStructInfoNode. + * \sa TupleStructInfoNode + */ +class TupleStructInfo : public StructInfo { + public: + /*! + * \brief Constructor + * \param fields Struct info of tuple fields. + * \param span The span of the AST. + */ + TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); +}; + +/*! + * \brief custom-defined StructInfo derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived struct info of the call. + */ +using StructInfoDeriveFunc = TypedEnvFunc; + +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to + * do best-effort structure information deduction. + */ +class FuncStructInfoNode : public StructInfoNode { + public: + /*! + * \brief The parameter struct info of the function. + * \note When params is NullOpt means the function can take arbitrary number of arguments. + * We define such functions as Opaque function. + */ + Optional> params; + /*! + * \brief The struct info of the function's return value. + */ + StructInfo ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be NullOpt, + * ret should be ObjectStructInfo() + */ + Optional derive_func; + + /*! + * \return Whether the func struct info is opaque. + * \note We define a function as opaque we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("ret", &ret); + v->Visit("derive_func", &derive_func); + v->Visit("span", &span); + } + + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { + return equal.DefEqual(params, other->params) && equal(ret, other->ret) && + equal(derive_func, other->derive_func); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(params); + hash_reduce(ret); + hash_reduce(derive_func); + } + + static constexpr const char* _type_key = "relax.FuncStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to FuncStructInfoNode. + * \sa FuncStructInfoNode + */ +class FuncStructInfo : public StructInfo { + public: + /*! + * \brief Constructor from parameter struct info and return value struct info. + * \param params The struct info of function parameters. + * \param ret The return value struct info. + * \param span The span of the AST. + * + * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncStructInfo(Array params, StructInfo ret, Span span = Span()); + + /*! + * \brief Constructing an opaque function struct info using derive_func. + * + * \param derive_func Derivation function. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + + /*! + * \brief Construct an opaque function using from return struct info. + * + * \param ret The struct info of the return value. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); +}; + +/*! + * \brief Match and check if expr have StructInfo T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying structure info type + */ +template +inline Optional MatchStructInfo(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->struct_info_.as()) { + return GetRef(ptr); + } else { + return NullOpt; + } +} + +/*! + * \brief Get the structure info of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match + * \tparam T the underlying structure info type + */ +template +inline const T* GetStructInfoAs(const Expr& expr) { + ICHECK(expr->struct_info_.defined()) + << "The struct_info is not populated, check if you have normalized the expr"; + return expr->struct_info_.as(); +} + +/*! + * \brief Get the underlying structure info of expr. + * + * \param expr The input expression. + * \return underlying struct info. + */ +inline StructInfo GetStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + return GetRef(ptr); +} + +/*! + * \brief Update the struct info of an Expr. + * \param expr The Expr whose struct info to be updated. + * \param shape The struct_info assigned. + * \note We ensure idempotence, that is we can only update the struct_info of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h new file mode 100644 index 000000000000..382b4ab2c936 --- /dev/null +++ b/include/tvm/relax/struct_info_functor.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/struct_info_functor.h + * \brief Functors and visitors for struct info. + */ +#ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ +#define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +template +class StructInfoFunctor; + +// functions to be overriden. +#define STRUCT_INFO_FUNCTOR_DEFAULT \ + { return VisitStructInfoDefault_(op, std::forward(args)...); } + +#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStructInfo_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class StructInfoFunctor { + private: + using TSelf = StructInfoFunctor; + using FStructInfo = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~StructInfoFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const StructInfo& n, Args... args) { + return VisitStructInfo(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitStructInfo(const StructInfo& n, Args... args) { + ICHECK(n.defined()); + static FStructInfo vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitStructInfo_(const ObjectStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const PrimStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const ShapeStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TensorStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TupleStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const FuncStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfoDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; // unreachable, written to stop compiler warning + } + + private: + // initialize the vtable. + static FStructInfo InitVTable() { + FStructInfo vtable; + // Set dispatch + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ObjectStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(PrimStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ShapeStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TensorStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode); + return vtable; + } +}; + +#undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH + +/*! + * \brief A struct info visitor. + */ +class TVM_DLL StructInfoVisitor : public StructInfoFunctor { + public: + void VisitStructInfo_(const ObjectStructInfoNode* op) override; + void VisitStructInfo_(const PrimStructInfoNode* op) override; + void VisitStructInfo_(const ShapeStructInfoNode* op) override; + void VisitStructInfo_(const TensorStructInfoNode* op) override; + void VisitStructInfo_(const TupleStructInfoNode* op) override; + void VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual void VisitStructInfoExprField(const Expr& expr) {} + virtual void VisitStructInfoExprField(const PrimExpr& expr) {} +}; + +/*! + * \brief StructInfoMutator that mutates struct info. + */ +class TVM_DLL StructInfoMutator : public StructInfoFunctor { + public: + StructInfo VisitStructInfo_(const ObjectStructInfoNode* op) override; + StructInfo VisitStructInfo_(const PrimStructInfoNode* op) override; + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TupleStructInfoNode* op) override; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; } + virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 83a24eb7ef32..0e66c9516866 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -42,11 +42,19 @@ static constexpr int kUnknownDim = -1; class ShapeTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + /*! \brief size of the shape. */ + int ndim; - bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; } + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); } static constexpr const char* _type_key = "relax.ShapeType"; TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); @@ -54,7 +62,8 @@ class ShapeTypeNode : public TypeNode { class ShapeType : public Type { public: - TVM_DLL ShapeType(Span span = Span()); + // TODO(relax-team): remove the default value later. + TVM_DLL ShapeType(int ndim = kUnknownDim, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); }; diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index bfd2a2b45280..3ed4f096d3c3 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -100,6 +100,10 @@ class FunctionFrameNode : public SeqExprFrameNode { * \sa ret_type */ Optional ret_shape; + /*! + * \brief The function return struct info. + */ + Optional ret_sinfo; /*! \brief The function attributes. */ Map attrs; /*! \brief The block builder to create Relax function. */ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index b5630cb366eb..432d4fd3408c 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -20,6 +20,7 @@ #define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ #include +#include #include #include @@ -28,45 +29,17 @@ namespace script { namespace ir_builder { namespace relax { -////////////////////////////// Shaped Type ////////////////////////////// +//////////////////////////////// Tensor ///////////////////////////////// /*! - * \brief A temporary data structure for unified type and shape in ir_builder. - * \note Used for `R.Tensor` and `R.Tuple` - */ -class ShapedTypeNode : public runtime::Object { - public: - /*! \brief The type, usually is DynTensorType or TupleType */ - Type type; - /*! \brief The shape, which is optional. */ - Optional shape; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("type", &type); - v->Visit("shape", &shape); - } - - static constexpr const char* _type_key = "script.ir_builder.relax.ShapedType"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapedTypeNode, runtime::Object); -}; - -class ShapedType : public runtime::ObjectRef { - public: - TVM_DLL explicit ShapedType(Type type, Optional shape); - - TVM_DEFINE_OBJECT_REF_METHODS(ShapedType, ObjectRef, ShapedTypeNode); -}; - -/*! - * \brief Create a ShapedType for a DynTensor. + * \brief Create a TensorStructInfo. * \param shape The shape of the tensor. It's runtime dependent if `shape` is None. * \param dtype The element data type of the tensor. It's runtime dependent if `dtype` is None. * \param ndim The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. - * \return The ShapedType that is only used in ir_builder. + * \return The TensorStructInfo. */ -TVM_DLL ShapedType Tensor(Optional> shape, DataType dtype, int ndim = -1); - -TVM_DLL ShapedType CreateShapedTuple(Array types, Array> shapes); +TVM_DLL tvm::relax::TensorStructInfo Tensor(Optional> shape, DataType dtype, + int ndim = -1); /////////////////////////////// Function //////////////////////////////// @@ -79,11 +52,10 @@ TVM_DLL FunctionFrame Function(); /*! * \brief Add a parameter to the last function frame. * \param name The name of the parameter. - * \param type The type of the parameter. - * \param shape The shape of the parameter. + * \param struct_info The struct_info of the parameter. * \return The created function parameter var. */ -TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::Expr& shape); +TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); /*! * \brief Specify the name of the last function frame. @@ -98,16 +70,10 @@ TVM_DLL void FuncName(const String& name); TVM_DLL void FuncAttrs(Map attrs); /*! - * \brief Specify the return type of the last function frame. - * \param ret_type The return type. Note: it's a standard `tvm::Type` instead of ShapedType. - */ -TVM_DLL void FuncRetType(tvm::Type ret_type); - -/*! - * \brief Specify the return shape of the last function frame. - * \param ret_shape The return shape. + * \brief Specify the return struct info of the last function frame. + * \param ret_sinfo The return struct info. */ -TVM_DLL void FuncRetShape(tvm::relax::Expr ret_shape); +TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); /*! * \brief Specify the return value of the last function frame. @@ -158,15 +124,14 @@ TVM_DLL Optional EmitMatchShape(const tvm::relax::Expr& value, ///////////////////////////// Type Deduce ////////////////////////////// /*! - * \brief Annotate and check the type and shape of relax var. + * \brief Annotate the struct info of a var. * \param var The input var to be annotated. - * \param anno_type The annotated type. - * \param anno_shape The annotated shape, which can be undefined. + * \param anno_struct_info The annotated struct info, which can be undefined. * \note This function will check if the type of var is compatible with the annotated type. * And we annotate to the var with more detailed type. */ -TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type, - const Optional& anno_shape); +TVM_DLL void AnnotateStructInfo(const tvm::relax::Var& var, + const tvm::relax::StructInfo& anno_struct_info); ///////////////////////////// If Then Else ///////////////////////////// diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 81f46058d2a4..6a07aed8dd7e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Common expressions data structures in the IR.""" +import tvm import tvm._ffi from .base import Node @@ -62,6 +63,17 @@ def shape(self): """ return _ffi_api.RelayExprShape(self) + @property + def struct_info(self) -> "tvm.relax.StructInfo": + """Get the struct info field + + Returns + ------- + struct_info : tvm.relax.StructInfo + The struct info if available. + """ + return _ffi_api.ExprStructInfo(self) + @tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 7230391f70be..4b80b16d2679 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -25,6 +25,7 @@ from . import analysis from . import transform from . import expr_functor +from . import struct_info # Expr Expr = expr.Expr @@ -80,3 +81,13 @@ ExprFunctor = expr_functor.ExprFunctor PyExprVisitor = expr_functor.PyExprVisitor PyExprMutator = expr_functor.PyExprMutator + + +# StructInfo +StructInfo = struct_info.StructInfo +ObjectStructInfo = struct_info.ObjectStructInfo +PrimStructInfo = struct_info.PrimStructInfo +ShapeStructInfo = struct_info.ShapeStructInfo +TensorStructInfo = struct_info.TensorStructInfo +TupleStructInfo = struct_info.TupleStructInfo +FuncStructInfo = struct_info.FuncStructInfo diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 1b4579713371..59824c16ad1d 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,14 +21,166 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List +from typing import Dict, List, Optional +from enum import IntEnum import tvm from tvm import tir -from tvm.relax.expr import DataflowBlock, GlobalVar, Var, Expr, Function, Binding +from tvm.relax.ty import Type +from tvm.relax.struct_info import StructInfo, FuncStructInfo +from tvm.relax.expr import DataflowBlock, GlobalVar, Var, Expr, Function, Binding, Call from . import _ffi_api +def get_static_type(sinfo: StructInfo) -> Type: + """Get the corresponding static type from a StructInfo. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + Returns + ------- + ret : Type + The corresponding static type. + """ + return _ffi_api.GetStaticType(sinfo) # type: ignore + + +def get_legacy_shape_hint(sinfo: StructInfo) -> Optional[Expr]: + """Get the corresponding shape from a StructInfo. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + Returns + ------- + ret : Type + The corresponding shape. + """ + return _ffi_api.GetLegacyShapeHint(sinfo) # type: ignore + + +def erase_to_well_defined( + sinfo: StructInfo, + shape_var_map: Dict[tir.Var, tir.PrimExpr] = None, + var_map: Dict[Var, Expr] = None, +) -> StructInfo: + """Erase sinfo into a well defined form. + + This function removes the StructInfo's dependencies on shape and vars that + are not defined in given maps. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + shape_var_map : Dict[tir.Var, tir.PrimExpr] + Specifies the defined shape vars and the values they should map to. + + var_map : Dict[Var, Expr] + Specifies the defined vars and the values they should map to. + + Returns + ------- + ret : StructInfo + The corresponding erased struct info. + """ + shape_var_map = {} if shape_var_map is None else shape_var_map + var_map = {} if var_map is None else var_map + + return _ffi_api.EraseToWellDefined(sinfo, shape_var_map, var_map) # type: ignore + + +class BaseCheckResult(IntEnum): + """Return result of fine-grained base check. + + Note + ---- + Base check comes with fine-grained fail levels. + + - FAIL_L0: The lhs and rhs have no intersection at all. + - FAIL_L1: We get the failure by looking at static information. + - FAIL_L2: We get the failure due to unknown symbolic variable relations. + """ + + FAIL_L0 = 0 + FAIL_L1 = 1 + FAIL_L2 = 2 + PASS = 3 + + +def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckResult: + """Run a base check to see if base subsumes derived. + + Parameters + ---------- + base: StructInfo + The base struct info. + + derived: StructInfo + The derived struct info. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + """ + return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore + + +def derive_call_ret_struct_info( + func_sinfo: FuncStructInfo, call: Call, ctx: "tvm.relax.BlockBuilder" +) -> StructInfo: + """Derive the call's ret value struct info from inputs. + + Parameters + ---------- + func_sinfo: FuncStructInfo + The call's function signature. + + call: Call + The call expression + + ctx: tvm.relax.BlockBuilder + The context block builder. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + + Note + ---- + This is an internal derivation function, call.op field is + ignored in this case and the derivation only depends on func_sinfo. + """ + return _ffi_api.DeriveCallRetStructInfo(func_sinfo, call, ctx) # type: ignore + + +def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: + """Unify the two struct info their least common ancestor. + + Parameters + ---------- + lhs: StructInfo + The left operand. + + rhs: StructInfo + The right operand. + + Returns + ------- + ret : StructInfo + The corresponding lca result. + """ + return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore + + def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 50e96a937eec..3d407dd89b60 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -78,6 +78,31 @@ def __exit__(self, ptype, value, trace): self._bb._begin_binding_block() +class TestingScope(object): + """Auxiliary scope for testing purposes""" + + def __init__(self, block_builder, def_vars): + self._bb = block_builder + shape_vars = [] + for var in def_vars: + if isinstance(var, tvm.tir.Var): + shape_vars.append(var) + else: + raise ValueError("def_vars only can take tir.Var") + # setup a dummy var so shape is in scope. + sparam = tvm.relax.Var("sparam") + tvm.relax.expr._update_struct_info(sparam, tvm.relax.ShapeStructInfo(shape_vars)) + self._scope_params = [sparam] + + def __enter__(self): + self._bb.begin_scope(self._scope_params) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + self._bb._end_block() + self._bb.end_scope() + + @tvm._ffi.register_object("relax.BlockBuilder") class BlockBuilder(Object): """A builder to build Relax IR for testing and dev. @@ -158,6 +183,7 @@ def _enter_function_scope(self, name, params, attrs): self._func_name = name self._func_params = params self._func_attrs = attrs + self.begin_scope(params) self._begin_binding_block() def _exit_function_scope(self, exc_type, exc_val, exc_tb): @@ -282,6 +308,21 @@ def function( attrs = {} return FunctionScope(self, name, params, attrs) + def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: + """Start a scope for unit-testing purposes. + + Parameters + ---------- + def_vars: List[tir.Var] + List of symbolic variables that are marked as defined in scope. + + Returns + ------- + ret: TestingScope + A TestingScope to setup builder for emit and other purposes. + """ + return TestingScope(self, def_vars) + def dataflow(self) -> DataflowScope: """Annotate a Relax dataflow block. @@ -581,12 +622,11 @@ def emit_func_output( if isinstance(output, (list, tuple)): output = Tuple(output) - self._func_ret = self.normalize(output) block = self._end_block() if len(block.bindings) > 0: self._blocks.append(block) - seqe = self.normalize(rx.SeqExpr(self._blocks, self._func_ret)) + seqe = self.normalize(rx.SeqExpr(self._blocks, output)) # The function's checked_type_ relies on the function body(seqe) to have deduced type # TODO(@yuchen): handle the case where the body's checked_type_ is null @@ -594,6 +634,7 @@ def emit_func_output( func = rx.Function(self._func_params, seqe, None, rx.RuntimeDepShape()) for key, value in self._func_attrs.items(): func = func.with_attr(key, value) + self.end_scope() self.add_func(func, self._func_name) def normalize(self, expr: Expr) -> Expr: @@ -754,3 +795,26 @@ def lookup_binding(self, var: Var) -> Optional[Expr]: The Expr bound to the input var. """ return _ffi_api.BlockBuilderLookupBinding(self, var) # type: ignore + + def begin_scope(self, params: Optional[List[Var]] = None) -> None: + """Begin a new scope, with optional parameters that + are visible within the scope. + + Parameters + ---------- + params: Optional[List[Var]] + Parameters that are visible within the scope. + + Note + ---- + This function should be called when new scope is introduced + (function, seq) to properly track the variable availability + and help the best effort deduction. + """ + + return _ffi_api.BlockBuilderBeginScope(self, params) # type: ignore + + def end_scope(self) -> None: + """End the current scope. Please see `begin_scope` for details""" + + return _ffi_api.BlockBuilderEndScope(self) # type: ignore diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 7e074b4e8c15..6ffd03d220d0 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -19,11 +19,12 @@ """The expression nodes of Relax.""" from typing import Any, List, Optional, Union import typing +import numpy as _np # type: ignore import tvm import tvm._ffi -import numpy as _np # type: ignore from tvm.runtime import ndarray as _nd +import tvm.relax from tvm._ffi import base as _base from .. import relay @@ -452,8 +453,11 @@ def const( if not dtype: # when dtype is None: int maps to "int32", float maps to "float32" - dtype = {_np.dtype("int64"): _np.int32, _np.dtype("float64"): _np.float32}.get( - value.dtype, None + dtype = { # type: ignore + _np.dtype("int64"): _np.int32, # type: ignore + _np.dtype("float64"): _np.float32, # type: ignore + }.get( + value.dtype, None # type: ignore ) if isinstance(value, (_np.ndarray, _np.generic)): @@ -472,9 +476,5 @@ def te_tensor(value: Expr, name: str = "rxplaceholder"): return _ffi_api.TETensor(value, name) # type: ignore -def _update_type(expr: Expr, type: Type) -> None: - _ffi_api.UpdateType(expr, type) # type: ignore - - -def _update_shape(expr: Expr, shape: Optional[tvm.runtime.Object]) -> None: - _ffi_api.UpdateShape(expr, shape) # type: ignore +def _update_struct_info(expr: Expr, struct_info: Optional["tvm.relax.StructInfo"]) -> None: + _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index ecaae9fdf6ab..b05ff53c77b5 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -31,6 +31,7 @@ from .expr import Call, If, TupleGetItem from .expr import Binding, MatchShape, VarBinding from .expr import BindingBlock, DataflowBlock +from .struct_info import StructInfo from ..relay import Id from ..ir.module import IRModule from .block_builder import BlockBuilder @@ -1444,7 +1445,7 @@ def lookup_binding(self, var: Var) -> Optional[Expr]: # Using self._outer() to ref _PyExprMutator return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore - def with_shape_and_type(self, var: Var, shape: Optional[Object], t: Type) -> Var: + def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: """Create a new var with specified shape and type if the original var's shape or type does not match with the specified ones. @@ -1452,10 +1453,8 @@ def with_shape_and_type(self, var: Var, shape: Optional[Object], t: Type) -> Var ---------- var : Var The var to be updated. - shape : Optional[Object] - The specified shape. - t : Type - The specified type. + struct_info : StructInfo + The struct info. Returns ------- @@ -1463,4 +1462,4 @@ def with_shape_and_type(self, var: Var, shape: Optional[Object], t: Type) -> Var The var filled with shape and type. """ # Using self._outer() to ref _PyExprMutator - return _ffi_api.PyExprMutatorWithShapeAndType(self._outer(), var, shape, t) # type: ignore + return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py new file mode 100644 index 000000000000..380ff6d38676 --- /dev/null +++ b/python/tvm/relax/struct_info.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The struct info nodes of the Relax language.""" +from typing import List, Optional, Tuple, Union + +import tvm._ffi +import tvm + +from tvm.ir import Span, Node, EnvFunc, Array, Type +from tvm.tir import PrimExpr +from .expr import Var, Expr, ShapeExpr + +from . import _ffi_api, ty, expr + + +class StructInfo(Node): + """The base class of all StructInfo. + + StructInfo contains both the static type + and runtime structural information. + """ + + def __eq__(self, other): + """Compare two struct info for structural equivalence.""" + return tvm.ir.structural_equal(self, other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Overload with structural equality.""" + return super().__eq__(other) + + def is_base_of(self, derived: "StructInfo") -> bool: + """Check if self is base of another derived struct info. + + Parameters + ---------- + derived : StructInfo + The derived struct info to be checked. + + Returns + ------- + result : bool + The check result. + """ + return _ffi_api.StructInfoIsBaseOf(self, derived) # type: ignore + + +@tvm._ffi.register_object("relax.ObjectStructInfo") +class ObjectStructInfo(StructInfo): + """StructInfo of an Object.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore + + +@tvm._ffi.register_object("relax.PrimStructInfo") +class PrimStructInfo(StructInfo): + """StructInfo of a primitive POD value. + + Parameters + ---------- + dtype : str + The data type of the prim value. + """ + + dtype: str + + def __init__(self, dtype: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore + + +@tvm._ffi.register_object("relax.ShapeStructInfo") +class ShapeStructInfo(StructInfo): + """StructInfo of a shape value. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + + Note + ---- + Do not specify values and ndim at the same time. + """ + + values: Optional[List[PrimExpr]] + ndim: int + span: Span + + def __init__( + self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ShapeStructInfo, values, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TensorStructInfo") +class TensorStructInfo(StructInfo): + """StructInfo of a Tensor value. + + Parameters + ---------- + shape : Optional[Expr] + The shape expression. + + dtype : Optional[str] + The content data type. + + ndim : Optional[int] + The number of dimensions of the tensor. + + Note + ---- + Do not specify shape and ndim at the same time. + """ + + shape: Optional[Expr] + dtype: str + ndim: int + span: Span + + def __init__( + self, + shape: Union[Optional[Expr], List[PrimExpr]] = None, + dtype: str = "float32", + ndim: int = -1, + span: Span = None, + ) -> None: + if isinstance(shape, (list, tuple, Array)): + shape = ShapeExpr(shape) + + self.__init_handle_by_constructor__( + _ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TupleStructInfo") +class TupleStructInfo(StructInfo): + """StructInfo of a Tuple value. + + Parameters + ---------- + fields: List[StructInfo] + The struct info of the fields. + """ + + fields: List[StructInfo] + span: Span + + def __init__(self, fields: List[StructInfo], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore + + +@tvm._ffi.register_object("relax.FuncStructInfo") +class FuncStructInfo(StructInfo): + """StructInfo of a function value. + + Parameters + ---------- + params: List[StructInfo] + The struct info of the fields. + + ret: StructInfo + The struct info of return value + """ + + params: Optional[List[StructInfo]] + ret: StructInfo + derive_func: Optional[EnvFunc] + span: Span + + def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.FuncStructInfo, params, ret, span # type: ignore + ) + + @staticmethod + def opaque_func( + *, + ret: Optional[StructInfo] = None, + derive_func: Optional[EnvFunc] = None, + span: Span = None, + ) -> "FuncStructInfo": + """ + Create an opaque FuncStructInfo. + + The opaque function takes either a ret + that specificies the struct info of the return value + or a derive_func that provides a customized derivation rule. + + Parameters + ---------- + ret: Optional[StructInfo] + The struct info of the the function return value. + + derive_func: Optional[EnvFunc] + The environment function used for derivation + + span: Optional[Span] + Optional span information of the ast. + + Returns + ------- + info: FuncStructInfo + + Note + ---- + We cannot specify ret and derive_func simultaneously. + """ + return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index 64d63e6acfae..784643843341 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -24,10 +24,17 @@ @tvm._ffi.register_object("relax.ShapeType") class ShapeType(Type): - """The type of shape in Relax.""" + """The type of shape in Relax. - def __init__(self, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.ShapeType, span) # type: ignore + Parameters + ---------- + ndim : Optional[int] + The size of the shape. + """ + + # TODO(relax-team): consider make ndim mandatory + def __init__(self, ndim: int = -1, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore @tvm._ffi.register_object("relax.ObjectType") diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index fd81341a8d39..afc7a5cf5c60 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -14,16 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=redefined-builtin, wrong-import-order +# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name """IRBuilder for Relax dialect""" import functools from typing import Dict, List, Optional, Tuple, Union import tvm -from tvm._ffi import register_object as _register_object from tvm.ir import Type -from tvm.relax import Call, Expr, ExternFunc, ShapeExpr, TupleGetItem, TupleType, Var, const +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, TupleType, Var, const +from tvm.relax.struct_info import StructInfo, TensorStructInfo +from tvm.relax.analysis import get_static_type ############################### Operators ############################### from tvm.relax.op import ( @@ -40,7 +41,6 @@ unique, memory, ) -from tvm.relax.ty import ObjectType, ShapeType, DynTensorType from tvm.relax.utils import convert_to_expr from tvm.runtime import Object as tvm_Object from tvm.tir import PrimExpr @@ -51,19 +51,11 @@ ############################## Tensor Type ############################## -@_register_object("script.ir_builder.relax.ShapedType") -class ShapedType(tvm_Object): - """A temporary Tensor type for `R.Tensor` in ir_builder.""" - - type: DynTensorType - shape: Optional[Expr] - - def tensor( shape: Optional[List[Union[PrimExpr, str]]] = None, dtype: Optional[str] = None, ndim: int = -1, -) -> ShapedType: +) -> TensorStructInfo: """Helper function for `R.Tensor` in parser Parameters ---------- @@ -75,8 +67,8 @@ def tensor( The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. Returns ------- - tensor_type: ShapedType - The ShapedType that is only used in ir_builder. + ret: TensorStructInfo + The result TensorStructInfo """ if shape is not None: @@ -90,26 +82,9 @@ def tensor( return _ffi_api.Tensor(shape, dtype, ndim) # pylint: disable=no-member # type: ignore -def create_shaped_tuple(types: List[Type], shapes: List[Optional[Expr]]) -> ShapedType: - """Helper function for `R.Tuple` in parser - Parameters - ---------- - types: List[Type] - The list of type of it's fields - shapes: List[Optional[Expr]] - The list of shape of it's fields. - Returns - ------- - tuple_type: ShapedType - The ShapedType that is only used in ir_builder. - """ - return _ffi_api.CreateShapedTuple(types, shapes) # pylint: disable=no-member # type: ignore - - ############################## Other Types ############################## -Object = ObjectType() # pylint: disable=invalid-name -Shape = ShapeType() # pylint: disable=invalid-name +Object = tvm.relax.ObjectStructInfo() # pylint: disable=invalid-name Void = TupleType([]) # pylint: disable=invalid-name ############################### Function ################################ @@ -125,30 +100,22 @@ def function() -> frame.FunctionFrame: return _ffi_api.Function() # pylint: disable=no-member # type: ignore -def arg(name: str, type: Union[Type, ShapedType], shape: Optional[ShapeExpr] = None) -> Var: +def arg(name: str, struct_info: StructInfo) -> Var: """Add a parameter to the last function frame. Parameters ---------- name: str The name of the parameter. - type: Union[Type, ShapedType] - The type of the parameter. It can be a typical TVM Type or a ShapedType, - which contains both type and shape - shape: Optional[ShapeExpr] - The shape of the parameter. + struct_info: StructInfo + The Struct Info of the parameter + Returns ------- var: Var The created function parameter var. """ - if isinstance(type, ShapedType): - if shape is not None: - raise ValueError("Cannot specify the shape if we use ShapedType") - shape = type.shape - type = type.type - - return _ffi_api.Arg(name, type, shape) # pylint: disable=no-member # type: ignore + return _ffi_api.Arg(name, struct_info) # pylint: disable=no-member # type: ignore def func_name(name: str) -> None: @@ -171,27 +138,14 @@ def func_attr(attrs: Dict[str, tvm_Object]) -> None: return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore -def func_ret_type(ret_type: Union[ShapedType, Type]) -> None: - """Specify the return type of the last function frame. +def func_ret_struct_info(ret_sinfo: StructInfo) -> None: + """Specify the return struct info of the last function frame. Parameters ---------- - ret_type: Union[ShapedType, Type] - The function return type. + ret_type: StructInfo + The function return struct info. """ - if isinstance(ret_type, ShapedType): - ret_type = ret_type.type - return _ffi_api.FuncRetType(ret_type) # pylint: disable=no-member # type: ignore - - -def func_ret_shape(ret_shape: Expr) -> None: - """Specify the return shape of the last function frame. - - Parameters - ---------- - ret_shape: Expr - The function return shape. - """ - return _ffi_api.FuncRetShape(ret_shape) # pylint: disable=no-member # type: ignore + return _ffi_api.FuncRetStructInfo(ret_sinfo) # pylint: disable=no-member # type: ignore def func_ret_value(value: Expr) -> None: @@ -233,7 +187,7 @@ def output(*vars: Tuple[Var]) -> None: def call_packed( func: str, *args: List[Expr], - type_args: Optional[Union[ShapedType, List[ShapedType]]] = None, + type_args: Optional[Union[StructInfo, List[StructInfo]]] = None, **kwargs: Dict[str, Expr], ) -> Call: """Create a relax Call, which calls a packed function. @@ -243,7 +197,7 @@ def call_packed( The name of extern function. args : List[Expr] The arguments. - type_args: Optional[Union[ShapedType, List[ShapedType]]] + type_args: Optional[Union[StructInfo, List[StructInfo]]] List of Types kwargs: Dict[str, Expr] The keyword arguments. @@ -264,13 +218,13 @@ def call_packed( for i, argument in enumerate(type_args): if callable(argument): argument = argument() - if isinstance(argument, ShapedType): - type_args[i] = argument.type + if isinstance(argument, StructInfo): + type_args[i] = get_static_type(argument) elif isinstance(argument, Type): type_args[i] = argument else: raise TypeError( - "call_packed `type_args` is expected to be list of ShapedType/Type, " + "call_packed `type_args` is expected to be list of StructInfo/Type, " f"but got {type(arg)}" ) @@ -289,7 +243,7 @@ def call_packed( def _tensor_type_wrapper(func): - """A wrapper to convert builder.ShapedType to relax.DynTensorType""" + """A wrapper to convert StructInfo to relax.DynTensorType""" def _convert_tensor_type(args): if isinstance(args, (list, tuple)): @@ -297,7 +251,7 @@ def _convert_tensor_type(args): return type(args)(new_args) if isinstance(args, dict): return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} - return args.type if isinstance(args, ShapedType) else args + return get_static_type(args) if isinstance(args, StructInfo) else args @functools.wraps(func) def wrapped(*args, **kwargs): @@ -349,21 +303,23 @@ def emit_match_shape(value: Expr, pattern: List[PrimExpr], emit_var: bool) -> Op ############################# Type Deduce ############################## -def annotate_type_shape(var: Var, anno_type: Type, anno_shape: ShapeExpr) -> None: - """Annotate and check the type of relax var. +def annotate_struct_info(var: Var, anno_struct_info: StructInfo) -> None: + """Annotate the struct info of relax var. + Parameters ---------- var: Var The input var to be annotated. - anno_type: Type - The annotated type - anno_shape: ShapeExpr - The annotated shape + anno_struct_info: StructInfo + The annotated struct info """ - _ffi_api.AnnotateTypeShape(var, anno_type, anno_shape) + _ffi_api.AnnotateStructInfo(var, anno_struct_info) + + +############################# If Then Else ############################# def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name @@ -401,14 +357,44 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name return _ffi_api.Else() # pylint: disable=no-member # type: ignore +######################## Symbolic Shape Rewriter ######################## + + +def RewriteSymbolicShape( + struct_info: StructInfo, + var_table: Dict[str, tvm.tir.Var], +) -> Tuple[StructInfo, List[tvm.tir.Var]]: + """Helper function to rewrite symbolic shape + + This function remaps the symbolic shape by + mapping certain vars to new variables. + + struct_info: StructInfo + The input struct info + + var_table: Dict[str, tvm.tir.Var] + Dictionary to map name of var to a new var. + + Returns + ------- + rewritten_info : StructInfo + The rewritten StructInfo + + undefined_vars: List[tvm.tir.Var] + List of undefined vars. + """ + return _ffi_api.RewriteSymbolicShape( + struct_info, var_table + ) # pylint: disable=no-member # type: ignore + + ############################### Importer ############################### __all__ = [ "Else", "If", "Object", - "Shape", - "ShapedType", + "RewriteSymbolicShape", "Then", "TupleGetItem", "Void", @@ -419,15 +405,13 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name "call_packed", "call_tir", "const", - "create_shaped_tuple", "dataflow", "emit", "emit_match_shape", "ewise_fma", "func_attr", "func_name", - "func_ret_type", - "func_ret_shape", + "func_ret_struct_info", "func_ret_value", "function", "invoke_closure", diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index 5bf0a21ca543..4e4f9240356b 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -18,6 +18,6 @@ from ...ir_builder.relax import * # pylint: disable=redefined-builtin from ...ir_builder.relax import ir as _relax from . import parser as _parser -from .entry import Callable, Tensor, Tuple, function, match_shape +from .entry import Callable, Shape, Tensor, Tuple, function, match_shape -__all__ = _relax.__all__ + ["Callable", "Tensor", "Tuple", "function", "match_shape"] +__all__ = _relax.__all__ + ["Callable", "Shape", "Tensor", "Tuple", "function", "match_shape"] diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 06ad29efeafe..58d857d0f805 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -17,18 +17,18 @@ # pylint: disable=missing-docstring, invalid-name import inspect from typing import Callable as _Callable -from typing import List, Optional +from typing import List, Optional, Tuple from typing import TypeVar as _TypeVar from typing import Union -from tvm.ir import FuncType, TypeConstraint, TypeVar -from tvm.relax import DynTensorType, Expr, Function +from tvm import relax +from tvm.relax import DynTensorType, Expr, Function, StructInfo from tvm.relax import Tuple as RxTuple from tvm.relax import Type, Var from tvm.runtime import ObjectGeneric from tvm.tir import PrimExpr -from ...ir_builder.relax import ShapedType, tensor, create_shaped_tuple +from ...ir_builder.relax import tensor from .._core import parse, utils FType = _TypeVar("FType", bound=_Callable) @@ -54,7 +54,7 @@ def __call__( shape: Optional[List[Union[PrimExpr, str]]] = None, dtype: str = None, ndim: int = -1, - ) -> ShapedType: + ) -> relax.TensorStructInfo: # scalar tensor case if shape is not None and len(shape) == 0: shape = [] @@ -86,36 +86,24 @@ class CallableProxy: a set of type constraints which we omit for the time being, a sequence of argument types, and a return type. - We can informally write them as: - `forall (type_params), (arg_types) -> ret_type where type_constraints` - Parameters ---------- - arg_types : List[Union[Type, ShapedType]] - The argument types - - ret_type : Type - The return type. + params : List[StructInfo] + The argument StructInfo - type_params : Optional[List[TypeVar]] - The type parameters + ret : StructInfo + The return StructInfo. - type_constraints : Optional[List[TypeConstraint]] - The type constraints. """ def __call__( self, - arg_types: List[Union[Type, ShapedType]], - ret_type: Type, - type_params: Optional[List[TypeVar]] = None, - type_constraints: Optional[List[TypeConstraint]] = None, - ) -> FuncType: - if isinstance(arg_types, ShapedType): - arg_types = [arg_types] - arg_types = [_convert_type(ty) for ty in arg_types] - ret_type = _convert_type(ret_type) - return FuncType(arg_types, ret_type, type_params, type_constraints) + params: Union[StructInfo, List[StructInfo], Tuple[StructInfo]], + ret: StructInfo, + ) -> relax.FuncStructInfo: + if not isinstance(params, (list, tuple)): + params = [params] + return relax.FuncStructInfo(params, ret) def __getitem__(self, keys) -> Var: return self(*keys) # pylint: disable=no-member # type: ignore @@ -131,25 +119,29 @@ class TupleProxy: Parameters ---------- - fields : List[Type] + fields : List[Union[Expr, Type, StructInfo]] The fields in the tuple """ def __call__( self, - *fields: List[Union[Expr, Type, ShapedType]], - ) -> Union[Expr, ShapedType]: + *fields: List[Union[Expr, Type, StructInfo]], + ) -> Union[Expr, StructInfo]: if len(fields) == 1 and isinstance(fields[0], (tuple, list)): fields = fields[0] + # TODO(siyuan): Revisit this part if all([isinstance(f, Expr) for f in fields]): return RxTuple(fields) - elif all([isinstance(f, (ShapedType, Type, TensorProxy)) for f in fields]): - types = [_convert_type(ty) for ty in fields] - shapes = [ty.shape if isinstance(ty, ShapedType) else None for ty in fields] - return create_shaped_tuple(types, shapes) else: - raise TypeError(f"Invalid tuple type: {fields}") + fields = list(fields) + for i, x in enumerate(fields): + if callable(x): + fields[i] = x() + if all([isinstance(f, StructInfo) for f in fields]): + return relax.TupleStructInfo(fields) + else: + raise TypeError(f"Invalid tuple type: {fields}") def __getitem__(self, keys) -> Var: return self(*keys) # pylint: disable=no-member # type: ignore @@ -157,6 +149,34 @@ def __getitem__(self, keys) -> Var: Tuple = TupleProxy() +############################### R.Shape ################################ + + +class ShapeProxy: + """The type of shape values. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __call__( + self, + values: Optional[List[PrimExpr]] = None, + ndim: int = -1, + ) -> StructInfo: + return relax.ShapeStructInfo(values, ndim) + + def __getitem__(self, keys) -> Var: + return self(*keys) # pylint: disable=no-member # type: ignore + + +Shape = ShapeProxy() + ############################ R.match_shape ############################# class MatchShapePair: value: Expr @@ -173,17 +193,3 @@ def match_shape(value: Expr, pattern: List[PrimExpr]): if pattern is None: raise ValueError("pattern of match_shape cannot be None") return MatchShapePair(value, pattern) - - -################################ utils ################################# - - -def _convert_type(ty: Union[Type, ShapedType, TensorProxy]) -> Type: - if isinstance(ty, TensorProxy): - return ty().type - if isinstance(ty, ShapedType): - return ty.type - elif isinstance(ty, Type): - return ty - else: - raise TypeError(f"Expect a Type or ShapedType, but got: {ty}") diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 442bccada45f..cf37d0183caf 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -16,11 +16,12 @@ # under the License. # pylint: disable=missing-docstring -from typing import Any, Union import numbers +from typing import Any, Optional, Tuple, Union from tvm import relax, tir from tvm.ir import Type +from tvm.relax import StructInfo, Expr from tvm.relax.utils import convert_to_expr from tvm.script.ir_builder.relax.frame import BlockFrame @@ -28,7 +29,7 @@ from ...ir_builder import relax as R from ...ir_builder.base import IRBuilder from .._core import Parser, dispatch, doc -from .entry import MatchShapePair, Tensor, ShapedType +from .entry import MatchShapePair def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -107,17 +108,21 @@ def eval_shape_annotation( return None -def eval_type_annotation(self: Parser, node: Union[doc.Expression, doc.expr]) -> Any: - type_annotation = self.eval_expr(node) - if callable(type_annotation): - type_annotation = Tensor() - if isinstance(type_annotation, ShapedType): - shape = eval_shape_annotation(self, node, type_annotation.shape) - return type_annotation.type, shape +# pylint: disable=inconsistent-return-statements +def eval_type_annotation( + self: Parser, node: Union[doc.Expression, doc.expr] +) -> Tuple[Type, Optional[Expr], StructInfo]: + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + if isinstance(annotation, StructInfo): + var_table = {k: v for k, v in self.var_table.get().items() if isinstance(v, tir.Var)} + annotation, undefined_vars = R.RewriteSymbolicShape(annotation, var_table) + for var in undefined_vars: + self.var_table.add(var.name, var) + return annotation else: - if not isinstance(type_annotation, Type): - self.report_error(node, f"Unsupported type annotation {type(type_annotation)}") - return type_annotation, None + self.report_error(node, f"Unsupported type annotation {annotation}") @dispatch.register(token="relax", type_name="FunctionDef") @@ -126,13 +131,9 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: with R.function(): R.func_name(node.name) if node.returns is not None: - ann_type, ann_shape = eval_type_annotation(self, node.returns) - R.func_ret_type(ann_type) + ann_sinfo = eval_type_annotation(self, node.returns) + R.func_ret_struct_info(ann_sinfo) - # TODO(relax-team): remove the following line when fixing ret_shape issue - ann_shape = relax.RuntimeDepShape() - - R.func_ret_shape(ann_shape) with self.with_dispatch_token("relax"): self.visit(node.args) self.visit_body(node.body) @@ -141,24 +142,24 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: @dispatch.register(token="relax", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: if node.returns is None: - ret_type, ret_shape = None, None + ret_sinfo, ret_type, ret_shape = relax.ObjectStructInfo(), None, None else: - ret_type, ret_shape = eval_type_annotation(self, node.returns) + ret_sinfo = eval_type_annotation(self, node.returns) + ret_type = relax.analysis.get_static_type(ret_sinfo) + ret_shape = relax.analysis.get_legacy_shape_hint(ret_sinfo) params = [] - arg_types = [] + params_sinfo = [] for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_type, param_shape = self.visit_tvm_annotation(arg.annotation) - arg_types.append(param_type) + param_sinfo = self.visit_tvm_annotation(arg.annotation) + param_type = relax.analysis.get_static_type(param_sinfo) + param_shape = relax.analysis.get_legacy_shape_hint(param_sinfo) + params_sinfo.append(param_sinfo) params.append(relax.Var(arg.arg, param_shape, param_type)) - # TODO(relax-team): remove the following line when fixing ret_shape issue in block builder - ret_shape = relax.RuntimeDepShape() - func_signature = relax.Function.create_unchecked(params, None, ret_type, ret_shape) global_var = I.decl_function(node.name, func_signature) - relax.expr._update_type(global_var, relax.FuncType(arg_types, ret_type)) self.var_table.add(node.name, global_var) @@ -193,8 +194,8 @@ def visit_arguments(self: Parser, node: doc.arguments) -> None: for arg in node.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") - param_type, param_shape = self.visit_tvm_annotation(arg.annotation) - param = R.arg(arg.arg, param_type, param_shape) + param_sinfo = self.visit_tvm_annotation(arg.annotation) + param = R.arg(arg.arg, param_sinfo) self.var_table.add(arg.arg, param) @@ -243,7 +244,7 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: lhs = node.target rhs = self.eval_expr(node.value) - ann_type, ann_shape = self.visit_tvm_annotation(node.annotation) + ann_sinfo = self.visit_tvm_annotation(node.annotation) self.eval_assign( target=lhs, source=rhs, @@ -252,7 +253,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: ) var = self.var_table.get().get(lhs.id) assert isinstance(var, relax.Var) - R.ir.annotate_type_shape(var, ann_type, ann_shape) + R.ir.annotate_struct_info(var, ann_sinfo) @dispatch.register(token="relax", type_name="Return") diff --git a/python/tvm/script/parser_v1/parser.py b/python/tvm/script/parser_v1/parser.py deleted file mode 100644 index 1f8f71c27168..000000000000 --- a/python/tvm/script/parser_v1/parser.py +++ /dev/null @@ -1,1403 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser For TIR - -We use [synr](https://synr.readthedocs.io) to get an AST that is stable over -different python versions. Synr also provides an error handling context that we -use for error reporting. -""" -import functools -import inspect -import json -import operator - -# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except, import-outside-toplevel -import types -from typing import Any, Callable, Dict, List, Optional, Union - -import tvm -from synr import Transformer, ast, to_ast -from tvm import IRModule, relax -from tvm._ffi.base import TVMError -from tvm.ir import GlobalVar -from tvm.ir.function import BaseFunc -from tvm.tir import buffer -from tvm.tir.function import PrimFunc - -from . import tir -from .context_maintainer import ContextMaintainer -from .diagnostics import TVMDiagnosticCtx -from .meta_unparser import MetaUnparser -from .registry import Registry -from .tir import ty -from .tir.intrin import Intrin -from .tir.node import BufferSlice, Slice -from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler -from .tir.special_stmt import SpecialStmt -from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr - - -class CallArgumentReader(object): - """Helper class to read required arguments from passed arguments. - - When parsing a function call, we need to match the arguments provided in - the AST to the required arguments of the function. This class makes sure - all the positional arguments are filled and also fill keyword arguments - with thier default value if a different value was not provided. - """ - - def __init__(self, func_name, args, kwargs, parser, node): - self.func_name = func_name - self.args = args - self.kwargs = kwargs - self.parser = parser - self.node = node - - def get_pos_only_arg(self, pos, name): - """Get corresponding position only function argument from argument list""" - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name not in self.kwargs: - # If no positional argument was found in the AST, we see if it was - # defined by name instead. - # TODO(tkonolige): this error message is not quite correct. The - # number of required arguments is >= pos - self.parser.report_error( - f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.", - self.node.span, - ) - else: - arg = self.kwargs[name] - - return arg - - def get_kwarg(self, pos, name, default): - """Get corresponding keyword function argument from argument list. - - If the user hasn't provided the argument, set it to the default value. - """ - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name in self.kwargs: - arg = self.kwargs[name] - else: - return default - - return arg - - def get_varargs(self, pos): - """Get corresponding variable argument from argument list""" - if len(self.args) >= pos and len(self.kwargs) == 0: - return self.args[pos - 1 :] - return [] - - -class TVMScriptParser(Transformer): - """Synr AST visitor pass which finally lowers to TIR. - - Notes for Extension - ------------------- - 1. To support a new type of AST node, add a function transform_xxx(). - 2. To support new functions, add the function to the appropriate registry: - We divide allowed function calls in TVM script into 3 categories, - intrin, scope_handler and special_stmt. - 1. intrin functions are low level functions like mod, load, and - constants. They correspond to a tir `IRNode`. They must have a - return value. The user can register intrin functions for the parser to - use. - 2. scope_handler functions have no return value. They take two - arguments: the parser and the AST node. scope_handler functions are - used in with and for statements. - 3. special_stmt functions handle cases that do not have a corresponding - tir `IRNode`. These functions take the parser and the AST node as - arguments and may return a value. - When visiting a Call node, we check the special_stmt registry first. If - no registered function is found, we then check the intrin registry. - When visiting With node, we check the with_scope registry. - When visiting For node, we check the for_scope registry. - """ - - _binop_maker = { - ast.BuiltinOp.Add: tvm.tir.Add, - ast.BuiltinOp.Sub: tvm.tir.Sub, - ast.BuiltinOp.Mul: tvm.tir.Mul, - ast.BuiltinOp.Div: tvm.tir.Div, - ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv, - ast.BuiltinOp.Mod: tvm.tir.FloorMod, - ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs), - ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs), - ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs), - ast.BuiltinOp.GT: tvm.tir.GT, - ast.BuiltinOp.GE: tvm.tir.GE, - ast.BuiltinOp.LT: tvm.tir.LT, - ast.BuiltinOp.LE: tvm.tir.LE, - ast.BuiltinOp.Eq: tvm.tir.EQ, - ast.BuiltinOp.NotEq: tvm.tir.NE, - ast.BuiltinOp.And: tvm.tir.And, - ast.BuiltinOp.Or: tvm.tir.Or, - } - - _unaryop_maker = { - ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs), - ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs), - ast.BuiltinOp.Not: tvm.tir.Not, - } - - # pylint gets confused here with synr.Transformer which doesn't have a - # custom init, so just disable it - def __init__( - self, base_lineno, tir_namespace, closure_vars - ): # pylint: disable=super-init-not-called - self.context = None - - self.base_lineno = base_lineno - self.current_lineno = 0 - self.current_col_offset = 0 - self.tir_namespace = tir_namespace - self.closure_vars = closure_vars - self.meta = None - self._inside_buffer_sugar = False - - def init_function_parsing_env(self): - """Initialize function parsing environment""" - self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter - - def init_meta(self, meta_dict): - if meta_dict is not None: - self.meta = tvm.ir.load_json(json.dumps(meta_dict)) - - def transform(self, node): - """Generic transformation for visiting the AST. Dispatches to - `transform_ClassName` for the appropriate ClassName.""" - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - - if hasattr(node, "lineno"): - self.current_lineno = self.base_lineno + node.lineno - 1 - if hasattr(node, "col_offset"): - self.current_col_offset = node.col_offset - - method = "transform_" + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - transform_res = visitor(node) - - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - - return transform_res - - def match_tir_namespace(self, identifier: str) -> bool: - """Check if the namespace is equal to tvm.script.tir""" - return identifier in self.tir_namespace - - def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]): - """Report an error occuring at a location. - - This just dispatches to synr's DiagnosticContext. - - Parameters - ---------- - message : str - Error message - span : Union[synr.ast.Span, tvm.ir.Span] - Location of the error - """ - if isinstance(span, tvm.ir.Span): - span = synr_span_from_tvm(span) - self.error(message, span) - - def parse_body(self, parent): - """Parse remaining statements in this scope. - - Parameters - ---------- - parent : synr.ast.Node - Parent node of this scope. Errors will be reported here. - """ - body = [] - spans = [] - stmt = parent - while len(self.context.node_stack[-1]) > 0: - stmt = self.context.node_stack[-1].pop() - spans.append(stmt.span) - res = self.transform(stmt) - if res is not None: - body.append(res) - if len(body) == 0: - self.report_error( - "Expected another statement at the end of this block. Perhaps you " - "used a concise statement and forgot to include a body afterwards.", - stmt.span, - ) - else: - return ( - tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans))) - if len(body) > 1 - else body[0] - ) - - def parse_arg_list(self, func, node_call): - """Match the arguments of a function call in the AST to the required - arguments of the function. This handles positional arguments, - positional arguments specified by name, keyword arguments, and varargs. - - Parameters - ---------- - func : Function - The function that provides the signature - - node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall] - The AST call node that calls into the function. - - Returns - ------- - arg_list : list - The parsed positional argument. - """ - assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall)) - # collect arguments - args = [self.transform(arg) for arg in node_call.params] - if isinstance(node_call, ast.TypeApply): - kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr - else: - kw_args = { - self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items() - } - # get the name and parameter list of func - if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)): - func_name, param_list = func.signature() - else: - self.report_error( - "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, " - f"but it is {type(func).__name__}", - node_call.span, - ) - # check arguments and parameter list and get a list of arguments - reader = CallArgumentReader(func_name, args, kw_args, self, node_call) - pos_only, kwargs, varargs = param_list - internal_args = list() - - for i, arg_name in enumerate(pos_only): - internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) - for i, arg_info in enumerate(kwargs): - arg_name, default = arg_info - internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) - if varargs is not None: - internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) - elif len(args) + len(kw_args) > len(pos_only) + len(kwargs): - self.report_error( - "Arguments mismatched. " - + f"Expected {len(pos_only) + len(kwargs)} args but got " - + f"{len(args) + len(kw_args)}", - node_call.span, - ) - return internal_args - - def parse_type(self, type_node, parent): - """Parse a type annotation. - - We require the parent object to the type so that we have a place to - report the error message if the type does not exist. - """ - if type_node is None: - self.report_error("A type annotation is required", parent.span) - res_type = self.transform(type_node) - return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate() - - def generic_visit(self, node): - """Fallback visitor if node type is not handled. Reports an error.""" - - self.report_error(type(node).__name__ + " AST node is not supported", node.span) - - def transform_Module(self, node): - """Module visitor - - Right now, we only support two formats for TVM Script. - - Example - ------- - 1. Generate a PrimFunc (If the code is printed, then it may also contain metadata) - .. code-block:: python - - import tvm - - @tvm.script - def A(...): - ... - - # returns a PrimFunc - func = A - - 2. Generate an IRModule - .. code-block:: python - - import tvm - - @tvm.script.ir_module - class MyMod(): - @T.prim_func - def A(...): - ... - @T.prim_func - def B(...): - ... - - __tvm_meta__ = ... - - # returns an IRModule - mod = MyMod - """ - if len(node.funcs) == 1: - return self.transform(next(iter(node.funcs.values()))) - elif len(node.funcs) == 0: - self.report_error( - "You must supply at least one class or function definition", node.span - ) - else: - self.report_error( - "Only one-function, one-class or function-with-meta source code is allowed", - ast.Span.union([x.span for x in list(node.funcs.values())[1:]]), - ) - - def transform_Class(self, node): - """Class definition visitor. - - A class can have multiple function definitions and a single - :code:`__tvm_meta__` statement. Each class corresponds to a single - :code:`IRModule`. - - Example - ------- - .. code-block:: python - - @tvm.script.ir_module - class MyClass: - __tvm_meta__ = {} - def A(): - T.evaluate(0) - """ - if len(node.assignments) == 1: - if not ( - len(node.assignments[0].lhs) == 1 - and isinstance(node.assignments[0].lhs[0], ast.Var) - and node.assignments[0].lhs[0].id.name == "__tvm_meta__" - ): - self.report_error( - "The only top level assignments allowed are `__tvm_meta__ = ...`", - node.assignments[0].span, - ) - self.init_meta( - MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) - ) - elif len(node.assignments) > 1: - self.report_error( - "Only a single top level `__tvm_meta__` is allowed", - ast.Span.union([x.span for x in node.assignments[1:]]), - ) - - return IRModule( - {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()} - ) - - def transform_Function(self, node): - """Function definition visitor. - - Each function definition is translated to a single :code:`PrimFunc`. - - There are a couple restrictions on TVM Script functions: - 1. Function arguments must have their types specified. - 2. The body of the function can contain :code:`func_attr` to specify - attributes of the function (like it's name). - 3. The body of the function can also contain multiple :code:`buffer_bind`s, - which give shape and dtype information to arguments. - 4. Return statements are implicit. - - Example - ------- - .. code-block:: python - - @T.prim_func - def my_function(x: T.handle): # 1. Argument types - T.func_attr({"global_symbol": "mmult"}) # 2. Function attributes - X_1 = tir.buffer_bind(x, [1024, 1024]) # 3. Buffer binding - T.evaluate(0) # 4. This function returns 0 - """ - - def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]): - if isinstance(decorator, ast.Call): - if len(decorator.params) != 1: - return False - func_name = decorator.func_name - else: - func_name = decorator - if isinstance(func_name, ast.Var): - return func_name.id.name == "as_torch" - - def check_decorator(decorators: List[ast.Expr]) -> bool: - """Check the decorator is `T.prim_func""" - if len(decorators) > 2 or len(decorators) == 0: - return False - if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]): - return False - d: ast.Expr = decorators[-1] - return ( - isinstance(d, ast.Attr) - and isinstance(d.object, ast.Var) - and self.match_tir_namespace(d.object.id.name) - and d.field.name == "prim_func" - ) - - self.init_function_parsing_env() - self.context.enter_scope(nodes=node.body.stmts) - - # add parameters of function - for arg in node.params: - # Note that this case is for T.match_buffer syntax sugar - if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( - self.transform(arg.ty.func_name), ty.GenericBufferType - ): - result = self.handle_match_buffer_type(arg.ty, arg.name) - if not isinstance(result, buffer.Buffer): - self.report_error( - "The result type of evaluating TypeCall and TypeApply stmt" - f" is wrong: {type(result)}. It should be a Buffer", - node.span, - ) - arg_name_with_handle = arg.name + "_handle" - arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle")) - self.context.func_buffer_map[arg_var] = result - self.context.update_symbol(arg.name, result, node) - else: - arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) - self.context.update_symbol(arg.name, arg_var, node) - self.context.func_params.append(arg_var) - - if not check_decorator(node.decorators): - self.report_error( - "All functions should be decorated by `T.prim_func`", - node.span, - ) - - # fetch the body of root block - body = self.parse_body(node.body) - - # return a tir.PrimFunc - dict_attr = self.context.func_dict_attr - ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None - func = tvm.tir.PrimFunc( - self.context.func_params, - body, - ret_type, - buffer_map=self.context.func_buffer_map, - attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, - span=tvm_span_from_synr(node.span), - ) - - # New Scope : Implicit root block - # Each function contains an implicit root block in TensorIR, - # so here we need a block scope for it. - # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func), - # the root block will not be added. The logic to add root block is in `_ffi_api.Complete` - - # Fix the PrimFunc - # 1. generate root block if necessary - # 2. generate surrounding loops for blocks if necessary - - func = call_with_error_reporting( - self.report_error, - node.span, - _ffi_api.Complete, - func, - self.context.root_alloc_buffers, - ) - - self.context.exit_scope() - return func - - def transform_Lambda(self, node): - """Lambda visitor - - Return an array of input parameters and the transformed lambda body. - """ - - self.context.enter_scope(nodes=[node.body]) - - # add parameters of the lambda - arg_vars = [] - for arg in node.params: - # Use "void" for dtype here. The actual type is not yet known and will be - # determined later. Using void type will allow IRSubstitute to do the - # replacement without flagging a type-mismatch error. - arg_var = tvm.te.var(arg.name, dtype="") - arg_vars.append(arg_var) - self.context.update_symbol(arg.name, arg_var, node) - - # the body of a lambda must be an expr - if not isinstance(node.body, ast.Expr): - self.report_error("The body of a lambda must be an expression", node.span) - - # transform the body of the lambda - body = self.transform(node.body) - - self.context.exit_scope() - return arg_vars, body - - def transform_Assign(self, node): - """Assign visitor - AST abstract grammar: - Assign(expr* targets, expr value, string? type_comment) - - By now 5 patterns of Assign is supported: - 1. special stmts with return value - 1.1 Buffer = T.match_buffer()/T.buffer_decl() - 1.2 Var = T.var() - 1.3 Var = T.env_thread() - 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr - 3. (Store) Var[PrimExpr] = PrimExpr - 4. with scope handlers with concise scoping and var def - 4.1 var = T.allocate() - 5. A call to a pure python function, consuming and producing TVMScript values. - The outputs are inlined into the following body (no variable is created). - x, y = f(...) - """ - - if isinstance(node.rhs, ast.Call): - # Pattern 1 & Pattern 4 - if isinstance(node.rhs.func_name, ast.Op): - func = None - else: - func = self.transform(node.rhs.func_name) - - if isinstance(func, WithScopeHandler): - if not func.concise_scope or not func.def_symbol: - self.report_error( - "with scope handler " + func.signature()[0] + " is not suitable here", - node.rhs.span, - ) - # Pattern 4 - arg_list = self.parse_arg_list(func, node.rhs) - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - elif isinstance(func, SpecialStmt): - # Pattern 1 - arg_list = self.parse_arg_list(func, node.rhs) - func.handle(node, self.context, arg_list, node.rhs.func_name.span) - return self.parse_body(node) - elif isinstance(func, types.FunctionType): - # Pattern 5 - args = [self.transform(arg) for arg in node.rhs.params] - try: - out = func(*args) - except Exception as e: - self.report_error( - "Error occured when invoking the function " - + func.__name__ - + ": \n" - + str(e), - node.rhs.span, - ) - - if len(node.lhs) == 1 and not isinstance(out, list): - out = [out] - - assert len(out) == len(node.lhs) - - for var, value in zip(node.lhs, out): - self.context.update_symbol(var.id.name, value, node) - - body = self.parse_body(node) - - for var, value in zip(node.lhs, out): - self.context.remove_symbol(var.id.name) - - return body - - if isinstance(node.rhs, (ast.Call, ast.Constant)): - # Pattern 4 of let binding - value = self.transform(node.rhs) - if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): - # This is a little confusing because it only is true when - # we have taken this branch. We might need to clarify what - # exectly is allowed in Assignments in tvmscript. - self.report_error( - "Left hand side of assignment must be an unqualified variable", - node.span, - ) - ast_var = node.lhs[0] - - if node.ty is None and hasattr(value, "dtype"): - var_ty = value.dtype - else: - var_ty = self.parse_type(node.ty, ast_var) - - var = tvm.te.var( - ast_var.id.name, - var_ty, - span=tvm_span_from_synr(ast_var.span), - ) - self.context.update_symbol(var.name, var, node) - body = self.parse_body(node) - self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) - - self.report_error( - """Assignments should be one of: - 1. A "special statement" with return value - 1.1 Buffer = T.match_buffer()/T.buffer_decl() - 1.2 Var = T.var() - 1.3 Var = T.env_thread() - 2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr - 3. A store into a variable: Var[PrimExpr] = PrimExpr - 4. A with scope handler with concise scoping and var def - 4.1 var = T.allocate() - 5. The right-hand side being a call to a pure python function, consuming and - producing TVMScript values. - x, y = f(...)""", - node.span, - ) - - def transform_SubscriptAssign(self, node): - """Visitor for statements of the form :code:`x[1] = 2`.""" - symbol = self.transform(node.params[0]) - indexes = self.transform(node.params[1]) - rhs = self.transform(node.params[2]) - rhs_span = tvm_span_from_synr(node.params[2].span) - if isinstance(symbol, tvm.tir.Buffer): - if len(indexes) != len(symbol.shape): - self.report_error( - f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, " - f"cannot be indexed by {len(indexes)}-dimensional indices.", - node.params[1].span, - ) - - def __convert_index(x): - if isinstance(x, Slice): - return x.as_index_expr(self.report_error) - return x - - # BufferStore - indexes = [__convert_index(x) for x in indexes] - return tvm.tir.BufferStore( - symbol, - tvm.runtime.convert(rhs, span=rhs_span), - indexes, - span=tvm_span_from_synr(node.span), - ) - else: - if symbol.dtype == "handle" and len(indexes) != 1: - self.report_error( - "Handles only support one-dimensional indexing. Use `T.match_buffer` to " - "construct a multidimensional buffer from a handle.", - node.params[0].span, - ) - if len(indexes) != 1: - self.report_error( - f"Store is only allowed with one index, but {len(indexes)} were provided.", - node.params[1].span, - ) - self.report_error( - "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span - ) - - def transform_AttrAssign(self, node): - """Visitor for statements of the form :code:`x.y = 2`.""" - obj = self.transform(node.params[0]) - field = node.params[1] - value = self.transform(node.params[2]) - - if not hasattr(obj, field.name): - self.error(f"Field {field.name} does not exist", field.span) - - var = getattr(obj, field.name) - - if not isinstance(var, tvm.tir.Var): - self.error( - f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span - ) - - body = self.parse_body(node) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) - - def transform_Assert(self, node): - """Assert visitor - - Pattern corresponds to concise mode of :code:`with T.Assert()`. - """ - - condition = self.transform(node.condition) - if node.msg is None: - self.report_error("Assert statements must have an error message.", node.span) - message = self.transform(node.msg) - body = self.parse_body(node) - return tvm.tir.AssertStmt( - condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span) - ) - - def transform_For(self, node): - """For visitor - AST abstract grammar: - For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) - By now 1 pattern of For is supported: - 1. for scope handler - for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/ - T.grid()/T.thread_binding() - """ - - if not isinstance(node.rhs, ast.Call): - self.report_error("The loop iterator should be a function call.", node.rhs.span) - func = self.transform(node.rhs.func_name) - if not isinstance(func, ForScopeHandler): - self.report_error( - "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span - ) - # prepare for new for scope - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno = node.span.start_line - self.current_col_offset = node.span.start_column - self.context.enter_scope(nodes=node.body.stmts) - # for scope handler process the scope - arg_list = [ - tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span)) - for arg in self.parse_arg_list(func, node.rhs) - ] - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - # exit the scope - self.context.exit_scope() - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - return res - - def transform_While(self, node): - """While visitor - AST abstract grammar: - While(expr condition, stmt* body) - """ - condition = self.transform(node.condition) - # body - self.context.enter_scope(nodes=node.body.stmts) - body = self.parse_body(node) - self.context.exit_scope() - - return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span)) - - def transform_With(self, node): - """With visitor - AST abstract grammar: - With(withitem* items, stmt* body, string? type_comment) - withitem = (expr context_expr, expr? optional_vars) - By now 2 patterns of With is supported: - 1. with scope handler with symbol def - with T.allocate() as targets: - 2. with scope handler without symbol def - with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize() - """ - - if not isinstance(node.rhs, ast.Call): - self.report_error( - "The context expression of a `with` statement should be a function call.", - node.rhs.span, - ) - - func = self.transform(node.rhs.func_name) - - if not isinstance(func, WithScopeHandler): - self.report_error( - f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span - ) - # prepare for new block scope - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno = node.body.span.start_line - self.current_col_offset = node.body.span.start_column - self.context.enter_block_scope(nodes=node.body.stmts) - # with scope handler process the scope - arg_list = self.parse_arg_list(func, node.rhs) - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - # exit the scope - self.context.exit_block_scope() - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - return res - - def transform_If(self, node): - """If visitor - AST abstract grammar: - If(expr test, stmt* body, stmt* orelse) - """ - - condition = self.transform(node.condition) - # then body - self.context.enter_scope(nodes=node.true.stmts) - then_body = self.parse_body(node) - self.context.exit_scope() - - # else body - if len(node.false.stmts) > 0: - self.context.enter_scope(nodes=node.false.stmts) - else_body = self.parse_body(node) - self.context.exit_scope() - else: - else_body = None - - return tvm.tir.IfThenElse( - condition, then_body, else_body, span=tvm_span_from_synr(node.span) - ) - - def transform_Call(self, node): - """Call visitor - - 3 different Call patterns are allowed: - 1. Intrin representing a PrimExpr/IterVar - 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max - 1.2 tir.range/reduce_axis/scan_axis/opaque_axis - 2. tir.Op(dtype, ...) - 3. other callable functions - """ - - if isinstance(node.func_name, ast.Op): - if node.func_name.name == ast.BuiltinOp.Subscript: - return self.transform_Subscript(node) - if node.func_name.name in self._binop_maker: - lhs = self.transform(node.params[0]) - # There is no supertype for everything that can appear in - # an expression, so we manually add what we might get here. - if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)): - # We would really like to report a more specific - # error here, but this parser contains no distinction - # between parsing statements and parsing expressions. All - # rules just call `transform`. - self.report_error( - f"Left hand side of binary op must be a PrimExpr, " - "but it is a {type(lhs).__name__}", - node.params[0].span, - ) - rhs = self.transform(node.params[1]) - if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)): - self.report_error( - f"Right hand side of binary op must be a PrimExpr, " - "but it is a {type(rhs).__name__}", - node.params[1].span, - ) - return call_with_error_reporting( - self.report_error, - node.span, - lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name]( - lhs, rhs, span=span - ), - node, - lhs, - rhs, - tvm_span_from_synr(node.span), - ) - if node.func_name.name in self._unaryop_maker: - rhs = self.transform(node.params[0]) - return self._unaryop_maker[node.func_name.name]( - rhs, span=tvm_span_from_synr(node.span) - ) - self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span) - else: - func = self.transform(node.func_name) - if isinstance(func, Intrin) and not func.stmt: - # pattern 1 - arg_list = self.parse_arg_list(func, node) - return call_with_error_reporting( - self.report_error, - node.func_name.span, - func.handle, - arg_list, - node.func_name.span, - ) - else: - args = [self.transform(arg) for arg in node.params] - kw_args = { - self.transform(k): self.transform(v) for k, v in node.keyword_params.items() - } - if isinstance(func, tvm.tir.op.Op): - if not "dtype" in kw_args.keys(): - self.report_error(f"{func} requires a dtype keyword argument.", node.span) - # pattern 2 - return tvm.tir.Call( - kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span) - ) - elif callable(func): - # pattern 3 - return func(*args, **kw_args) - else: - self.report_error( - f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).", - node.func_name.span, - ) - - def transform_UnassignedCall(self, node): - """Visitor for statements that are function calls. - - This handles function calls that appear on thier own line like `tir.realize`. - - Examples - -------- - .. code-block:: python - - @T.prim_func - def f(): - A = T.buffer_decl([10, 10]) - T.realize(A[1:2, 1:2], "") # This is an UnassignedCall - A[1, 1] = 2 # This is also an UnassignedCall - """ - # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. - if isinstance(node.call.func_name, ast.Op): - if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign: - return self.transform_SubscriptAssign(node.call) - - if node.call.func_name.name == ast.BuiltinOp.AttrAssign: - return self.transform_AttrAssign(node.call) - - self.report_error( - "Binary and unary operators are not allowed as a statement", node.span - ) - - # handle a regular function call - func = self.transform(node.call.func_name) - arg_list = self.parse_arg_list(func, node.call) - - if isinstance(func, tir.scope_handler.AssertHandler): - self.report_error( - "A standalone `T.Assert` is not allowed. Use `assert condition, message` " - "instead.", - node.call.func_name.span, - ) - - if isinstance(func, Intrin): - if func.stmt: - return call_with_error_reporting( - self.report_error, - node.call.func_name.span, - func.handle, - arg_list, - node.call.func_name.span, - ) - else: - self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span) - elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: - func.enter_scope(node, self.context, arg_list, node.call.func_name.span) - func.body = self.parse_body(node) - return func.exit_scope(node, self.context, arg_list, node.call.func_name.span) - elif isinstance(func, SpecialStmt) and not func.def_symbol: - func.handle(node, self.context, arg_list, node.call.func_name.span) - return - - self.report_error( - "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a " - f"special statement, but got {type(func).__name__}.", - node.call.func_name.span, - ) - - def transform_Slice(self, node): - """Index slice visitor.""" - start = self.transform(node.start) - end = self.transform(node.end) - if not ( - isinstance(node.step, ast.Constant) - and isinstance(node.step.value, int) - and node.step.value > 0 - ): - self.report_error( - "Only positive integer step size is supported for slices.", node.step.span - ) - return Slice(start, end, node.step.value, tvm_span_from_synr(node.span)) - - def transform_Subscript(self, node): - """Array access visitor. - - By now only 3 types of Subscript are supported: - 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) - Var[index] Buffer element access() - 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) - 3. Array[index], Buffer element access - """ - - symbol = self.transform(node.params[0]) - if symbol is None: - self.report_error( - f"Variable {node.params[0].id.name} is not defined.", node.params[0].span - ) - - indexes = [self.transform(x) for x in node.params[1].values] - if isinstance(symbol, tvm.tir.expr.Var): - if symbol.dtype == "handle": - self.report_error( - "Cannot read directly from a handle, use `T.match_buffer` " - "to create a buffer to read from.", - node.params[0].span, - ) - if len(indexes) > 1: - self.report_error( - "Only a single index can be provided when indexing into a `var`.", - node.params[1].span, - ) - index = indexes[0] - if not isinstance(index, (tvm.tir.PrimExpr, int)): - self.report_error( - "Var load index should be an int or PrimExpr, but it is a" + type(index), - node.span, - ) - - self.report_error( - "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span - ) - elif isinstance(symbol, tvm.tir.Buffer): - return BufferSlice( - symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) - ) - elif isinstance(symbol, tvm.container.Array): - if len(indexes) > 1: - self.report_error( - "Array access should be one-dimension access, but the indices are " - + str(indexes), - node.span, - ) - index = indexes[0] - if not isinstance(index, (int, tvm.tir.expr.IntImm)): - self.report_error( - "Array access index expected int or IntImm, but got " + type(index), - node.span, - ) - if int(index) >= len(symbol): - self.report_error( - f"Array access out of bound, size: {len(symbol)}, got index {index}.", - node.span, - ) - return symbol[int(index)] - else: - self.report_error( - f"Cannot subscript from a {type(symbol).__name__}. Only variables and " - "buffers are supported.", - node.params[0].span, - ) - - def transform_Attr(self, node): - """Visitor for field access of the form `x.y`. - - This visitor is used to lookup function and symbol names. We have two - cases to handle here: - 1. If we have a statement of the form `tir.something`, then we lookup - `tir.something` in the `Registry`. If the function is not in the - registry, then we try to find a `tvm.ir.op.Op` with the same name. - 2. All other names `tvm.something` are lookup up in this current python - namespace. - """ - - def get_full_attr_name(node: ast.Attr) -> str: - reverse_field_names = [node.field.name] - while isinstance(node.object, ast.Attr): - node = node.object - reverse_field_names.append(node.field.name) - if isinstance(node.object, ast.Var): - reverse_field_names.append(node.object.id.name) - return ".".join(reversed(reverse_field_names)) - - if isinstance(node.object, (ast.Var, ast.Attr)): - full_attr_name = get_full_attr_name(node) - attr_object, fields = full_attr_name.split(".", maxsplit=1) - if self.match_tir_namespace(attr_object): - func_name = "tir." + fields - res = Registry.lookup(func_name) - if res is not None: - return res - try: - return tvm.ir.op.Op.get(func_name) - except TVMError as e: - # Check if we got an attribute error - if e.args[0].find("AttributeError"): - self.report_error(f"Unregistered function `tir.{fields}`.", node.span) - else: - raise e - - symbol = self.transform(node.object) - if symbol is None: - self.report_error("Unsupported Attribute expression.", node.object.span) - if not hasattr(symbol, node.field.name): - self.report_error( - f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span - ) - res = getattr(symbol, node.field.name) - return res - - def transform_TypeAttr(self, node): - """Visitor for field access of the form `x.y` for types. - - We have two cases here: - 1. If the type is of the form `T.something`, we look up the type in - the `tir` namespace in this module. - 2. If the type is of the form `tvm.x.something` then we look up - `tvm.x.something` in this modules namespace. - """ - if isinstance(node.object, ast.TypeVar): - if self.match_tir_namespace(node.object.id.name): - if not hasattr(tir, node.field.name): - self.report_error( - f"Invalid type annotation `tir.{node.field.name}`.", node.span - ) - return getattr(tir, node.field.name) - - symbol = self.transform(node.object) - if symbol is None: - self.report_error("Unsupported Attribute expression", node.object.span) - if not hasattr(symbol, node.field): - self.report_error( - f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span - ) - res = getattr(symbol, node.field) - return res - - def transform_DictLiteral(self, node): - """Dictionary literal visitor. - - Handles dictionary literals of the form `{x:y, z:2}`. - """ - - keys = [self.transform(key) for key in node.keys] - values = [self.transform(value) for value in node.values] - - return dict(zip(keys, values)) - - def transform_Tuple(self, node): - """Tuple visitor. - - Handles tuples of the form `(x, y, 2)`. - """ - - return tuple(self.transform(element) for element in node.values) - - def transform_ArrayLiteral(self, node): - """List literal visitor. - - Handles lists of the form `[x, 2, 3]`. - """ - - return [self.transform(element) for element in node.values] - - def transform_Var(self, node): - """Variable visitor - - Handles variables like `x` in `x = 2`. - """ - - name = node.id.name - if name == "meta": - return self.meta - symbol = Registry.lookup(name) - if symbol is not None: - return symbol - symbol = self.context.lookup_symbol(name) - if symbol is not None: - return symbol - self.report_error(f"Unknown identifier {name}.", node.span) - - def transform_TypeVar(self, node): - """Type variable visitor. - - Equivalent to `transform_Var` but for types. - """ - name = node.id.name - symbol = Registry.lookup(name) or self.context.lookup_symbol(name) - if symbol is not None: - return symbol - self.report_error(f"Unknown identifier {name}.", node.span) - - def transform_Constant(self, node): - """Constant value visitor. - - Constant values include `None`, `"strings"`, `2` (integers), `4.2` - (floats), and `true` (booleans). - """ - return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span)) - - def transform_TypeConstant(self, node): - """Constant value visitor for types. - - See `transform_Constant`. - """ - if self._inside_buffer_sugar: - return self.transform_Constant(node) - - return node.value - - def transform_TypeTuple(self, node): - """Tuple value visitor for types. - - Mostly used in `transform_TypeCall` and `transform_TypeApply`. - """ - return [self.transform(value) for value in node.values] - - def transform_TypeCall(self, node): - """TypeCall visitor - - This occurs when an expression is used inside a T.Buffer - parameter annotation. - """ - - # ast.Call has the BuiltinOp as node.func_name.name, where - # ast.TypeCall has the BuiltinOp as node.func_name. So we can - # delegate to self.transform_Call, but the error messages for - # unsupported operations will highlight the entire expression - # and not just the function itself. - op = ast.Op(node.span, node.func_name) - call = ast.Call(node.span, op, node.params, node.keyword_params) - return self.transform_Call(call) - - def transform_TypeApply(self, node): - """Visitor for Type[Type] expressions. - - Mostly used for ``T.Ptr`` expressions. - """ - func = self.transform(node.func_name) - - if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"): - self.report_error( - f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), " - f"but found {type(func).__name__} instead.", - node.span, - ) - - param_types = [] - for idx, param in enumerate(node.params): - param_type = self.transform(param) - if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx): - self.report_error( - f"Expected a type but found {type(param).__name__} " - f"at {idx}th type argument", - param.span, - ) - - param_types.append(param_type) - - if len(param_types) == 1: - return func[param_types[0]] - else: - return func[param_types] - - def handle_match_buffer_type(self, node, buffer_name): - """special function to handle syntax sugar for match buffer. - - This method is for buffer declarations in the function parameters. - """ - func = self.transform(node.func_name) - assert isinstance(func, SpecialStmt) - - # parse args and kwargs for TypeCall and TypeApply - self._inside_buffer_sugar = True - try: - arg_list = self.parse_arg_list(func, node) - finally: - self._inside_buffer_sugar = False - - # Note that the third element in arg_list would always be the 'name' - # TODO: This index is hardcoded as a workaround. Better to make it programmatic - if arg_list[2] is None: - arg_list[2] = buffer_name - buf = func.handle(node, self.context, arg_list, node.func_name.span) - return buf - - def transform_Return(self, node): - self.report_error( - "TVM script does not support return statements. Instead the last statement in any " - "block is implicitly returned.", - node.span, - ) - - -def get_tir_namespace(script: Union[Callable, type]) -> List[str]: - assert inspect.isfunction(script) or inspect.isclass(script) - env: Dict[str, Any] = script.__globals__ - return [key for key in env.keys() if env[key] == tir] - - -def from_source( - input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None -) -> Union[PrimFunc, IRModule]: - """Parse function or string into PrimFunc or IRModule. - - If possible, pass the TVM script in as a function so that line numbers and - filename will be accurate. - - Parameters - ---------- - input_module : Union[str, Callable] - The python function to be parsed. - - tir_prefix : Optional[List[str]] - The tir prefix list. Only works for str input, default by "tir" and "T". - - Returns - ------- - output : Union[Function, Module] - The Function or Module in IR. - """ - if isinstance(input_func, str): - tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix - return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {})) - elif inspect.isfunction(input_func): - _, start_line = inspect.getsourcelines(input_func) - env: Dict[str, Any] = input_func.__globals__ - namespace = [key for key in env.keys() if env[key] is tir] - _closure_vars = inspect.getclosurevars(input_func) - closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals} - parser = TVMScriptParser(start_line, namespace, closure_vars) - result = to_ast(input_func, TVMDiagnosticCtx(), parser) - return result - else: - raise TypeError("Only function definitions are supported.") - - -def ir_module(input_module=None, metadata=None) -> IRModule: - """Decorate a python class as tvm IRModule. - - Parameters - ---------- - input_module : type - The python class to be parsed. - - metadata : Optional[Union[str, DictAttrs]] - The metadata attributes to be parsed. - - Returns - ------- - mod : IRModule - The result IRModule. - """ - if metadata is not None: - from .relax.parser import RelaxTransformer as _RelaxTransformer - - _RelaxTransformer.update_meta(metadata) - - if input_module is None: - return functools.partial(ir_module, metadata=metadata) - - def _ir_module(input_module: type) -> IRModule: - if inspect.isclass(input_module): - func_dict = { - name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) - } - mod = IRModule(func_dict, attrs=metadata) - mod = relax.transform.Normalize()(mod) - mod = relax.transform.ResolveGlobals()(mod) - # FIXME(@altanh): where is the source map? - return mod - - raise TypeError("Only class definitions are supported.") - - return _ir_module(input_module) diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 336575a93e97..a39bb3e738f1 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -68,6 +68,34 @@ DiagnosticBuilder Diagnostic::Help(Span span) { return DiagnosticBuilder(DiagnosticLevel::kHelp, span); } +DiagnosticBuilder Diagnostic::Bug(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kBug, loc); +} + +DiagnosticBuilder Diagnostic::Error(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kError, loc); +} + +DiagnosticBuilder Diagnostic::Warning(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kWarning, loc); +} + +DiagnosticBuilder Diagnostic::Note(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kNote, loc); +} + +DiagnosticBuilder Diagnostic::Help(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kHelp, loc); +} + +DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Error(const Object* loc) { return Error(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(GetRef(loc)); } + /* Diagnostic Renderer */ TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode); diff --git a/src/ir/type.cc b/src/ir/type.cc index 86dda2a27424..6fe46bfba88b 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -25,9 +25,10 @@ #include namespace tvm { -PrimType::PrimType(runtime::DataType dtype) { +PrimType::PrimType(runtime::DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index 5057d3dbf7ff..7ca8c132facd 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -48,6 +48,8 @@ Doc RelaxScriptPrinter::Print(const ObjectRef& node) { return tir::AsTVMScriptDoc(Downcast(node), "T", false); } else if (node->IsInstance()) { return Doc::StrLiteral(Downcast(node)); + } else if (node->IsInstance()) { + return VisitStructInfo(Downcast(node)); } else { return VisitNode(node); } @@ -269,7 +271,11 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::RuntimeDepShapeNode* op) { Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) { Doc doc; if (op->var.defined()) { - doc << Print(op->var) << PrintVarAnnotation(op->var) << " = "; + doc << Print(op->var); + if (const auto& sinfo = MatchStructInfo(op->var)) { + doc << ": " << Print(sinfo); + } + doc << " = "; } doc << "R.match_shape("; // TODO(@altanh): maybe op->pattern should just be a ShapeExpr? @@ -296,8 +302,8 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) { } } doc << Print(op->var); - if (print_annotation) { - doc << PrintVarAnnotation(op->var); + if (print_annotation && op->var->struct_info_.defined()) { + doc << ": " << Print(GetStructInfo(op->var)); } doc << " = " << Print(op->value); return doc; @@ -412,7 +418,7 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::MaxNode* op) { } Doc RelaxScriptPrinter::VisitType_(const relax::ShapeTypeNode* node) { - return Doc::Text("R.Shape"); + return Doc::Text("R.Shape(ndim=") << node->ndim << ")"; } Doc RelaxScriptPrinter::VisitType_(const relax::ObjectTypeNode* node) { @@ -420,24 +426,19 @@ Doc RelaxScriptPrinter::VisitType_(const relax::ObjectTypeNode* node) { } Doc RelaxScriptPrinter::VisitType_(const relax::DynTensorTypeNode* node) { - // NOTE: to print shape information, use PrintTensorAnnotation - return PrintTensorAnnotation(GetRef(node), NullOpt); + return Doc::Text("R.Tensor(ndim=") << node->ndim << ", dtype=" << PrintDType(node->dtype) << ")"; } Doc RelaxScriptPrinter::VisitType_(const relay::TupleTypeNode* node) { if (node->fields.empty()) { - return Doc::Text("R.Tuple()"); + return Doc::Text("R.Tuple"); } - Doc doc; - std::vector fields; for (Type ty : node->fields) { fields.push_back(Print(ty)); } - doc << "R.Tuple(" << Doc::Concat(fields) << ")"; - - return doc; + return Doc::Text("R.Tuple(") << Doc::Concat(fields) << ")"; } Doc RelaxScriptPrinter::VisitType_(const relay::FuncTypeNode* node) { @@ -578,10 +579,9 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& for (size_t i = 0; i < func->params.size(); ++i) { const relax::Var& var = func->params[i]; Doc param; - param << Print(var) << PrintVarAnnotation(var); + param << Print(var) << ": " << Print(GetStructInfo(var)); params.push_back(param); } - print_symbolic_shape_as_str_ = false; if (is_global) { ICHECK(symbolic_vars_.empty()); @@ -590,10 +590,13 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& // Step 2: print the function signature doc << "@R.function" << Doc::NewLine(); doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")"; - if (func->ret_type.defined()) { - doc << " -> " << Print(func->ret_type); + if (const auto& func_sinfo = MatchStructInfo(func)) { + StructInfo ret_sinfo = func_sinfo.value()->ret; + doc << " -> " << Print(ret_sinfo); } doc << ":" << Doc::NewLine(4); + // TODO(siyuan): Add printing of composite expression + print_symbolic_shape_as_str_ = false; // Step 3: print function attr Doc header_attr; @@ -643,69 +646,70 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& return doc; } -Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) { - // TODO(@altanh): we should consider moving annotation into binding - Doc doc; - Type annotation = var->checked_type_; - if (annotation.defined()) { - doc << ": "; - if (const relax::DynTensorTypeNode* tty = annotation.as()) { - doc << PrintTensorAnnotation(GetRef(tty), var->shape_); - } else if (const TupleTypeNode* tty = annotation.as()) { - doc << PrintTupleAnnotation(GetRef(tty), var->shape_); - } else { - doc << Print(annotation); +Doc RelaxScriptPrinter::VisitStructInfo_(const ObjectStructInfoNode* op) { + return Doc::Text("R.Object"); +} + +Doc RelaxScriptPrinter::VisitStructInfo_(const PrimStructInfoNode* op) { + // TODO(@relax-team): support PrimStructInfo printing and parsing + LOG(FATAL) << "Not allowed to print PrimStructInfo for now."; + return Doc::Text(""); +} + +Doc RelaxScriptPrinter::VisitStructInfo_(const ShapeStructInfoNode* op) { + if (op->values.defined()) { + std::vector fields; + for (const PrimExpr& field : op->values.value()) { + fields.push_back(Print(field)); } + return Doc::Text("R.Shape([") << Doc::Concat(fields, Doc::Text(", ")) << "])"; + } else { + return Doc::Text("R.Shape(ndim=") << op->ndim << ")"; } - return doc; } -Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty, - const Optional& shape) { - Doc doc; - doc << "R.Tensor("; - // Print shape annotation - if (shape.defined()) { - doc << Print(Downcast(shape.value())); - } else { - doc << "None"; +Doc RelaxScriptPrinter::VisitStructInfo_(const TensorStructInfoNode* op) { + Doc doc = Doc::Text("R.Tensor"); + std::vector fields; + if (op->shape.defined()) { + fields.push_back(Print(op->shape.value())); } - // Print dtype annotation - if (!ty->dtype.is_void()) { - doc << ", dtype=" << PrintDType(ty->dtype); + if (!op->IsUnknownDtype()) { + fields.push_back(Doc::Text("dtype=") << PrintDType(op->dtype)); } - // Print ndim annotation only when it cannot be inferred from shape itself. - if (!shape.defined() || shape->IsInstance()) { - doc << ", ndim=" << ty->ndim; + if (!op->shape.defined() && !op->IsUnknownNdim()) { + fields.push_back(Doc::Text("ndim=") << op->ndim); + } + if (fields.size() > 0) { + doc << "(" << Doc::Concat(fields, Doc::Text(", ")) << ")"; } - doc << ")"; return doc; } -Doc RelaxScriptPrinter::PrintTupleAnnotation(const TupleType& ty, - const Optional& shape) { - Doc doc; - doc << "R.Tuple"; +Doc RelaxScriptPrinter::VisitStructInfo_(const TupleStructInfoNode* op) { + Doc doc = Doc::Text("R.Tuple"); std::vector fields; - if (!(shape.defined() && shape.value().as())) { - return Print(ty); - } - const TupleNode* shape_tuple = shape.value().as(); - for (size_t i = 0; i < ty->fields.size(); i++) { - if (const auto* tensor_field = ty->fields[i].as()) { - fields.push_back( - PrintTensorAnnotation(GetRef(tensor_field), shape_tuple->fields[i])); - } else if (const auto* tuple_field = ty->fields[i].as()) { - fields.push_back( - PrintTupleAnnotation(GetRef(tuple_field), shape_tuple->fields[i])); - } else { - fields.push_back(Print(ty->fields[i])); - } + for (const StructInfo& field : op->fields) { + fields.push_back(Print(field)); } doc << "(" << Doc::Concat(fields, Doc::Text(", ")) << ")"; return doc; } +Doc RelaxScriptPrinter::VisitStructInfo_(const FuncStructInfoNode* op) { + Doc doc = Doc::Text("R.Callable"); + std::vector params; + if (!op->IsOpaque()) { + for (const StructInfo& arg : op->params.value()) { + params.push_back(Print(arg)); + } + // Do not print derive_func. + return doc << "((" << Doc::Concat(params) << "), " << Print(op->ret) << ")"; + } else { + return doc; + } +} + Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = "x") { if (prefix.empty()) { prefix = fallback; diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 998d8d83da70..3039b1161367 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -242,7 +243,8 @@ namespace relax { class RelaxScriptPrinter : public relax::IRFunctor, public tir::ExprFunctor, public TypeFunctor, - public AttrFunctor { + public AttrFunctor, + public relax::StructInfoFunctor { public: explicit RelaxScriptPrinter(bool show_meta_data, TextMetaDataContext* meta) : show_meta_data_(show_meta_data), meta_(meta) {} @@ -303,10 +305,6 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc PrintIfStmt(const relax::Var& var, const relax::If& ite); Doc PrintFunctionDef(const Doc& name, const relax::Function& func, bool is_global); - Doc PrintVarAnnotation(const relax::Var& var); - Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional& shape); - Doc PrintTupleAnnotation(const TupleType& ty, const Optional& shape); - Doc VisitType_(const relax::ShapeTypeNode* node) override; Doc VisitType_(const relax::ObjectTypeNode* node) override; Doc VisitType_(const relax::DynTensorTypeNode* node) override; @@ -322,6 +320,16 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitAttr_(const tir::IntImmNode* op) override; Doc VisitAttr_(const tir::FloatImmNode* op) override; + //------------------------------------ + // Overload of StructInfo printing functions + //------------------------------------ + Doc VisitStructInfo_(const ObjectStructInfoNode* op) override; + Doc VisitStructInfo_(const PrimStructInfoNode* op) override; + Doc VisitStructInfo_(const ShapeStructInfoNode* op) override; + Doc VisitStructInfo_(const TensorStructInfoNode* op) override; + Doc VisitStructInfo_(const TupleStructInfoNode* op) override; + Doc VisitStructInfo_(const FuncStructInfoNode* op) override; + Doc GetUniqueName(std::string prefix, std::string fallback); /*! diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc new file mode 100644 index 000000000000..70ce5ac06e90 --- /dev/null +++ b/src/relax/analysis/shape_analysis.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file shape_analysis.cc + * + * \brief Utilities for shape analysis. + */ + +#include +#include + +namespace tvm { +namespace relax { + +bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!ana->CanProveEqual(lhs[i], rhs[i])) return false; + } + return true; +} + +bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + + if (lhs_shape && rhs_shape) { + return CanProveShapeEqual(lhs_shape->values, rhs_shape->values, ana); + } else { + return false; + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc new file mode 100644 index 000000000000..06a2a2872bcd --- /dev/null +++ b/src/relax/analysis/struct_info_analysis.cc @@ -0,0 +1,933 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_analysis.cc + * \brief Implementations of foundation struct info analysis + * + * \note Update this file when you added a new StructInfo. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +//-------------------------- +// GetStaticType +//-------------------------- +class StaticTypeDeriver : public StructInfoFunctor { + public: + Type VisitStructInfo_(const ObjectStructInfoNode* op) final { return ObjectType(op->span); } + + Type VisitStructInfo_(const PrimStructInfoNode* op) final { + return PrimType(op->dtype, op->span); + } + + Type VisitStructInfo_(const ShapeStructInfoNode* op) final { + return ShapeType(op->ndim, op->span); + } + + Type VisitStructInfo_(const TensorStructInfoNode* op) final { + return DynTensorType(op->ndim, op->dtype); + } + + Type VisitStructInfo_(const TupleStructInfoNode* op) final { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + return TupleType(fields, op->span); + } + + Type VisitStructInfo_(const FuncStructInfoNode* op) final { + if (op->IsOpaque()) return PackedFuncType(op->span); + Array params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + Type ret = this->VisitStructInfo(op->ret); + return FuncType(params, ret, {}, {}, op->span); + } +}; + +Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } + +TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { + return GetStaticType(info); +}); + +//-------------------------- +// GetLegacyShapeHint +//-------------------------- +// TODO(relax-team) remove this function after phasing out shape. +class LegacyShapeDeriver : public StructInfoFunctor(const StructInfo&)> { + public: + Optional VisitStructInfo_(const ObjectStructInfoNode* op) final { return NullOpt; } + + Optional VisitStructInfo_(const PrimStructInfoNode* op) final { return NullOpt; } + + Optional VisitStructInfo_(const ShapeStructInfoNode* op) final { return NullOpt; } + + Optional VisitStructInfo_(const TensorStructInfoNode* op) final { + if (op->shape.defined()) { + return op->shape; + } else { + return RuntimeDepShape(); + } + } + + Optional VisitStructInfo_(const TupleStructInfoNode* op) final { + bool valid = true; + Array fields = op->fields.Map([this, &valid](const StructInfo& sinfo) { + Optional shape = this->VisitStructInfo(sinfo); + valid &= shape.defined(); + return shape.value_or(Expr(nullptr)); + }); + + // recursively collect structinfo to make sure legacy shape is also well defined. + if (valid && fields.size() != 0) { + Tuple tuple(fields, op->span); + Array tuple_sinfo; + for (Expr field : tuple->fields) { + tuple_sinfo.push_back(GetStructInfo(field)); + } + UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo)); + return tuple; + } else { + return NullOpt; + } + } + + Optional VisitStructInfo_(const FuncStructInfoNode* op) final { return NullOpt; } +}; + +Optional GetLegacyShapeHint(const StructInfo& info) { return LegacyShapeDeriver()(info); } + +TVM_REGISTER_GLOBAL("relax.analysis.GetLegacyShapeHint").set_body_typed([](const StructInfo& info) { + return GetLegacyShapeHint(info); +}); + +//-------------------------- +// StructInfoFromType +//-------------------------- + +StructInfo StructInfoFromType(const Type& type) { + return StructInfoFromTypeLegacyShapeHint(type, NullOpt); +} + +StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional shape_hint) { + if (type.as()) { + return ObjectStructInfo(type->span); + } else if (const PrimTypeNode* prim_type = type.as()) { + return PrimStructInfo(prim_type->dtype, prim_type->span); + } else if (const ShapeTypeNode* shape_type = type.as()) { + return ShapeStructInfo(shape_type->ndim, type->span); + } else if (const DynTensorTypeNode* tensor_type = type.as()) { + if (!shape_hint.defined() || shape_hint->IsInstance()) { + return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); + } else { + return TensorStructInfo(shape_hint.value(), tensor_type->dtype); + } + } else if (const TupleTypeNode* tuple_type = type.as()) { + Array fields; + if (shape_hint.defined() && shape_hint.value()->IsInstance()) { + Array shape_hint_fields = Downcast(shape_hint.value())->fields; + ICHECK_EQ(shape_hint_fields.size(), tuple_type->fields.size()); + for (size_t i = 0; i < tuple_type->fields.size(); ++i) { + fields.push_back( + StructInfoFromTypeLegacyShapeHint(tuple_type->fields[i], shape_hint_fields[i])); + } + } else { + for (const Type& field : tuple_type->fields) { + fields.push_back(StructInfoFromType(field)); + } + } + return TupleStructInfo(fields, type->span); + } else if (const FuncTypeNode* func_type = type.as()) { + Array params = + func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); + StructInfo ret = StructInfoFromType(func_type->ret_type); + return FuncStructInfo(params, ret, func_type->span); + } else { + LOG(FATAL) << "Unsupported type: " << type; + return StructInfo(); + } +} + +//-------------------------- +// EraseToWellDefined +//-------------------------- +class WellDefinedEraser : public StructInfoMutator, + public ExprMutatorBase, + public tir::ExprMutator { + public: + WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) + : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { + bool has_undefined = false; + Optional> values; + + if (op->values.defined()) { + std::swap(has_undefined_, has_undefined); + values = op->values.value().Map([&](PrimExpr val) { return this->VisitPrimExpr(val); }); + std::swap(has_undefined_, has_undefined); + } + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } + } else { + return ShapeStructInfo(op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { + bool has_undefined = false; + Optional shape; + + if (op->shape.defined()) { + std::swap(has_undefined_, has_undefined); + shape = relax::ExprMutatorBase::VisitExpr(op->shape.value()); + std::swap(has_undefined_, has_undefined); + } + + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + if (shape.defined()) { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final { + // NOTE: we always require func struct info to be well-defined. + // + // All the occuring symbolic variables are defined in parameters' + // struct info annotations. So there is no needed to erase. + return GetRef(op); + } + + using relax::ExprMutatorBase::VisitExpr_; + using tir::ExprMutator::VisitExpr_; + + // connect things up + PrimExpr VisitPrimExpr(const PrimExpr& expr) { + // apply eager simplification + PrimExpr val = tir::ExprMutator::VisitExpr(expr); + if (!val.same_as(expr)) { + return ana_->Simplify(val); + } else { + return val; + } + } + + Expr VisitExpr_(const DataflowVarNode* var) final { + return VisitExpr_(static_cast(var)); + } + + Expr VisitExpr_(const VarNode* var) final { + Optional ret; + if (f_var_map_ != nullptr) { + ret = f_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + if (ret.defined()) { + ICHECK(ret.as() || ret.as()) + << "Only allow Expr in StructInfo to be ShapeExpr or Var"; + } + return ret.value_or(GetRef(var)); + } + + PrimExpr VisitExpr_(const tir::VarNode* var) final { + Optional ret; + if (f_shape_var_map_ != nullptr) { + ret = f_shape_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + + if (ret.defined()) { + PrimExpr value = ret.value(); + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; + return value; + } else { + return GetRef(var); + } + } + + private: + bool has_undefined_ = false; + std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const Var& var)> f_var_map_; + arith::Analyzer* ana_; +}; + +StructInfo EraseToWellDefined( + const StructInfo& info, std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); + } else { + return WellDefinedEraser(f_shape_var_map, f_var_map, ana).VisitStructInfo(info); + } +} + +StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; + + if (!shape_var_map.empty()) { + f_shape_var_map = [&](const tir::Var& var) -> Optional { + auto it = shape_var_map.find(var); + if (it != shape_var_map.end()) return (*it).second; + return NullOpt; + }; + } + + if (!var_map.empty()) { + f_var_map = [&](const Var& var) -> Optional { + auto it = var_map.find(var); + if (it != var_map.end()) return (*it).second; + return NullOpt; + }; + } + + return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); +} + +TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") + .set_body_typed([](const StructInfo& info, Map shape_var_map, + Map var_map) { + return EraseToWellDefined(info, shape_var_map, var_map); + }); + +//-------------------------- +// IsBaseOf +//-------------------------- +class StructInfoBaseChecker + : public StructInfoFunctor { + public: + explicit StructInfoBaseChecker(arith::Analyzer* ana) : analyzer_(ana) {} + + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + // quick path + // Note: subclass may disable this quick path if we need to go over all struct info. + if (lhs.same_as(other)) return BaseCheckResult::kPass; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is base of everything + BaseCheckResult VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return BaseCheckResult::kPass; + } + + BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0; + } + + BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs have unknown ndim + if (lhs->IsUnknownNdim()) return BaseCheckResult::kPass; + + // ndim must match + if (lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs does not have symbolic value + if (!lhs->values.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs do. + if (!rhs->values.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->values.value(), rhs->values.value()); + } + + BaseCheckResult VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // dtype mismatch + if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) { + if (rhs->IsUnknownDtype()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // ndim msiamtch + if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs does not have defined shape and everything else matches + if (!lhs->shape.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs don't + if (!rhs->shape.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->shape.value(), rhs->shape.value()); + } + + BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return ArrayCheck(lhs->fields, rhs->fields); + } + + BaseCheckResult VisitStructInfo_(const FuncStructInfoNode* lhs, + const StructInfo& other) override { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + // function proving is best effort. + return lhs->derive_func.same_as(rhs->derive_func) ? BaseCheckResult::kPass + : BaseCheckResult::kFailL2; + } + // no derivation function, only depends on ret + return this->VisitStructInfo(lhs->ret, rhs->ret); + } + + // Function check is best effort. + // rhs is opaque but lhs is not + if (rhs->IsOpaque()) return BaseCheckResult::kFailL2; + + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check and not ArrayCheck. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort BaseArrayCheck. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), other)) return BaseCheckResult::kPass; + + auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); + auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); + return CombineCheck(param_check, ret_check); + } + + protected: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // customizable functions. + /*! + * \brief Check symbolic shape value equivalence. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + // get static shape checking right. + auto* int_lhs = lhs.as(); + auto* int_rhs = rhs.as(); + if (int_lhs && int_rhs) { + if (int_lhs->value == int_rhs->value) { + return BaseCheckResult::kPass; + } else { + return BaseCheckResult::kFailL0; + } + } + return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2; + } + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult ShapeMatchCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + + BaseCheckResult ret = BaseCheckResult::kPass; + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = PrimValueMatchCheck(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } + + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return Check result. + */ + virtual BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) { + if (lhs.same_as(rhs)) return BaseCheckResult::kPass; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } else { + return BaseCheckResult::kFailL2; + } + } + + /*! + * \brief CheckShape function parameters. + * \param lhs The left hand params. + * \param rhs The right hand params. + * \return Check result. + */ + virtual BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) { + auto res = ArrayCheck(lhs, rhs); + // treat L1 failures in params checking as L2. + if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; + return res; + } + // helper functions + /*! + * \brief Combine check results. + * \param lhs The left operand. + * \param rhs The righr operand. + * \return The check result. + */ + static BaseCheckResult CombineCheck(BaseCheckResult lhs, BaseCheckResult rhs) { + if (lhs == BaseCheckResult::kFailL0 || rhs == BaseCheckResult::kFailL0) { + return BaseCheckResult::kFailL0; + } + if (lhs == BaseCheckResult::kFailL1 || rhs == BaseCheckResult::kFailL1) { + return BaseCheckResult::kFailL1; + } + if (lhs == BaseCheckResult::kFailL2 || rhs == BaseCheckResult::kFailL2) { + return BaseCheckResult::kFailL2; + } + return BaseCheckResult::kPass; + } + + /*! + * \brief Generic helper function to check arrays. + * \param lhs The left operand. + * \param rhs The right operand. + */ + BaseCheckResult ArrayCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + BaseCheckResult ret = BaseCheckResult::kPass; + + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } +}; + +BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoBaseChecker(&inst)(base, derived); + } else { + return StructInfoBaseChecker(ana)(base, derived); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { + return static_cast(StructInfoBaseCheck(base, derived)); + }); + +bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { + return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; +} + +TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) { + return IsBaseOf(base, derived); + }); + +//-------------------------- +// DeriveStructInfo +//-------------------------- + +// NOTE: we are reusing StructInfoBaseChecker here to populate a mapping +// from the expressions in arg(rhs) to var in param. +class CallRetStructInfoDeriver : public StructInfoBaseChecker { + public: + explicit CallRetStructInfoDeriver(arith::Analyzer* ana) : StructInfoBaseChecker(ana) {} + + // No short cut, so we can recursively populate all pairs. + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + StructInfo Derive(const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + // opaque derivation + if (finfo->IsOpaque()) { + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + return finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + return finfo->ret; + } + } + + // Normal function signature derivation. + auto params = finfo->params.value(); + if (params.size() != call->args.size()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "number of arguments and parameters mismatch:" + << " expected " << params.size() << ", given " << call->args.size()); + } + // Visit each param arg pair, check and populate the var map + for (size_t i = 0; i < params.size(); ++i) { + auto arg_sinfo = GetStructInfo(call->args[i]); + BaseCheckResult res = this->VisitStructInfo(params[i], arg_sinfo); + // Report error if we find L1 level failure + // L2 level is best effort so we don't report. + // The behavior of L2 can be customized later. + if (res == BaseCheckResult::kFailL0 || res == BaseCheckResult::kFailL1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Argument " << i << " type mismatch:" + << " expected " << params[i] << ", given " << arg_sinfo); + } + } + // map the ret using the populated var map. + return EraseToWellDefined(finfo->ret, shape_var_map_, var_map_); + } + + protected: + // Whether to populate map in params. + bool populate_mapping_{true}; + // for simplicity, we make these fields public so the user can access them. + Map shape_var_map_; + Map var_map_; + + using StructInfoBaseChecker::ShapeMatchCheck; + + // Match shape values in between param(lhs) and arg(rhs) + BaseCheckResult PrimValueMatchCheck(const PrimExpr& param, const PrimExpr& arg) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + + if (auto* ptr = param.as()) { + auto var = GetRef(ptr); + auto it = shape_var_map_.find(var); + // not populated + if (it == shape_var_map_.end()) { + shape_var_map_.Set(var, arg); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + PrimExpr mapped_value = (*it).second; + if (analyzer_->CanProveEqual(mapped_value, arg)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } else { + // Best effort + // Do not attempt to do prove when param contains a symbolic expr. + // such expression might depends on a later defined var in params created by dyn fusion. + // example: f(a: Tensor[(n+1)], s: Shape[(n,)]), the (n+1) case here. + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + } + + BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::ShapeMatchCheck(lhs, rhs); + } + + if (auto* ptr = lhs.as()) { + auto var = GetRef(ptr); + auto it = var_map_.find(var); + // not populated + if (it == var_map_.end()) { + var_map_.Set(var, rhs); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + Expr mapped_value = (*it).second; + if (CanProveShapeEqual(mapped_value, rhs, analyzer_)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } + auto lhs_shape = lhs.as(); + auto rhs_shape = rhs.as(); + ICHECK(lhs_shape) << "lhs must have a shape"; + if (!rhs_shape) return BaseCheckResult::kFailL2; + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } + + BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) final { + // Set populate mapping to false + // so we do not pick up symbolic vars in params with function type. + // + // @R.function + // def f(g: R.Func([R.Tensor[(n,)]], R.Tensor[(n+1,)]), + // x: R.Tensor[(m,)]) -> R.Tensor[(m,)]: + // ... + // + // For example, in the above function f, we should avoid + // pick up n in g's signature. + bool populate_mapping = false; + std::swap(populate_mapping_, populate_mapping); + auto ret = StructInfoBaseChecker::FuncParamsCheck(lhs, rhs); + std::swap(populate_mapping_, populate_mapping); + return ret; + } +}; + +StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return CallRetStructInfoDeriver(&inst).Derive(finfo, call, ctx); + } else { + return CallRetStructInfoDeriver(ana).Derive(finfo, call, ctx); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") + .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetStructInfo(finfo, call, ctx); + }); + +//-------------------------- +// UnifyToLCA +//-------------------------- +class StructInfoLCAFinder + : public StructInfoFunctor { + public: + explicit StructInfoLCAFinder(arith::Analyzer* ana) : analyzer_(ana) {} + + StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + // quick path + if (lhs.same_as(other)) return lhs; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is based of everything, unify to object. + StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + if (lhs->dtype == rhs->dtype) return GetRef(lhs); + // PrimType will be treated as their boxed(object) values + // as a result we can unify to object. + return ObjectStructInfo(lhs->span); + } + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownDim; + if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || + !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { + // prefers return same when possible + if (!lhs->values.defined() && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return ShapeStructInfo(ndim, lhs->span); + } + } + // equals to each other + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // find the target dtype and ndim. + DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownDim; + // if ndim mismatch or one side of shape is missing + // then we cannot keep in symbolic shape + if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || + !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) { + // reuse lhs when possible + if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return TensorStructInfo(dtype, ndim, lhs->span); + } + } + // symbolic shape match but dtype mismatch + if (lhs->dtype != dtype) { + return TensorStructInfo(lhs->shape.value(), dtype, lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Optional> fields = UnifyArray(lhs->fields, rhs->fields); + // tuple length not the same. + if (!fields.defined()) return ObjectStructInfo(lhs->span); + + // same length tuple. + if (!fields.same_as(lhs->fields)) { + return TupleStructInfo(fields.value(), lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + if (lhs->derive_func.same_as(rhs->derive_func)) { + return GetRef(lhs); + } else { + // Create a new opaque with object return + return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), lhs->span); + } + } else { + // no derivation function, only depends on ret + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + if (ret.same_as(lhs->ret)) return GetRef(lhs); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + } + // rhs is opaque, lhs is not + if (rhs->IsOpaque()) { + // unify ret value, note that rhs's ret is context free(because it is opaque) + // so result of the unify is also context-free. + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + + // Both lhs and rhs are not opaque + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort of unify types without considering var remap. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), GetRef(rhs))) { + return GetRef(lhs); + } + + auto params = UnifyArray(lhs->params.value(), rhs->params.value()); + auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); + + if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { + return GetRef(lhs); + } else { + // fail to unify the params + if (!params.defined()) { + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } else { + return FuncStructInfo(params.value(), ret, lhs->span); + } + } + } + + private: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // check arrays + Optional> UnifyArray(const Array& lhs, + const Array& rhs) { + if (lhs.same_as(rhs)) return lhs; + if (lhs.size() != rhs.size()) return NullOpt; + size_t index = 0; + return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); + } +}; + +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoLCAFinder(&inst)(lhs, rhs); + } else { + return StructInfoLCAFinder(ana)(lhs, rhs); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") + .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { + return StructInfoLCA(lhs, rhs); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 0bbb207e9639..5a31e9aec8e1 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -20,6 +20,7 @@ /*! * \file relax/analysis/well_formed.cc * \brief Check if the IRModule is well-formed. + * * This pass is supposed to be applied to normalized Relax AST. * If it's malformed, messages will be logged as Warning. * This pass will check: @@ -52,6 +53,7 @@ #include #include #include +#include #include #include @@ -60,30 +62,16 @@ namespace tvm { namespace relax { -class WellFormedChecker; - -/*! \brief Helper to visit PrimExpr in the shape annotation and check if the symbolic vars in - * the PrimExpr are defined.*/ -class PrimExprVisitor : public tir::ExprVisitor { - public: - std::unordered_set symbolic_var_set_; - WellFormedChecker* checker_; - - explicit PrimExprVisitor(WellFormedChecker* checker) : checker_(checker) {} - - void VisitExpr_(const tir::VarNode* op); -}; - +// TODO(relax-team): Consider further refactor using +// Scope Frame to store manage the var context. +// /*! \brief Helper to implement well formed check.*/ -class WellFormedChecker : public relax::ExprVisitor { +class WellFormedChecker : public relax::ExprVisitor, + public relax::StructInfoVisitor, + public tir::ExprVisitor { public: - Optional diag_ctx; - bool well_formed = true; - explicit WellFormedChecker(const Optional& ctx) - : diag_ctx(ctx), prim_expr_visitor_(this) {} - void Malformed(Diagnostic diag) { well_formed = false; LOG(WARNING) << "This IR is not well formed: " << diag->message; @@ -91,29 +79,42 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitExpr(const Expr& expr) override { if (!expr.as() && !expr->checked_type_.defined()) { - Malformed(Diagnostic::Error(expr->span) - << "The checked_type_ of Expr " << expr << " is nullptr."); + Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); } - ExprVisitor::VisitExpr(expr); + relax::ExprVisitor::VisitExpr(expr); } void RegisterGlobalVar(GlobalVar var) { global_var_set_.insert(var); } private: + // Possible mode of visitor + enum class VisitMode { + /*! + * \brief Check all vars are well-defined + */ + kDefault, + /*! + * \brief Match define the vars on first occurance. + * Do not check the well-defined property of composite expr. + */ + kMatchVarDef + }; + void VisitExpr_(const GlobalVarNode* op) { GlobalVar var = GetRef(op); if (global_var_set_.count(var) == 0) { - Malformed(Diagnostic::Error(var->span) - << "GlobalVar " << op->name_hint << " is not defined."); + Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " is not defined."); } if (op->checked_type_.defined()) { if ((!op->checked_type_->IsInstance()) && (!op->checked_type_->IsInstance())) { - Malformed(Diagnostic::Error(var->span) << "The checked_type_ of GlobalVar " << op->name_hint - << " must be either FuncType or PackedFuncType."); + Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " << op->name_hint + << " must be either FuncType or PackedFuncType."); } } + + CheckStructInfo(op); } void VisitExpr_(const TupleNode* op) { @@ -122,80 +123,79 @@ class WellFormedChecker : public relax::ExprVisitor { if (IsLeafExpr(expr)) { this->VisitExpr(expr); } else { - Malformed(Diagnostic::Error(expr->span) + Malformed(Diagnostic::Error(expr) << "Tuple is not in ANF form, field " << i << " gets " << expr->GetTypeKey()); } } - if (op->shape_) { - this->VisitExpr(Downcast(op->shape_.value())); - } + CheckStructInfo(op); } void VisitExpr_(const TupleGetItemNode* op) { if (IsLeafExpr(op->tuple)) { this->VisitExpr(op->tuple); } else { - Malformed(Diagnostic::Error(op->span) + Malformed(Diagnostic::Error(op) << "The tuple value in a TupleGetItem node must be a leaf expression."); } + CheckStructInfo(op); } void VisitExpr_(const VarNode* op) { Var var = GetRef(op); if (var_set_.count(var) == 0) { - Malformed(Diagnostic::Error(var->span) << "Var " << op->name_hint() << " is not defined."); + Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); } + CheckStructInfo(op); } void VisitExpr_(const DataflowVarNode* op) { DataflowVar var = GetRef(op); if (!is_dataflow_) { - Malformed(Diagnostic::Error(var->span) + Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << " is used outside DataflowBlock."); } if (dataflow_var_set_.count(var) == 0) { - Malformed(Diagnostic::Error(var->span) - << "DataflowVar " << op->name_hint() << " is not defined."); + Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << " is not defined."); } + CheckStructInfo(op); } void VisitExpr_(const FunctionNode* op) { // save the var_set_ for local function - std::unordered_set previous_var_set_ = var_set_; - for (Var param : op->params) { - // register symbolic var defined in the shape annotation of function params - if (param->shape_) { - Expr var_shape = Downcast(param->shape_); - if (var_shape.as() || var_shape.as()) { - VisitExpr(var_shape); - } else { - for (PrimExpr expr : Downcast(var_shape)->values) { - if (expr.as()) { - prim_expr_visitor_.symbolic_var_set_.insert(Downcast(expr)); - } else { - prim_expr_visitor_(expr); - } - } - } + auto prev_var_set = var_set_; + auto prev_symbolic_var_set = symbolic_var_set_; + // symbolic var is not captured across function boundaries + symbolic_var_set_.clear(); + + // first populate defs in params + WithMode(VisitMode::kMatchVarDef, [&]() { + ICHECK(mode_ == VisitMode::kMatchVarDef); + for (Var param : op->params) { + relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); } + }); + // check all expr are well defined. + for (Var param : op->params) { this->VisitVarDef(param); } + if (auto seq = op->body.as()) { this->VisitSeqExpr(seq); } else { - Malformed(Diagnostic::Error(op->span) << "Function bodies must be sequence expressions"); + Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions"); } - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_.clear(); + + var_set_ = prev_var_set; + symbolic_var_set_ = prev_symbolic_var_set; } void VisitExpr_(const CallNode* op) { if (IsLeafExpr(op->op)) { this->VisitExpr(op->op); } else { - Malformed(Diagnostic::Error(op->span) << "The called expression must be a leaf expression"); + Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); } for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; @@ -207,50 +207,48 @@ class WellFormedChecker : public relax::ExprVisitor { } } - if (op->shape_) { - this->VisitExpr(Downcast(op->shape_.value())); - } + CheckStructInfo(op); } void VisitExpr_(const IfNode* op) { if (IsLeafExpr(op->cond)) { this->VisitExpr(op->cond); } else { - Malformed(Diagnostic::Error(op->span) - << "The condition for an if node must be a leaf expression."); + Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); } auto true_seq = op->true_branch.as(); auto false_seq = op->false_branch.as(); if (true_seq && false_seq) { - std::unordered_set previous_var_set_ = var_set_; - std::unordered_set previous_symbolic_var_set_ = - prim_expr_visitor_.symbolic_var_set_; + std::unordered_set previous_var_set = var_set_; + std::unordered_set previous_symbolic_var_set = + symbolic_var_set_; this->VisitSeqExpr(true_seq); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; this->VisitSeqExpr(false_seq); - var_set_ = previous_var_set_; - prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_; + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; } else { - Malformed(Diagnostic::Error(op->span) << "If node branches must be seq exprs"); + Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs"); } + CheckStructInfo(op); } void VisitExpr_(const ShapeExprNode* op) { for (PrimExpr expr : op->values) { // check if the symbolic vars in the expr are defined, e.g, 2 * m - prim_expr_visitor_(expr); + tir::ExprVisitor::VisitExpr(expr); if (!expr.dtype().is_int()) { - Malformed(Diagnostic::Error(expr->span) + Malformed(Diagnostic::Error(expr) << "Shape expressions must be of integer type, but got " << expr.dtype()); } } + CheckStructInfo(op); } void VisitExpr_(const SeqExprNode* op) { - Malformed(Diagnostic::Error(op->span) - << "SeqExpr only serves as the function body in FunctionNode, " - "or the true/false branch body in IfNode."); + Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function body in FunctionNode, " + "or the true/false branch body in IfNode."); } void VisitSeqExpr(const SeqExprNode* op) { @@ -260,9 +258,10 @@ class WellFormedChecker : public relax::ExprVisitor { this->VisitBindingBlock(block); } if (!IsLeafExpr(op->body)) { - Malformed(Diagnostic::Error(op->span) << "SeqExpr bodies must be leaf expressions."); + Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf expressions."); } this->VisitExpr(op->body); + CheckStructInfo(op); } void VisitBinding_(const VarBindingNode* binding) { @@ -272,14 +271,15 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitBinding_(const MatchShapeNode* binding) { this->VisitExpr(binding->value); - for (PrimExpr expr : binding->pattern) { - if (expr.as()) { - // register symbolic var implicitly defined in the pattern of MatchShape - prim_expr_visitor_.symbolic_var_set_.insert(Downcast(expr)); - } else { - // check if the symbolic var in the expr are defined, e.g, 2 * m - prim_expr_visitor_(expr); + // define the vars + WithMode(VisitMode::kMatchVarDef, [&]() { + for (PrimExpr expr : binding->pattern) { + this->VisitStructInfoExprField(expr); } + }); + + for (PrimExpr expr : binding->pattern) { + this->VisitStructInfoExprField(expr); } if (binding->var.defined()) { @@ -298,26 +298,28 @@ class WellFormedChecker : public relax::ExprVisitor { void VisitVarDef_(const DataflowVarNode* var) { if (!is_dataflow_) { - Malformed(Diagnostic::Error(var->span) + Malformed(Diagnostic::Error(var) << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock."); } DataflowVar lv = GetRef(var); if (dataflow_var_set_.count(lv) == 1) { - Malformed(Diagnostic::Error(var->span) + Malformed(Diagnostic::Error(var) << "DataflowVar " << lv->name_hint() << " is defined more than once."); } // register DataflowVar dataflow_var_set_.insert(lv); + CheckStructInfo(var); } void VisitVarDef_(const VarNode* var) { Var gv = GetRef(var); if (var_set_.count(gv) == 1) { - Malformed(Diagnostic::Error(var->span) + Malformed(Diagnostic::Error(var) << "Var " << gv->name_hint() << " is defined more than once."); } // register Var var_set_.insert(gv); + CheckStructInfo(var); } void VisitVarDef(const Var& var) { @@ -328,29 +330,80 @@ class WellFormedChecker : public relax::ExprVisitor { } else { LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); } + } + + void VisitExpr_(const tir::VarNode* op) final { + tir::Var var = GetRef(op); + // default mode, check defined. + if (symbolic_var_set_.count(var) == 0) { + this->Malformed(Diagnostic::Error(var) + << "Symbolic Var " << var->name_hint << " is not defined."); + } + } + + void VisitStructInfoExprField(const Expr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurance + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (var_set_.count(var) == 0) { + var_set_.insert(var); + } + } + if (auto* shape = expr.as()) { + for (auto val : shape->values) { + this->VisitStructInfoExprField(val); + } + } + } else { + relax::ExprVisitor::VisitExpr(expr); + } + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurance + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (symbolic_var_set_.count(var) == 0) { + symbolic_var_set_.insert(var); + } + } + } else { + tir::ExprVisitor::VisitExpr(expr); + } + } - if (var->shape_) { - VisitExpr(Downcast(var->shape_.value())); + void CheckStructInfo(const ExprNode* op) { + auto* sinfo = op->struct_info_.as(); + if (sinfo != nullptr) { + this->VisitStructInfo(GetRef(sinfo)); + } else { + Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. " + << " Expr.type_key=" << op->GetTypeKey()); } } + // Run callback with mode. + template + void WithMode(VisitMode mode, FType callback) { + std::swap(mode_, mode); + callback(); + std::swap(mode_, mode); + } + bool is_dataflow_ = false; + // Current visit mode. + VisitMode mode_ = VisitMode::kDefault; + // set of context variables. std::unordered_set global_var_set_; std::unordered_set var_set_; std::unordered_set dataflow_var_set_; - PrimExprVisitor prim_expr_visitor_; + std::unordered_set symbolic_var_set_; }; -void PrimExprVisitor::VisitExpr_(const tir::VarNode* op) { - tir::Var var = GetRef(op); - if (symbolic_var_set_.count(var) == 0) { - checker_->Malformed(Diagnostic::Error(var->span) - << "Symbolic Var " << var->name_hint << " is not defined."); - } -} - bool WellFormed(const IRModule& m, Optional diag_ctx) { - WellFormedChecker well_formed_checker = WellFormedChecker(diag_ctx); + WellFormedChecker well_formed_checker = WellFormedChecker(); for (const auto& it : m->functions) { // register GlobalVar in the IRModule first well_formed_checker.RegisterGlobalVar(it.first); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index c439731467d7..5a84ae0a1621 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -21,9 +21,12 @@ * \file src/relax/block_builder.cc */ #include +#include #include #include #include +#include +#include #include #include #include @@ -128,9 +131,19 @@ class BlockBuilderImpl : public BlockBuilderNode { } GlobalVar gvar = GlobalVar(func_name); - ICHECK(func->checked_type_.defined()) - << "The function to be added does not have checked_type_."; - gvar->checked_type_ = func->checked_type_; + StructInfo finfo; + if (func->struct_info_.defined()) { + finfo = GetStructInfo(func); + } else if (auto* prim_func = func.as()) { + // NOTE: use a slightly different struct info than checked type + // in PrimFunc so handle can turn into Tensor. + // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. + finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + } else { + finfo = StructInfoFromType(func->checked_type()); + } + UpdateStructInfo(gvar, finfo); + context_mod_->Add(gvar, func); ctx_func_dedup_map_->emplace(func, gvar); @@ -162,6 +175,16 @@ class BlockBuilderImpl : public BlockBuilderNode { } } + void ReportFatal(const Diagnostic& diagnostic) final { + // TODO(relax-team): Print more context information by looking + // into the diagnostic->loc and surrounding IRModule. + // We do not materialzie DiagnosticContext to avoid double referencing to + // the change IRModule in COW. Additionally, we need to be able to + // continue use the builder after an error is thrown to avoid state building up. + // in an interactive environment. + LOG(FATAL) << diagnostic->message; + } + //------------------------------- // Scope management //------------------------------- @@ -175,22 +198,52 @@ class BlockBuilderImpl : public BlockBuilderNode { void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } + void BeginScope(Optional> params) final { + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + // + // TODO(relax-team): Add support for relax Var in struct info annotations. + Map shape_var_map; + for (const Var& var : params.value_or(Array())) { + const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; + } + shape_var_map.Set(shape_var, shape_expr); + } + } + scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); + } + + void EndScope() final { scope_stack_.pop_back(); } + BindingBlock EndBlock() final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); BindingBlock ret = cur_frame->is_dataflow ? DataflowBlock(cur_frame->bindings) : BindingBlock(cur_frame->bindings); block_stack_.pop_back(); return ret; } - bool CurrentBlockIsDataFlow() final { return CurrentFrame()->is_dataflow; } + bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } Var Emit(Expr expr, String name_hint) final { - return this->Emit(expr, CurrentFrame()->is_dataflow, name_hint); + return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); } Var Emit(VarBinding binding) final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); if (cur_frame->is_dataflow) { ICHECK(binding->var.as()) << "Emit can only be used for local bindings in a dataflow block, use EmitOutput for " @@ -204,20 +257,17 @@ class BlockBuilderImpl : public BlockBuilderNode { Var EmitMatchShape(Expr value, Array pattern, String name_hint) final { value = this->Normalize(value); - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); Var var = CreateVar(cur_frame->is_dataflow, name_hint); - if (value->checked_type().as()) { - UpdateType(var, ShapeType()); - } else if (const DynTensorTypeNode* tty = value->checked_type().as()) { - ShapeExpr shape = ShapeExpr(pattern); - UpdateShape(var, shape); - DataType dtype = tty->dtype; - UpdateType(var, DynTensorType(pattern.size(), dtype)); + if (value->struct_info_.as()) { + UpdateStructInfo(var, ShapeStructInfo(pattern.size())); + } else if (const auto* tensor_sinfo = value->struct_info_.as()) { + UpdateStructInfo(var, TensorStructInfo(ShapeExpr(pattern), tensor_sinfo->dtype)); } else { - this->diag_ctx_.EmitFatal( - Diagnostic::Error(value->span) - << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."); + this->ReportFatal( + Diagnostic::Error(value) + << "The value passed to EmitMatchShape must be of TensorStructInfo or ShapeStructInfo."); } MatchShape match_shape = MatchShape(value, pattern, var); @@ -228,7 +278,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } Var EmitMatchShape(MatchShape binding) final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); @@ -236,7 +286,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } Var EmitOutput(Expr output, String name_hint) final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; @@ -244,7 +294,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } Var EmitOutput(VarBinding binding) final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; ICHECK(!binding->var.as()) << "EmitOutput can only emit Var bindings."; @@ -255,7 +305,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } void EmitNormalized(Binding binding) final { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); if (auto* var_binding = binding.as()) { if (!cur_frame->is_dataflow) { @@ -294,17 +344,32 @@ class BlockBuilderImpl : public BlockBuilderNode { /*! * \brief Binding map used by normalizer. * - * \note The normalizer only caches reuse in the current scope + * \note The normalizer only caches reuse in the current block scope * and will not cache bindings from parent scope. */ std::unordered_map normalize_binding_map; }; + /*! + * \brief A representation of a scope frame. + * + * A scope frame records tracks the context of current scope. + */ + struct ScopeFrame { + // NOTE: for simplicity, only tracks symbolic var for now + // the scope is only used for erasure, so less information means + // more conservative analysis. + // Consider impl alternative: merge with block frame if we have more frame kinds. + // + // TODO(relax-team) tracks the var defined also through match-cast. + /*! \brief set of defined symbolic vars, value as themself. */ + Map shape_var_map; + }; /*! \brief A stack to store block frames. */ std::vector block_stack_; - /*! \brief A diagnostic context for reporting errors. */ - DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); + /*! \brief A stack to store scope frames. */ + std::vector scope_stack_; /*! \brief A binding table that maps var to value. */ std::unordered_map binding_table_; @@ -324,11 +389,20 @@ class BlockBuilderImpl : public BlockBuilderNode { * or other scope calls this value can change if the block stack get updated, * then the block frame is no longer valid. */ - BlockFrame* CurrentFrame() { + BlockFrame* CurrentBlockFrame() { ICHECK(!block_stack_.empty()) << "no block is being built"; return &block_stack_.back(); } + /*! + * \return The current scope frame. + * \note only use this value + */ + ScopeFrame* CurrentScopeFrame() { + ICHECK(!scope_stack_.empty()) << "no scope is being opened"; + return &scope_stack_.back(); + } + /*! * \brief Emits an Expr, and returns the variable it is bound to. * \param expr The Expr to be emitted. @@ -344,10 +418,9 @@ class BlockBuilderImpl : public BlockBuilderNode { Var var = CreateVar(is_dataflow, name_hint); // set the values - UpdateType(var, expr->checked_type_); - UpdateShape(var, expr->shape_); + UpdateStructInfo(var, Downcast(expr->struct_info_.value())); - CurrentFrame()->bindings.push_back(VarBinding(var, expr)); + CurrentBlockFrame()->bindings.push_back(VarBinding(var, expr)); // update the binding table binding_table_[var->vid] = expr; @@ -390,6 +463,42 @@ class BlockBuilderImpl : public BlockBuilderNode { ctx_func_dedup_map_->emplace(func, gv); } } + + // Collect all the variables that a parameter var can define. + // The collector is used to making sure that we record the + // shape vars as defined when calling BeginScope(params) + class StructInfoVarCollector : public StructInfoVisitor { + public: + static Map Collect(const StructInfo& struct_info) { + StructInfoVarCollector collector; + collector(struct_info); + return collector.shape_var_map_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) final { + if (const auto* shape_expr = op->shape.as()) { + for (const PrimExpr& s : shape_expr->values) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + } + + void VisitStructInfo_(const ShapeStructInfoNode* op) final { + for (const PrimExpr& s : op->values.value_or(Array())) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + + private: + Map shape_var_map_; + }; }; //--------------------------------------- @@ -399,6 +508,13 @@ class BlockBuilderImpl : public BlockBuilderNode { Expr VisitExpr_(const OP* op) final { return GetRef(op); } // TODO(relax-team): Check normalize logic after struct info. + +// Normalizer on struct info: +// +// We take benefit of the following invariants(that are checked in constructor): +// - If an expr appears in StructInfo, then it is already normalized. +// As a result, we do not need to peek into StructInfo in Normalization. +// - Constant, ShapeExpr, already have their StructInfo populated in constructing time. class Normalizer : public BlockBuilderImpl, private ExprFunctor { public: explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} @@ -407,11 +523,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitExpr(expr); // Invariant: // After Normalize: an Expr always have - // checked_type (with the exception of Op). + // struct_info (with the exception of Op). if (!normalized->IsInstance()) { - ICHECK(normalized->checked_type_.defined()) - << "The checked_type_ of an Expr except OpNode after " - "normalization must not be nullptr. However, this Expr does not have checked_type_: " + ICHECK(normalized->struct_info_.defined()) + << "The struct_info_ of an Expr except OpNode after " + "normalization must not be nullptr. However, this Expr does not have struct_info_: " << normalized; } @@ -425,14 +541,16 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorIsInstance()) return arg; + if (auto* prim_func = arg.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } if (!block_stack_.empty()) { // cache lookup - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); auto it = cur_frame->normalize_binding_map.find(arg); if (it != cur_frame->normalize_binding_map.end()) { return it->second; @@ -446,7 +564,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorEmit(post, ""); // NOTE: current frame addr can change due to underlying vector // re-allocation, redo lookup - CurrentFrame()->normalize_binding_map[arg] = var; + CurrentBlockFrame()->normalize_binding_map[arg] = var; return var; } else { return post; @@ -462,34 +580,37 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Expr VisitVar_(const typename T::ContainerType* var) { - bool shape_unchanged = true; - Expr new_shape; - if (var->shape_) { - new_shape = this->VisitExpr(Downcast(var->shape_.value())); - shape_unchanged &= new_shape.same_as(var->shape_); - } - - if (shape_unchanged) { - return GetRef(var); - } else { - Var new_var = T(var->vid, NullOpt, var->checked_type_, var->span); - UpdateShape(new_var, new_shape); - return new_var; - } + // Parameters and free-vars must be present with struct info + // Other vars must have already been normalized through binding + ICHECK(var->struct_info_.defined()) + << "Var " << var->name_hint() << " does not have struct info."; + return GetRef(var); } Expr VisitExpr_(const VarNode* var) final { return VisitVar_(var); } Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_(var); } + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + Expr NormalizePrimFunc(tir::PrimFunc prim_func) { + if (!prim_func->struct_info_.defined()) { + auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + UpdateStructInfo(prim_func, finfo); + } + return prim_func; + } + Expr VisitExpr(const Expr& expr) final { // Temp patch to ensure we handle inline PrimFunc case. // TODO(relax-team) remove such cases from parser and testcases. - if (expr->IsInstance()) return expr; + if (auto* prim_func = expr.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } // lookup normalize map if (!block_stack_.empty()) { - BlockFrame* cur_frame = CurrentFrame(); + BlockFrame* cur_frame = CurrentBlockFrame(); auto it = cur_frame->normalize_binding_map.find(expr); if (it != cur_frame->normalize_binding_map.end()) { return it->second; @@ -517,55 +638,33 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor new_fields; + for (const Expr& field : op->fields) { Expr new_field = this->NormalizeArgument(field); new_fields.push_back(new_field); unchanged &= new_field.same_as(field); } - Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields); - - // only do shape/type inference if the Tuple does not have shape/type - if (tuple->shape_ && tuple->checked_type_.defined()) { - return tuple; - } - // Tuple's checked_type must not be null - if (!tuple->checked_type_.defined()) { - Array tuple_type; + Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields, op->span); + // Update tuple fields. + if (!tuple->struct_info_.defined()) { + Array tuple_sinfo; for (Expr field : tuple->fields) { - ICHECK(field->checked_type_.defined()) - << "The checked_type_ of the field " << field << " of Tuple has not propagated."; - tuple_type.push_back(field->checked_type_); + tuple_sinfo.push_back(GetStructInfo(field)); } - UpdateType(tuple, TupleType(tuple_type)); - } - - // NOTE: Tuple's shape can be null - // When a tuple consists of all DynTensorType elements or nested tuple of DynTensorTypes, - // it has a shape. - if (!tuple->shape_.defined()) { - UpdateShape(tuple, GetTupleShape(tuple)); - } - - // TODO(relax-team): revisit after struct info. - // recurse into its shape in case its shape also need to be normalized - if (tuple->shape_ && tuple->shape_.value()->IsInstance()) { - this->VisitExpr(Downcast(tuple->shape_.value())); + UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo, op->span)); } - return tuple; } Expr VisitExpr_(const FunctionNode* op) final { - Expr new_body = this->VisitWithNewScope(op->body); - Function func; + Expr new_body = this->VisitWithNewScope(op->body, op->params); + if (new_body.same_as(op->body)) { - func = GetRef(op); + return GetRef(op); } else { - func = Function(op->params, new_body, op->ret_type, op->ret_shape, op->attrs); + return Function(op->params, new_body, op->ret_type, op->ret_shape, op->attrs); } - // NOTE: the shape_ of Function is left as null for now - return func; } Expr VisitExpr_(const CallNode* op) final { @@ -587,28 +686,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorattrs, op->type_args); } - // only do shape/type inference if the Call does not have shape/type - if (call->shape_.defined() && call->checked_type_.defined()) { - return call; + if (!call->struct_info_.defined()) { + auto inferred_sinfo = InferStructInfo(call); + UpdateStructInfo(call, inferred_sinfo); } - // Update the type prior to updating the shape, since the shape inference may need the updated - // type in cases of Call for ExternFunc. - if (!call->checked_type_.defined()) { - // type inference - auto inferred_type = InferType(call, this->diag_ctx_, this->context_mod_); - UpdateType(call, inferred_type); - } - - if (!call->shape_) { - // shape inference - auto inferred_shape = InferShape(call, this->diag_ctx_, this->context_mod_); - if (inferred_shape) { - UpdateShape(call, inferred_shape.value()); - } - } - - CheckShapeTypeConsistency(call->shape_, call->checked_type_); return call; } @@ -640,20 +722,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op); } else { - seq_expr = SeqExpr(normalized_blocks, new_body); + seq_expr = SeqExpr(normalized_blocks, new_body, op->span); } // only do shape/type inference if the SeqExpr does not have shape/type - if (seq_expr->shape_ && seq_expr->checked_type_.defined()) { - return seq_expr; - } - - if (!seq_expr->shape_ && seq_expr->body->shape_) { - UpdateShape(seq_expr, seq_expr->body->shape_); - } - - if (!seq_expr->checked_type_.defined() && seq_expr->body->checked_type_.defined()) { - UpdateType(seq_expr, seq_expr->body->checked_type_); + if (!seq_expr->struct_info_.defined()) { + UpdateStructInfo(seq_expr, EraseToWellDefinedInScope(GetStructInfo(seq_expr->body))); } return seq_expr; } @@ -668,25 +742,13 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorfalse_branch)) { if_node = GetRef(op); } else { - if_node = If(new_cond, new_true, new_false); - } - - if (!op->checked_type_.defined()) { - ICHECK(new_true->checked_type_.defined() && new_false->checked_type_.defined()) - << "The checked_type_ of true and false branches must not be nullptr."; - UpdateType(if_node, FindLCA(new_true->checked_type_, new_false->checked_type_)); + if_node = If(new_cond, new_true, new_false, op->span); } - - if (!op->shape_.defined()) { - if (new_true->shape_ && new_false->shape_ && - this->ShapeStructEqual(Downcast(new_true->shape_.value()), - Downcast(new_false->shape_.value()))) { - UpdateShape(if_node, new_true->shape_); - } else { - UpdateShape(if_node, RuntimeDepShape()); - } + if (!if_node->struct_info_.defined()) { + auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true)); + auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false)); + UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info)); } - return if_node; } @@ -696,23 +758,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctortuple) ? GetRef(op) : TupleGetItem(new_tuple, op->index); - // only do shape/type inference if the TupleGetItem does not have shape/type - if (node->shape_ && node->checked_type_.defined()) { - return node; - } - - if (!node->checked_type_.defined()) { - const TupleTypeNode* tuple_type = node->tuple->checked_type_.as(); - ICHECK(tuple_type) << "The checked_type_ of Tuple must be TupleTypeNode."; - UpdateType(node, tuple_type->fields[node->index]); - } - - if (!node->shape_ && node->tuple->shape_) { - // TODO(@prakalp, @yuchen): assign the shape_ to RuntimeDepShape when we cannot obtain the - // field - if (const TupleNode* shape = node->tuple->shape_.as()) { - UpdateShape(node, shape->fields[node->index]); - } + if (!node->struct_info_.defined()) { + auto opt = MatchStructInfo(node->tuple); + ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo."; + UpdateStructInfo(node, opt.value()->fields[node->index]); } return node; @@ -727,21 +776,26 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitExpr(binding->value); - if (new_value.same_as(binding->value) || new_value.same_as(binding->var)) { - // if new_value = binding->var, then we have found an ANF binding site, so just return it - return binding; + if (!new_value.same_as(binding->value)) { + binding = VarBinding(binding->var, new_value, binding->span); } - return VarBinding(binding->var, new_value); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + return binding; } - MatchShape VisitMatchShape(const MatchShape& binding) { + MatchShape VisitMatchShape(MatchShape binding) { Expr new_value = this->VisitExpr(binding->value); - if (new_value.same_as(binding->value)) { - return binding; + if (!new_value.same_as(binding->value)) { + binding = MatchShape(new_value, binding->pattern, binding->var, binding->span); + } + if (binding->var.defined() && !binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); } - return MatchShape(new_value, binding->pattern, binding->var); + return binding; } BindingBlock VisitBindingBlock(const BindingBlock& block) { @@ -767,221 +821,67 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { - return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), - [](const PrimExpr& e) { return e->IsInstance(); }); - } else if (const auto* shape_tuple = shape.as()) { - return std::all_of(shape_tuple->fields.begin(), shape_tuple->fields.end(), - [&](const Expr& e) { return IsConstantShapes(e); }); + // Helper function to infer the type of a Call. + StructInfo InferStructInfo(const Call& call) { + if (auto* op_ptr = call->op.as()) { + // Case 1: the op field is a primitive op, look up FInferStructInfo attribute + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + return op_map_infer_struct_info_[op](call, GetRef(this)); } else { - return false; + // derive using function parameters + ICHECK(call->op->struct_info_.defined()); + auto opt = MatchStructInfo(call->op); + ICHECK(opt) << "Call->op must contains a function struct info"; + FuncStructInfo finfo = opt.value(); + return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); } } - // Helper function to infer the shape of a Call. - Optional InferShape(const Call& call, DiagnosticContext diag_ctx, IRModule ctx_mod) { - if (call->op.as()) { - std::function f_create_type = [&f_create_type](const Type& type) -> Expr { - if (!type.defined() || type->IsInstance() || - type->IsInstance() || type->IsInstance()) { - return Expr(); - } - if (const auto* tuple_type = type.as()) { - if (tuple_type->fields.size() == 0) { - // VoidType (i.e. empty TupleType) does not have shape - return Expr(); - } - Array fields; - fields.reserve(tuple_type->fields.size()); - for (const Type& field_type : tuple_type->fields) { - fields.push_back(f_create_type(field_type)); - } - return Tuple(fields); - } else if (type->IsInstance()) { - return RuntimeDepShape(); - } else { - LOG(FATAL) << "Unsupported relax type: " << type->GetTypeKey(); - throw; - } - }; - return f_create_type(call->checked_type_); - } else if (call->op.as()) { - // primitive op: look up FInferShape attribute - Op op = Downcast(call->op); - if (op_map_infer_shape_.count(op)) { - return op_map_infer_shape_[op](call, diag_ctx); - } - } else if (const auto* gv = call->op.as()) { - // global function: find the function's shape_ - auto it_func = ctx_mod->functions.find(GetRef(gv)); - - if (it_func != ctx_mod->functions.end()) { - if (const auto* func = (*it_func).second.as()) { - if (!func->body.defined()) { - return func->ret_shape; - } - // TODO(relax-team): migrate shape deduction to `ret_shape` - Expr func_shape = Downcast(func->body->shape_); - if (IsConstantShapes(func_shape)) { - // Case 1. Nested tuples of constant shapes - return func_shape; - } else { - // TODO(@yuchen): add deducer for other cases - return RuntimeDepShape(); - } - } - } - // TODO(@yuchen): add this check after normalization in parser - // else { - // LOG(FATAL) << "ValueError: Cannot find function " << gv->name_hint - // << " in the context IRModule."; - // } - } else if (const auto* var = call->op.as()) { - if (var->shape_) { - return Downcast(var->shape_.value()); - } - Optional val = this->LookupBinding(GetRef(var)); - if (const auto* func_node = val.value().as()) { - Function func = GetRef(func_node); - if (func->ret_type.as()) { - Expr func_shape = Downcast(func_node->body->shape_); - if (IsConstantShapes(func_shape)) { - return func_shape; - } else { - // TODO(@yuchen, @yongwww): add deducer for other cases - return RuntimeDepShape(); - } - } - } - } else { - LOG(FATAL) << "ValueError: Failed to do shape inference for " << call->op->GetTypeKey(); + // erase to well defined within current scope. + StructInfo EraseToWellDefinedInScope(StructInfo info) { + if (scope_stack_.empty()) { + return EraseToWellDefined(info); } - - return NullOpt; + auto* curr_scope = CurrentScopeFrame(); + auto f_shape_var_map = [curr_scope](tir::Var var) -> Optional { + auto it = curr_scope->shape_var_map.find(var); + if (it != curr_scope->shape_var_map.end()) return (*it).second; + return NullOpt; + }; + return EraseToWellDefined(info, f_shape_var_map); } - // Helper function to infer the type of a Call. - Type InferType(const Call& call, DiagnosticContext diag_ctx, IRModule ctx_mod) { - if (call->op.as()) { - // Case 1: the op field is a primitive op, look up FInferType attribute - Op op = Downcast(call->op); - if (op_map_infer_type_.count(op)) { - return op_map_infer_type_[op](call, diag_ctx); - } else { - LOG(FATAL) << "ValueError: Cannot find the FInferType attribute registered to op: " - << op->name; - } + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + // SeqExpr do not need to prepare for normalization. + if (expr.as()) { + this->BeginScope(params); + Expr ret = this->VisitExpr(expr); + this->EndScope(); + return ret; } else { - // Case 2: the op field is of callable type - ICHECK(call->op->checked_type_.defined()) - << "When the op field is not an OpNode, the CallNode's op must have checked_type_."; - if (call->op->checked_type_.as()) { - if (call->type_args.defined()) { - if (call->type_args.size() == 0) { - return ObjectType(); - } else if (call->type_args.size() == 1) { - return call->type_args.front(); - } else { - return TupleType(call->type_args); - } - } else { - LOG(FATAL) << "ExternFunc call must have type args."; - } - } else if (auto* func_node = call->op->checked_type_.as()) { - return func_node->ret_type; - } - } - LOG(FATAL) << "ValueError: the CallNode's op has to be either an OpNode, or has " - << " Callable (i.e., PackedFuncType or FuncType) as its checked_type_"; - throw; - } + this->BeginScope(params); - // Helper function to check if the provided shape and type is consistent. - // Throw internal exceptions if they are not consistent. - void CheckShapeTypeConsistency(const Optional& opt_shape, const Type& type) { - if (!type.defined() || type->IsInstance() || type->IsInstance() || - type->IsInstance()) { - ICHECK(!opt_shape.defined()) - << "When the type of an Expr is undefined/ShapeType/FuncType/ObjectType, the shape of " - "this Expr is expected to be undefined. However, the actual shape is defined and is " - << opt_shape.value(); - } else if (const auto* dyn_tensor_type = type.as()) { - // `opt_shape` should either be a relax::Expr or undefined. - if (opt_shape.defined()) { - const auto* shape = opt_shape.as(); - ICHECK(shape != nullptr) << "The shape of an Expr, if defined, is expected to be a Relax " - "Expr. However, the actual shape is not a Relax Expr and is " - << opt_shape.value()->GetTypeKey(); - ICHECK(shape->checked_type()->IsInstance()) - << "The shape of an Expr, if defined, is expected to be a Relax Expr which has type " - "ShapeType. However, the actual shape has type " - << shape->checked_type()->GetTypeKey(); - } - - const auto* shape_expr = opt_shape.as(); - if (dyn_tensor_type->IsUnknownNdim()) { - ICHECK(shape_expr == nullptr) - << "When the type of an Expr is DynTensorType with unknown ndim, the shape of the Expr " - "is expected not to be a ShapeExpr. However, the actual shape is ShapeExpr " - << GetRef(shape_expr); - } else if (shape_expr != nullptr) { - ICHECK(dyn_tensor_type->ndim == static_cast(shape_expr->values.size())) - << "When the type of an Expr is DynTensorType with known ndim and the shape of that " - "Expr is a ShapeExpr, the ShapeExpr should have as many values as the ndim " - "indicates. However, the actual Expr type has ndim " - << dyn_tensor_type->ndim << " while the actual Expr shape is " - << GetRef(shape_expr) << ", which has length " << shape_expr->values.size(); + this->BeginBindingBlock(); + Expr post = this->NormalizeArgument(expr); + BindingBlock prologue = this->EndBlock(); + // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. + // Don't wrap if it's already a seq and there are no bindings to add + if (post.as() && prologue->bindings.empty()) { + return post; } - } else if (const auto* tuple_type = type.as()) { - const auto* tuple_shape = opt_shape.as(); - if (tuple_shape == nullptr) { - ICHECK(tuple_type->fields.size() == 0) - << "When the type of an Expr is TupleType and the shape of that Expr is not a Tuple, " - "it means that the type should be a VoidType, which is represented as an empty " - "TupleType. However, here the shape is not a tuple while the type has " - << tuple_type->fields.size() << " field(s)."; - } else { - ICHECK_EQ(tuple_shape->fields.size(), tuple_type->fields.size()) - << "When the type of an Expr is TupleType and the shape of that Expr is a Tuple, the " - "two should have the same number of fields. However, the type has " - << tuple_type->fields.size() << " field(s) while the shape has " - << tuple_shape->fields.size() << " field(s)"; - int n_field = tuple_shape->fields.size(); - // Recursively check the consistency. - for (int i = 0; i < n_field; ++i) { - CheckShapeTypeConsistency(tuple_shape->fields[i], tuple_type->fields[i]); - } + Array bindings; + if (!prologue->bindings.empty()) { + bindings.push_back(prologue); } - } else { - LOG(FATAL) << "Unsupported relax type: " << type->GetTypeKey(); - } - } - Expr VisitWithNewScope(const Expr& expr) { - // SeqExpr do not need to prepare for normalization. - if (expr.as()) return this->VisitExpr(expr); + SeqExpr seq(bindings, post); + UpdateStructInfo(seq, EraseToWellDefinedInScope(GetStructInfo(seq->body))); - this->BeginBindingBlock(); - Expr post = this->NormalizeArgument(expr); - BindingBlock prologue = this->EndBlock(); - // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. - // Don't wrap if it's already a seq and there are no bindings to add - if (post.as() && prologue->bindings.empty()) { - return post; - } - Array bindings; - if (!prologue->bindings.empty()) { - bindings.push_back(prologue); + this->EndScope(); + return seq; } - - SeqExpr seq(bindings, post); - UpdateShape(seq, post->shape_); - UpdateType(seq, post->checked_type_); - return seq; } Array FlattenBlocks(const Array& blocks) { @@ -1067,11 +967,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor op_map_infer_shape_ = Op::GetAttrMap("FInferShape"); - - /*! \brief Operator to type inference map. */ - tvm::OpAttrMap op_map_infer_type_ = Op::GetAttrMap("FInferType"); + /*! \brief Operator struct info inference map. */ + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); }; BlockBuilder BlockBuilder::Create(Optional mod) { @@ -1151,5 +1049,11 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") .set_body_method(&BlockBuilderNode::LookupBinding); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope") + .set_body_method(&BlockBuilderNode::EndScope); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a931d7eec081..e28d13228bb2 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -16,7 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include #include +#include #include #include @@ -26,6 +28,13 @@ RelayExpr RelayExprNode::shape() const { if (this->shape_.defined()) { return Downcast(this->shape_); } + if (this->struct_info_.defined()) { + Optional shape = + relax::GetLegacyShapeHint(Downcast(this->struct_info_.value())); + if (shape.defined()) { + return shape.value(); + } + } static const Op& op = Op::Get("relax.shape_of"); RelayExpr self = GetRef(this); relax::Call call_shape_of(op, {self}, {}, {}); @@ -231,21 +240,19 @@ TVM_REGISTER_NODE_TYPE(ShapeExprNode); ShapeExpr::ShapeExpr(Array values, Span span) { ObjectPtr n = make_object(); - Array new_values; - new_values.reserve(values.size()); - for (const PrimExpr& value : values) { - PrimExpr new_value = value; + + n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { - new_value = tvm::cast(DataType::Int(64), value); - } else if (value.dtype() != DataType::Int(64)) { - LOG(FATAL) << "the value in ShapeExpr can only have dtype of int64"; + return tvm::cast(DataType::Int(64), value); } - new_values.push_back(new_value); - } - n->values = std::move(new_values); + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); n->span = span; n->shape_ = NullOpt; - n->checked_type_ = ShapeType(); + n->checked_type_ = ShapeType(values.size()); + n->struct_info_ = ShapeStructInfo(values, span); data_ = std::move(n); } @@ -258,6 +265,7 @@ TVM_REGISTER_NODE_TYPE(RuntimeDepShapeNode); RuntimeDepShape::RuntimeDepShape(Span span) { ObjectPtr n = make_object(); n->span = span; + n->struct_info_ = ShapeStructInfo(kUnknownDim); n->checked_type_ = ShapeType(); data_ = std::move(n); } @@ -284,9 +292,15 @@ TVM_REGISTER_NODE_TYPE(VarNode); Var::Var(Id vid, Optional shape_annotation, Optional type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); - n->shape_ = std::move(shape_annotation); + // invariant for transition, always require type ann if shape is provided. + if (shape_annotation) { + ICHECK(type_annotation) << "Var requires type annotation if we provide shape ann"; + } if (type_annotation) { + StructInfo sinfo = StructInfoFromTypeLegacyShapeHint(type_annotation.value(), shape_annotation); + n->struct_info_ = sinfo; n->checked_type_ = std::move(type_annotation.value()); + n->shape_ = GetLegacyShapeHint(sinfo); } n->span = std::move(span); data_ = std::move(n); @@ -308,9 +322,15 @@ DataflowVar::DataflowVar(Id vid, Optional shape_annotation, Optional Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); - n->shape_ = std::move(shape_annotation); + // invariant for transition, always require type ann if shape is provided. + if (shape_annotation) { + ICHECK(type_annotation) << "Var requires type annotation if we provide shape ann"; + } if (type_annotation) { + StructInfo sinfo = StructInfoFromTypeLegacyShapeHint(type_annotation.value(), shape_annotation); + n->struct_info_ = sinfo; n->checked_type_ = std::move(type_annotation.value()); + n->shape_ = GetLegacyShapeHint(sinfo); } n->span = std::move(span); data_ = std::move(n); @@ -332,15 +352,19 @@ Constant::Constant(runtime::NDArray data, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); n->span = std::move(span); - DataType dtype = n->data.DataType(); - ShapeTuple shape_tuple = n->data.Shape(); - Type type = DynTensorType(shape_tuple.size(), dtype); - n->checked_type_ = type; + + // set struct info. Array values; - for (size_t dim = 0; dim < shape_tuple.size(); dim++) { + auto shape_tuple = n->data.Shape(); + for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); } - n->shape_ = ShapeExpr(values); + TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span); + + n->struct_info_ = tinfo; + n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype); + n->shape_ = tinfo->shape; + data_ = std::move(n); } @@ -430,38 +454,45 @@ Function::Function(Array params, Expr body, Type ret_type, Expr ret_shape, // Set the function type. // For function, we take a conservative approach and require the function type // to be known at construction time. - Array param_types; + Array param_sinfo; + for (const Var& param : params) { - CHECK(param->checked_type_.defined()) - << "relax.Function requires params to contain checked_type_"; - param_types.push_back(param->checked_type_); + CHECK(param->struct_info_.defined()) + << "relax.Function requires params to contain struct_info_"; + param_sinfo.push_back(GetStructInfo(param)); } - if (!ret_type.defined()) { - CHECK(body->checked_type_.defined()) - << "relax.Function requires body to contain deduced checked_type_" - << " or ret_type to be supplied"; - ret_type = body->checked_type_; - } else { - if (body->checked_type_.defined()) { - CHECK(IsBaseOf(ret_type, body->checked_type_)) - << "relax.Function requires the deduced body->checked_type_ to be a subtype of the " - "annotated ret_type but meet body->checked_type_: " - << body->checked_type_ << ", ret_type: " << ret_type; - - // Use the more refined body->checked_type_ as the return type. - ret_type = body->checked_type_; + Optional ret_sinfo; + Optional body_sinfo; + + if (body->struct_info_.defined()) { + body_sinfo = GetStructInfo(body); + } + + if (ret_type.defined()) { + ret_sinfo = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape); + // allow body to override ret if body is more fine-grained. + if (body_sinfo.defined()) { + if (IsBaseOf(ret_sinfo.value(), body_sinfo.value())) { + ret_sinfo = body_sinfo; + } } + } else { + CHECK(body_sinfo.defined()) + << "Function do not have a return signature and body is not normalized"; + ret_sinfo = body_sinfo; } - auto func_type = FuncType(param_types, ret_type, {}, {}); + + FuncStructInfo func_sinfo(param_sinfo, ret_sinfo.value()); // set the fields ObjectPtr n = make_object(); n->params = std::move(params); n->body = std::move(body); - n->ret_type = std::move(ret_type); - n->ret_shape = std::move(ret_shape); - n->checked_type_ = std::move(func_type); + n->ret_type = GetStaticType(ret_sinfo.value()); + n->ret_shape = GetLegacyShapeHint(ret_sinfo.value()).value_or(ret_shape); + n->checked_type_ = GetStaticType(func_sinfo); + n->struct_info_ = std::move(func_sinfo); n->attrs = std::move(attrs); n->span = std::move(span); data_ = std::move(n); @@ -475,15 +506,31 @@ TVM_REGISTER_GLOBAL("relax.Function") Function Function::CreateUnchecked(Array params, Expr body, Type ret_type, Expr ret_shape, DictAttrs attrs, Span span) { + // TODO(@Hzfengsy): revisit `CreateUnchecked` after the parser_v1 removed + + Array param_sinfo; + for (Var param : params) { ICHECK(param->checked_type_.defined()) << "relax.Function requires params to contain checked_type_."; + param_sinfo.push_back(GetStructInfo(param)); + } + + StructInfo ret_info; + + if (ret_type.defined()) { + ret_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape); + } else { + ret_info = FuncStructInfo::OpaqueFunc(); } + FuncStructInfo finfo(param_sinfo, ret_info); // set the fields ObjectPtr n = make_object(); n->params = std::move(params); n->body = std::move(body); + n->checked_type_ = GetStaticType(finfo); + n->struct_info_ = std::move(finfo); n->ret_type = std::move(ret_type); n->ret_shape = std::move(ret_shape); n->attrs = std::move(attrs); @@ -497,13 +544,41 @@ TVM_REGISTER_GLOBAL("relax.Function_CreateUnchecked") return Function::CreateUnchecked(params, body, ret_type, ret_shape, attrs, span); }); +// Special opaque derivation function for ExternFunc +// Take look at type_args to figure out the return StructInfo. +// TODO(relax-team): revisit type_args related deduction. +TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_ty_args") + .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { + if (call->type_args.defined()) { + if (call->type_args.size() == 0) { + return ObjectStructInfo(); + } else if (call->type_args.size() == 1) { + return StructInfoFromType(call->type_args[0]); + } else { + return StructInfoFromType(TupleType(call->type_args)); + } + } else { + return ObjectStructInfo(); + } + }); + +// Get the derive function. +FuncStructInfo GetExternFuncStructInfo() { + EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_ty_args"); + StructInfoDeriveFunc derive; + derive = fn; + return FuncStructInfo::OpaqueFunc(derive); +} + TVM_REGISTER_NODE_TYPE(ExternFuncNode); ExternFunc::ExternFunc(String global_symbol, Span span) { ObjectPtr n = make_object(); n->global_symbol = std::move(global_symbol); n->span = span; - n->checked_type_ = PackedFuncType(); + static auto sinfo = GetExternFuncStructInfo(); + n->struct_info_ = sinfo; + n->checked_type_ = GetStaticType(sinfo); data_ = std::move(n); } @@ -511,25 +586,5 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, return ExternFunc(global_symbol, span); }); -void UpdateType(Expr expr, Type type) { - ICHECK(!expr->checked_type_.defined() || tvm::StructuralEqual()(expr->checked_type_, type)) - << "the checked_type_ of the Expr to be updated must be nullptr for idempotency"; - expr->checked_type_ = type; -} - -TVM_REGISTER_GLOBAL("relax.UpdateType").set_body_typed([](Expr expr, Type type) { - UpdateType(expr, type); -}); - -void UpdateShape(Expr expr, Optional shape) { - ICHECK(!expr->shape_.defined()) - << "the shape_ of the Expr to be updated must be nullptr for idempotency"; - expr->shape_ = shape; -} - -TVM_REGISTER_GLOBAL("relax.UpdateShape").set_body_typed([](Expr expr, Optional shape) { - UpdateShape(expr, shape); -}); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 87981b5e5e64..89c862aae8d4 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -152,7 +152,12 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } -void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { + for (PrimExpr val : op->values) { + this->VisitPrimExpr(val); + } + this->VisitSpan(op->span); +} void ExprVisitor::VisitExpr_(const RuntimeDepShapeNode* op) { this->VisitSpan(op->span); } @@ -170,6 +175,8 @@ void ExprVisitor::VisitType(const Type& t) {} void ExprVisitor::VisitSpan(const Span& span) {} +void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} + // implementations of binding visitor dispatch RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ConstantNode); @@ -367,7 +374,15 @@ Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { } } -Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { + auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); + + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeExpr(values, op->span); + } +} Expr ExprMutatorBase::VisitExpr_(const RuntimeDepShapeNode* op) { return GetRef(op); } @@ -421,6 +436,8 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { Type ExprMutatorBase::VisitType(const Type& t) { return t; } +PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } + // ================== // ExprMutator @@ -478,7 +495,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { Type ret_type = this->VisitType(op->ret_type); Expr ret_shape = this->VisitExpr(op->ret_shape); - Expr body = this->VisitWithNewScope(op->body); + Expr body = this->VisitWithNewScope(op->body, params); if (all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body) && ret_shape.same_as(op->ret_shape)) { @@ -565,7 +582,7 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { return; } - Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_); + Var temp = WithStructInfo(new_var, GetStructInfo(new_value)); if (!temp.same_as(new_var)) { new_var = temp; this->var_remap_[binding->var->vid] = new_var; @@ -580,18 +597,14 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { Var new_var; if (binding->var.defined()) { - // in the case of `x = R.match_shape(val, pattern)`, we want `x` to directly get `pattern` as - // the shape when `val` is a tensor. - Optional new_shape; - Type new_type = new_value->checked_type_; - if (new_value->checked_type_.defined() && new_value->checked_type_.as()) { - new_shape = new_pattern; - ICHECK(new_shape->IsInstance()); - int ndim = Downcast(new_shape.value())->values.size(); - new_type = DynTensorType(ndim, new_value->checked_type_.as()->dtype); + StructInfo new_sinfo = GetStructInfo(new_value); + + if (auto* ptr = new_sinfo.as()) { + new_sinfo = TensorStructInfo(new_pattern, ptr->dtype); } new_var = this->VisitVarDef(binding->var); - Var temp = WithShapeAndType(new_var, new_shape, new_type); + + Var temp = WithStructInfo(new_var, new_sinfo); if (!temp.same_as(new_var)) { new_var = temp; this->var_remap_[binding->var->vid] = new_var; @@ -630,42 +643,14 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { } Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { - bool shape_unchanged = true; - Expr new_shape; - if (var->shape_) { - new_shape = this->VisitExpr(Downcast(var->shape_.value())); - shape_unchanged &= new_shape.same_as(var->shape_); - } - - if (shape_unchanged) { - return GetRef(var); - } else { - Var new_var = DataflowVar(var->vid, NullOpt, var->checked_type_, var->span); - UpdateShape(new_var, new_shape); - - this->var_remap_[var->vid] = new_var; - return new_var; - } + // If an Expr have struct info, they must already be normalized, + // This invariant is checked at the constructor location. + // to simplify our overall development complexity and keep var def + // stable by default. + return GetRef(var); } -Var ExprMutator::VisitVarDef_(const VarNode* var) { - bool shape_unchanged = true; - Expr new_shape; - if (var->shape_) { - new_shape = this->VisitExpr(Downcast(var->shape_.value())); - shape_unchanged &= new_shape.same_as(var->shape_); - } - - if (shape_unchanged) { - return GetRef(var); - } else { - Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span); - UpdateShape(new_var, new_shape); - - this->var_remap_[var->vid] = new_var; - return new_var; - } -} +Var ExprMutator::VisitVarDef_(const VarNode* var) { return GetRef(var); } void ExprMutator::VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { @@ -701,50 +686,36 @@ Var ExprMutator::VisitVarDef(const Var& var) { return ret; } -Expr ExprMutator::VisitWithNewScope(const Expr& expr) { - if (expr->IsInstance()) { - return this->VisitExpr(expr); - } else { - builder_->BeginBindingBlock(); - Expr ret = this->VisitExpr(expr); - BindingBlock prologue = builder_->EndBlock(); - if (!prologue->bindings.empty()) { - ret = SeqExpr({prologue}, ret); - } - return ret; - } +Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + builder_->EndScope(); + return ret; } Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } -Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) { - // shape/type changes if it goes from defined -> undefined or the other way, hence xor - bool shape_changed = var->shape_.operator bool() ^ shape.operator bool(); - shape_changed |= var->shape_ && shape && - !builder_->CanProveShapeEqual(Downcast(var->shape_.value()), - Downcast(shape.value())); +Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { + ICHECK(struct_info.defined()); - bool type_changed = var->checked_type_.defined() ^ type.defined(); - type_changed |= var->checked_type_.defined() && type.defined() && - !StructuralEqual()(var->checked_type_, type); - - if (shape_changed || type_changed) { - Var new_var = var.as() ? DataflowVar(var->vid, NullOpt, NullOpt, var->span) - : Var(var->vid, NullOpt, NullOpt, var->span); - UpdateShape(new_var, var->shape_); - UpdateType(new_var, var->checked_type_); - var = new_var; - } - - if (shape_changed) { - var->shape_ = shape; - } - - if (type_changed) { - var->checked_type_ = type; + // TODO(relax-team) add StructInfoEqual check + if (var->struct_info_.defined()) { + // use same-as as a quick path + if (var->struct_info_.same_as(struct_info) || + StructuralEqual()(var->struct_info_, struct_info)) { + return var; + } else { + Var new_var = var.as() ? DataflowVar(var->vid, NullOpt, NullOpt, var->span) + : Var(var->vid, NullOpt, NullOpt, var->span); + UpdateStructInfo(new_var, struct_info); + return new_var; + } + } else { + UpdateStructInfo(var, struct_info); + return var; } - - return var; } TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); @@ -857,9 +828,9 @@ TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") return mutator->LookupBinding(var); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithShapeAndType") - .set_body_typed([](PyExprMutator mutator, Var var, Optional shape, Type type) { - return mutator->WithShapeAndType(var, shape, type); +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") + .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { + return mutator->WithStructInfo(var, sinfo); }); TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc new file mode 100644 index 000000000000..4b53baa43c8e --- /dev/null +++ b/src/relax/ir/struct_info.cc @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/struct_info.cc + * \brief Relax struct info. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +ObjectStructInfo::ObjectStructInfo(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { + return ObjectStructInfo(span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "ObjectStructInfo()"; + }); + +// Prim +PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { + return PrimStructInfo(dtype, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "PrimStructInfo(" << node->dtype << ")"; + }); + +// Shape +ShapeStructInfo::ShapeStructInfo(Array values, Span span) { + ObjectPtr n = make_object(); + n->ndim = static_cast(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { + ObjectPtr n = make_object(); + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") + .set_body_typed([](Optional> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->values.defined()) { + p->stream << "ShapeStructInfo(" << node->values.value() << ")"; + } else { + p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; + } + }); + +// Tensor +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { + ObjectPtr n = make_object(); + // assign ndim before move + Optional sinfo = MatchStructInfo(shape); + ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + ICHECK(shape->IsInstance() || shape->IsInstance()) + << "We require shape to be normalized when constructing TensorStructInfo"; + n->ndim = sinfo.get()->ndim; + // assign rest of the fields. + n->shape = std::move(shape); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) { + ObjectPtr n = make_object(); + n->ndim = ndim; + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TensorStructInfo") + .set_body_typed([](Optional shape, DataType dtype, int ndim, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype, span); + } else { + return TensorStructInfo(dtype, ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->shape.defined()) { + p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; + } else { + p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; + } + }); + +// Tuple +TupleStructInfo::TupleStructInfo(Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TupleStructInfo") + .set_body_typed([](Array fields, Span span) { + return TupleStructInfo(fields, span); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "TupleStructInfo(" << node->fields << ")"; + }); + +// Func +FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->ret = std::move(ret); + n->span = span; + data_ = std::move(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { + ObjectPtr n = make_object(); + n->derive_func = std::move(derive_func); + n->ret = ObjectStructInfo(); + n->span = span; + return FuncStructInfo(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->ret = std::move(ret); + n->span = span; + return FuncStructInfo(n); +} + +TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfo") + .set_body_typed([](Array params, StructInfo ret, Span span) { + return FuncStructInfo(params, ret, span); + }); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") + .set_body_typed([](Optional ret, Optional derive_func, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; + }); + +// Helper functions +void UpdateStructInfo(Expr expr, StructInfo struct_info) { + ICHECK(!expr->struct_info_.defined()) + << "the struct_info_ of the Expr to be updated must be nullptr for idempotency"; + expr->struct_info_ = struct_info; + // also set checked type + expr->checked_type_ = GetStaticType(struct_info); + expr->shape_ = GetLegacyShapeHint(struct_info); +} + +TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { + UpdateStructInfo(expr, struct_info); +}); + +TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { + return GetStructInfo(expr); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc new file mode 100644 index 000000000000..199491e3c63f --- /dev/null +++ b/src/relax/ir/struct_info_functor.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_functor.cc + * \brief Implementations of struct info functors. + */ +#include + +namespace tvm { +namespace relax { + +void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) { + if (op->values.defined()) { + for (PrimExpr value : op->values.value()) { + this->VisitStructInfoExprField(value); + } + } +} + +void StructInfoVisitor::VisitStructInfo_(const TensorStructInfoNode* op) { + if (op->shape.defined()) { + this->VisitStructInfoExprField(op->shape.value()); + } +} + +void StructInfoVisitor::VisitStructInfo_(const TupleStructInfoNode* op) { + for (StructInfo field : op->fields) { + this->VisitStructInfo(field); + } +} + +void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + if (op->params.defined()) { + for (StructInfo param : op->params.value()) { + this->VisitStructInfo(param); + } + } + this->VisitStructInfo(op->ret); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { + Optional> values; + + if (op->values.defined()) { + // if no changes are made the original array will be returned. + values = op->values.value().Map( + [this](const PrimExpr& expr) { return this->VisitStructInfoExprField(expr); }); + } + + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { + Optional shape; + + if (op->shape.defined()) { + shape = this->VisitStructInfoExprField(op->shape.value()); + } + + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + + if (fields.same_as(op->fields)) { + return GetRef(op); + } else { + return TupleStructInfo(fields, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { + Optional> params; + + if (op->params.defined()) { + params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + } + + StructInfo ret = this->VisitStructInfo(op->ret); + + if (params.same_as(op->params) && ret.same_as(op->ret)) { + return GetRef(op); + } else { + ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; + return FuncStructInfo(params.value(), ret, op->span); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 1c5ff9b44961..71a06cbce787 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -29,13 +29,16 @@ namespace relax { TVM_REGISTER_NODE_TYPE(ShapeTypeNode); -ShapeType::ShapeType(Span span) { +ShapeType::ShapeType(int ndim, Span span) { ObjectPtr n = make_object(); + n->ndim = ndim; n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](Span span) { return ShapeType(span); }); +TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { + return ShapeType(ndim, span); +}); ObjectType::ObjectType(Span span) { ObjectPtr n = make_object(); diff --git a/src/relax/ir/type_analysis.cc b/src/relax/ir/type_analysis.cc index 5d00bce73bae..7c532d816c95 100644 --- a/src/relax/ir/type_analysis.cc +++ b/src/relax/ir/type_analysis.cc @@ -41,8 +41,8 @@ class BaseTypeChecker : public TypeFunctor { bool VisitType_(const ObjectTypeNode* base) final { return true; } bool VisitType_(const ShapeTypeNode* base) final { - if (derived_.as()) { - return true; + if (auto* rhs = derived_.as()) { + return base->ndim == kUnknownDim || base->ndim == rhs->ndim; } return false; } @@ -122,7 +122,8 @@ class LCAVisitor : public TypeFunctor { Type VisitType_(const ObjectTypeNode* t) final { return ObjectType(); } Type VisitType_(const ShapeTypeNode* t) final { - if (u_.as()) { + if (auto* rhs = u_.as()) { + if (t->ndim == rhs->ndim) return GetRef(t); return ShapeType(); } return ObjectType(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index a9389ed60db5..a60b9bf775e7 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -54,26 +55,44 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { return false; } -Type ReturnVoidType(const Call& call, DiagnosticContext diag_ctx) { return VoidType(); } +StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { + return TupleStructInfo(Array()); +} -Type ReturnObjectType(const Call& call, DiagnosticContext diag_ctx) { return ObjectType(); } +StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { + return ObjectStructInfo(); +} -Type ReturnShapeType(const Call& call, DiagnosticContext diag_ctx) { return ShapeType(); } +StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { + return ShapeStructInfo(kUnknownDim); +} // call_tir -Expr InferShapeCallTIR(const Call& call, DiagnosticContext diag_ctx) { - Expr output_shape = call->args[2]; - return output_shape; +StructInfo CallTIRStructInfoFromShapeType(Expr shape, Type type) { + if (auto* tuple = shape.as()) { + auto* ptr_type = type.as(); + ICHECK(ptr_type != nullptr) << "Expect tuple type and shape to be consistent."; + ICHECK_EQ(ptr_type->fields.size(), tuple->fields.size()); + Array arr; + for (size_t i = 0; i < ptr_type->fields.size(); ++i) { + arr.push_back(CallTIRStructInfoFromShapeType(tuple->fields[i], ptr_type->fields[i])); + } + return TupleStructInfo(arr); + } else { + auto* ptr_type = type.as(); + ICHECK(ptr_type != nullptr) << "Expect singleton shape to correspond to DynTensorType."; + return TensorStructInfo(shape, ptr_type->dtype); + } } -Type InferTypeArg(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { + Expr output_shape = call->args[2]; if (call->type_args.size() != 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "type_args should have exact 1 output type."); + ctx->ReportFatal(Diagnostic::Error(call) << "type_args should have exact 1 output type."); } Type output_type = call->type_args[0]; - return output_type; + return CallTIRStructInfoFromShapeType(output_shape, output_type); } RELAY_REGISTER_OP("relax.call_tir") @@ -84,8 +103,7 @@ RELAY_REGISTER_OP("relax.call_tir") .add_argument("packed_ints", "Expr", "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") - .set_attr("FInferShape", InferShapeCallTIR) - .set_attr("FInferType", InferTypeArg); + .set_attr("FInferStructInfo", InferStructInfoCallTIR); Expr MakeCallTIR(Expr func, Tuple args, Expr output_shape, Type output_type, Optional packed_ints) { @@ -109,7 +127,7 @@ RELAY_REGISTER_OP("relax.print") .set_attrs_type() .set_num_inputs(-1) .add_argument("vals", "Array", "Values to print.") - .set_attr("FInferType", ReturnVoidType) + .set_attr("FInferStructInfo", ReturnVoidStructInfo) .set_attr("FCallPacked", "relax.run.print"); Expr MakePrint(Array vals, std::string format) { @@ -125,21 +143,21 @@ TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); // can't actually name it assert or else Python will consider it a syntax error -Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { // Ensure that the condition argument is a boolean scalar. // Also permitted is a tensor with unknown shape and unknown dtype // (checked dynamically in that case). Returns void. if (call->args.size() < 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Assert must have at least one argument (the condition)."); + ctx->ReportFatal(Diagnostic::Error(call) + << "Assert must have at least one argument (the condition)."); } Type arg_type = call->args[0]->checked_type(); if (!IsBoolScalarType(arg_type)) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "The argument to assert must be a boolean scalar type, but received " - << arg_type); + ctx->ReportFatal(Diagnostic::Error(call) + << "The argument to assert must be a boolean scalar type, but received " + << arg_type); } - return VoidType(); + return ReturnVoidStructInfo(call, ctx); } TVM_REGISTER_NODE_TYPE(AssertOpAttrs); @@ -150,7 +168,7 @@ RELAY_REGISTER_OP("relax.assert_op") .add_argument("vals", "Array", "The first value is used as the assertion condition. The others are used as " "format arguments if there is an error.") - .set_attr("FInferType", InferAssertType) + .set_attr("FInferStructInfo", InferAssertStructInfo) .set_attr("FCallPacked", "relax.run.assert_op"); Expr MakeAssertOp(Expr condition, Array vals, std::string format) { @@ -172,7 +190,7 @@ RELAY_REGISTER_OP("relax.make_closure") .set_num_inputs(2) .add_argument("func", "Expr", "The closure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferType", ReturnObjectType); + .set_attr("FInferStructInfo", ReturnObjectStructInfo); Expr MakeClosure(Expr func, Tuple args) { static const Op& op = Op::Get("relax.make_closure"); @@ -183,11 +201,21 @@ TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); // invoke_closure +StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) { + if (call->type_args.empty()) { + return ObjectStructInfo(); + } else if (call->type_args.size() == 1) { + return StructInfoFromType(call->type_args[0]); + } else { + return StructInfoFromType(TupleType(call->type_args)); + } +} + RELAY_REGISTER_OP("relax.invoke_closure") .set_num_inputs(2) .add_argument("closure", "Expr", "The VMClosure.") .add_argument("args", "Tuple", "The captured variables.") - .set_attr("FInferType", InferTypeArg); + .set_attr("FInferStructInfo", InferStructInfoInvokeClosure); Expr InvokeClosure(Expr closure, Tuple args, Array type_args) { static const Op& op = Op::Get("relax.invoke_closure"); @@ -201,7 +229,7 @@ TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); RELAY_REGISTER_OP("relax.shape_of") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") - .set_attr("FInferType", ReturnShapeType); + .set_attr("FInferStructInfo", ReturnShapeStructInfo); Expr MakeShapeOf(Expr expr) { static const Op& op = Op::Get("relax.shape_of"); @@ -212,22 +240,19 @@ TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); // alloc_tensor -Expr InferShapeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[0]; } - -Type InferTypeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { - auto attrs = call->attrs.as(); +StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { + const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "must be AllocTensorAttrs, but got " << call->attrs->GetTypeKey(); - auto output_shape = call->args[0].as(); - ICHECK(output_shape != nullptr) << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); - return DynTensorType(output_shape->values.size(), attrs->dtype); + ICHECK(call->args[0].as()) + << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); + return TensorStructInfo(call->args[0], attrs->dtype); } RELAY_REGISTER_OP("relax.builtin.alloc_tensor") .set_attrs_type() .set_num_inputs(1) .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .set_attr("FInferShape", InferShapeAllocTensor) - .set_attr("FInferType", InferTypeAllocTensor); + .set_attr("FInferStructInfo", InferStructInfoAllocateTensor); Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { auto attrs = make_object(); @@ -245,7 +270,7 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage") .set_attrs_type() .set_num_inputs(1) .add_argument("total_space", "Expr", "The total space of the storage to allocate.") - .set_attr("FInferType", ReturnObjectType); + .set_attr("FInferStructInfo", ReturnObjectStructInfo); Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string storage_scope, DataType dtype) { @@ -261,17 +286,12 @@ TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocSto // memory planning alloc_tensor -Expr InferShapeMemAllocTensor(const Call& call, DiagnosticContext diag_ctx) { - return call->args[1]; -} - -Type InferTypeMemAllocTensor(const Call& call, DiagnosticContext diag_ctx) { - auto attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "must be MemAllocTensorAttrs , but got " << call->attrs->GetTypeKey(); - if (const auto* output_shape = call->args[1].as()) { - return DynTensorType(output_shape->values.size(), attrs->dtype); - } - return DynTensorType::CreateUnknownNDim(attrs->dtype, Span()); +StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "must be MemAllocTensorAttrs, but got " << call->attrs->GetTypeKey(); + ICHECK(GetStructInfoAs(call->args[1])) + << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); + return TensorStructInfo(call->args[1], attrs->dtype); } RELAY_REGISTER_OP("relax.memory.alloc_tensor") @@ -279,8 +299,7 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor") .set_num_inputs(2) .add_argument("storage", "Expr", "The storage to allocate the tensor to.") .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .set_attr("FInferShape", InferShapeMemAllocTensor) - .set_attr("FInferType", InferTypeMemAllocTensor); + .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor); Expr MakeMemAllocTensor(Expr storage, Expr shape, int offset, DataType dtype) { auto attrs = make_object(); @@ -297,7 +316,7 @@ TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocT RELAY_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) .add_argument("storage", "Expr", "The storage to be killed.") - .set_attr("FInferType", ReturnVoidType); + .set_attr("FInferStructInfo", ReturnVoidStructInfo); Expr MakeMemKillStorage(Expr storage) { static const Op& op = Op::Get("relax.memory.kill_storage"); @@ -311,7 +330,7 @@ TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillSt RELAY_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) .add_argument("tensor", "Expr", "The tensor to be killed.") - .set_attr("FInferType", ReturnVoidType); + .set_attr("FInferStructInfo", ReturnVoidStructInfo); Expr MakeMemKillTensor(Expr tensor) { static const Op& op = Op::Get("relax.memory.kill_tensor"); @@ -326,7 +345,7 @@ RELAY_REGISTER_OP("relax.vm.builtin.alloc_storage") .set_attrs_type() .set_num_inputs(1) .add_argument("size", "Expr", "The size of the storage to allocate.") - .set_attr("FInferType", ReturnObjectType); + .set_attr("FInferStructInfo", ReturnObjectStructInfo); Expr MakeVMAllocStorage(Expr size, DataType dtype, int64_t runtime_device_index) { auto attrs = make_object(); @@ -342,13 +361,15 @@ TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_storage").set_body_typed(MakeVMAl Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; } -Type InferTypeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { auto attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "must be VMAllocTensorAttrs , but got " << call->attrs->GetTypeKey(); + if (const auto* output_shape = call->args[1].as()) { - return DynTensorType(output_shape->values.size(), attrs->dtype); + return TensorStructInfo(GetRef(output_shape), attrs->dtype); } - return DynTensorType::CreateUnknownNDim(attrs->dtype, Span()); + return TensorStructInfo(attrs->dtype, kUnknownDim); } RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") @@ -356,8 +377,7 @@ RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") .set_num_inputs(2) .add_argument("storage", "Expr", "The storage to allocate the tensor to.") .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .set_attr("FInferShape", InferShapeVMAllocTensor) - .set_attr("FInferType", InferTypeVMAllocTensor); + .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor); Expr MakeVMAllocTensor(Expr storage, Expr shape, int offset, DataType dtype) { auto attrs = make_object(); @@ -376,7 +396,7 @@ RELAY_REGISTER_OP("relax.vm.builtin.store_shape") .set_num_inputs(2) .add_argument("shape", "Expr", "The shape to be stored.") .add_argument("heap", "Expr", "The heap to store the shape.") - .set_attr("FInferType", ReturnVoidType); + .set_attr("FInferStructInfo", ReturnVoidStructInfo); Expr MakeStoreShape(Expr shape, Expr heap, Array indices) { auto attrs = make_object(); @@ -393,7 +413,7 @@ RELAY_REGISTER_OP("relax.vm.builtin.load_shape") .set_attrs_type() .set_num_inputs(1) .add_argument("heap", "Expr", "The heap to load the shape from.") - .set_attr("FInferType", ReturnShapeType); + .set_attr("FInferStructInfo", ReturnShapeStructInfo); Expr MakeLoadShape(Expr heap, Array indices) { auto attrs = make_object(); @@ -411,7 +431,7 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn") .add_argument("func", "Expr", "The destination-passing-style function.") .add_argument("args", "Tuple", "The input arguments (list of tensors and last argument is ShapeExpr)") - .set_attr("FInferType", ReturnVoidType); + .set_attr("FInferStructInfo", ReturnVoidStructInfo); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 7f97716ec024..a2d3f9b13b5c 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -56,8 +56,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs); .set_num_inputs(2) \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .set_attr("FInferShape", InferShapeBinaryBroadcast) \ - .set_attr("FInferType", InferTypeBinaryBroadcast) + .set_attr("FInferStructInfo", InferStructInfoBroadcast) } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index a3933376adc0..e72cfcc0aefd 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -35,23 +35,54 @@ RELAX_REGISTER_BINARY_BROADCAST_OP("multiply") .describe("Elementwise multiply with broadcasting") .set_support_level(1); -Expr InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Binary broadcast op should have 2 arguments"); + ctx->ReportFatal(Diagnostic::Error(call) << "Binary broadcast op should have 2 arguments"); } - Expr lhs_shape = call->args[0]->shape(); - Expr rhs_shape = call->args[1]->shape(); - auto* s0 = lhs_shape.as(); - auto* s1 = rhs_shape.as(); - if (s0 && s1) { + auto* lhs_sinfo = GetStructInfoAs(call->args[0]); + auto* rhs_sinfo = GetStructInfoAs(call->args[1]); + if (!lhs_sinfo || !rhs_sinfo) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Both lhs and rhs should be Tensor for broadcasting, but got " + << call->args[0]->struct_info_->GetTypeKey() << " and " + << call->args[0]->struct_info_->GetTypeKey()); + } + + // DateType + DataType output_dtype; + if (lhs_sinfo->IsUnknownDtype() || rhs_sinfo->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (lhs_sinfo->dtype != rhs_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << lhs_sinfo->dtype << " and " << rhs_sinfo->dtype + << " must be equal for broadcasting operators"); + } else { + output_dtype = lhs_sinfo->dtype; + } + + // ndims + int output_ndim; + if (lhs_sinfo->IsUnknownNdim() || rhs_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownDim; + } else { + output_ndim = std::max(lhs_sinfo->ndim, rhs_sinfo->ndim); + } + + auto* lhs_shape = lhs_sinfo->shape.as(); + auto* rhs_shape = rhs_sinfo->shape.as(); + // Shapes and ndims + if (lhs_shape && rhs_shape) { + // If all inputs have shapes, directly infer shapes std::vector output_shape; - size_t ndim0 = s0->values.size(); - size_t ndim1 = s1->values.size(); + + size_t lhs_ndim = lhs_sinfo->ndim; + size_t rhs_ndim = rhs_sinfo->ndim; + size_t max_ndim = std::max(lhs_ndim, rhs_ndim); + size_t i = 1; - for (; i <= std::min(ndim0, ndim1); ++i) { - PrimExpr dim0 = s0->values[ndim0 - i]; - PrimExpr dim1 = s1->values[ndim1 - i]; + for (; i <= std::min(lhs_ndim, rhs_ndim); ++i) { + const PrimExpr& dim0 = lhs_shape->values[lhs_ndim - i]; + const PrimExpr& dim1 = rhs_shape->values[rhs_ndim - i]; if (EqualConstInt(dim0, 1)) { output_shape.push_back(dim1); } else if (EqualConstInt(dim1, 1)) { @@ -59,57 +90,19 @@ Expr InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { } else if (EqualCheck(dim0, dim1)) { output_shape.push_back(dim0); } else { - // defer the computation of output shapes to runtime - // e.g., broadcast Tensor([m, n]), Tensor([k]) -> defer to runtime - Call call_infer(ExternFunc(String("vm.binary_broadcast_shape_infer")), - {call->args[0], call->args[1]}, {}, {}); - call_infer->checked_type_ = ShapeType(); - return call_infer; + // Use simple fallback when shape mismatch. + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); } } - size_t max_ndim = std::max(ndim0, ndim1); - auto& longer_shape = (ndim0 > ndim1) ? s0 : s1; + auto& longer_shape = (lhs_ndim > rhs_ndim) ? lhs_shape : rhs_shape; for (; i <= max_ndim; ++i) { output_shape.push_back(longer_shape->values[max_ndim - i]); } - return ShapeExpr(Array(output_shape.rbegin(), output_shape.rend())); - } - return RuntimeDepShape(); -} - -Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 2) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Binary broadcast op should have 2 arguments"); - } - Type lhs_type = call->args[0]->checked_type(); - Type rhs_type = call->args[1]->checked_type(); - auto* t0 = lhs_type.as(); - auto* t1 = rhs_type.as(); - if (!t0 || !t1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Both lhs and rhs should be DynTensor for broadcasting, but got " - << lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey()); - } - - DataType output_dtype; - if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) { - output_dtype = DataType::Void(); - } else if (t0->dtype != t1->dtype) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Data types " << t0->dtype << " and " << t1->dtype - << " must be equal for broadcasting operators"); - } else { - output_dtype = t0->dtype; - } - - int output_ndim; - if (t0->IsUnknownNdim() || t1->IsUnknownNdim()) { - output_ndim = -1; + Expr shape = ShapeExpr(Array(output_shape.rbegin(), output_shape.rend())); + return TensorStructInfo(shape, output_dtype); } else { - output_ndim = std::max(t0->ndim, t1->ndim); + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); } - return DynTensorType(output_ndim, output_dtype); } } // namespace relax diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 30b12f3a0d39..52bba8e22b9a 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -36,8 +36,7 @@ namespace tvm { namespace relax { -Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx); -Expr InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx); +StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index fbd34b825d79..d13d23baabd2 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -27,24 +27,9 @@ namespace tvm { namespace relax { -RELAY_REGISTER_OP("relax.ewise_fma") - .set_num_inputs(3) - .add_argument("e1", "Expr", "The input expression") - .add_argument("e2", "Expr", "The input expression") - .add_argument("e3", "Expr", "The input expression") - .set_attr("FInferShape", InferShapeEwiseFMA) - .set_attr("FInferType", InferTypeEwiseFMA); - -Expr MakeEwiseFma(Expr expr1, Expr expr2, Expr expr3) { - static const Op& op = Op::Get("relax.ewise_fma"); - return Call(op, {expr1, expr2, expr3}, {}, {}); -} - -TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(MakeEwiseFma); - Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { if (call->args.size() != 3) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments"); + diag_ctx.EmitFatal(Diagnostic::Error(call) << "EwiseFMA op should have 3 arguments"); } Type type0 = call->args[0]->checked_type(); Type type1 = call->args[1]->checked_type(); @@ -53,7 +38,7 @@ Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { auto* t1 = type1.as(); auto* t2 = type2.as(); if (!t0 || !t1 || !t2) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) + diag_ctx.EmitFatal(Diagnostic::Error(call) << "The 3 arguments of EwiseFMA should be DynTensor"); } @@ -61,7 +46,7 @@ Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { if (t0->IsUnknownDtype() || t1->IsUnknownDtype() || t2->IsUnknownDtype()) { output_dtype = DataType::Void(); } else if (t0->dtype != t1->dtype || t1->dtype != t2->dtype) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) + diag_ctx.EmitFatal(Diagnostic::Error(call) << "Data types " << t0->dtype << ", " << t1->dtype << ", and " << t2->dtype << " must be equal for EwiseFMA"); } else { @@ -77,27 +62,41 @@ Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { return DynTensorType(output_ndim, output_dtype); } -Expr InferShapeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments"); + ctx->ReportFatal(Diagnostic::Error(call) << "EwiseFMA op should have 3 arguments"); + } + + auto* t0 = GetStructInfoAs(call->args[0]); + auto* t1 = GetStructInfoAs(call->args[1]); + auto* t2 = GetStructInfoAs(call->args[2]); + + if (!t0 || !t1 || !t2) { + ctx->ReportFatal(Diagnostic::Error(call) << "EwiseFMA expects three tensor inputs"); } - Expr shape0 = call->args[0]->shape(); - Expr shape1 = call->args[1]->shape(); - Expr shape2 = call->args[2]->shape(); - auto* s0 = shape0.as(); - auto* s1 = shape1.as(); - auto* s2 = shape2.as(); + + DataType output_dtype; + if (t0->IsUnknownDtype() || t1->IsUnknownDtype() || t2->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t0->dtype != t1->dtype || t1->dtype != t2->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << t0->dtype << ", " << t1->dtype << ", and " << t2->dtype + << " must be equal for EwiseFMA"); + } else { + output_dtype = t0->dtype; + } + + auto* s0 = t0->shape.as(); + auto* s1 = t1->shape.as(); + auto* s2 = t2->shape.as(); if (s0 && s1 && s2) { - std::vector output_shape; + Array output_shape; size_t ndim0 = s0->values.size(); size_t ndim1 = s1->values.size(); size_t ndim2 = s2->values.size(); if (ndim0 != ndim1 || ndim1 != ndim2) { - LOG(INFO) << ndim0; - LOG(INFO) << ndim1; - LOG(INFO) << ndim2; - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); } for (size_t i = 0; i < ndim0; ++i) { PrimExpr dim0 = s0->values[i]; @@ -106,14 +105,35 @@ Expr InferShapeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) { if (EqualCheck(dim0, dim1) && EqualCheck(dim1, dim2)) { output_shape.push_back(dim0); } else { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "The 3 arguments of EwiseFMA must have the same shape"); + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same shape"); } } - return ShapeExpr(Array(output_shape.begin(), output_shape.end())); + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); } - return RuntimeDepShape(); + + int output_ndim; + if (t0->IsUnknownNdim() || t1->IsUnknownNdim() || t2->IsUnknownNdim()) { + output_ndim = kUnknownDim; + } else { + output_ndim = t0->ndim; + } + return TensorStructInfo(output_dtype, output_ndim); +} + +RELAY_REGISTER_OP("relax.ewise_fma") + .set_num_inputs(3) + .add_argument("e1", "Expr", "The input expression") + .add_argument("e2", "Expr", "The input expression") + .add_argument("e3", "Expr", "The input expression") + .set_attr("FInferStructInfo", InferStructInfoEwiseFMA); + +Expr MakeEwiseFma(Expr expr1, Expr expr2, Expr expr3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {expr1, expr2, expr3}, {}, {}); } +TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(MakeEwiseFma); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h index 7abfb4c02ffe..6b019f7829fa 100644 --- a/src/relax/op/tensor/ternary.h +++ b/src/relax/op/tensor/ternary.h @@ -37,7 +37,7 @@ namespace relax { /* relax.ewise_fma */ Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx); -Expr InferShapeEwiseFMA(const Call& call, DiagnosticContext diag_ctx); +StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 7f194b9b49e1..57e15454abd6 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -29,17 +29,6 @@ namespace relax { TVM_REGISTER_NODE_TYPE(UniqueAttrs); -RELAY_REGISTER_OP("relax.unique") - .describe( - "This operation returns the unique elements and the new index of each item in a given " - "tensor.") - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") - .set_attrs_type() - .set_attr("FInferShape", InferShapeUnique) - .set_attr("FInferType", InferTypeUnique) - .set_attr("FCallPacked", "relax.run.unique"); - Expr MakeUnique(Expr data, bool sorted, bool return_inverse, bool return_counts, int dim) { auto attrs = make_object(); attrs->sorted = sorted; @@ -52,37 +41,37 @@ Expr MakeUnique(Expr data, bool sorted, bool return_inverse, bool return_counts, TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(MakeUnique); -Type InferTypeUnique(const Call& call, DiagnosticContext diag_ctx) { +StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument"); + ctx->ReportFatal(Diagnostic::Error(call) << "Unique op should have 1 argument"); } - auto* input_ty = call->args[0]->checked_type().as(); - if (!input_ty) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Input should be DynTensor, but got " - << call->args[0]->checked_type()->GetTypeKey()); - } - - // TODO(prakalp): Add support for return_inverse, return_counts and dim attributes. Only defaults - // are supported right now. auto unique_attrs = call->attrs.as(); - if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1) - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "support for return_inverse, return_counts, and dim is not implemented"); - return DynTensorType(/*ndim=*/1, input_ty->dtype); -} -Expr InferShapeUnique(const Call& call, DiagnosticContext diag_ctx) { - if (call->args.size() != 1) { - diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument"); + auto input_sinfo = GetStructInfoAs(call->args[0]); + + if (!input_sinfo) { + ctx->ReportFatal(Diagnostic::Error(call) << "Input should be Tensor, but got " + << call->args[0]->struct_info_->GetTypeKey()); } - auto unique_attrs = call->attrs.as(); + // Only default values of these attributes are supported right now. - if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1) - diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "support for return_inverse, return_counts, and dim is not implemented"); - return RuntimeDepShape(call->span); + if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "support for return_inverse, return_counts, and dim is not implemented"); + } + + return TensorStructInfo(input_sinfo->dtype, /*ndim=*/1); } +RELAY_REGISTER_OP("relax.unique") + .describe( + "This operation returns the unique elements and the new index of each item in a given " + "tensor.") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoUnique) + .set_attr("FCallPacked", "relax.run.unique"); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 87cc7d619188..4ebc5d957401 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -447,8 +447,7 @@ class FunctionCreator : public ExprMutator { attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); function_ = Function(/*params=*/params_, // /*body=*/body, // - /*ret_type=*/body->checked_type_, - /*ret_shape=*/RuntimeDepShape(), + Type(), Expr(), /*attrs=*/DictAttrs(attrs)); } diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index f6a58cce1264..17a44bece0a9 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -96,8 +96,8 @@ class LambdaLifter : public ExprMutator { Array typed_captured_vars; Map rebinding_map; for (auto free_var : captured_vars) { - Var var = Var(free_var->name_hint(), NullOpt, free_var->checked_type_, free_var->span); - var->shape_ = free_var->shape_; + Var var = Var(free_var->name_hint(), NullOpt, NullOpt, free_var->span); + UpdateStructInfo(var, GetStructInfo(free_var)); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } @@ -180,8 +180,8 @@ class LambdaLifter : public ExprMutator { ICHECK(lifted_func.defined()); // Add the lifted function to the module. + UpdateStructInfo(global, GetStructInfo(lifted_func)); builder_->UpdateFunction(global, lifted_func); - UpdateType(global, lifted_func->checked_type()); if (!is_closure) { return std::move(global); diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 8beb2b6b5a58..ed42e616ba1f 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -26,6 +26,7 @@ #include #include +#include #include namespace tvm { @@ -41,7 +42,7 @@ class NormalizeMutator : public ExprMutatorBase { } Expr VisitExpr_(const FunctionNode* op) { - Expr body = this->VisitWithNewScope(op->body); + Expr body = this->VisitWithNewScope(op->body, op->params); if (body.same_as(op->body)) { return GetRef(op); @@ -62,13 +63,15 @@ class NormalizeMutator : public ExprMutatorBase { } } - Expr VisitWithNewScope(const Expr& expr) { + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { builder_->BeginBindingBlock(); + builder_->BeginScope(params); Expr ret = this->VisitExpr(expr); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { ret = SeqExpr({prologue}, ret); } + builder_->EndScope(); return ret; } @@ -146,12 +149,10 @@ class NormalizeMutator : public ExprMutatorBase { }; Expr new_value = this->VisitExpr(binding->value); - if (!binding->var->checked_type_.defined()) { - UpdateType(binding->var, new_value->checked_type_); - } - if (!binding->var->shape_.defined()) { - UpdateShape(binding->var, new_value->shape_); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); } + if (new_value.same_as(binding->value)) { emit(GetRef(binding)); } else { @@ -163,11 +164,8 @@ class NormalizeMutator : public ExprMutatorBase { Expr new_value = this->VisitExpr(binding->value); if (binding->var.defined()) { - if (!binding->var->checked_type_.defined()) { - UpdateType(binding->var, new_value->checked_type_); - } - if (!binding->var->shape_.defined()) { - UpdateShape(binding->var, new_value->shape_); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); } } if (new_value.same_as(binding->value)) { diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index c7c26278057e..72b227bac73e 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -33,8 +34,8 @@ class ToNonDFMutator : public ExprMutator { public: Var VisitVarDef(const Var& var) final { if (var.as()) { - Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span); - UpdateShape(new_var, var->shape_); + Var new_var = Var(var->vid, NullOpt, NullOpt, var->span); + UpdateStructInfo(new_var, GetStructInfo(var)); this->var_remap_[var->vid] = new_var; return new_var; } diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index ddbddd4b1de1..e8ffbbf577bb 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -39,6 +40,22 @@ GlobalVar DeclFunction(const String& func_name, const Optional& func_s CHECK(!frame->global_var_map.count(func_name)) << "ValueError: function " << func_name << " already exists"; GlobalVar gv = GlobalVar(func_name); + if (func_signature.defined()) { + const BaseFunc& func = func_signature.value(); + if (func->struct_info_.defined()) { + gv->struct_info_ = tvm::relax::GetStructInfo(func); + } else if (const auto* prim_func = func.as()) { + // NOTE: use a slightly different struct info than checked type + // in PrimFunc so handle can turn into Tensor. + // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. + gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc( + tvm::relax::StructInfoFromType(prim_func->ret_type)); + } else { + LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey(); + } + } else { + gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc(); + } CHECK(frame->functions.find(gv) == frame->functions.end()) << "ValueError: function " << func_name << " has already been defined."; frame->global_var_map.Set(func_name, gv); diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 4d4b4609d00e..a0bfd3446092 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -54,7 +54,7 @@ void FunctionFrameNode::ExitWithScope() { // Step 1: Create the function. CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " "`return` to return an Expr"; - + this->block_builder->BeginScope(params); Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); Expr func_shape = ret_shape.value_or(tvm::relax::RuntimeDepShape()); if (func_shape->IsInstance()) { @@ -63,6 +63,7 @@ void FunctionFrameNode::ExitWithScope() { // func_shape = tvm::relax::DeriveFuncRetShape(params, body); } auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, /*body=*/body, /*ret_type=*/ret_type.value_or(Type()), diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 1aa39ee62636..5470fb091a73 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -49,16 +50,10 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) ////////////////////////////// Tensor Type ////////////////////////////// -ShapedType::ShapedType(Type type, Optional shape) { - auto n = make_object(); - n->type = std::move(type); - n->shape = std::move(shape); - data_ = std::move(n); -} - -TVM_REGISTER_NODE_TYPE(ShapedTypeNode); +using tvm::relax::TensorStructInfo; +using tvm::relax::TupleStructInfo; -ShapedType Tensor(Optional> shape, DataType dtype, int ndim) { +TensorStructInfo Tensor(Optional> shape, DataType dtype, int ndim) { using namespace tvm::relax; ICHECK_GE(ndim, -1) << "ndim must be >= -1, but got " << ndim; if (shape.defined() && ndim >= 0) { @@ -67,36 +62,15 @@ ShapedType Tensor(Optional> shape, DataType dtype, int ndim) { } else if (shape.defined()) { ndim = shape.value().size(); } - Expr shape_expr = RuntimeDepShape(); if (shape.defined()) { - shape_expr = ShapeExpr(shape.value()); - } - return ShapedType(DynTensorType(ndim, dtype), shape_expr); -} - -ShapedType CreateShapedTuple(Array types, Array> shapes) { - CHECK_EQ(types.size(), shapes.size()) - << "ValueError: The number of types and shapes mismatched, got " << types.size() << " vs " - << shapes.size(); - Array _shapes; - bool has_none_shape = false; - for (const auto& shape : shapes) { - if (shape.defined()) { - _shapes.push_back(shape.value()); - } else { - has_none_shape = true; - break; - } - } - Optional final_shape = NullOpt; - if (!has_none_shape) { - final_shape = tvm::relax::Tuple(_shapes); + ShapeExpr shape_expr(shape.value()); + return TensorStructInfo(shape_expr, dtype); + } else { + return TensorStructInfo(dtype, ndim); } - return ShapedType(TupleType(types), final_shape); } TVM_REGISTER_GLOBAL("script.ir_builder.relax.Tensor").set_body_typed(Tensor); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.CreateShapedTuple").set_body_typed(CreateShapedTuple); /////////////////////////////// Function //////////////////////////////// @@ -111,9 +85,11 @@ FunctionFrame Function() { return FunctionFrame(n); } -tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::Expr& shape) { +tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) { FunctionFrame frame = FindFunctionFrame("R.Arg"); - tvm::relax::Var var(name, shape, type); + // TODO(relax-team): Update constructor to include struct info as argument. + tvm::relax::Var var(name, NullOpt, NullOpt); + UpdateStructInfo(var, struct_info); frame->params.push_back(var); return var; } @@ -135,22 +111,15 @@ void FuncAttrs(Map attrs) { frame->attrs = attrs; } -void FuncRetType(tvm::Type ret_type) { - FunctionFrame frame = FindFunctionFrame("R.ret_type"); - if (frame->ret_type.defined()) { - LOG(FATAL) << "ValueError: Duplicate function return type, previous one is:\n " - << frame->ret_type.value(); - } - frame->ret_type = ret_type; -} - -void FuncRetShape(tvm::relax::Expr ret_shape) { - FunctionFrame frame = FindFunctionFrame("R.ret_shape"); - if (frame->ret_shape.defined()) { - LOG(FATAL) << "ValueError: Duplicate function return type, previous one is:\n " - << frame->ret_type.value(); +void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { + FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); + if (frame->ret_sinfo.defined()) { + LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " + << frame->ret_sinfo.value(); } - frame->ret_shape = ret_shape; + frame->ret_sinfo = ret_sinfo; + frame->ret_type = GetStaticType(ret_sinfo); + frame->ret_shape = GetLegacyShapeHint(ret_sinfo); } void FuncRetValue(const tvm::relax::Expr& value) { @@ -180,8 +149,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function) TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetType").set_body_typed(FuncRetType); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetShape").set_body_typed(FuncRetShape); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); ///////////////////////////// BindingBlock ////////////////////////////// @@ -264,37 +232,56 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(Emi ///////////////////////////// Type Deduce ////////////////////////////// -void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type, - const Optional& anno_shape) { +void AnnotateStructInfo(const tvm::relax::Var& var, + const tvm::relax::StructInfo& anno_struct_info) { using tvm::relax::IsBaseOf; - if (var->checked_type_.defined()) { - const Type& var_type = var->checked_type(); - CHECK(IsBaseOf(anno_type, var_type) || IsBaseOf(var_type, anno_type)) - << "TypeError: The annotated type and value type are not compatible. " - << "The Type is expected to be " << var_type << " but got annotation: " << anno_type; - } - - if (var->shape_.defined() && anno_shape.defined()) { - tvm::relax::Expr var_shape = Downcast(var->shape_.value()); - auto check_shape = [](const tvm::relax::Expr& lhs, const tvm::relax::Expr& rhs) { - if (lhs->IsInstance() || - rhs->IsInstance()) { - return true; - } else { - const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); - return block_builder->CanProveShapeEqual(lhs, rhs); - } - }; - CHECK(check_shape(var_shape, anno_shape.value())) - << " The shape of var " << var->name_hint() << " is expected to be " << var_shape - << " but got annotation: " << anno_shape.value(); + using tvm::relax::StructInfo; + + Type type = GetStaticType(anno_struct_info); + Optional shape = GetLegacyShapeHint(anno_struct_info); + + // TODO(siyuan, ruihang): Revisit the checks aftr we fully migrate to struct info. + // consider simplify assumption to require var not contain struct info at all. + if (var->struct_info_.defined()) { + // Case 1. The var already has struct info. + const StructInfo& var_struct_info = Downcast(var->struct_info_.value()); + CHECK(IsBaseOf(var_struct_info, anno_struct_info) || + IsBaseOf(anno_struct_info, var_struct_info)) + << "ValueError: The annotation struct info is not a base of the existing struct info."; + } else { + // Case 2. The var doesn't have struct info. + // This path may be removed later + + if (var->checked_type_.defined()) { + const Type& var_type = var->checked_type(); + CHECK(IsBaseOf(type, var_type) || IsBaseOf(var_type, type)) + << "TypeError: The annotated type and value type are not compatible. " + << "The Type is expected to be " << var_type << " but got annotation: " << type; + } + if (var->shape_.defined() && shape.defined()) { + tvm::relax::Expr var_shape = Downcast(var->shape_.value()); + auto check_shape = [](const tvm::relax::Expr& lhs, const tvm::relax::Expr& rhs) { + if (lhs->IsInstance() || + rhs->IsInstance()) { + return true; + } else { + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + return block_builder->CanProveShapeEqual(lhs, rhs); + } + }; + CHECK(check_shape(var_shape, shape.value())) + << " The shape of var " << var->name_hint() << " is expected to be " << var_shape + << " but got annotation: " << shape.value(); + } } - var->checked_type_ = anno_type; - var->shape_ = anno_shape; + var->checked_type_ = type; + var->shape_ = shape; + var->struct_info_ = anno_struct_info; } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateTypeShape").set_body_typed(AnnotateTypeShape); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateStructInfo") + .set_body_typed(AnnotateStructInfo); ///////////////////////////// If Then Else ///////////////////////////// @@ -320,6 +307,55 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); +//////////////////////// Symbolic Shape Rewriter //////////////////////// + +using tvm::relax::Expr; +class SymbolicShapeRewriter : public tvm::relax::StructInfoMutator { + public: + explicit SymbolicShapeRewriter(Map var_table) + : var_table_(std::move(var_table)) {} + + Array undefined_vars_; + + private: + Expr VisitStructInfoExprField(const Expr& expr) { + if (const auto* shape_expr = expr.as()) { + Array new_shape; + bool changed = false; + for (const tvm::PrimExpr& s : shape_expr->values) { + if (const auto* var = s.as()) { + auto it = var_table_.find(var->name_hint); + if (it != var_table_.end()) { + new_shape.push_back((*it).second); + changed = true; + } else { + undefined_vars_.push_back(GetRef(var)); + var_table_.Set(var->name_hint, GetRef(var)); + new_shape.push_back(s); + } + } else { + // TODO(siyuan, ruihang): confirm and use VisitPrimExpr to recursive rewrite. + new_shape.push_back(s); + } + } + if (changed) { + return tvm::relax::ShapeExpr(new_shape); + } + } + return expr; + } + + private: + Map var_table_; +}; + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.RewriteSymbolicShape") + .set_body_typed([](tvm::relax::StructInfo struct_info, Map var_table) { + SymbolicShapeRewriter rewriter(var_table); + tvm::relax::StructInfo rewritten_info = rewriter(std::move(struct_info)); + return Array{rewritten_info, rewriter.undefined_vars_}; + }); + } // namespace relax } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index 3e31dd8d224f..3a72eb8e364a 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ #define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#include #include #include diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py new file mode 100644 index 000000000000..9776c0a0b145 --- /dev/null +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -0,0 +1,562 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests analysis functions of struct info""" + +import pytest +import tvm +import tvm.testing +from tvm import relax as rx, TVMError +from tvm import tir + + +def test_get_static_type_basic(): + # object + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) + + +def test_get_static_type_shape(): + # shape + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2)) + + +def test_get_static_type_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(s4), rx.DynTensorType(ndim=3, dtype="int64") + ) + + +def test_get_static_type_tuple(): + # tuple + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(t1), + rx.TupleType( + [ + rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]), + rx.ShapeType(ndim=3), + ] + ), + ) + + +def test_get_static_type_func(): + # tuple + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_type(): + x = rx.DynTensorType(ndim=3, dtype="float32") + y = rx.DynTensorType(ndim=3, dtype="float32") + z = rx.DynTensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(fn_info(1)), fn_type()) + + +def test_erase_to_well_defined_basic(): + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) + + +def test_erase_to_well_defined_shape(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + # have undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) + ) + # all defined + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) + + # replacement + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) + ) + + # partial defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) + ) + + +def test_erase_to_well_defined_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + rshape = rx.Var("shape", type_annotation=rx.ShapeType(ndim=2)) + s0 = rx.TensorStructInfo(rshape, dtype="int32") + + # undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, None), + rx.TensorStructInfo(ndim=2, dtype="int32"), + ) + + # defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0 + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), + rx.TensorStructInfo([1, 2], dtype="int32"), + ) + + s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), + rx.TensorStructInfo([4, 2], dtype="float32"), + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + s2 = rx.TensorStructInfo([1, 2], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) + + +def test_erase_to_well_defined_tuple(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(t1, {m: m + 1}), + rx.TupleStructInfo( + [ + rx.TupleStructInfo( + [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] + ), + rx.ShapeStructInfo([1, m + 1]), + ] + ), + ) + + +def test_erase_to_well_defined_func(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) + + +def test_base_check(): + BR = rx.analysis.BaseCheckResult + bcheck = rx.analysis.struct_info_base_check + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + assert bcheck(obj0, prim0) == BR.PASS + assert bcheck(obj0, shape1) == BR.PASS + assert bcheck(obj0, tensor2) == BR.PASS + assert obj0.is_base_of(tensor2) + + # prim + assert prim0.is_base_of(prim0) + assert not prim0.is_base_of(prim1) + assert bcheck(prim0, obj0) == BR.FAIL_L1 + assert bcheck(prim0, prim0) == BR.PASS + assert bcheck(prim0, prim1) == BR.FAIL_L0 + + # shape + assert bcheck(shape0, obj0) == BR.FAIL_L1 + assert bcheck(shape0, prim0) == BR.FAIL_L0 + + # unknown dim + assert bcheck(shape0, shape1) == BR.PASS + assert bcheck(shape1, shape0) == BR.FAIL_L1 + + # ndim mismatch + assert bcheck(shape1, shape2) == BR.FAIL_L0 + + # lhs do not have symbolic value but ndim match + assert bcheck(shape2, shape3) == BR.PASS + + # rhs do not symbolic but lhs do + assert bcheck(shape3, shape2) == BR.FAIL_L2 + + # shape mismatch + assert bcheck(shape3, shape4) == BR.FAIL_L2 + assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) + + # tensor + assert bcheck(tensor0, obj0) == BR.FAIL_L1 + assert bcheck(tensor0, prim0) == BR.FAIL_L0 + assert bcheck(tensor0, shape0) == BR.FAIL_L0 + + # dtype mismatch + assert bcheck(tensor0, tensor1) == BR.FAIL_L0 + assert bcheck(tensor0, tensor3) == BR.FAIL_L0 + assert bcheck(tensor3, tensor4) == BR.FAIL_L0 + assert bcheck(tensor1, tensor2) == BR.FAIL_L0 + + # ndim mismatch + assert bcheck(tensor2, tensor5) == BR.FAIL_L0 + + # static shape mismatch + assert bcheck(tensor5, tensor6) == BR.FAIL_L0 + + # match + assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) + assert tensor0.is_base_of(tensor2) + assert tensor0.is_base_of(tensor4) + assert tensor0.is_base_of(tensor5) + assert tensor0.is_base_of(tensor6) + assert tensor2.is_base_of(tensor4) + assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + assert t0.is_base_of(t1) + + assert bcheck(t0, t2) == BR.FAIL_L0 + assert bcheck(t0, t3) == BR.FAIL_L1 + + assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) + assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + assert fn_info_shape(1).is_base_of(fn_info_shape(1)) + assert fn_info_erased().is_base_of(fn_info_shape(1)) + assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 + + fopaque = rx.FuncStructInfo.opaque_func() + assert fopaque.is_base_of(fn_info_shape(1)) + + +def _check_derive(ctx, finfo, args_sinfo, ret): + gv = rx.GlobalVar("test") + rx.expr._update_struct_info(gv, finfo) + args = [] + for i, sinfo in enumerate(args_sinfo): + arg = rx.Var("arg%i" % i) + rx.expr._update_struct_info(arg, sinfo) + args.append(arg) + call = rx.Call(gv, args) + derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx) + tvm.ir.assert_structural_equal(ret, derived_ret) + + +def test_derive_call_ret_struct_info(): + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("float32") + + n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64") + bb = rx.BlockBuilder() + # derivation cases + with bb.testing_scope(def_vars=[n, m]): + + def func0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([n, m], "float32") + z = rx.TensorStructInfo([m + c, n], "float32") + return rx.FuncStructInfo([x], z) + + # Tensor => Tensor + _check_derive( + bb, + func0(1), + [rx.TensorStructInfo([10, 11], "float32")], + rx.TensorStructInfo([12, 10], "float32"), + ) + + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo([n, m], "float32")], + rx.TensorStructInfo([m + 2, n], "float32"), + ) + + # passing in information that cannot deduce n, m + # it is still OK as type still matches, return an + # eriased output + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32")], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error: wrong number of arguments + with pytest.raises(TVMError): + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error:type mismatch + with pytest.raises(TVMError): + _check_derive(bb, func0(2), [obj0], obj0) + + # opaque derivation + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + _check_derive(bb, fopaque0(), [obj0, prim0], obj0) + _check_derive(bb, fopaque1(), [obj0, prim0], prim0) + + # recursive tuple derivation + def func_tuple0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, c], "float32") + x1 = rx.TensorStructInfo([n + c, m], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + _check_derive( + bb, + func_tuple0(2), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 2], "float32"), + rx.TensorStructInfo([n + 2, 10], "float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + def func_tuple1(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, m], "float32") + x1 = rx.TensorStructInfo([n + c, c], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + # Still OK, to pass erased tensor into n+2, n is captured by other argument. + _check_derive( + bb, + func_tuple1(4), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 4], "float32"), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]), + ) + + # tuple length mismatch is not causes an error + with pytest.raises(TVMError): + _check_derive( + bb, + func_tuple0(4), + [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + # mixed shape types + def func_shape_mixed(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.ShapeStructInfo([n, m]) + f0 = func_tuple0(c) + z = rx.ShapeStructInfo([m + n, c]) + return rx.FuncStructInfo([x0, f0], z) + + _check_derive( + bb, + func_shape_mixed(3), + [ + rx.ShapeStructInfo([10, 20]), + rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2)), + ], + rx.ShapeStructInfo([30, 3]), + ) + + +def _check_lca(lhs, rhs, target): + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) + + +def test_struct_info_lca(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + _check_lca(obj0, prim0, obj0) + _check_lca(obj0, prim1, obj0) + + # shape + _check_lca(shape0, tensor0, obj0) + _check_lca(shape0, shape1, shape0) + _check_lca(shape1, shape2, shape0) + _check_lca(shape1, shape3, shape0) + + _check_lca(shape2, shape3, shape2) + _check_lca(shape3, shape4, shape2) + _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4) + + # tensor + _check_lca(tensor0, prim0, obj0) + _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None)) + _check_lca(tensor0, tensor2, tensor0) + _check_lca(tensor0, tensor4, tensor0) + + _check_lca(tensor2, tensor4, tensor2) + _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32")) + _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32")) + _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + _check_lca(t0, t1, t0) + _check_lca(t0, t2, obj0) + _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0])) + + t5 = rx.TupleStructInfo([t0, t1]) + t6 = rx.TupleStructInfo([t1, t2]) + + _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0])) + + t7 = rx.TupleStructInfo([]) + _check_lca(t7, rx.TupleStructInfo([]), t7) + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + fopaque2 = lambda: rx.FuncStructInfo.opaque_func( + ret=rx.TensorStructInfo(ndim=2, dtype="float32") + ) + + _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) + _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) + + _check_lca(fopaque0(), fopaque1(), fopaque0()) + _check_lca(fopaque0(), fn_info_shape(1), fopaque0()) + _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 8c45298f5034..a5c1ff12e5cb 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -129,7 +129,9 @@ def test_symbolic_var(): def test_symbolic_var_invalid_type(): - with pytest.raises(tvm.TVMError, match="the value in ShapeExpr can only have dtype of int64"): + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): dim = tir.Var("dim", "float32") type_anno = rx.DynTensorType(ndim=1, dtype="float32") y = rx.Var("y", [dim], type_anno) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 66e9f352787d..86da133e1c3d 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -259,7 +259,7 @@ def test_shape_of(): assert 'Var(name_hint="v0")' in s0_str shape_anno = [96, 54] - v1 = rx.Var("v1", shape_anno) + v1 = rx.Var("v1", shape_anno, rx.DynTensorType(ndim=2)) s1 = v1.shape s1_str = dump_ast(s1) assert s1_str.startswith("ShapeExpr("), s1_str diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 249c76ca547a..0cdae6458dd7 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -238,19 +238,17 @@ def test_binary_shape_type_deduction(): assert lv1.checked_type.dtype == "float16" lv2 = bb.emit(rx.op.multiply(z, w)) - assert isinstance(lv2.shape, rx.Call) assert isinstance(lv2.checked_type, rx.DynTensorType) assert lv2.checked_type.ndim == 1 assert lv2.checked_type.dtype == "float16" lv3 = bb.emit(rx.op.multiply(y, w)) - assert isinstance(lv3.shape, rx.Call) - assert isinstance(lv3.checked_type, rx.DynTensorType) + assert isinstance(lv3.struct_info, rx.TensorStructInfo) assert lv3.checked_type.ndim == 1 assert lv3.checked_type.dtype == "float16" gv0 = bb.emit_output(lv3) bb.emit_func_output(gv0) - assert isinstance(gv0.shape, rx.Call) + assert isinstance(gv0.checked_type, rx.DynTensorType) assert gv0.checked_type.ndim == 1 assert gv0.checked_type.dtype == "float16" @@ -278,7 +276,7 @@ def test_emit_match_shape(): # lv1: Shape = match_shape(shape, [m, n]) lv1 = bb.match_shape(y, [m, n]) - assert lv1.checked_type == rx.ShapeType() + assert lv1.checked_type == rx.ShapeType(2) gv0 = bb.emit_output(lv1) bb.emit_func_output(gv0) @@ -346,8 +344,14 @@ def test_normalize(): bb.normalize(tuple_1) assert_structural_equal(tuple_1.checked_type, rx.TupleType([type_anno0, type_anno1])) assert_structural_equal(tuple_1.shape, rx.Tuple([x.shape, y.shape])) + assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) + + # Note sure if it's needed assert_structural_equal( - tuple_1.shape.checked_type, rx.TupleType([rx.ShapeType(), rx.ShapeType()]) + tuple_1.shape.struct_info, + rx.TupleStructInfo([rx.ShapeStructInfo([m, n]), rx.ShapeStructInfo([n])]), ) # Nested Tuple @@ -357,10 +361,15 @@ def test_normalize(): tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) ) assert_structural_equal(tuple_2.shape, rx.Tuple([x.shape, rx.Tuple([x.shape, y.shape])])) - assert_structural_equal( - tuple_2.shape.checked_type, - rx.TupleType([rx.ShapeType(), rx.TupleType([rx.ShapeType(), rx.ShapeType()])]), - ) + assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) + # assert_structural_equal( + # tuple_2.shape.checked_type, + # rx.TupleType([rx.ShapeType(), rx.TupleType([rx.ShapeType(), rx.ShapeType()])]), + # ) def test_call_te(): diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 5b13ea65b889..5ced0a72833f 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -136,12 +136,15 @@ def test_shape_expr() -> None: assert s.values[0] == m assert s.values[1] == n + assert isinstance(s.struct_info, rx.ShapeStructInfo) + def test_func(): type_anno = rx.DynTensorType(2, "float32") x = rx.Var("foo", type_annotation=type_anno) bindings = [rx.VarBinding(x, rx.const(1))] blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) ret_type = rx.DynTensorType(-1, "float32") ret_shape = rx.RuntimeDepShape() @@ -150,7 +153,7 @@ def test_func(): assert func.params[0] == x assert func.body == seqe assert func.ret_type == ret_type - assert func.ret_shape == ret_shape + assert isinstance(func.ret_shape, rx.RuntimeDepShape) assert func.attrs["global_symbol"] == "func" @@ -161,7 +164,7 @@ def test_shape_of(): assert s0.op.name == "relax.shape_of" shape_anno = [96, 54] - v1 = rx.Var("v1", shape_anno) + v1 = rx.Var("v1", shape_anno, rx.DynTensorType(ndim=2)) s1 = v1.shape for x, y in zip(shape_anno, s1): assert x == y @@ -171,17 +174,19 @@ def test_shape_expr(): shape_expr = rx.ShapeExpr([10, 20]) assert shape_expr.values[0] == 10 assert shape_expr.values[1] == 20 - assert shape_expr.checked_type == rx.ShapeType() + assert shape_expr.checked_type == rx.ShapeType(ndim=2) assert shape_expr.shape_ is None x = rx.Var("v0", (10, 20), rx.DynTensorType(2, "float32")) assert x.shape_.values[0] == 10 assert x.shape_.values[1] == 20 - assert x.shape_.checked_type == rx.ShapeType() + assert x.shape_.checked_type == rx.ShapeType(ndim=2) assert x.shape_.shape_ is None m = tir.Var("m", "int32") - with pytest.raises(tvm.TVMError, match="the value in ShapeExpr can only have dtype of int64"): + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): rx.ShapeExpr([m, 3]) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index b3a6e0434724..68d4fefb8638 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -30,7 +30,7 @@ from tvm.relax.expr import Call, If, TupleGetItem from tvm.relax.expr import Binding, MatchShape, VarBinding from tvm.relax.expr import BindingBlock, DataflowBlock -from tvm.relax.expr import _update_shape, _update_type + m, n = tir.Var("m", "int64"), tir.Var("n", "int64") type_anno1 = relax.DynTensorType(1, "float32") @@ -281,7 +281,7 @@ def emit(b: VarBinding): emit(binding) return - temp = self.with_shape_and_type(new_var, new_value.shape_, new_value._checked_type_) + temp = self.with_struct_info(new_var, new_value.struct_info) if not temp.same_as(new_var): new_var = temp self.set_var_remap(binding.var.vid, new_var) @@ -294,11 +294,14 @@ def visit_match_shape_(self, binding: MatchShape) -> None: new_pattern = self.visit_expr(ShapeExpr(binding.pattern)) if binding.var: - new_shape = None - if new_value._checked_type_ and isinstance(new_value._checked_type_, DynTensorType): - new_shape = new_pattern + new_sinfo = None + if isinstance(new_value.struct_info, TensorStructInfo): + new_sinfo = relax.TensorStructInfo(new_pattern, dtype=new_value.struct_info) + else: + new_sinfo = new_value.struct_info + new_var = self.visit_var_def(binding.var) - temp = self.with_shape_and_type(new_var, new_shape, new_value._checked_type_) + temp = self.with_struct_info(new_var, new_sinfo) if not temp.same_as(new_var): new_var = temp self.set_var_remap(binding.var.vid, new_var) @@ -339,8 +342,7 @@ def visit_var_def_(self, var: Var) -> None: if shape_unchanged: return var else: - new_var = Var(var.vid, None, var._checked_type_, var.span) - _update_shape(new_var, new_shape) + new_var = Var(var.vid, new_shape, var._checked_type_, var.span) self.set_var_remap(var.vid, new_var) return new_var @@ -357,8 +359,7 @@ def visit_dataflow_var_def_(self, var: DataflowVar) -> None: if shape_unchanged: return var else: - new_var = DataflowVar(var.vid, None, var._checked_type_, var.span) - _update_shape(new_var, new_shape) + new_var = DataflowVar(var.vid, new_shape, var._checked_type_, var.span) self.set_var_remap(var.vid, new_var) return new_var @@ -403,6 +404,7 @@ def test_var(): basic_check(x, "Var", "Var") +@pytest.mark.skip("Revisit PyMutator tests after struct info") def test_dataflow_var(): lv = relax.DataflowVar("lv", [n], type_anno1) basic_check(lv, "DataflowVar", "DataflowVar") @@ -566,7 +568,7 @@ def test_function(): blocks = [relax.BindingBlock(bindings)] seq_expr = relax.SeqExpr(blocks, x) ret_type = relax.DynTensorType(1, "float32") - ret_shape = relax.RuntimeDepShape() + ret_shape = relax.ShapeExpr([n]) func = relax.Function([x], seq_expr, ret_type, ret_shape) basic_check( func, @@ -586,7 +588,7 @@ def test_function(): [ "ShapeExpr", "VarDef", - "RuntimeDepShape", + "ShapeExpr", "Constant", "ShapeExpr", "VarDef", diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index 743c5efb0f03..bcc7c5d77338 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. - import pytest import tvm +import tvm.testing from tvm import relax from tvm import tir @@ -204,6 +204,7 @@ def foo(x: R.Tensor((3, 3), "float32")): check_roundtrip(foo) +@pytest.mark.skip("Need to fix string ast expr") def test_primexpr_arithmetic(): @R.function def foo(x: R.Tensor(("n", "m"), "float32")): @@ -343,6 +344,7 @@ def h( check_roundtrip(my_module) +@pytest.mark.skip("Need to fix string ast expr") def test_tir_max(): @R.function def tir_max(x: R.Tensor(("m", "n"), "float32")): @@ -353,6 +355,7 @@ def tir_max(x: R.Tensor(("m", "n"), "float32")): check_roundtrip(tir_max) +@pytest.mark.skip("Need to fix string ast expr") def test_tir_cast(): @R.function def tir_cast(x: R.Tensor(("m",), "float32")): @@ -434,4 +437,4 @@ def local_func_3( if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py new file mode 100644 index 000000000000..9e6ac21be7af --- /dev/null +++ b/tests/python/relax/test_struct_info.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import pytest + +from tvm import relax as rx, TVMError, tir + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_object_struct_info(): + s0 = rx.ObjectStructInfo() + s1 = rx.ObjectStructInfo() + + # can turn into str + str(s0) + _check_equal(s0, s1) + + assert isinstance(s0, rx.ObjectStructInfo) + _check_json_roundtrip(s0) + + +def test_prim_struct_info(): + s0 = rx.PrimStructInfo("float32") + s1 = rx.PrimStructInfo("float32") + s2 = rx.PrimStructInfo("int32") + + _check_equal(s0, s1) + + # can turn into str + str(s0) + + assert s0 == s1 + assert s0 != s2 + + assert isinstance(s0, rx.PrimStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + assert s1.dtype == "float32" + assert s2.dtype == "int32" + + # wrong API constructors + with pytest.raises(TVMError): + rx.PrimStructInfo(1) + + +def test_shape_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.ShapeStructInfo([1, n + 1, m]) + s1 = rx.ShapeStructInfo([1, n + 1, m]) + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert s0.values[2] == m + + assert isinstance(s0, rx.ShapeStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.ShapeStructInfo(ndim=2) + + assert s2.ndim == 2 + assert s2.values is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # can turn into str + str(s0) + + # wrong argument type + with pytest.raises(TVMError): + rx.ShapeStructInfo(1) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=2) + + +def test_tensor_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, n + 1, m], "float32") + s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert isinstance(s0, rx.TensorStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.TensorStructInfo(ndim=2, dtype="int32") + + assert s2.ndim == 2 + assert s2.dtype == "int32" + assert s2.shape is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # take in opaque var + rshape = rx.Var("shape", type_annotation=rx.ShapeType(ndim=2)) + + s3 = rx.TensorStructInfo(rshape, dtype="int32") + assert s3.dtype == "int32" + assert s3.shape == rshape + assert s3.ndim == 2 + _check_json_roundtrip(s3) + + # can turn into str + str(s0) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=2) + + +def test_tuple_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, 2, m + n], "float32") + s1 = rx.ObjectStructInfo() + + t0 = rx.TupleStructInfo([s0, s1]) + t1 = rx.TupleStructInfo([s0, rx.ObjectStructInfo()]) + t2 = rx.TupleStructInfo([s0, s0]) + + _check_equal(t0, t1) + + assert t0 == t1 + + assert isinstance(t0, rx.TupleStructInfo) + t0 = _check_json_roundtrip(t0) + t1 = _check_json_roundtrip(t1) + t2 = _check_json_roundtrip(t2) + + # can turn into str + str(t0) + + # wrong argument type + with pytest.raises(TVMError): + rx.TupleStructInfo(1) + + +def test_func_struct_info(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n, m], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + f1 = fn_info(1) + f2 = fn_info(2) + f3 = rx.FuncStructInfo.opaque_func() + + _check_equal(f0, f1) + + assert f0 == f1 + assert f0 != f2 + + assert len(f0.params) == 2 + assert isinstance(f0.ret, rx.TensorStructInfo) + assert f2.derive_func is None + assert f3.params is None + assert f3.derive_func is None + _check_equal(f3.ret, rx.ObjectStructInfo()) + + assert isinstance(f0, rx.FuncStructInfo) + f0 = _check_json_roundtrip(f0) + f1 = _check_json_roundtrip(f1) + f2 = _check_json_roundtrip(f2) + f3 = _check_json_roundtrip(f3) + + # can turn into str + str(f3) diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 5b7d269dfe6d..bd8023640652 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -18,6 +18,7 @@ import tvm import tvm.script import tvm.testing +import pytest from tvm import relax from tvm.ir.base import assert_structural_equal from tvm.script import relax as R, tir as T diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index f3cac8f3bf51..785b95860297 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -89,7 +89,9 @@ def test_closure(): @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): outer_func = lifted_func_0 in_call = outer_func(x) res = R.invoke_closure(in_call, (y,), type_args=(R.Tensor(ndim=2, dtype="float32"))) @@ -192,6 +194,7 @@ def while_loop( # Perform Lamda Lifting after = transform.LambdaLift()(before) assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) _check_save_roundtrip(after) diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 404cfa456dd6..98806abc61a0 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -34,7 +34,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) R.func_name("foo") R.func_attr({"Primitive": 1}) x = R.arg("x", R.tensor((128, 128), "float32")) - R.func_ret_type(R.tensor(dtype="float32", ndim=2)) + R.func_ret_struct_info(R.tensor(dtype="float32", ndim=2)) out = R.emit(R.call_tir("extern_func", x, (128, 128), dtype="float32")) IRBuilder.name("out", out) R.func_ret_value(out) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 0d23575ba3c8..abce4c8d82a1 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -747,6 +747,19 @@ def bar(x: R.Tensor): return R.print(x, format="{}") +def test_erase_to_well_defined(): + @R.function + def foo(x: R.Tensor): + q = x + m, n = T.var("int64"), T.var("int64") + z = R.match_shape(q, (m, n)) + w = z + return w + + assert isinstance(foo.ret_shape, RuntimeDepShape) + _check(foo, None) + + @pytest.mark.skip(reason="potential upstream Metadata changes.") def test_meta(): metadata = tvm.ir.load_json( diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index eb1e28ddfb34..c73244b8db09 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -915,7 +915,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: @R.function def relax_matmul_tir( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") - ) -> R.Tensor: + ) -> R.Tensor((32, 32), dtype="float32"): with R.dataflow(): gv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") R.output(gv0) @@ -924,12 +924,12 @@ def relax_matmul_tir( @R.function def relax_matmul_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") - ) -> Object: + ) -> R.Object: gv0 = R.call_packed("test.vm.mul", x, w, type_args=(R.Tensor(ndim=2, dtype="float32"))) return gv0 @R.function - def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> Object: + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object: gv0 = relax_matmul_tir(x, w) gv1 = relax_matmul_packed(gv0, gv0) return gv1