diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 3c39f47478c9..cd8d9ada413a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -369,15 +369,6 @@ class RelayExprNode : public BaseExprNode { */ mutable Type checked_type_ = Type(nullptr); - /*! - * \brief Stores the result of static shape analysis. It must be a RelayExpr - * and ObjectRef is used here to avoid cyclic typing. - * - * \note The value will be optional if a static shape can not be inferred. - * use .shape() instead to acesss an always defined shape expression. - */ - mutable Optional shape_ = Optional(); - /*! * \brief Stores the result of structure information of the * expression that encapsulate both static shape and @@ -390,13 +381,6 @@ class RelayExprNode : public BaseExprNode { */ inline const Type& checked_type() const; - /*! - * \return An expression which corresponds to the shape of the expression. - * - * Only valid when the expression's type is a Tensor. - */ - RelayExpr shape() const; - /*! * \brief Check if the inferred(checked) type of the Expr * is backed by a TTypeNode and return it. diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index ac1631e5e384..ba35840d4007 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -95,7 +95,6 @@ class TypeFunctor { virtual R VisitType_(const relax::ShapeTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const relax::ObjectTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const relax::DynTensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const relax::DimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const relax::PackedFuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); @@ -123,7 +122,6 @@ class TypeFunctor { TVM_TYPE_FUNCTOR_DISPATCH(relax::ShapeTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(relax::ObjectTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(relax::DynTensorTypeNode); - TVM_TYPE_FUNCTOR_DISPATCH(relax::DimTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(relax::PackedFuncTypeNode); return vtable; } diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 4e38c055e147..eb83dc66f271 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -85,22 +85,6 @@ TVM_DLL Type GetStaticType(const StructInfo& info); */ TVM_DLL StructInfo StructInfoFromType(const Type& type); -// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_ -/*! - * \brief Get the corresponding struct info from static type. - * \param type The input type - * \param shape_hint The shape hint - * \return the corresponding struct info. - */ -TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional shape_hint); - -/*! - * \brief Get the corresponding legacy shape hint from struct info - * \param info The struct info. - * \return the corresponding legacy shape hint. - */ -TVM_DLL Optional GetLegacyShapeHint(const StructInfo& info); - /*! * \return Derive the call's ret value struct info from inputs. * \param func_info The function struct info. @@ -425,17 +409,6 @@ std::pair>, Array> FunctionUseDef(const Function& fn); */ TVM_DLL Function RemoveAllUnused(const Function fn); -/*! - * \brief Given the argument vars and body, derives a return shape for a function with those args - * and that body. 
If the body's shape contains free shape vars (those not used in the args), the - * return shape is relaxed to RuntimeDepShape; otherwise, the body's shape is used. - * - * \param args The argument variables, ideally with the shape_ field filled in - * \param body The functino body, ideally with the shape_ field filled in - * \return An expression that can serve as the return shape for the function - */ -TVM_DLL Expr DeriveFuncRetShape(Array args, Expr body); - } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index e9392b4fb1d2..fce2151651e4 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -72,14 +72,6 @@ class BlockBuilderNode : public Object { */ virtual NameTable* name_table() = 0; - /*! - * \brief Check if two shape expressions can be proven equal at compile time. - * \param lhs The input lhs shape. - * \param rhs The input rhs shape. - * \return Whether we can prove lhs shape is the same as the rhs shape. - */ - virtual bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs) = 0; - /*! * \brief Get the context IRModule in this builder. * @@ -167,15 +159,6 @@ class BlockBuilderNode : public Object { */ virtual Var Emit(Expr expr, String name_hint = "") = 0; - /*! - * \brief Emits a variable binding, and returns the bound Var. - * \param binding The variable binding. - * \return The bound variable. - * - * \note This function requires binding to be pre-normalized. - */ - virtual Var Emit(VarBinding binding) = 0; - /*! * \brief Emit a MatchShape. * \param value The value of the MatchShape to be emitted. @@ -185,15 +168,6 @@ class BlockBuilderNode : public Object { */ virtual Var EmitMatchShape(Expr value, Array pattern, String name_hint = "") = 0; - /*! - * \brief Emit a MatchShape binding. - * \param binding The MatchShape binding to be emitted. - * \return The variable bound to the MatchShape. - * - * \note This function requires binding to be pre-normalized. - */ - virtual Var EmitMatchShape(MatchShape binding) = 0; - /*! * \brief Generate an output for the current dataflow block. * \param output The output variable of the block. @@ -202,15 +176,6 @@ class BlockBuilderNode : public Object { */ virtual Var EmitOutput(Expr output, String name_hint = "") = 0; - /*! - * \brief Generate an output for the current dataflow block. - * \param binding The output binding to output. - * \return The variable bound to \p output. - * - * \note This function requires binding to be pre-normalized. - */ - virtual Var EmitOutput(VarBinding binding) = 0; - /*! * \brief Emit a binding that is already normalized. * diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 6fbf0523d4c8..33f509d7dbcd 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -46,7 +46,6 @@ class OrPattern; class AndPattern; class NotPattern; class ShapePattern; -class RuntimeDepShapePattern; class TypePattern; class DataTypePattern; class AttrPattern; @@ -112,8 +111,6 @@ class DFPattern : public ObjectRef { TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; /*! \brief Syntatic Sugar for creating a ShapePattern */ TVM_DLL ShapePattern HasShape(const Array& shape) const; - /*! \brief Syntatic Sugar for creating a RuntimeDepShapePattern */ - TVM_DLL RuntimeDepShapePattern HasRuntimeDepShape() const; /*! 
\brief Syntatic Sugar for duplicating the current pattern */ TVM_DLL DFPattern dup() const; @@ -778,30 +775,6 @@ class ExternFuncPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); }; -/*! - * \brief A pattern that asserting a root pattern has a runtime-dependent shape. - * \sa RuntimeDepShape - * \sa RuntimeDepShapePattern - */ -class RuntimeDepShapePatternNode : public DFPatternNode { - public: - DFPattern pattern; /*!< The root pattern to match */ - void VisitAttrs(tvm::AttrVisitor* v) {} - - static constexpr const char* _type_key = "relax.dpl.RuntimeDepShapePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeDepShapePatternNode, DFPatternNode); -}; - -/*! - * \brief Managed reference to RuntimeDepShapePatternNode. - * \sa RuntimeDepShapePatternNode - */ -class RuntimeDepShapePattern : public DFPattern { - public: - TVM_DLL explicit RuntimeDepShapePattern(DFPattern pattern); - TVM_DEFINE_OBJECT_REF_METHODS(RuntimeDepShapePattern, DFPattern, RuntimeDepShapePatternNode); -}; - /*! \brief Syntatic Sugar for creating a VarPattern with a name */ VarPattern IsVar(const String& name); /*! \brief Syntatic Sugar for creating a ConstantPattern */ diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h index 4ac5fe1173b4..983881ddc9a7 100644 --- a/include/tvm/relax/dataflow_pattern_functor.h +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -98,8 +98,6 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const RuntimeDepShapePatternNode* op, - Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const DataflowVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const GlobalVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -135,7 +133,6 @@ class DFPatternFunctor { RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); - RELAX_DFPATTERN_FUNCTOR_DISPATCH(RuntimeDepShapePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode); @@ -170,7 +167,6 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const WildcardPatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; - void VisitDFPattern_(const RuntimeDepShapePatternNode* op) override; void VisitDFPattern_(const DataflowVarPatternNode* op) override; void VisitDFPattern_(const GlobalVarPatternNode* op) override; void VisitDFPattern_(const ExternFuncPatternNode* op) override; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 558a5c3d3de7..e893affac932 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -135,17 +135,17 @@ class CallNode : public ExprNode { v->Visit("args", &args); v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); - v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip type_args check for primitive ops. 
equal->MarkGraphNode(); return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(type_args, other->type_args)); + (IsPrimitiveOp(op) || equal(type_args, other->type_args)) && + equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -156,6 +156,7 @@ class CallNode : public ExprNode { if (!IsPrimitiveOp(op)) { hash_reduce(type_args); } + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.Call"; @@ -214,16 +215,15 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); - v->Visit("span", &span); - v->Visit("shape_", &shape_); v->Visit("_checked_type_", &checked_type_); v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); } bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(cond, other->cond) && equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch); + equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -231,6 +231,7 @@ hash_reduce(cond); hash_reduce(true_branch); hash_reduce(false_branch); + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.If"; @@ -270,28 +271,17 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); - v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { - // specially handle empty tuple as a constant is not a graph node. - if (fields.size() == other->fields.size() && fields.size() == 0) { - return true; - } else { - equal->MarkGraphNode(); - return equal(fields, other->fields); - } + // struct info can be deterministically derived from fields. + return equal(fields, other->fields); } - void SHashReduce(SHashReducer hash_reduce) const { - if (fields.size() != 0) { - hash_reduce->MarkGraphNode(); - hash_reduce(fields); - } - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } static constexpr const char* _type_key = "relax.expr.Tuple"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); @@ -329,13 +319,13 @@ class TupleGetItemNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); - v->Visit("span", &span); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); } bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from tuple and index. return equal(tuple, other->tuple) && equal(index, other->index); } @@ -380,22 +370,17 @@ class ShapeExprNode : public ExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("values", &values); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { - return equal(values, other->values) && equal(checked_type_, other->checked_type_) && - equal(shape_, other->shape_); + // struct info can be deterministically derived from values. 
+ return equal(values, other->values); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(values); - hash_reduce(checked_type_); - hash_reduce(shape_); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); } static constexpr const char* _type_key = "relax.expr.ShapeExpr"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -410,43 +395,6 @@ class ShapeExpr : public Expr { TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); }; -/*! \brief Runtime dependent shape expression. - * - * Sometimes shape of a tensor cannot be deduced statically either because the shape is truly data - * dependent such as output of `unique` operator or cannot be deduced because of limited shape - * inference capability. - */ -class RuntimeDepShapeNode : public ExprNode { - public: - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape_", &shape_); - v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); - v->Visit("span", &span); - } - - bool SEqualReduce(const RuntimeDepShapeNode* other, SEqualReducer equal) const { - return equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(checked_type_); - hash_reduce(shape_); - } - - static constexpr const char* _type_key = "relax.expr.RuntimeDepShape"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; - TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeDepShapeNode, ExprNode); -}; - -class RuntimeDepShape : public Expr { - public: - TVM_DLL explicit RuntimeDepShape(Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RuntimeDepShape, Expr, RuntimeDepShapeNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(RuntimeDepShapeNode); -}; - /*! \brief The variable class for all Relax bindings. 
*/ class VarNode : public ExprNode { public: @@ -459,7 +407,6 @@ class VarNode : public ExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("vid", &vid); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); @@ -467,14 +414,12 @@ class VarNode : public ExprNode { bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal(vid, other->vid) && equal(checked_type_, other->checked_type_) && - equal(shape_, other->shape_); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(vid); - hash_reduce(shape_); - hash_reduce(checked_type_); + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.Var"; @@ -502,7 +447,6 @@ class DataflowVarNode : public VarNode { public: void VisitAttrs(AttrVisitor* v) { v->Visit("vid", &vid); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); @@ -510,14 +454,12 @@ class DataflowVarNode : public VarNode { bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal(vid, other->vid) && equal(shape_, other->shape_) && - equal(checked_type_, other->checked_type_); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(vid); - hash_reduce(shape_); - hash_reduce(checked_type_); + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.DataflowVar"; @@ -557,13 +499,13 @@ class ConstantNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); - v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); } bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. 
return equal(data, other->data); } @@ -751,7 +693,6 @@ class SeqExprNode : public ExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("blocks", &blocks); v->Visit("body", &body); - v->Visit("shape_", &shape_); v->Visit("struct_info_", &struct_info_); v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); @@ -759,14 +700,13 @@ bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { return equal(blocks, other->blocks) && equal(body, other->body) && - equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); + equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(blocks); hash_reduce(body); - hash_reduce(shape_); - hash_reduce(checked_type_); + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.SeqExpr"; @@ -796,19 +736,17 @@ class FunctionNode : public BaseFuncNode { v->Visit("params", &params); v->Visit("body", &body); v->Visit("ret_struct_info", &ret_struct_info); - v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); - v->Visit("struct_info_", &struct_info_); v->Visit("attrs", &attrs); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal.DefEqual(params, other->params) && equal(body, other->body) && - equal(ret_struct_info, other->ret_struct_info) && - equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_) && - equal(attrs, other->attrs); + equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) && + equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -816,9 +754,8 @@ hash_reduce.DefHash(params); hash_reduce(body); hash_reduce(ret_struct_info); - hash_reduce(checked_type_); - hash_reduce(shape_); hash_reduce(attrs); + hash_reduce(struct_info_); } static constexpr const char* _type_key = "relax.expr.Function"; @@ -868,14 +805,18 @@ class ExternFuncNode : public BaseFuncNode { void VisitAttrs(AttrVisitor* v) { v->Visit("global_symbol", &global_symbol); v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { - return equal(global_symbol, other->global_symbol); + return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(global_symbol); + hash_reduce(struct_info_); + } static constexpr const char* _type_key = "relax.expr.ExternFunc"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -890,6 +831,19 @@ class ExternFunc : public BaseFunc { TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); }; +/*! + * \brief Get the shape of Expr. + * \param expr The input expr. + * \return The corresponding shape. + * + * \note This function requires expr to be normalized. + * The function will report an error if expr's StructInfo is not TensorStructInfo. + * It will try to return a symbolic shape when possible. If the tensor does not + * have a compile-time symbolic shape, the function will return + * Call(relax.op.shape_of, [expr]). 
+ */ +TVM_DLL Expr GetShapeOf(const Expr& expr); + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 306e961badb8..236233e6c7ce 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -145,7 +145,6 @@ class ExprFunctor { virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const RuntimeDepShapeNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -169,7 +168,6 @@ class ExprFunctor { RELAX_EXPR_FUNCTOR_DISPATCH(VarNode); RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode); RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode); - RELAX_EXPR_FUNCTOR_DISPATCH(RuntimeDepShapeNode); RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode); RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode); @@ -199,7 +197,6 @@ class ExprVisitor : public ExprFunctor { void VisitExpr_(const VarNode* op) override; void VisitExpr_(const DataflowVarNode* op) override; void VisitExpr_(const ShapeExprNode* op) override; - void VisitExpr_(const RuntimeDepShapeNode* op) override; void VisitExpr_(const ExternFuncNode* op) override; void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const FunctionNode* op) override; @@ -281,7 +278,6 @@ class ExprMutatorBase : public ExprFunctor { Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const DataflowVarNode* op) override; Expr VisitExpr_(const ShapeExprNode* op) override; - Expr VisitExpr_(const RuntimeDepShapeNode* op) override; Expr VisitExpr_(const ExternFuncNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const FunctionNode* op) override; @@ -467,8 +463,6 @@ class PyExprVisitorNode : public Object, public ExprVisitor { PackedFunc f_visit_dataflow_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ PackedFunc f_visit_shape_expr_{nullptr}; - /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ - PackedFunc f_visit_runtime_dep_shape_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ PackedFunc f_visit_extern_func_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ @@ -569,7 +563,6 @@ class PyExprVisitorNode : public Object, public ExprVisitor { PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); - PY_EXPR_VISITOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); @@ -598,8 +591,6 @@ class PyExprVisitor : public ObjectRef { * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. 
* \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. - * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* - * op)`. * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. @@ -630,10 +621,9 @@ class PyExprVisitor : public ObjectRef { TVM_DLL static PyExprVisitor MakePyExprVisitor( PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, - PackedFunc f_visit_runtime_dep_shape_, PackedFunc f_visit_extern_func_, - PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, - PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, - PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, @@ -650,7 +640,6 @@ class PyExprVisitor : public ObjectRef { n->f_visit_var_ = f_visit_var_; n->f_visit_dataflow_var_ = f_visit_dataflow_var_; n->f_visit_shape_expr_ = f_visit_shape_expr_; - n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; n->f_visit_extern_func_ = f_visit_extern_func_; n->f_visit_global_var_ = f_visit_global_var_; n->f_visit_function_ = f_visit_function_; @@ -692,8 +681,6 @@ class PyExprMutatorNode : public Object, public ExprMutator { PackedFunc f_visit_dataflow_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ PackedFunc f_visit_shape_expr_{nullptr}; - /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ - PackedFunc f_visit_runtime_dep_shape_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ PackedFunc f_visit_extern_func_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. 
*/ @@ -820,7 +807,6 @@ class PyExprMutatorNode : public Object, public ExprMutator { PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); - PY_EXPR_MUTATOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); @@ -841,7 +827,6 @@ class PyExprMutatorNode : public Object, public ExprMutator { PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); - PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(RuntimeDepShapeNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); @@ -870,8 +855,6 @@ class PyExprMutator : public ObjectRef { * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. - * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* - * op)`. * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. @@ -902,10 +885,10 @@ class PyExprMutator : public ObjectRef { TVM_DLL static PyExprMutator MakePyExprMutator( BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, - PackedFunc f_visit_shape_expr_, PackedFunc f_visit_runtime_dep_shape_, - PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, - PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, - PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, @@ -918,7 +901,6 @@ class PyExprMutator : public ObjectRef { n->f_visit_var_ = f_visit_var_; n->f_visit_dataflow_var_ = f_visit_dataflow_var_; n->f_visit_shape_expr_ = f_visit_shape_expr_; - n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; n->f_visit_extern_func_ = f_visit_extern_func_; n->f_visit_global_var_ = f_visit_global_var_; n->f_visit_function_ = f_visit_function_; diff --git a/include/tvm/relax/ir_functor.h b/include/tvm/relax/ir_functor.h index 53b718beb9a5..5615e00188f0 100644 --- a/include/tvm/relax/ir_functor.h +++ b/include/tvm/relax/ir_functor.h @@ -77,7 +77,6 @@ class IRFunctor { virtual R VisitNode_(const 
relax::VarNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::DataflowVarNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::ShapeExprNode* op, Args... args) IR_FUNCTOR_DEFAULT; - virtual R VisitNode_(const relax::RuntimeDepShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::MatchShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::VarBindingNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::BindingBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT; @@ -104,7 +103,6 @@ class IRFunctor { RELAX_IR_FUNCTOR_DISPATCH(relax::VarNode); RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowVarNode); RELAX_IR_FUNCTOR_DISPATCH(relax::ShapeExprNode); - RELAX_IR_FUNCTOR_DISPATCH(relax::RuntimeDepShapeNode); RELAX_IR_FUNCTOR_DISPATCH(relax::MatchShapeNode); RELAX_IR_FUNCTOR_DISPATCH(relax::VarBindingNode); RELAX_IR_FUNCTOR_DISPATCH(relax::BindingBlockNode); diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 3722d8eb4512..92d2b58ea4f2 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -24,9 +24,8 @@ #ifndef TVM_RELAX_OP_ATTR_TYPES_H_ #define TVM_RELAX_OP_ATTR_TYPES_H_ +#include #include -#include -#include #include #include @@ -35,16 +34,6 @@ namespace tvm { namespace relax { -/*! - * \brief Infer the output shape for operators. This function will - * be invoked to fill the \p shape_ field of expressions. - * \param call The call node. - * \param diag_ctx The diagnostic context for reporting errors. - * \return The inferred output shape expression. - */ -using FInferShape = - runtime::TypedPackedFunc; - /*! * \brief Infer output struct info given the call * @@ -54,15 +43,6 @@ using FInferShape = using FInferStructInfo = runtime::TypedPackedFunc; -/*! - * \brief Infer the output type for operators. This function will - * be invoked to fill the \p checked_type_ field of expressions. - * \param call The call node. - * \param diag_ctx The diagnostic context for reporting errors. - * \return The inferred output type. - */ -using FInferType = runtime::TypedPackedFunc; - /*! * \brief Packed function implementation for operators. The relax operator will be lowered to * this packed function call during codegen. diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 6efe8fdef2e0..34e4ef4a3e15 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -98,12 +98,12 @@ class ShapeStructInfoNode : public StructInfoNode { Optional> values; /*! * \brief The number of dimension of the shape, can be unknown. - * \sa kUnknownDim + * \sa kUnknownNDim */ int ndim; /*! \return Whether the struct info contains unknown ndim. */ - bool IsUnknownNdim() const { return ndim == kUnknownDim; } + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } void VisitAttrs(AttrVisitor* v) { v->Visit("values", &values); @@ -138,7 +138,7 @@ class ShapeStructInfo : public StructInfo { TVM_DLL ShapeStructInfo(Array values, Span span = Span()); /*! * \brief Construction with known unknown symbolic shape patterns. - * \param ndim Number of dimensions -- can be kUnknownDim + * \param ndim Number of dimensions -- can be kUnknownNDim * \param span The span of the AST. */ TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); @@ -160,12 +160,12 @@ class TensorStructInfoNode : public StructInfoNode { DataType dtype; /*! * \brief The number of dimension of the tensor, can be unknown. 
- * \sa kUnknownDim + * \sa kUnknownNDim */ int ndim; /*! \return Whether the struct info contains unknown ndim. */ - bool IsUnknownNdim() const { return ndim == kUnknownDim; } + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } /*! \return Whether the struct info contains unknown dtype. */ bool IsUnknownDtype() const { return dtype.is_void(); } diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 0e66c9516866..0e8fef005ef8 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -38,7 +38,7 @@ namespace tvm { namespace relax { /*! \brief Indicates the number of dimensions of a tensor is unknown at compile time. */ -static constexpr int kUnknownDim = -1; +static constexpr int kUnknownNDim = -1; class ShapeTypeNode : public TypeNode { public: @@ -63,7 +63,7 @@ class ShapeTypeNode : public TypeNode { class ShapeType : public Type { public: // TODO(relax-team): remove the default value later. - TVM_DLL ShapeType(int ndim = kUnknownDim, Span span = Span()); + TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); }; @@ -112,7 +112,7 @@ class DynTensorTypeNode : public BaseTensorTypeNode { hash_reduce(dtype); } - inline bool IsUnknownNdim() const { return ndim == kUnknownDim; } + inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } inline bool IsUnknownDtype() const { return dtype.is_void(); } @@ -141,25 +141,6 @@ class DynTensorType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); }; -class DimTypeNode : public TypeNode { - public: - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } - - bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const { return true; } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } - - static constexpr const char* _type_key = "relax.DimType"; - TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode); -}; - -class DimType : public Type { - public: - TVM_DLL DimType(Span span = Span()); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode); -}; - class PackedFuncTypeNode : public TypeNode { public: void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index f781372f23b5..12f751e82141 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -125,7 +125,7 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, /*! * \brief Check if the given expression is a "leaf" node for normalization purposes. * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, - * GlobalVar, RuntimeDepShape, Op, ExternFunc, and Tuple. + * GlobalVar, Op, ExternFunc, and Tuple. * Tuples are included in this list mainly for convenience in grouping operator arguments. * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that * values nested inside them are also leaves. 
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c49f12ef5339..854050464d4a 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -84,7 +84,6 @@ class ConstantNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); } bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { @@ -131,7 +130,6 @@ class TupleNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { @@ -330,7 +328,6 @@ class CallNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); - v->Visit("shape_", &shape_); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index deda47c24a36..1fce137300dc 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -52,17 +52,6 @@ def checked_type(self): raise ValueError("The type checker has not populated the checked_type for this node") return ret - @property - def shape(self): - """Get the shape of tvm.relay.Expr. - - Returns - ------- - shape : tvm.ir.RelayExpr - The expression that represents the shape. - """ - return _ffi_api.RelayExprShape(self) - @property def struct_info(self) -> "tvm.relax.StructInfo": """Get the struct info field @@ -234,7 +223,6 @@ def is_relax_expr(expr: RelayExpr) -> bool: relax.Var, relax.DataflowVar, relax.ShapeExpr, - relax.RuntimeDepShape, relax.SeqExpr, relax.Function, relax.ExternFunc, diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 4b80b16d2679..d19bc70f0471 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -28,66 +28,56 @@ from . 
import struct_info # Expr -Expr = expr.Expr -Span = expr.Span -SourceName = expr.SourceName -Id = expr.Id -GlobalVar = expr.GlobalVar -Var = expr.Var -DataflowVar = expr.DataflowVar -Binding = expr.Binding -MatchShape = expr.MatchShape -VarBinding = expr.VarBinding -BindingBlock = expr.BindingBlock -DataflowBlock = expr.DataflowBlock -SeqExpr = expr.SeqExpr -ShapeExpr = expr.ShapeExpr -RuntimeDepShape = expr.RuntimeDepShape -Tuple = expr.Tuple -TupleGetItem = expr.TupleGetItem -Function = expr.Function -ExternFunc = expr.ExternFunc -Call = expr.Call -If = expr.If +from .expr import ( + Expr, + Span, + SourceName, + Id, + GlobalVar, + Var, + DataflowVar, + Binding, + MatchShape, + VarBinding, + BindingBlock, + DataflowBlock, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + Function, + ExternFunc, + Call, + If, + Constant, +) -# helper functions -const = expr.const -Constant = expr.Constant -extern = expr.extern -te_tensor = expr.te_tensor +from .expr import const, extern, get_shape_of # Type -Type = ty.Type -ShapeType = ty.ShapeType -ObjectType = ty.ObjectType -DynTensorType = ty.DynTensorType -DimType = ty.DimType -TupleType = ty.TupleType -FuncType = ty.FuncType -PackedFuncType = ty.PackedFuncType +from .ty import Type, ObjectType, ShapeType, DynTensorType, TupleType, FuncType, PackedFuncType # VM -ExecBuilder = exec_builder.ExecBuilder -VirtualMachine = vm.VirtualMachine +from .exec_builder import ExecBuilder +from .vm import VirtualMachine # Operator from .op.base import call_tir, make_closure, invoke_closure from .op.op_attrs import VMAllocStorageAttrs, VMAllocTensorAttrs # IRBuilder -BlockBuilder = block_builder.BlockBuilder +from .block_builder import BlockBuilder # ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -PyExprVisitor = expr_functor.PyExprVisitor -PyExprMutator = expr_functor.PyExprMutator - +from .expr_functor import ExprFunctor, PyExprVisitor, PyExprMutator # StructInfo -StructInfo = struct_info.StructInfo -ObjectStructInfo = struct_info.ObjectStructInfo -PrimStructInfo = struct_info.PrimStructInfo -ShapeStructInfo = struct_info.ShapeStructInfo -TensorStructInfo = struct_info.TensorStructInfo -TupleStructInfo = struct_info.TupleStructInfo -FuncStructInfo = struct_info.FuncStructInfo +from .struct_info import ( + StructInfo, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + TensorStructInfo, + TupleStructInfo, + FuncStructInfo, +) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index c5b9850e8b64..e92d09e266f6 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Optional +from typing import Dict, List from enum import IntEnum import tvm @@ -48,22 +48,6 @@ def get_static_type(sinfo: StructInfo) -> Type: return _ffi_api.GetStaticType(sinfo) # type: ignore -def get_legacy_shape_hint(sinfo: StructInfo) -> Optional[Expr]: - """Get the corresponding shape from a StructInfo. - - Parameters - ---------- - sinfo : StructInfo - The input struct info. - - Returns - ------- - ret : Type - The corresponding shape. 
- """ - return _ffi_api.GetLegacyShapeHint(sinfo) # type: ignore - - def erase_to_well_defined( sinfo: StructInfo, shape_var_map: Dict[tir.Var, tir.PrimExpr] = None, @@ -298,29 +282,6 @@ def shape_vars(expr: Expr) -> List[tir.Var]: return _ffi_api.shape_vars(expr) # type: ignore -def derive_func_ret_shape(args: List[Var], body: Expr) -> Expr: - """ - Given the argument vars and body, derives a return shape for - a function with those args and that body. - If the body's shape contains free shape vars (those not used in the args), the - return shape is relaxed to RuntimeDepShape; otherwise, the body's shape is used. - - Parameters - ---------- - args: List[Var] - The argument variables, ideally with the shape_ field filled in - - body: Expr - The functino body, ideally with the shape_ field filled in - - Returns - ------- - ret: Expr - An expression that can serve as the return shape for the function - """ - return _ffi_api.derive_func_ret_shape(args, body) # type: ignore - - def bound_vars(expr: Expr) -> List[Var]: """ Return all bound variables from expression expr. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index e8dc1b03d5aa..679370bff61c 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -33,8 +33,7 @@ BindingBlock, Tuple, BaseFunc, - VarBinding, - MatchShape, + Binding, ) from .op.base import call_tir from . import _ffi_api @@ -702,24 +701,6 @@ def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None: """ return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) # type: ignore - def can_prove_shape_equal(self, lhs: Expr, rhs: Expr) -> bool: - """Check if two shape expressions can be proven equal at compile time. - - Parameters - ---------- - lhs : Expr - The input lhs shape. - - rhs: Expr - The input rhs shape. - - Returns - ------- - ret : bool - Whether we can prove lhs shape is the same as the rhs shape. - """ - return _ffi_api.BlockBuilderCanProveShapeEqual(self, lhs, rhs) # type: ignore - def current_block_is_dataflow(self) -> bool: """Check if the block being built is DataflowBlock or not. @@ -730,50 +711,15 @@ def current_block_is_dataflow(self) -> bool: """ return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) # type: ignore - def emit_var_binding(self, binding: VarBinding) -> Var: - """Emits a variable binding, and returns the bound Var. - - Parameters - ---------- - binding: VarBinding - The variable binding. - - Returns - ------- - var: Var - The bound variable. - """ - return _ffi_api.BlockBuilderEmitVarBinding(self, binding) # type: ignore - - def emit_output_var_binding(self, binding: VarBinding) -> Var: - """Generate an output for the current dataflow block. - - Parameters - ---------- - binding: VarBinding - The output binding to output. - - Returns - ------- - var: Var - The variable bound to output. - """ - return _ffi_api.BlockBuilderEmitOutputVarBinding(self, binding) # type: ignore - - def match_shape_binding(self, binding: MatchShape) -> Var: - """Emit a MatchShape binding. + def emit_normalized(self, binding: Binding) -> None: + """Emit an already normalized binding. Parameters ---------- - binding: MatchShape - The MatchShape binding to be emitted. - - Returns - ------- - var: Var - The variable bound to the MatchShape. + binding: Binding + The binding to be emitted. 
""" - return _ffi_api.BlockBuilderEmitMatchShapeBinding(self, binding) # type: ignore + _ffi_api.BlockBuilderEmitNormalized(self, binding) # type: ignore def lookup_binding(self, var: Var) -> Optional[Expr]: """Lookup a var in the binding table binding_table_. diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 97b9ce317cb2..31dbffda4a77 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -204,17 +204,6 @@ def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: """ return ffi.match_expr(self, expr, var2val) # type: ignore - def has_rt_dep_shape(self) -> "AndPattern": - """ - Syntax sugar for assuming current node has a runtime-dependent shape - - Returns - ------- - result: AndPattern - The resulting AndPattern - """ - return RuntimeDepShapePattern(self) - def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": """ The current pattern being used by another pattern (sequence) @@ -276,14 +265,6 @@ def fork_to(self, *args) -> None: self ^ v -@register_df_node -class RuntimeDepShapePattern(DFPattern): - """A pattern matching a Relax RuntimeDepShape.""" - - def __init__(self, pattern: DFPattern): - self.__init_handle_by_constructor__(ffi.RuntimeDepShapePattern, pattern) # type: ignore - - @register_df_node class ExprPattern(DFPattern): """A pattern which matches an expression. diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index e60f846b172a..17e36fca211c 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -212,14 +212,6 @@ def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: raise ValueError("Wrong type") -@tvm._ffi.register_object("relax.expr.RuntimeDepShape") -class RuntimeDepShape(Expr): - """A shape expression which allows users to construct a runtime dependent shape.""" - - def __init__(self, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.RuntimeDepShape, span) # type: ignore - - @tvm._ffi.register_object("relax.expr.Constant") class Constant(Expr): def __init__(self, data: tvm.nd.NDArray, span: Span = None) -> None: @@ -509,5 +501,29 @@ def te_tensor(value: Expr, name: str = "rxplaceholder"): return _ffi_api.TETensor(value, name) # type: ignore +def get_shape_of(expr: Expr) -> Expr: + """Get shape of expr. + + Parameters + ---------- + expr: Expr + The input expr. + + Returns + ------- + shape: Expr + The shape expression + + Note + ---- + This function requires expr to be normalized. + The function will report an error if expr's StructInfo is not TensorStructInfo. + It will try to return symbolic function when possible. If the tensor do not + have a compile-time symbolic shape, the function will then choose to return + `Call(relax.op.shape_of, [expr])`. 
+ """ + return _ffi_api.GetShapeOf(expr) # type: ignore + + def _update_struct_info(expr: Expr, struct_info: Optional[StructInfo]) -> None: _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index b05ff53c77b5..7370d2aa14ec 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -26,7 +26,7 @@ from .expr import Type, Span, Expr from .expr import Function, ExternFunc from .expr import Constant, Var, DataflowVar -from .expr import ShapeExpr, RuntimeDepShape +from .expr import ShapeExpr from .expr import GlobalVar, SeqExpr, Tuple from .expr import Call, If, TupleGetItem from .expr import Binding, MatchShape, VarBinding @@ -127,8 +127,6 @@ def visit_expr(self, expr: Expr) -> Expr: ret = self.visit_var_(expr) elif isinstance(expr, ShapeExpr): ret = self.visit_shape_expr_(expr) - elif isinstance(expr, RuntimeDepShape): - ret = self.visit_runtime_dep_shape_(expr) elif isinstance(expr, ExternFunc): ret = self.visit_extern_func_(expr) elif isinstance(expr, GlobalVar): # type: ignore @@ -165,9 +163,6 @@ def visit_var_(self, op: Var): def visit_shape_expr_(self, op: ShapeExpr): raise NotImplementedError() - def visit_runtime_dep_shape_(self, op: RuntimeDepShape): - raise NotImplementedError() - def visit_extern_func_(self, op: ExternFunc): raise NotImplementedError() @@ -254,7 +249,6 @@ def __init__( f_visit_var_: Callable = None, f_visit_dataflow_var_: Callable = None, f_visit_shape_expr_: Callable = None, - f_visit_runtime_dep_shape_: Callable = None, f_visit_extern_func_: Callable = None, f_visit_global_var_: Callable = None, f_visit_function_: Callable = None, @@ -285,7 +279,6 @@ def __init__( f_visit_var_, f_visit_dataflow_var_, f_visit_shape_expr_, - f_visit_runtime_dep_shape_, f_visit_extern_func_, f_visit_global_var_, f_visit_function_, @@ -375,7 +368,6 @@ def MyExprVisitor(PyExprVisitor): "visit_var_", "visit_dataflow_var_", "visit_shape_expr_", - "visit_runtime_dep_shape_", "visit_extern_func_", "visit_global_var_", "visit_function_", @@ -514,19 +506,6 @@ def visit_shape_expr_(self, op: ShapeExpr) -> None: # Using self._outer() to ref _PyExprVisitor return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: - """Visit RuntimeDepShape. - Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) - on the C++ side. - - Parameters - ---------- - op : RuntimeDepShape - The RuntimeDepShape to be visited. - """ - # Using self._outer() to ref _PyExprVisitor - return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore - def visit_extern_func_(self, op: ExternFunc) -> None: """Visit ExternFunc. 
Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) @@ -754,7 +733,6 @@ def __init__( f_visit_var_: Callable = None, f_visit_dataflow_var_: Callable = None, f_visit_shape_expr_: Callable = None, - f_visit_runtime_dep_shape_: Callable = None, f_visit_extern_func_: Callable = None, f_visit_global_var_: Callable = None, f_visit_function_: Callable = None, @@ -786,7 +764,6 @@ def __init__( f_visit_var_, f_visit_dataflow_var_, f_visit_shape_expr_, - f_visit_runtime_dep_shape_, f_visit_extern_func_, f_visit_global_var_, f_visit_function_, @@ -892,7 +869,6 @@ def MyExprMutator(PyExprMutator): "visit_var_", "visit_dataflow_var_", "visit_shape_expr_", - "visit_runtime_dep_shape_", "visit_extern_func_", "visit_global_var_", "visit_function_", @@ -1075,24 +1051,6 @@ def visit_shape_expr_(self, op: ShapeExpr) -> Expr: # Using self._outer() to ref _PyExprMutator return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: - """Visit RuntimeDepShape. - Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) - on the C++ side. - - Parameters - ---------- - op : RuntimeDepShape - The RuntimeDepShape to be visited. - - Returns - ------- - result : Expr - The Expr after transformation - """ - # Using self._outer() to ref _PyExprMutator - return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore - def visit_extern_func_(self, op: ExternFunc) -> Expr: """Visit ExternFunc. Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index d80542a7be58..f272b3aa5a73 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -134,9 +134,6 @@ def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values)) ) - def visit_runtime_dep_shape_(self, op: relax.RuntimeDepShape) -> str: - return self.build_expr(op, "RuntimeDepShape") - def visit_extern_func_(self, op: relax.ExternFunc) -> str: # ExternFunc does not inherit from relax.Expr either, # so it doesn't have checked_type_ or shape_ fields and we don't use build_expr @@ -238,8 +235,6 @@ def visit_type_(self, type_node: relax.Type) -> str: if type_node.dtype != "": fields["dtype"] = type_node.dtype return self.build_ast_node("DynTensorType", **fields) - if isinstance(type_node, relax.DimType): - return self.build_ast_node("DimType") if isinstance(type_node, relax.TupleType): return self.build_ast_node( "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields)) diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py index e8dddae1ef4f..77125faa5086 100644 --- a/python/tvm/relax/testing/nn.py +++ b/python/tvm/relax/testing/nn.py @@ -119,7 +119,7 @@ def _unpack_params(value: object) -> List[relax.Var]: def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: """Utility function to initialize model's parameters.""" - shape_dict = {v.name_hint: v.shape_ for v in mod["main"].params} + shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} params = [] for k, v in shape_dict.items(): if k.startswith("data"): diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py index d6c90da05688..b72b2f30ee2b 100644 --- 
a/python/tvm/relax/transform/tuning_api/default_functions.py +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -244,7 +244,9 @@ def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): else: # If build passes, set up runner input and measure the performance. args_info = [ - TensorInfo(shape=[int(i) for i in p.shape], dtype=p.checked_type.dtype) + TensorInfo( + shape=[int(i) for i in p.struct_info.shape], dtype=p.struct_info.dtype + ) for p in mod["main"].params ] # convert list[Var] to list[TensorInfo] runner_input = RunnerInput( diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index 784643843341..8af63d95d4af 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -75,14 +75,6 @@ def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span) # type: ignore -@tvm._ffi.register_object("relax.DimType") -class DimType(Type): - """The type of indices/shape dimensions in Relax.""" - - def __init__(self, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.DimType, span) # type: ignore - - def is_base_of(base: Type, derived: Type) -> bool: """Check the subtype relationship between base and derived. diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index d63047cb6e20..51dfce26ef8c 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -80,34 +80,6 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - raise TypeError(f"Unsupported type {type(value)} in assignment") -def eval_shape_annotation( - self: Parser, node: Union[doc.Expression, doc.expr], shape: relax.Expr -) -> Any: - if shape is None: - return None - if isinstance(shape, relax.RuntimeDepShape): - return shape - elif isinstance(shape, relax.ShapeExpr): - shape = list(shape.values) - for i, expr in enumerate(shape): - # Define the symbolic shape var - if isinstance(expr, tir.Var): - name = expr.name - if name in self.var_table.get(): - shape[i] = self.var_table.get()[name] - else: - self.var_table.add(name, shape[i]) - return relax.ShapeExpr(shape) - elif isinstance(shape, relax.Tuple): - shapes = [eval_shape_annotation(self, node, s) for s in shape.fields] - if any([s is None for s in shapes]): - return None - return relax.Tuple(shapes) - else: - self.report_error(node, f"Unsupported shape {type(shape)}") - return None - - # pylint: disable=inconsistent-return-statements def eval_type_annotation( self: Parser, node: Union[doc.Expression, doc.expr] diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index bcad70cec943..6bdaa52810b8 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -123,10 +123,6 @@ class VarVisitor : protected ExprVisitor { VisitExpr(arg); } - if (call_node->shape_) { - VisitExpr(Downcast(call_node->shape_.value())); - } - if (const GlobalVarNode* global_var_node = call_node->op.as()) { called_global_vars_.Insert(GetRef(global_var_node)); } diff --git a/src/relax/analysis/func_ret_shape.cc b/src/relax/analysis/func_ret_shape.cc deleted file mode 100644 index 45f283e308fd..000000000000 --- a/src/relax/analysis/func_ret_shape.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include - -namespace tvm { -namespace relax { - -Expr DeriveFuncRetShape(Array args, Expr body) { - std::unordered_set arg_shape_var_set; - for (auto v : args) { - if (const ExprNode* s = v->shape_.as()) { - Expr shape = GetRef(s); - Array arg_shape_vars = ShapeVars(shape); - for (auto v : arg_shape_vars) { - arg_shape_var_set.insert(v); - } - } - } - - if (const ExprNode* s = body->shape_.as()) { - Expr body_shape = GetRef(s); - Array body_shape_vars = ShapeVars(body_shape); - for (auto v : body_shape_vars) { - // if the body shape contains a free var, then we can't - // be more specific than RuntimeDepShape - if (arg_shape_var_set.count(v) == 0) { - return RuntimeDepShape(); - } - } - // all vars are defined in the args, so we can use the body shape - return body_shape; - } - return RuntimeDepShape(); -} - -TVM_REGISTER_GLOBAL(("relax.analysis.derive_func_ret_shape")).set_body_typed(DeriveFuncRetShape); - -} // namespace relax -} // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 06a2a2872bcd..2de06fe5d6f2 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -71,66 +71,11 @@ TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const Stru return GetStaticType(info); }); -//-------------------------- -// GetLegacyShapeHint -//-------------------------- -// TODO(relax-team) remove this function after phasing out shape. -class LegacyShapeDeriver : public StructInfoFunctor(const StructInfo&)> { - public: - Optional VisitStructInfo_(const ObjectStructInfoNode* op) final { return NullOpt; } - - Optional VisitStructInfo_(const PrimStructInfoNode* op) final { return NullOpt; } - - Optional VisitStructInfo_(const ShapeStructInfoNode* op) final { return NullOpt; } - - Optional VisitStructInfo_(const TensorStructInfoNode* op) final { - if (op->shape.defined()) { - return op->shape; - } else { - return RuntimeDepShape(); - } - } - - Optional VisitStructInfo_(const TupleStructInfoNode* op) final { - bool valid = true; - Array fields = op->fields.Map([this, &valid](const StructInfo& sinfo) { - Optional shape = this->VisitStructInfo(sinfo); - valid &= shape.defined(); - return shape.value_or(Expr(nullptr)); - }); - - // recursively collect structinfo to make sure legacy shape is also well defined. 
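For reference, the struct_info-first accessors that replace the legacy shape hints removed here can be exercised from Python as in the minimal sketch below; it reuses only APIs that appear in the updated Python files and tests of this patch, and the variable names are illustrative.

import tvm
from tvm import relax as rx
from tvm import tir
from tvm.script import relax as R

# A Relax var annotated with a symbolic [m, n] float16 tensor.
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", R.Tensor([m, n], "float16"))

# Shape, dtype and rank are read from struct_info rather than the removed shape_ field.
sinfo = x.struct_info
assert isinstance(sinfo, rx.TensorStructInfo)
assert sinfo.ndim == 2 and sinfo.dtype == "float16"
assert isinstance(sinfo.shape, rx.ShapeExpr)  # carries the symbolic dims m, n
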
- if (valid && fields.size() != 0) { - Tuple tuple(fields, op->span); - Array tuple_sinfo; - for (Expr field : tuple->fields) { - tuple_sinfo.push_back(GetStructInfo(field)); - } - UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo)); - return tuple; - } else { - return NullOpt; - } - } - - Optional VisitStructInfo_(const FuncStructInfoNode* op) final { return NullOpt; } -}; - -Optional GetLegacyShapeHint(const StructInfo& info) { return LegacyShapeDeriver()(info); } - -TVM_REGISTER_GLOBAL("relax.analysis.GetLegacyShapeHint").set_body_typed([](const StructInfo& info) { - return GetLegacyShapeHint(info); -}); - //-------------------------- // StructInfoFromType //-------------------------- StructInfo StructInfoFromType(const Type& type) { - return StructInfoFromTypeLegacyShapeHint(type, NullOpt); -} - -StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional shape_hint) { if (type.as()) { return ObjectStructInfo(type->span); } else if (const PrimTypeNode* prim_type = type.as()) { @@ -138,24 +83,11 @@ StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional sh } else if (const ShapeTypeNode* shape_type = type.as()) { return ShapeStructInfo(shape_type->ndim, type->span); } else if (const DynTensorTypeNode* tensor_type = type.as()) { - if (!shape_hint.defined() || shape_hint->IsInstance()) { - return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); - } else { - return TensorStructInfo(shape_hint.value(), tensor_type->dtype); - } + return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); } else if (const TupleTypeNode* tuple_type = type.as()) { Array fields; - if (shape_hint.defined() && shape_hint.value()->IsInstance()) { - Array shape_hint_fields = Downcast(shape_hint.value())->fields; - ICHECK_EQ(shape_hint_fields.size(), tuple_type->fields.size()); - for (size_t i = 0; i < tuple_type->fields.size(); ++i) { - fields.push_back( - StructInfoFromTypeLegacyShapeHint(tuple_type->fields[i], shape_hint_fields[i])); - } - } else { - for (const Type& field : tuple_type->fields) { - fields.push_back(StructInfoFromType(field)); - } + for (const Type& field : tuple_type->fields) { + fields.push_back(StructInfoFromType(field)); } return TupleStructInfo(fields, type->span); } else if (const FuncTypeNode* func_type = type.as()) { @@ -782,7 +714,7 @@ class StructInfoLCAFinder auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); - int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownDim; + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { // prefers return same when possible @@ -802,7 +734,7 @@ class StructInfoLCAFinder // find the target dtype and ndim. DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); - int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownDim; + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; // if ndim mismatch or one side of shape is missing // then we cannot keep in symbolic shape if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index f69472700372..67f16f17f108 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -41,7 +41,7 @@ * 10. The IR is in ANF: * (a) Expressions cannot contain nested complex expressions. 
* Here are the expressions that may be nested inside other expressions: - * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape, + * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, * Op, Tuple (we call these "leaf" expressions). * (b) The right-hand side of a binding may contain a non-leaf expression * (where all expressions nested in it are leaf expressions), diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index e44f2205f240..10a769d1b6cb 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -187,7 +188,7 @@ class JSONSerializer * will flatten it. */ std::vector AddNode(JSONGraphObjectPtr node, const Expr& expr) { - auto checked_type = expr->checked_type(); + auto struct_info = GetStructInfo(expr); auto node_id = nodes_.size(); nodes_.push_back(node); std::vector ret; @@ -195,26 +196,27 @@ class JSONSerializer TypeVector dtype; // Flatten tuple node. - if (const auto* tuple_type = checked_type.as()) { - for (size_t i = 0; i < tuple_type->fields.size(); ++i) { - const auto* tensor_type = tuple_type->fields[i].as(); - ICHECK(tensor_type) << "Expect DynTensorType, but received: ." - << tuple_type->fields[i]->GetTypeKey(); - ICHECK(expr->shape_.defined()) << "Expect shape to be defined. "; - ShapeExpr output_shape = Downcast(expr->shape_.value()); + if (const auto* tuple_sinfo = struct_info.as()) { + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ." + << tuple_sinfo->fields[i]->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); ret.push_back(JSONGraphNodeEntry(node_id, i)); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_type->dtype)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); } - node->SetNumOutput(tuple_type->fields.size()); + node->SetNumOutput(tuple_sinfo->fields.size()); } else { - const auto* tensor_type = checked_type.as(); - ICHECK(tensor_type) << "Expect DynTensorType, but received: " << checked_type->GetTypeKey(); - ICHECK(expr->shape_.defined()) << "Expect shape to be defined. 
"; - ShapeExpr output_shape = Downcast(expr->shape_.value()); + const auto* tensor_sinfo = struct_info.as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: " + << struct_info->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_type->dtype)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); ret.push_back(JSONGraphNodeEntry(node_id, 0)); } std::vector shape_attrs; diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 93b9310444c7..29a6f38990f8 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,49 @@ namespace tvm { namespace relax { +class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { + public: + // collect the PrimExpr slot for a given function + static Map Collect(Function func) { + PrimExprSlotCollector collector; + // collect shape delcarations in func params + for (auto param : func->params) { + collector.VisitStructInfo(GetStructInfo(param)); + collector.VisitExpr(param); + } + collector.VisitExpr(func->body); + // avoid create any slot for static shape. + if (!collector.dyn_shape_) collector.slot_map_.clear(); + return std::move(collector.slot_map_); + } + + private: + void VisitPrimExpr(const PrimExpr& expr) final { + if (!expr->IsInstance()) { + dyn_shape_ = true; + } + if (slot_map_.count(expr) == 0) { + slot_map_.Set(expr, slot_count_++); + } + } + + void VisitExpr_(const FunctionNode* op) final { + // Do not recurse into function node as it is self-contained + } + + void VisitStructInfo_(const FuncStructInfoNode* op) final { + // Do not recurse into function struct info as it is self-contained + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + + void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } + + bool dyn_shape_ = false; + int slot_count_ = 0; + Map slot_map_; +}; + class VMShapeLowerMutator : public ExprMutator { public: static DataType ShapeDType() { return DataType::Int(64); } @@ -43,7 +87,7 @@ class VMShapeLowerMutator : public ExprMutator { Expr func = p.second; if (func->IsInstance()) { // prepare mapping and heap var - expr2slot_ = PrepareExpr2Slot(Downcast(func)); + expr2slot_ = PrimExprSlotCollector::Collect(Downcast(func)); heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); shape_heap_ = Var("shape_heap", TensorStructInfo(ShapeExpr({heap_size_}), ShapeDType())); @@ -89,16 +133,16 @@ class VMShapeLowerMutator : public ExprMutator { Expr VisitExpr_(const FunctionNode* node) override { if (heap_size_->value > 0) { builder_->BeginBindingBlock(); - builder_->Emit(VarBinding( - shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); - + auto alloc_shape_heap = builder_->Normalize( + Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})); + builder_->EmitNormalized(VarBinding(shape_heap_, alloc_shape_heap)); for (Var param : node->params) { - if (param->shape_.operator bool() && param->shape_.value().as()) { - if (auto* param_type = param->checked_type_.as()) { - if (param_type->ndim != 0) { - Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh"); - StoreShape(shape, 
Downcast(param->shape_.value())->values); - } + // TODO(relax-team): handle generalized case with tuple of Tensors + if (auto* tensor_info = GetStructInfoAs(param)) { + auto* shape_expr = tensor_info->shape.as(); + if (tensor_info->ndim != 0 && shape_expr) { + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh"); + StoreShape(shape, shape_expr->values); } } } @@ -125,7 +169,7 @@ class VMShapeLowerMutator : public ExprMutator { // The ret_type is weakened to unknown-dimensional DynTensorType. // TODO(@yuchen): change all tensor types in the function to unknown ndim if (const auto* tensor_sinfo = ret_struct_info.as()) { - ret_struct_info = TensorStructInfo(tensor_sinfo->dtype, /*ndim=*/kUnknownDim); + ret_struct_info = TensorStructInfo(tensor_sinfo->dtype, /*ndim=*/kUnknownNDim); } return builder_->Normalize(Function(node->params, new_body, ret_struct_info, node->attrs)); @@ -168,32 +212,6 @@ class VMShapeLowerMutator : public ExprMutator { return ret; } - Map PrepareExpr2Slot(Function expr) const { - int cnt = 0; - bool is_dyn_shape = false; - Map ret; - auto func = [&](const Expr& e) { - if (e->IsInstance()) { - ShapeExpr shape = Downcast(e); - for (auto prim_e : shape->values) { - if (!prim_e->IsInstance()) { - is_dyn_shape = true; - } - if (ret.count(prim_e) == 0) { - ret.Set(prim_e, cnt++); - } - } - } - }; - PostOrderVisit(expr, func); - - // Avoid allocating shape heap and do shape computation for static-shape program - if (!is_dyn_shape) { - ret.clear(); - } - return ret; - } - /*! \brief Store symbolic shape into indices of the VM shape heap. */ void StoreShape(Expr shape, Array pattern) { static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); @@ -201,8 +219,9 @@ class VMShapeLowerMutator : public ExprMutator { Array indices; for (size_t i = 0; i < pattern.size(); ++i) { - Integer idx = expr2slot_.at(pattern[i]); - indices.push_back(idx); + auto it = expr2slot_.find(pattern[i]); + ICHECK(it != expr2slot_.end()) << "PrimExpr pattern " << pattern[i] << " is not in expr2slot"; + indices.push_back((*it).second); } store_shape_attr->indices = indices; builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv"); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index df253f6aec02..bea13438ffce 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -69,54 +69,6 @@ class BlockBuilderImpl : public BlockBuilderNode { //------------------------------- NameTable* name_table() final { return name_table_.get(); } - bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs) final { - if (lhs.same_as(rhs)) { - return true; - } - - // TODO(relax-team): revisit this logic after struct info. - if (lhs->IsInstance() && rhs->IsInstance()) { - return true; - } - - // try run symbolic shape proves that two shape equals each other. - if (lhs->IsInstance() && rhs->IsInstance()) { - const auto* lhs_shape = lhs.as(); - const auto* rhs_shape = rhs.as(); - size_t lhs_ndim = lhs_shape->values.size(); - size_t rhs_ndim = rhs_shape->values.size(); - if (lhs_ndim != rhs_ndim) { - return false; - } - for (size_t i = 0; i < lhs_ndim; ++i) { - PrimExpr lhs_dim = lhs_shape->values[i]; - PrimExpr rhs_dim = rhs_shape->values[i]; - if (lhs_dim.dtype() != rhs_dim.dtype() || !analyzer_.CanProveEqual(lhs_dim, rhs_dim)) { - return false; - } - } - return true; - } - - // tuple comparison - // TODO(relax-team): can be removed later after struct info. 
- if (lhs->IsInstance() && rhs->IsInstance()) { - const auto* lhs_tuple = lhs.as(); - const auto* rhs_tuple = rhs.as(); - if (lhs_tuple->fields.size() != rhs_tuple->fields.size()) { - return false; - } - for (size_t i = 0; i < lhs_tuple->fields.size(); ++i) { - if (!CanProveShapeEqual(lhs_tuple->fields[i], rhs_tuple->fields[i])) { - return false; - } - } - return true; - } - - return false; - } - IRModule GetContextIRModule() const final { return context_mod_; } GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final { @@ -242,18 +194,6 @@ class BlockBuilderImpl : public BlockBuilderNode { return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); } - Var Emit(VarBinding binding) final { - BlockFrame* cur_frame = CurrentBlockFrame(); - if (cur_frame->is_dataflow) { - ICHECK(binding->var.as()) - << "Emit can only be used for local bindings in a dataflow block, use EmitOutput for " - "output bindings instead"; - } - cur_frame->bindings.push_back(binding); - binding_table_[binding->var->vid] = binding->value; - return binding->var; - } - Var EmitMatchShape(Expr value, Array pattern, String name_hint) final { value = this->Normalize(value); @@ -277,14 +217,6 @@ class BlockBuilderImpl : public BlockBuilderNode { return var; } - Var EmitMatchShape(MatchShape binding) final { - BlockFrame* cur_frame = CurrentBlockFrame(); - // NOTE match shape do not follow simple binding rule - // as a result should not appear in binding table. - cur_frame->bindings.push_back(binding); - return binding->var; - } - Var EmitOutput(Expr output, String name_hint) final { BlockFrame* cur_frame = CurrentBlockFrame(); @@ -293,17 +225,6 @@ class BlockBuilderImpl : public BlockBuilderNode { return Emit(output, false, name_hint); } - Var EmitOutput(VarBinding binding) final { - BlockFrame* cur_frame = CurrentBlockFrame(); - - ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; - ICHECK(!binding->var.as()) << "EmitOutput can only emit Var bindings."; - - cur_frame->bindings.push_back(binding); - binding_table_[binding->var->vid] = binding->value; - return binding->var; - } - void EmitNormalized(Binding binding) final { BlockFrame* cur_frame = CurrentBlockFrame(); @@ -312,14 +233,23 @@ class BlockBuilderImpl : public BlockBuilderNode { ICHECK(!var_binding->var.as()) << "Cannot emit dataflowvar in non-dataflow block"; } + // normalized check + ICHECK(var_binding->var->struct_info_.defined()); + ICHECK(var_binding->value->struct_info_.defined()); cur_frame->bindings.push_back(binding); binding_table_[var_binding->var->vid] = var_binding->value; } else { - auto* ptr = binding.as(); - ICHECK(ptr); - if (!cur_frame->is_dataflow) { - ICHECK(!ptr->var.as()) << "Cannot emit dataflowvar in non-dataflow block"; + auto* match_shape = binding.as(); + ICHECK(match_shape); + if (match_shape->var.defined()) { + if (!cur_frame->is_dataflow) { + ICHECK(!match_shape->var.as()) + << "Cannot emit dataflowvar in non-dataflow block"; + } + ICHECK(match_shape->var->struct_info_.defined()); } + // normalized check + ICHECK(match_shape->value->struct_info_.defined()); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. 
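The consolidated EmitNormalized path expects struct_info to be populated on both the bound var and the value; from Python this is provided by the regular BlockBuilder flow that the updated block-builder tests rely on. A minimal sketch, reusing only calls that appear in those tests (names illustrative):

import tvm
from tvm import relax as rx
from tvm import tir
from tvm.script import relax as R

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", R.Tensor([m, n], "float16"))
y = rx.Var("y", R.Tensor([m, n], "float16"))

bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
    with bb.dataflow():
        lv0 = bb.emit(rx.op.add(x, y))
        # The normalizer populates struct_info on the emitted var.
        tvm.ir.assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16"))
        gv0 = bb.emit_output(lv0)
    bb.emit_func_output(gv0)
func = bb.get()["func"]
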
cur_frame->bindings.push_back(binding); @@ -572,7 +502,6 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor GetTupleShape(const Tuple& tuple) { - Array tuple_shape; - for (Expr field : tuple->fields) { - if (field->shape_.defined()) { - tuple_shape.push_back(Downcast(field->shape_.value())); - } else { - break; - } - } - if (tuple_shape.size() == tuple->fields.size()) { - return Tuple(tuple_shape); - } - return NullOpt; - } - Expr VisitExpr_(const TupleNode* op) final { bool unchanged = true; Array new_fields; @@ -1003,29 +916,19 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder bui return builder->Emit(expr); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitVarBinding") - .set_body_typed([](BlockBuilder builder, VarBinding binding) { - return builder->Emit(binding); - }); - TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShape") .set_body_typed([](BlockBuilder builder, Expr value, Array pattern) { return builder->EmitMatchShape(value, pattern); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShapeBinding") - .set_body_typed([](BlockBuilder builder, MatchShape binding) { - return builder->EmitMatchShape(binding); - }); - TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") .set_body_typed([](BlockBuilder builder, const Expr& output) { return builder->EmitOutput(output); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutputVarBinding") - .set_body_typed([](BlockBuilder builder, VarBinding binding) { - return builder->EmitOutput(binding); +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") + .set_body_typed([](BlockBuilder builder, Binding binding) { + return builder->EmitNormalized(binding); }); TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") @@ -1042,9 +945,6 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") .set_body_method(&BlockBuilderNode::GetContextIRModule); -TVM_REGISTER_GLOBAL("relax.BlockBuilderCanProveShapeEqual") - .set_body_method(&BlockBuilderNode::CanProveShapeEqual); - TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 680e62e7f906..76bfdb12d2bc 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -432,9 +433,12 @@ static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Arr bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { // no need to jump, as var.shape == value.shape - if (const ShapeExprNode* shape_expr = expr->shape().as()) - return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && - VisitDFPattern(op->pattern, expr); + if (const auto* tinfo = GetStructInfoAs(expr)) { + if (const ShapeExprNode* shape_expr = tinfo->shape.as()) { + return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && + VisitDFPattern(op->pattern, expr); + } + } return false; } @@ -494,10 +498,6 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -bool DFPatternMatcher::VisitDFPattern_(const RuntimeDepShapePatternNode* op, const Expr& expr) { - return expr->shape_->IsInstance() && VisitDFPattern(op->pattern, expr); -} - bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val) { if (var2val.defined()) // autojump is enabled with var2val. 
return DFPatternMatcher(std::move(var2val.value())).Match(pattern, expr); diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher_impl.h index 13f2b36840e1..89f3d114c1e3 100644 --- a/src/relax/ir/dataflow_matcher_impl.h +++ b/src/relax/ir/dataflow_matcher_impl.h @@ -63,7 +63,6 @@ class DFPatternMatcher : public DFPatternFunctorstream << "GlobalVarPattern(" << node->name_hint() << ")"; }); -TVM_REGISTER_NODE_TYPE(RuntimeDepShapePatternNode); -RuntimeDepShapePattern::RuntimeDepShapePattern(DFPattern root) { - ObjectPtr n = make_object(); - n->pattern = std::move(root); - data_ = std::move(n); -} -TVM_REGISTER_GLOBAL("relax.dpl.RuntimeDepShapePattern").set_body_typed([](DFPattern root) { - return RuntimeDepShapePattern(std::move(root)); -}); -RELAX_PATTERN_PRINTER_DEF(RuntimeDepShapePatternNode, [](auto p, auto node) { - p->stream << "RuntimeDepShapePattern(" << node->pattern << " has runtime-dep shape)"; -}); - TVM_REGISTER_NODE_TYPE(ExprPatternNode); ExprPattern::ExprPattern(Expr expr) { ObjectPtr n = make_object(); @@ -364,9 +351,6 @@ class DFPatternDuplicator : public DFPatternFunctor DFPattern VisitDFPattern_(const TypePatternNode* op) override { return TypePattern(op->pattern, op->type); } - DFPattern VisitDFPattern_(const RuntimeDepShapePatternNode* op) override { - return RuntimeDepShapePattern(op->pattern); - } DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { return DataflowVarPattern(op->name); } @@ -401,9 +385,6 @@ DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { ShapePattern DFPattern::HasShape(const Array& shape) const { return ShapePattern(*this, shape); } -RuntimeDepShapePattern DFPattern::HasRuntimeDepShape() const { - return RuntimeDepShapePattern(*this); -} DFPattern::operator PatternSeq() const { return PatternSeq{{*this}}; } diff --git a/src/relax/ir/dataflow_pattern_functor.cc b/src/relax/ir/dataflow_pattern_functor.cc index d05ebee4ca6b..37a98f28beef 100644 --- a/src/relax/ir/dataflow_pattern_functor.cc +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -101,7 +101,6 @@ void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPatte // leaf nodes. 
void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} -void DFPatternVisitor::VisitDFPattern_(const RuntimeDepShapePatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const DataflowVarPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const GlobalVarPatternNode* op) {} diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 151052725cf0..48828689ad05 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -22,7 +22,7 @@ */ #include "./emit_te.h" -#include +#include namespace tvm { namespace relax { @@ -57,19 +57,16 @@ te::Tensor TETensor(Expr value, std::string name) { n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); } - - Expr shape_expr = value->shape(); - CHECK(shape_expr->IsInstance()) + ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; + auto* tensor_sinfo = GetStructInfoAs(value); + ICHECK(tensor_sinfo) << "Value must be a tensor"; + auto* shape_expr = tensor_sinfo->shape.as(); + CHECK(shape_expr) << "ValueError: Expression does not have an known symbolic shape, please consider use " "match_shape " << "to constrain the shape before passing into te_tensor"; - Array shape = Downcast(shape_expr)->values; - n->shape = shape; - Type type = value->checked_type(); - ICHECK(type->IsInstance()) - << "ValueError: Expression should have a inferred DynTensorType: " << type->GetTypeKey(); - DataType dtype = Downcast(type)->dtype; - n->dtype = dtype; + n->shape = shape_expr->values; + n->dtype = tensor_sinfo->dtype; return te::PlaceholderOp(n).output(0); } diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index e89353280b89..0f0ce72211da 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -23,27 +23,6 @@ #include namespace tvm { - -RelayExpr RelayExprNode::shape() const { - if (this->shape_.defined()) { - return Downcast(this->shape_); - } - if (this->struct_info_.defined()) { - Optional shape = - relax::GetLegacyShapeHint(Downcast(this->struct_info_.value())); - if (shape.defined()) { - return shape.value(); - } - } - static const Op& op = Op::Get("relax.shape_of"); - RelayExpr self = GetRef(this); - relax::Call call_shape_of(op, {self}, {}, {}); - call_shape_of->checked_type_ = relax::ShapeType(); - return call_shape_of; -} - -TVM_REGISTER_GLOBAL("ir.RelayExprShape").set_body_method(&RelayExprNode::shape); - namespace relax { using tvm::ReprPrinter; using tvm::runtime::Optional; @@ -250,7 +229,6 @@ ShapeExpr::ShapeExpr(Array values, Span span) { return value; }); n->span = span; - n->shape_ = NullOpt; n->checked_type_ = ShapeType(values.size()); n->struct_info_ = ShapeStructInfo(values, span); data_ = std::move(n); @@ -260,20 +238,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, return ShapeExpr(values, span); }); -TVM_REGISTER_NODE_TYPE(RuntimeDepShapeNode); - -RuntimeDepShape::RuntimeDepShape(Span span) { - ObjectPtr n = make_object(); - n->span = span; - n->struct_info_ = ShapeStructInfo(kUnknownDim); - n->checked_type_ = ShapeType(); - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("relax.RuntimeDepShape").set_body_typed([](Span span) { - return RuntimeDepShape(span); -}); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { const ShapeExprNode* node = static_cast(ref.get()); @@ -294,7 +258,6 @@ Var::Var(Id vid, Optional 
struct_info_annotation, Span span) { n->vid = std::move(vid); if (struct_info_annotation) { n->checked_type_ = GetStaticType(struct_info_annotation.value()); - n->shape_ = GetLegacyShapeHint(struct_info_annotation.value()); } n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -318,7 +281,6 @@ DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Sp n->vid = std::move(vid); if (struct_info_annotation) { n->checked_type_ = GetStaticType(struct_info_annotation.value()); - n->shape_ = GetLegacyShapeHint(struct_info_annotation.value()); } n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -351,8 +313,6 @@ Constant::Constant(runtime::NDArray data, Span span) { n->struct_info_ = tinfo; n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype); - n->shape_ = tinfo->shape; - data_ = std::move(n); } @@ -557,5 +517,24 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, return ExternFunc(global_symbol, span); }); +Expr GetShapeOf(const Expr& expr) { + // default case, to be normalized. + ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; + auto* tinfo = GetStructInfoAs(expr); + + ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; + if (tinfo->shape.defined()) return tinfo->shape.value(); + + static const Op& op = Op::Get("relax.shape_of"); + // default case, call shape of, eagerly normalize the expr. + relax::Call call_shape_of(op, {expr}, {}, {}); + UpdateStructInfo(call_shape_of, ShapeStructInfo(tinfo->ndim)); + return call_shape_of; +} + +TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { + return GetShapeOf(expr); +}); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index d6ff6d1b2749..d2cbe8562362 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -85,13 +85,7 @@ namespace relax { void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } -void ExprVisitor::VisitExpr_(const ConstantNode* op) { - this->VisitSpan(op->span); - - if (op->shape_) { - this->VisitExpr(Downcast(op->shape_.value())); - } -} +void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } @@ -100,10 +94,6 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { for (Expr field : op->fields) { this->VisitExpr(field); } - - if (op->shape_) { - this->VisitExpr(Downcast(op->shape_.value())); - } } // Visit the use-site of a defined Var @@ -132,10 +122,6 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { for (Expr arg : op->args) { this->VisitExpr(arg); } - - if (op->shape_) { - this->VisitExpr(Downcast(op->shape_.value())); - } } void ExprVisitor::VisitExpr_(const IfNode* op) { @@ -159,8 +145,6 @@ void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); } -void ExprVisitor::VisitExpr_(const RuntimeDepShapeNode* op) { this->VisitSpan(op->span); } - void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const SeqExprNode* op) { @@ -215,21 +199,9 @@ void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { } } -void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { - this->VisitSpan(var->span); - - if (var->shape_) { - this->VisitExpr(Downcast(var->shape_.value())); - } -} - -void 
ExprVisitor::VisitVarDef_(const VarNode* var) { - this->VisitSpan(var->span); +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); } - if (var->shape_) { - this->VisitExpr(Downcast(var->shape_.value())); - } -} +void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); } void ExprVisitor::VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { @@ -383,8 +355,6 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { } } -Expr ExprMutatorBase::VisitExpr_(const RuntimeDepShapeNode* op) { return GetRef(op); } - Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { @@ -557,24 +527,9 @@ RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { Var new_var = this->VisitVarDef(binding->var); - auto emit = [this](VarBinding b) { - if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { - this->builder_->EmitOutput(b); - } else { - this->builder_->Emit(b); - } - }; - - // FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy - // if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - // // no-op if there is no change - // emit(GetRef(binding)); - // return; - // } - // fast path: reemit binding if nothing changes if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - emit(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); return; } @@ -584,7 +539,7 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { this->var_remap_[binding->var->vid] = new_var; } - emit(VarBinding(new_var, new_value)); + builder_->EmitNormalized(VarBinding(new_var, new_value)); } void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { @@ -610,15 +565,12 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { // reemit old binding if nothing changes if (new_value.same_as(binding->value) && new_pattern.same_as(binding->pattern)) { if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { - builder_->EmitMatchShape(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); return; } } - // TODO(@altanh, @yuchen): shape and type inference here too... 
- // TODO(@yuchen): when value's shape/type changed, create new var - // TODO(@yuchen): group the can prove shape/type logic and replace var into a function - builder_->EmitMatchShape( + builder_->EmitNormalized( MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 4b53baa43c8e..d831f9eedc0e 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -94,7 +94,7 @@ TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") .set_body_typed([](Optional> values, int ndim, Span span) { if (values.defined()) { - CHECK_EQ(ndim, kUnknownDim) << "ValueError: Cannot both specify values and ndim"; + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; return ShapeStructInfo(values.value(), span); } else { return ShapeStructInfo(ndim, span); @@ -141,7 +141,7 @@ TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); TVM_REGISTER_GLOBAL("relax.TensorStructInfo") .set_body_typed([](Optional shape, DataType dtype, int ndim, Span span) { if (shape.defined()) { - CHECK_EQ(ndim, kUnknownDim) << "ValueError: Cannot both specify shape and ndim"; + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; return TensorStructInfo(shape.value(), dtype, span); } else { return TensorStructInfo(dtype, ndim, span); @@ -234,7 +234,6 @@ void UpdateStructInfo(Expr expr, StructInfo struct_info) { expr->struct_info_ = struct_info; // also set checked type expr->checked_type_ = GetStaticType(struct_info); - expr->shape_ = GetLegacyShapeHint(struct_info); } TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 71a06cbce787..49ef1d7163f1 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -72,16 +72,6 @@ TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int ndim, DataType return DynTensorType(ndim, dtype, span); }); -DimType::DimType(Span span) { - ObjectPtr n = make_object(); - n->span = span; - data_ = std::move(n); -} - -TVM_REGISTER_NODE_TYPE(DimTypeNode); - -TVM_REGISTER_GLOBAL("relax.DimType").set_body_typed([](Span span) { return DimType(span); }); - PackedFuncType::PackedFuncType(Span span) { ObjectPtr n = make_object(); n->span = span; diff --git a/src/relax/ir/type_analysis.cc b/src/relax/ir/type_analysis.cc index 7c532d816c95..3a04eb51b6c2 100644 --- a/src/relax/ir/type_analysis.cc +++ b/src/relax/ir/type_analysis.cc @@ -42,7 +42,7 @@ class BaseTypeChecker : public TypeFunctor { bool VisitType_(const ShapeTypeNode* base) final { if (auto* rhs = derived_.as()) { - return base->ndim == kUnknownDim || base->ndim == rhs->ndim; + return base->ndim == kUnknownNDim || base->ndim == rhs->ndim; } return false; } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index a60b9bf775e7..a72fa0643521 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -64,7 +64,7 @@ StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { } StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { - return ShapeStructInfo(kUnknownDim); + return ShapeStructInfo(kUnknownNDim); } // call_tir @@ -369,7 +369,7 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ct if (const auto* output_shape = call->args[1].as()) { return TensorStructInfo(GetRef(output_shape), attrs->dtype); } - return TensorStructInfo(attrs->dtype, kUnknownDim); + return 
TensorStructInfo(attrs->dtype, kUnknownNDim); } RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index e72cfcc0aefd..3900c7f18b1f 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -63,7 +63,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx) { // ndims int output_ndim; if (lhs_sinfo->IsUnknownNdim() || rhs_sinfo->IsUnknownNdim()) { - output_ndim = kUnknownDim; + output_ndim = kUnknownNDim; } else { output_ndim = std::max(lhs_sinfo->ndim, rhs_sinfo->ndim); } diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index d13d23baabd2..0d72fddccc5f 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -114,7 +114,7 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { int output_ndim; if (t0->IsUnknownNdim() || t1->IsUnknownNdim() || t2->IsUnknownNdim()) { - output_ndim = kUnknownDim; + output_ndim = kUnknownNDim; } else { output_ndim = t0->ndim; } diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index f72b607f4163..21e144f48df7 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -54,52 +55,43 @@ class CallTIRMutator : public ExprMutator { if (call->op == call_tir_op) { Array outs; - if (call->shape_) { - if (call->shape_.value()->IsInstance()) { - // single output case - ShapeExpr output_shape = Downcast(call->shape_.value()); + if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { + // single output case + const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); + ICHECK(tensor_sinfo->shape.defined()) + << "the TensorStructInfo shape of call_tir has not populated"; + auto alloc_tensor_attr = make_object(); + alloc_tensor_attr->dtype = tensor_sinfo->dtype; + alloc_tensor_attr->runtime_device_index = 0; + outs.push_back(builder_->Emit(Call(alloc_tensor_op, // + {Downcast(tensor_sinfo->shape.value())}, // + Attrs(alloc_tensor_attr)), + "alloc")); + } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { + // multiple output case + const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto& field = tuple_sinfo->fields[i]; + + ICHECK(field->IsInstance()) + << "call_tir expects Tuple of TensorStructInfo, but got " << field + << " as an element of TupleStructInfo"; + const auto& field_tensor = Downcast(field); + ICHECK(field_tensor->shape.defined()) + << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor + << " as an element of TupleStructInfo"; auto alloc_tensor_attr = make_object(); - - if (call->checked_type_.defined()) { - auto output_type = Downcast(call->checked_type_); - alloc_tensor_attr->dtype = output_type->dtype; - alloc_tensor_attr->runtime_device_index = 0; - outs.push_back(builder_->Emit( - Call(alloc_tensor_op, {output_shape}, Attrs(alloc_tensor_attr)), "alloc")); - } else { - LOG(FATAL) << "ValueError: the checked_type_ of call_tir has not populated."; - } - } else { - // multiple output case - ICHECK(call->shape_.value()->IsInstance()) - << "call_tir expects ShapeExpr or Tuple as its shape, but got " << call->shape_; - ICHECK(call->checked_type_->IsInstance()) - << "call_tir expects DynTensorType or TupleType as its checked type, but got " - << call->checked_type_; - Tuple 
output_shapes = Downcast(call->shape_); - TupleType output_types = Downcast(call->checked_type_); - ICHECK(output_shapes->fields.size() == output_types->fields.size()) - << "The output of call_tir should have the same amount of fields in its shape_ and " - "checked_type_"; - for (size_t i = 0; i < output_shapes->fields.size(); ++i) { - ICHECK(output_shapes->fields[i]->IsInstance()) - << "call_tir expects Tuple of ShapeExprs, but got " << output_shapes->fields[i] - << " as an element of tuple"; - ICHECK(output_types->fields[i]->IsInstance()) - << "call_tir expects TupleType of DynTensorType, but got " - << output_types->fields[i] << " as an element of TupleType"; - auto output_type = Downcast(output_types->fields[i]); - auto alloc_tensor_attr = make_object(); - alloc_tensor_attr->dtype = output_type->dtype; - alloc_tensor_attr->runtime_device_index = 0; - outs.push_back(builder_->Emit( - Call(alloc_tensor_op, {Downcast(output_shapes->fields[i])}, - Attrs(alloc_tensor_attr)), - "alloc")); - } + alloc_tensor_attr->dtype = field_tensor->dtype; + alloc_tensor_attr->runtime_device_index = 0; + outs.push_back(builder_->Emit( + Call(alloc_tensor_op, {Downcast(field_tensor->shape.value())}, + Attrs(alloc_tensor_attr)), + "alloc")); } } else { - LOG(FATAL) << "ValueError: the shape of call_tir has not populated."; + LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or " + "TupleStructInfo, but got" + << expr->struct_info_; } Array args; diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index e938f8938c1b..b0af3d7d4f4e 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -26,6 +26,7 @@ #include #include +#include #include namespace tvm { @@ -59,20 +60,12 @@ class BindingCanonicalizer : public ExprMutator { Expr new_value = this->VisitExpr(binding->value); Var new_var = this->VisitVarDef(binding->var); - auto emit = [this](VarBinding b) { - if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { - this->builder_->EmitOutput(b); - } else { - this->builder_->Emit(b); - } - }; - if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - emit(GetRef(binding)); + this->builder_->EmitNormalized(GetRef(binding)); return; } - emit(VarBinding(new_var, new_value)); + this->builder_->EmitNormalized(VarBinding(new_var, new_value)); } void VisitBinding_(const MatchShapeNode* binding) override { @@ -89,22 +82,20 @@ class BindingCanonicalizer : public ExprMutator { } // if the LHS and RHS have the same shape_, we canonicalize to a var binding instead - if (new_var.defined() && new_value->shape_.defined() && - builder_->CanProveShapeEqual(Downcast(new_var->shape_), - Downcast(new_value->shape_))) { - builder_->Emit(VarBinding(new_var, new_value)); + if (new_var.defined() && StructuralEqual()(GetStructInfo(new_var), GetStructInfo(new_value))) { + builder_->EmitNormalized(VarBinding(new_var, new_value)); return; } // reemit old binding if nothing changes if (new_value.same_as(binding->value)) { if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { - builder_->EmitMatchShape(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); return; } } - builder_->EmitMatchShape(MatchShape(new_value, binding->pattern, new_var)); + builder_->EmitNormalized(MatchShape(new_value, binding->pattern, new_var)); } private: @@ -130,16 +121,10 @@ class BindingCanonicalizer : public ExprMutator { // In this case, we could be 
overriding user annotations. // 2. If the child is a Var and the parent is a DataflowVar. // That could result in a DataflowVar leaving the current DataflowBlock. - bool annotations_differ = - AnnotationsDiffer(v->shape_, parent_var->shape_, - [&](const ObjectRef& shape1, const ObjectRef& shape2) { - return builder_->CanProveShapeEqual(Downcast(shape1), - Downcast(shape2)); - }) || - AnnotationsDiffer(v->checked_type_, parent_var->checked_type_, - [&](const ObjectRef& type1, const ObjectRef& type2) { - return tvm::StructuralEqual()(type1, type2); - }); + bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_, + [&](const ObjectRef& lhs, const ObjectRef& rhs) { + return tvm::StructuralEqual()(lhs, rhs); + }); bool var_to_dataflow = (!v.as() && parent_var.as()); return !annotations_differ && !var_to_dataflow; } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index c0e51cff2da1..15dc33f9c985 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -242,8 +243,7 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const FunctionNode* func) final { // Step 1. Create buffers for function params for (const Var& relax_param : func->params) { - auto ret = CreateParamsAndBuffers(relax_param->checked_type(), // - relax_param->shape(), // + auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), // relax_param->name_hint()); const Array& params = ret.first; const Array& buffers = ret.second; @@ -468,22 +468,23 @@ class FusedTIRConstructor : public ExprVisitor { /*! * \brief Create an TIR func params and buffers with specified relax type and shape - * \param type The specified relax type, which can be DynTensorType or Tuple - * \param shape The specified shape, which can be ShapeExpr or Tuple + * \param struct_info The struct info * \param name_hint The name hint for params and buffers * \param index The index used for unique name_hint if type is Tuple. * -1 means no need to add postfix since the relax param is not a Tuple. * \return The created TIR func params and buffers */ static std::pair, Array> CreateParamsAndBuffers( - Type type, relax::Expr shape, const String& name_hint, int index = -1) { + StructInfo struct_info, const String& name_hint, int index = -1) { Array params; Array buffers; - if (const auto* shape_expr = shape.as()) { + if (const auto* tensor = struct_info.as()) { // Case 1. the relax param is a DynTensor, we directly create a tir var and buffer - ICHECK(type->IsInstance()); + const auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape."; + String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index); - DataType dtype = Downcast(type)->dtype; + DataType dtype = tensor->dtype; tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name); // Differentiate buffer name and param name by adding prefix `v_` to param // Every symbol should be unique in TVMScript, and Buffer is used more than param @@ -491,15 +492,12 @@ class FusedTIRConstructor : public ExprVisitor { tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle())); params.push_back(std::move(param)); buffers.push_back(std::move(buffer)); - } else if (const auto* shape_tuple = shape.as()) { + } else if (const auto* tuple = struct_info.as()) { // Case 2. 
the relax param is a Tuple, we recursively visit each field until it's a DynTensor - ICHECK(type->IsInstance()); - TupleType tuple_type = Downcast(type); // Enable postfix if (index == -1) index = 0; - for (size_t i = 0; i < shape_tuple->fields.size(); ++i) { - auto ret = - CreateParamsAndBuffers(tuple_type->fields[i], shape_tuple->fields[i], name_hint, index); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index); const Array& ret_params = ret.first; const Array& ret_buffers = ret.second; ICHECK_EQ(ret_params.size(), ret_buffers.size()); @@ -631,9 +629,24 @@ class TIRFuseMutator : public ExprMutator { using ExprMutator::VisitExpr_; + // Get shape from call tir + static Expr GetCallTIRShape(StructInfo sinfo) { + if (auto* tuple = sinfo.as()) { + Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + return Tuple(fields); + } else { + auto* tensor = sinfo.as(); + ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; + auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; + return GetRef(shape_expr); + } + } + Expr VisitExpr_(const CallNode* op) final { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - Call call = Downcast(ExprMutator::VisitExpr_(op)); + Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); + if (call->op->IsInstance()) { // Case 1. It is a relax cross function call GlobalVar old_gv = Downcast(call->op); @@ -649,7 +662,8 @@ class TIRFuseMutator : public ExprMutator { arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); } // Step b. Create call_tir - Array call_args = {fused_tir_gv, Tuple(arg_list), call->shape()}; + Array call_args = {fused_tir_gv, Tuple(arg_list), + GetCallTIRShape(GetStructInfo(call))}; return Call(call_tir_op_, call_args, call->attrs, {call->checked_type()}); } else { // Case 1.2. The callee function is not primitive, nothing to do. @@ -672,9 +686,9 @@ class TIRFuseMutator : public ExprMutator { /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. 
*/ Array FlattenArg(const Expr& arg) { - if (const auto* tuple_shape = arg->shape().as()) { + if (const auto* tuple_sinfo = GetStructInfoAs(arg)) { Array arg_list; - for (size_t i = 0; i < tuple_shape->fields.size(); ++i) { + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { Expr new_arg = builder_->Emit(TupleGetItem(arg, i)); Array flattened = FlattenArg(new_arg); arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 3f3e60a30e56..335e2fc7a16f 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -140,23 +140,15 @@ class NormalizeMutator : public ExprMutatorBase { } void VisitBinding_(const VarBindingNode* binding) { - auto emit = [this](VarBinding b) { - if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { - this->builder_->EmitOutput(b); - } else { - this->builder_->Emit(b); - } - }; - Expr new_value = this->VisitExpr(binding->value); if (!binding->var->struct_info_.defined()) { UpdateStructInfo(binding->var, GetStructInfo(new_value)); } if (new_value.same_as(binding->value)) { - emit(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); } else { - emit(VarBinding(binding->var, new_value)); + builder_->EmitNormalized(VarBinding(binding->var, new_value)); } } @@ -169,9 +161,9 @@ class NormalizeMutator : public ExprMutatorBase { } } if (new_value.same_as(binding->value)) { - builder_->EmitMatchShape(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); } else { - builder_->EmitMatchShape(MatchShape(new_value, binding->pattern, binding->var)); + builder_->EmitNormalized(MatchShape(new_value, binding->pattern, binding->var)); } } diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 8c14d3aac6ca..ef41614f794f 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -75,7 +75,7 @@ class CodeGenRunner : ExprMutator { tmp_args.push_back(VisitExpr(arg)); } new_args.push_back(Tuple(tmp_args)); - new_args.push_back(func->body->shape()); + new_args.push_back(GetShapeOf(func->body)); static const Op& call_op = Op::Get("relax.call_tir"); diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 91575e4c8b61..4bf7588510ab 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -80,8 +80,8 @@ bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unkn bool IsLeafExpr(const Expr& expr) { // NB: tuples are treated as leaf nodes for ergonomics return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as() || expr.as() || - expr.as() || expr.as(); + expr.as() || expr.as() || expr.as() || + expr.as(); } } // namespace relax diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc index 79bbe6f93696..cb18f0fe4ec3 100644 --- a/src/relay/printer/relax_script_printer.cc +++ b/src/relay/printer/relax_script_printer.cc @@ -262,12 +262,6 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) { return doc; } -Doc RelaxScriptPrinter::VisitNode_(const relax::RuntimeDepShapeNode* op) { - Doc doc; - doc << "None"; - return doc; -} - Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) { Doc doc; if (op->var.defined()) { diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h index 61f3c79ea98e..42b95db854bf 100644 --- a/src/relay/printer/text_printer.h +++ b/src/relay/printer/text_printer.h @@ -276,7 +276,6 @@ class RelaxScriptPrinter : public 
relax::IRFunctor, Doc VisitNode_(const relax::VarNode* op) override; Doc VisitNode_(const relax::DataflowVarNode* op) override; Doc VisitNode_(const relax::ShapeExprNode* op) override; - Doc VisitNode_(const relax::RuntimeDepShapeNode* op) override; Doc VisitNode_(const relax::MatchShapeNode* op) override; Doc VisitNode_(const relax::VarBindingNode* op) override; Doc VisitNode_(const relax::BindingBlockNode* op) override; @@ -600,8 +599,7 @@ class TextPrinter { doc << tir_text_printer_.Print(node); } else if (node.defined() && (node->IsInstance() || node->IsInstance() || - node->IsInstance() || - node->IsInstance())) { + node->IsInstance())) { doc << relax_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 1de12fc9e2f0..da0531ac984d 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -212,8 +212,9 @@ Optional EmitMatchShape(const tvm::relax::Expr& value, // if (!emit_var) { // If we don't intend to emit a variable, just emit the binding and return. - tvm::relax::MatchShape match_shape(value, pattern, tvm::relax::Var{nullptr}); - block_builder->EmitMatchShape(match_shape); + tvm::relax::MatchShape match_shape(block_builder->Normalize(value), pattern, + tvm::relax::Var{nullptr}); + block_builder->EmitNormalized(match_shape); return NullOpt; } else { // Otherwise, we need to emit a variable and bind it to the match shape. @@ -230,49 +231,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(Emi void AnnotateStructInfo(const tvm::relax::Var& var, const tvm::relax::StructInfo& anno_struct_info) { - using tvm::relax::IsBaseOf; - using tvm::relax::StructInfo; - - Type type = GetStaticType(anno_struct_info); - Optional shape = GetLegacyShapeHint(anno_struct_info); - - // TODO(siyuan, ruihang): Revisit the checks aftr we fully migrate to struct info. - // consider simplify assumption to require var not contain struct info at all. - if (var->struct_info_.defined()) { - // Case 1. The var already has struct info. - const StructInfo& var_struct_info = Downcast(var->struct_info_.value()); - CHECK(IsBaseOf(var_struct_info, anno_struct_info) || - IsBaseOf(anno_struct_info, var_struct_info)) - << "ValueError: The annotation struct info is not a base of the existing struct info."; - } else { - // Case 2. The var doesn't have struct info. - // This path may be removed later - - if (var->checked_type_.defined()) { - const Type& var_type = var->checked_type(); - CHECK(IsBaseOf(type, var_type) || IsBaseOf(var_type, type)) - << "TypeError: The annotated type and value type are not compatible. 
" - << "The Type is expected to be " << var_type << " but got annotation: " << type; - } - if (var->shape_.defined() && shape.defined()) { - tvm::relax::Expr var_shape = Downcast(var->shape_.value()); - auto check_shape = [](const tvm::relax::Expr& lhs, const tvm::relax::Expr& rhs) { - if (lhs->IsInstance() || - rhs->IsInstance()) { - return true; - } else { - const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); - return block_builder->CanProveShapeEqual(lhs, rhs); - } - }; - CHECK(check_shape(var_shape, shape.value())) - << " The shape of var " << var->name_hint() << " is expected to be " << var_shape - << " but got annotation: " << shape.value(); - } - } - - var->checked_type_ = type; - var->shape_ = shape; + var->checked_type_ = GetStaticType(anno_struct_info); var->struct_info_ = anno_struct_info; } diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index b6ce6c75d10a..2e638bfd32a5 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -26,7 +26,6 @@ remove_all_unused, name_to_binding, shape_vars, - derive_func_ret_shape, all_vars, free_vars, bound_vars, @@ -257,33 +256,6 @@ def test_shape_var_nested(): assert sv1 in vars -def test_derive_func_ret_shape_no_free(): - sv1 = tir.Var("sv1", "int64") - sv2 = tir.Var("sv2", "int64") - sv3 = tir.Var("sv3", "int64") - a1 = rx.Var("a1", R.Tensor([sv1, sv2])) - a2 = rx.Var("a2", R.Tensor([sv2, sv3])) - body = a2 - shape_expr = derive_func_ret_shape([a1, a2], body) - - assert isinstance(shape_expr, rx.ShapeExpr) - assert shape_expr[0] == sv2 - assert shape_expr[1] == sv3 - - -def test_derive_func_ret_shape_free(): - sv1 = tir.Var("sv1", "int64") - sv2 = tir.Var("sv2", "int64") - sv3 = tir.Var("sv3", "int64") - a1 = rx.Var("a1", R.Tensor([sv1, sv2])) - a2 = rx.Var("a2", R.Tensor([sv2, sv1])) - # Artifically introducing a free shape variable. 
- # This would not be a valid program, but this is being done to test the logic - body = rx.Var("a3", R.Tensor([sv1, sv3])) - shape_expr = derive_func_ret_shape([a1, a2], body) - assert isinstance(shape_expr, rx.RuntimeDepShape) - - @tvm.script.ir_module class VarExample: @R.function diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 37a96b962d29..9f4cea057b4f 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -248,16 +248,16 @@ def test_func(): def test_shape_of(): - v0 = rx.Var("v0") - s0 = v0.shape + v0 = rx.Var("v0", R.Tensor(ndim=2)) + s0 = rx.get_shape_of(v0) s0_str = dump_ast(s0) assert s0_str.startswith("Call(") assert 'op=Op(name="relax.shape_of")' in s0_str assert "args=" in s0_str - assert 'Var(name_hint="v0")' in s0_str + assert 'name_hint="v0"' in s0_str v1 = rx.Var("v1", R.Tensor([96, 54])) - s1 = v1.shape + s1 = rx.get_shape_of(v1) s1_str = dump_ast(s1) assert s1_str.startswith("ShapeExpr("), s1_str assert "values=" in s1_str diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 77c7fe39dea2..4ff524096a2a 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -80,10 +80,7 @@ def test_function_single_block(): assert func.params[0] == x assert func.params[1] == y assert func.body.body == gv0 - assert gv0.shape[0] == m - assert gv0.shape[1] == n - assert gv0.checked_type.ndim == 2 - assert gv0.checked_type.dtype == "float16" + assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16")) assert len(func.body.blocks) == 1 assert len(func.body.blocks[0].bindings) == 3 @@ -110,10 +107,8 @@ def test_function_multi_blocks(): bb.emit_func_output(gv2) func = bb.get()["func"] - assert gv2.shape[0] == m - assert gv2.shape[1] == n - assert gv2.checked_type.ndim == 2 - assert gv2.checked_type.dtype == "float16" + + assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16")) assert func.params[0] == x assert func.params[1] == y assert func.body.body == gv2 @@ -214,28 +209,21 @@ def test_binary_shape_type_deduction(): with bb.function("func", [x, y, z, w]): with bb.dataflow(): lv0 = bb.emit(rx.op.add(x, y)) - assert lv0.shape[0] == m - assert lv0.shape[1] == n - assert isinstance(lv0.checked_type, rx.DynTensorType) - assert lv0.checked_type.ndim == 2 - assert lv0.checked_type.dtype == "float16" + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16")) lv1 = bb.emit(rx.op.multiply(x, z)) - assert lv1.shape[0] == m - assert lv1.shape[1] == 5 - assert isinstance(lv1.checked_type, rx.DynTensorType) - assert lv1.checked_type.ndim == 2 - assert lv1.checked_type.dtype == "float16" + assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16")) lv2 = bb.emit(rx.op.multiply(z, w)) - assert isinstance(lv2.checked_type, rx.DynTensorType) - assert lv2.checked_type.ndim == 1 - assert lv2.checked_type.dtype == "float16" + assert isinstance(lv2.struct_info, rx.TensorStructInfo) + assert lv2.struct_info.ndim == 1 + assert lv2.struct_info.dtype == "float16" lv3 = bb.emit(rx.op.multiply(y, w)) assert isinstance(lv3.struct_info, rx.TensorStructInfo) - assert lv3.checked_type.ndim == 1 - assert lv3.checked_type.dtype == "float16" + assert lv3.struct_info.ndim == 1 + assert lv3.struct_info.dtype == "float16" + gv0 = bb.emit_output(lv3) bb.emit_func_output(gv0) @@ -257,14 +245,11 @@ def test_emit_match_shape(): # match_shape(x: Tensor(_, 
"float32"], [m, n)) lv0 = bb.match_shape(x, [m, n]) assert isinstance(lv0, rx.DataflowVar) - assert lv0.shape[0] == m - assert lv0.shape[1] == n - assert lv0.checked_type.ndim == 2 - assert lv0.checked_type.dtype == "float32" + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) # lv1: Shape = match_shape(shape, [m, n]) lv1 = bb.match_shape(y, [m, n]) - assert lv1.checked_type == rx.ShapeType(2) + assert lv1.struct_info == rx.ShapeStructInfo(ndim=2) gv0 = bb.emit_output(lv1) bb.emit_func_output(gv0) @@ -295,7 +280,7 @@ def test_emit_match_shape_binding_in_dataflow_block(): with bb.function("main", [x]): with bb.dataflow(): - bb.match_shape_binding(match_shape) + bb.emit_normalized(match_shape) bb.emit_output(gv) bb.emit_func_output(x) @@ -312,8 +297,6 @@ def test_emit_match_shape_binding_in_dataflow_block(): def test_normalize(): m = tir.Var("m", "int64") n = tir.Var("n", "int64") - type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") - type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") x = rx.Var("x", R.Tensor([m, n], "float16")) y = rx.Var("y", R.Tensor([n], "float16")) @@ -321,44 +304,34 @@ def test_normalize(): # Call node add_call = rx.op.multiply(x, y) - assert isinstance(add_call.shape, rx.Call) bb.normalize(add_call) - assert isinstance(add_call.shape, rx.ShapeExpr) - assert add_call.shape[0] == m - assert add_call.shape[1] == n + shape = rx.get_shape_of(add_call) + + assert isinstance(shape, rx.ShapeExpr) + assert shape[0] == m + assert shape[1] == n # Tuple node tuple_1 = rx.Tuple([x, y]) bb.normalize(tuple_1) - assert_structural_equal(tuple_1.checked_type, rx.TupleType([type_anno0, type_anno1])) - assert_structural_equal(tuple_1.shape, rx.Tuple([x.shape, y.shape])) assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) - # Note sure if it's needed - assert_structural_equal( - tuple_1.shape.struct_info, - rx.TupleStructInfo([rx.ShapeStructInfo([m, n]), rx.ShapeStructInfo([n])]), - ) - # Nested Tuple tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) bb.normalize(tuple_2) + type_anno0 = x.checked_type + type_anno1 = y.checked_type assert_structural_equal( tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) ) - assert_structural_equal(tuple_2.shape, rx.Tuple([x.shape, rx.Tuple([x.shape, y.shape])])) assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) - # assert_structural_equal( - # tuple_2.shape.checked_type, - # rx.TupleType([rx.ShapeType(), rx.TupleType([rx.ShapeType(), rx.ShapeType()])]), - # ) def test_call_te(): @@ -540,20 +513,16 @@ def test_emit_tuple_get_item(): y = bb.emit_te(topi.nn.batch_norm, data, gamma, beta, moving_mean, moving_var) z = bb.emit(rx.TupleGetItem(y, 0)) - assert z.shape[0] == n - assert z.shape[1] == m - assert z.shape[2] == 224 - assert z.shape[3] == 224 - assert z.checked_type.ndim == 4 - assert z.checked_type.dtype == "float32" + assert_structural_equal( + z.struct_info, rx.TensorStructInfo([n, m, 224, 224], dtype="float32") + ) w = bb.emit(rx.TupleGetItem(y, 1)) - assert w.shape[0] == m - assert w.checked_type.dtype == "float32" + 
assert_structural_equal(w.struct_info, rx.TensorStructInfo([m], dtype="float32")) o = bb.emit(rx.TupleGetItem(y, 2)) - assert o.shape[0] == m - assert o.checked_type.dtype == "float32" + assert_structural_equal(o.struct_info, rx.TensorStructInfo([m], dtype="float32")) + bb.emit_func_output([y, w], params=[data, gamma, beta, moving_mean, moving_var]) func = bb.get()["rx_func"] diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 7a57633ea10c..cc2ac7e7ea3e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -241,22 +241,13 @@ def test_prim_arr_pattern(): assert pattern[0] == 32 assert pattern[1] == 32 assert isinstance(pattern, PrimArrPattern) - assert pattern.match(bindings[0].var.shape) + assert pattern.match(rx.get_shape_of(bindings[0].var)) n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") symbolic_shape = rx.ShapeExpr([n, m, n + m]) assert is_shape([n, m, n + m]).match(symbolic_shape) assert not is_shape([n, m, n * m]).match(symbolic_shape) -def test_rt_dep_shape_pattern(): - # runtime-dep-shape var - rts_var = rx.Var("rts_var", R.Tensor("float32", ndim=4)) - # static-shape var - ss_var = rx.Var("ss_var", R.Tensor([32, 32], "float32")) - assert wildcard().has_rt_dep_shape().match(rts_var) - assert not wildcard().has_rt_dep_shape().match(ss_var) - - def test_extern_fn_pattern(): pattern = ExternFuncPattern("test.blockbuilder.nop") assert pattern.match(rx.ExternFunc("test.blockbuilder.nop")) diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index f7a59bcaedf0..55d7ee072cc4 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -25,13 +25,12 @@ def test_var() -> None: v0 = rx.Var("v0") assert v0.name_hint == "v0" - assert v0.shape_ is None assert v0._checked_type_ is None assert v0.struct_info_ is None shape = [54, 96] v1 = rx.Var("v1", R.Tensor(shape, "float32")) assert v1.name_hint == "v1" - for s0, s1 in zip(v1.shape_, shape): + for s0, s1 in zip(v1.struct_info.shape, shape): assert s0 == s1 assert v1.checked_type == rx.DynTensorType(2, "float32") tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32")) @@ -40,15 +39,13 @@ def test_var() -> None: def test_dataflow_var() -> None: v0 = rx.DataflowVar("v0") assert v0.name_hint == "v0" - assert v0.shape_ is None assert v0._checked_type_ is None assert v0.struct_info_ is None shape = [54, 96] v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16")) assert v1.name_hint == "v1" - for s0, s1 in zip(v1.shape_, shape): - assert s0 == s1 + assert v1._checked_type_ == rx.DynTensorType(2, "float16") assert isinstance(v1, rx.DataflowVar) tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) @@ -77,8 +74,6 @@ def test_match_shape() -> None: assert b1.pattern[0] == m assert b1.pattern[1] == n assert b1.var is not None - for s0, s1 in zip(b1.var.shape, [m, n]): - assert s0 == s1 assert b1.var.checked_type == rx.DynTensorType(2, "float32") @@ -146,14 +141,14 @@ def test_func(): def test_shape_of(): - v0 = rx.Var("v0") - s0 = v0.shape + v0 = rx.Var("v0", R.Tensor("float32", ndim=2)) + s0 = rx.get_shape_of(v0) assert isinstance(s0, rx.Call) assert s0.op.name == "relax.shape_of" shape = [96, 54] v1 = rx.Var("v1", R.Tensor(shape)) - s1 = v1.shape + s1 = rx.get_shape_of(v1) for x, y in zip(shape, s1): assert x == y @@ -170,13 +165,13 @@ def test_shape_expr(): assert shape_expr.values[0] == 10 assert 
shape_expr.values[1] == 20 assert shape_expr.checked_type == rx.ShapeType(ndim=2) - assert shape_expr.shape_ is None + tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) x = rx.Var("v0", R.Tensor((10, 20), "float32")) - assert x.shape_.values[0] == 10 - assert x.shape_.values[1] == 20 - assert x.shape_.checked_type == rx.ShapeType(ndim=2) - assert x.shape_.shape_ is None + assert x.struct_info.shape[0] == 10 + assert x.struct_info.shape[1] == 20 + assert x.struct_info.shape.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) m = tir.Var("m", "int32") with pytest.raises( diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index ef17ca4a3e50..a0abaa72d75c 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -31,7 +31,6 @@ GlobalVar, If, MatchShape, - RuntimeDepShape, SeqExpr, ShapeExpr, Tuple, @@ -138,9 +137,6 @@ def visit_tuple_getitem_(self, op: TupleGetItem) -> None: def visit_shape_expr_(self, op: ShapeExpr) -> None: self.log.add("ShapeExpr") - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: - self.log.add("RuntimeDepShape") - def visit_extern_func_(self, op: ExternFunc) -> None: self.log.add("ExternFunc") @@ -257,11 +253,6 @@ def visit_shape_expr_(self, op: ShapeExpr) -> Expr: self.log.add("ShapeExpr") return op - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: - op = self.visit_expr_post_order(op) - self.log.add("RuntimeDepShape") - return op - def visit_extern_func_(self, op: ExternFunc) -> Expr: op = self.visit_expr_post_order(op) self.log.add("ExternFunc") @@ -277,15 +268,9 @@ def visit_var_binding_(self, binding: VarBinding) -> None: new_value = self.visit_expr(binding.value) new_var = self.visit_var_def(binding.var) - def emit(b: VarBinding): - if self.builder_.current_block_is_dataflow() and not isinstance(b.var, DataflowVar): - self.builder_.emit_output_var_binding(b) - else: - self.builder_.emit_var_binding(b) - self.log.add("VarBinding") if binding.var.same_as(new_var) and binding.value.same_as(new_value): - emit(binding) + self.builder_.emit_normalized(binding) return temp = self.with_struct_info(new_var, new_value.struct_info) @@ -293,7 +278,7 @@ def emit(b: VarBinding): new_var = temp self.set_var_remap(binding.var.vid, new_var) - emit(VarBinding(new_var, new_value)) + self.builder_.emit_normalized(VarBinding(new_var, new_value)) def visit_match_shape_(self, binding: MatchShape) -> None: """Identical with ExprMutator::VisitBinding_(const MatchShapeNode* binding) on the C++ side.""" @@ -316,10 +301,10 @@ def visit_match_shape_(self, binding: MatchShape) -> None: self.log.add("MatchShape") if binding.value.same_as(new_value) and binding.pattern.same_as(new_pattern): if not binding.var or (binding.var and binding.var.same_as(new_var)): - self.builder_.match_shape_binding(binding) + self.builder_.emit_normalized(binding) return - self.builder_.match_shape_binding(MatchShape(new_value, new_pattern.values, new_var)) + self.builder_.emit_normalized(MatchShape(new_value, new_pattern.values, new_var)) def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" @@ -339,37 +324,13 @@ def visit_dataflow_block_(self, block: DataflowBlock) -> None: def visit_var_def_(self, var: Var) -> None: """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ 
side.""" - shape_unchanged = True - new_shape = None - if var.shape_: - new_shape = self.visit_expr(var.shape_) - shape_unchanged &= var.shape_.same_as(new_shape) - self.log.add("VarDef") - if shape_unchanged: - return var - else: - new_var = Var(var.vid, new_shape, var._checked_type_, var.span) - - self.set_var_remap(var.vid, new_var) - return new_var + return var def visit_dataflow_var_def_(self, var: DataflowVar) -> None: """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" - shape_unchanged = True - new_shape = None - if var.shape_: - new_shape = self.visit_expr(var.shape_) - shape_unchanged &= var.shape_.same_as(new_shape) - self.log.add("DataflowVarDef") - if shape_unchanged: - return var - else: - new_var = DataflowVar(var.vid, new_shape, var._checked_type_, var.span) - - self.set_var_remap(var.vid, new_var) - return new_var + return var def basic_check(expr, visitor_str, mutator_str): @@ -443,9 +404,7 @@ def test_seq_expr(): "\tVar", ] ), - "\n".join( - ["Constant", "ShapeExpr", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"] - ), + "\n".join(["Constant", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"]), ) @@ -454,11 +413,6 @@ def test_shape_expr(): basic_check(x, "ShapeExpr", "ShapeExpr") -def test_runtime_dep_shape(): - runtime_dep_shape = relax.RuntimeDepShape() - basic_check(runtime_dep_shape, "RuntimeDepShape", "RuntimeDepShape") - - def test_call(): call_node = relax.op.add(x, y) basic_check( @@ -591,10 +545,8 @@ def test_function(): ), "\n".join( [ - "ShapeExpr", "VarDef", "Constant", - "ShapeExpr", "VarDef", "VarBinding", "BindingBlock", diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index f18c06ec9db9..8f013923d359 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -27,43 +27,6 @@ # c.f. 
tests/python/unittest/test_tvmscript_error_report.py -def check_shape(e, s): - if isinstance(e, relax.ShapeExpr): - pass - elif isinstance(e, relax.Call): - e = e.shape - elif isinstance(e, relax.Expr): - e = e.shape_ - - if s is None: - assert e is None - return - - if isinstance(s, relax.RuntimeDepShape): - assert isinstance(e, relax.RuntimeDepShape) - return - - assert len(e) == len(s) - - for edim, sdim in zip(e, s): - if isinstance(sdim, str): - assert isinstance(edim, tir.Var) - assert edim.name == sdim - else: - assert isinstance(edim, tir.IntImm) - assert edim.value == sdim - - -def check_tensor_var(v, s, d, ndim=None): - assert isinstance(v._checked_type_, relax.ty.DynTensorType) - assert v._checked_type_.dtype == d - if isinstance(s, (list, tuple)): - assert v._checked_type_.ndim == len(s) - if ndim is not None: - assert v._checked_type_.ndim == ndim - check_shape(v, s) - - def check_call(call, op, args): assert isinstance(call, relax.Call) if isinstance(op, str): @@ -97,14 +60,22 @@ def f( sh, shape_of = sh_bind.var, sh_bind.value o, o_call_packed = o_bind.var, o_bind.value - check_tensor_var(x, (32, "m"), "float32") - check_tensor_var(y, ("m",), "float32") - check_tensor_var(r, relax.RuntimeDepShape(), "int64") - check_tensor_var(z, (32, "m"), "float32") - check_tensor_var(w, relax.RuntimeDepShape(), "") - check_tensor_var(q, relax.RuntimeDepShape(), "", ndim=2) - assert isinstance(t._checked_type_, relax.ty.DynTensorType) - assert isinstance(sh._checked_type_, relax.ty.ShapeType) + m = tvm.tir.Var("m", dtype="int64") + assert_structural_equal( + x.struct_info, relax.TensorStructInfo([32, m], "float32"), map_free_vars=True + ) + assert_structural_equal( + y.struct_info, relax.TensorStructInfo([m], "float32"), map_free_vars=True + ) + assert_structural_equal( + r.struct_info, relax.TensorStructInfo(dtype="int64", ndim=-1), map_free_vars=True + ) + assert_structural_equal( + z.struct_info, relax.TensorStructInfo([32, m], "float32"), map_free_vars=True + ) + assert_structural_equal( + w.struct_info, relax.TensorStructInfo(ndim=-1, dtype=""), map_free_vars=True + ) check_call(mm, "relax.multiply", [x, y]) check_call(mul, "relax.multiply", [z, z]) @@ -199,8 +170,6 @@ def f(x: R.Tensor(dtype="float32")): match_sh = f.body.blocks[0].bindings[0] pattern, value = match_sh.pattern, match_sh.value - - check_shape(pattern, ("n", "m")) check_call(value, "relax.shape_of", [f.params[0]]) @@ -219,17 +188,22 @@ def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor: y_bind = f.body.blocks[0].bindings[0] y, ite = y_bind.var, y_bind.value - check_tensor_var(cond, tuple(), "bool") - check_tensor_var(x, (1,), "float32") + assert_structural_equal( + cond.struct_info, relax.TensorStructInfo([], "bool"), map_free_vars=True + ) + assert_structural_equal( + x.struct_info, relax.TensorStructInfo([1], "float32"), map_free_vars=True + ) assert isinstance(y, relax.Var) assert y.name_hint == "y" assert isinstance(ite, relax.If) assert ite.checked_type == relax.DynTensorType(1, "float32") - check_shape(ite.shape, (1,)) + assert_structural_equal(ite.struct_info, relax.TensorStructInfo([1], "float32")) + assert y.checked_type == relax.DynTensorType(1, "float32") - check_shape(y.shape, (1,)) + assert_structural_equal(y.struct_info, relax.TensorStructInfo([1], "float32")) assert isinstance(ite.true_branch, relax.SeqExpr) assert isinstance(ite.false_branch, relax.SeqExpr) @@ -315,13 +289,14 @@ def f(x: R.Tensor, y: R.Tensor((32,), "float32")): isinstance(annot.fields[1], 
relax.ty.DynTensorType) and annot.fields[1].dtype == "float32" ) - assert isinstance(t.shape_, relax.Tuple) assert isinstance(tup, relax.Tuple) assert_structural_equal(tup.fields, [x, y]) - - assert isinstance(tup.shape_, relax.Tuple) - check_shape(tup.fields[0], relax.RuntimeDepShape()) - check_shape(tup.fields[1], (32,)) + assert_structural_equal( + tup.struct_info, + relax.TupleStructInfo( + [relax.TensorStructInfo(ndim=-1, dtype=""), relax.TensorStructInfo([32], "float32")] + ), + ) def test_tuplegetitem(): @@ -425,12 +400,13 @@ def f(x: R.Tensor): q_bind = f.body.blocks[1].bindings[1] assert x2_bind.var.name_hint == "x2" - check_tensor_var(x2_bind.var, ("n", "m"), "") - check_shape(x2_bind.pattern, ("n", "m")) - assert x2_bind.value == x - - check_shape(z_shape_bind.pattern, ("n", "m")) + m = tvm.tir.Var("m", dtype="int64") + n = tvm.tir.Var("n", dtype="int64") + assert_structural_equal( + x2_bind.var.struct_info, relax.TensorStructInfo([n, m], ""), map_free_vars=True + ) + assert x2_bind.value == x assert q_bind.value.args[1] == x2_bind.var @@ -525,7 +501,7 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: return z x, y = f.params - B = x.shape_[0] + B = x.struct_info.shape[0] mm_bind, z_bind = f.body.blocks[0].bindings assert mm_bind.var.name_hint == "my_matmul" @@ -572,7 +548,6 @@ def f(x: R.Tensor((3, 3), "float32")): (z_bind, w_bind, o_bind, k_bind) = f.body.blocks[0].bindings z_var, z_value = z_bind.var, z_bind.value - check_tensor_var(z_var, ("n", "m"), "float32") assert isinstance(z_value.op, relax.ExternFunc) assert z_value.op.global_symbol == "contrib.my_matmul" @@ -656,10 +631,10 @@ def f(x: R.Tensor(("n", "m"), "float32")): return z x = f.params[0] - n, m = x.shape_ + n, m = x.struct_info.shape z_bind, sh_bind = f.body.blocks[0].bindings - assert_structural_equal(z_bind.var.shape_.values, [tir.Mul(n, m)]) + assert_structural_equal(z_bind.var.struct_info.shape.values, [tir.Mul(n, m)]) assert_structural_equal(sh_bind.value.values, [tir.Add(n, m), tir.FloorDiv(n, m)]) @@ -780,12 +755,13 @@ def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.T assert func_g.body.body == g_call_var gv_bind = func_j.body.blocks[0].bindings[0] - assert gv_bind.value.checked_type.ndim == 2 - assert gv_bind.value.checked_type.dtype == "float32" - assert gv_bind.var.checked_type.ndim == 2 - assert gv_bind.var.checked_type.dtype == "float32" - check_shape(gv_bind.value, ("n", "n")) - check_shape(gv_bind.var, ("n", "n")) + n = tvm.tir.Var("n", "int64") + assert_structural_equal( + gv_bind.value.struct_info, relax.TensorStructInfo([n, n], "float32"), map_free_vars=True + ) + assert_structural_equal( + gv_bind.var.struct_info, relax.TensorStructInfo([n, n], "float32"), map_free_vars=True + ) # check call_packed checked_type_ gv0_bind = func_k.body.blocks[0].bindings[0] @@ -806,9 +782,10 @@ def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.T # check SeqExpr type/shape assert isinstance(func_j.body, relax.SeqExpr) - assert func_j.body.checked_type.dtype == "float32" - assert func_j.body.checked_type.ndim == 2 - check_shape(func_j.body, ("n", "n")) + + assert_structural_equal( + func_j.body.struct_info, relax.TensorStructInfo([n, n], "float32"), map_free_vars=True + ) # check tuple type/shape gv1_bind = func_j.body.blocks[0].bindings[1] @@ -817,22 +794,22 @@ def k(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.T isinstance(gv1_bind.var.checked_type, relax.TupleType) assert 
gv1_bind.var.checked_type.fields[0].ndim == 2 assert gv1_bind.var.checked_type.fields[0].dtype == "float32" - isinstance(gv1_bind.var.shape, relax.Tuple) - isinstance(gv1_bind.value.shape, relax.Tuple) - check_shape(gv1_bind.value.shape.fields[0], ("n", "n")) - check_shape(gv1_bind.value.shape.fields[1], ("n", "n")) - check_shape(gv1_bind.var.shape.fields[0], ("n", "n")) - check_shape(gv1_bind.var.shape.fields[1], ("n", "n")) + + assert_structural_equal( + gv1_bind.var.struct_info, + relax.TupleStructInfo( + [relax.TensorStructInfo([n, n], "float32"), relax.TensorStructInfo([n, n], "float32")] + ), + map_free_vars=True, + ) # check TupleGetItem type/shape gv2_bind = func_j.body.blocks[0].bindings[2] isinstance(gv2_bind.value, relax.TupleGetItem) - assert gv2_bind.value.checked_type.ndim == 2 - assert gv2_bind.value.checked_type.dtype == "float32" - assert gv2_bind.var.checked_type.ndim == 2 - assert gv2_bind.var.checked_type.dtype == "float32" - check_shape(gv2_bind.value.shape, ("n", "n")) - check_shape(gv2_bind.var, ("n", "n")) + + assert_structural_equal( + gv2_bind.var.struct_info, relax.TensorStructInfo([n, n], "float32"), map_free_vars=True + ) def test_function_attrs(): @@ -899,7 +876,6 @@ def memory(x: R.Tensor) -> R.Tensor: b0, b1, b2, b3 = memory.body.blocks[0].bindings assert b0.value.op.name == "relax.memory.alloc_storage" assert isinstance(b0.value.args[0], relax.ShapeExpr) - check_shape(b0.value.args[0], (1024,)) assert isinstance(b0.value.attrs, relax.op.MemAllocStorageAttrs) assert b1.value.op.name == "relax.memory.alloc_tensor" diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index bcc7c5d77338..d58a2063f2d2 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -327,7 +327,7 @@ def f(x: R.Tensor(("n", "n"))) -> R.Tensor: return r @R.function - def g(y: R.Tensor(("n", "n"))) -> R.Tensor: + def g(y: R.Tensor(("n", "n"))) -> R.Tensor(("n", "n"), "float32"): n = T.var("int64") r = relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32") return r @@ -381,11 +381,6 @@ def test_shape_expr(): assert x.__str__() == "(10, 5)" -def test_runtime_dep_shape(): - x = relax.RuntimeDepShape() - assert x.__str__() == "None" - - def test_func_type(): # Since current all functions have "global_symbol" attribute, we can't # use the same name for different functions, even it's a local function. 
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 61b95d762507..6bdcd0701c94 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -102,22 +102,17 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): assert len(After.get_global_vars()) == 2 main = After["main"] ewise_fma_fused = After["ewise_fma_fused"] - - # check sub function call type inference - assert_structural_equal(ewise_fma_fused.body.checked_type, relax.DynTensorType(2, "float32")) sub_func_call = main.body.blocks[0].bindings[1].value sub_func_call_var = main.body.blocks[0].bindings[1].var - assert_structural_equal(sub_func_call.checked_type, relax.DynTensorType(2, "float32")) - assert_structural_equal(sub_func_call_var.checked_type, relax.DynTensorType(2, "float32")) - # check sub function call shape inference - assert isinstance(ewise_fma_fused.body.shape, relax.ShapeExpr) - assert ewise_fma_fused.body.shape.values[0] == 3 - assert ewise_fma_fused.body.shape.values[1] == 4 - assert sub_func_call.shape.values[0] == 3 - assert sub_func_call.shape.values[1] == 4 - assert sub_func_call_var.shape.values[0] == 3 - assert sub_func_call_var.shape.values[1] == 4 + # check sub function struct_info inference + assert_structural_equal( + ewise_fma_fused.body.struct_info, relax.TensorStructInfo([3, 4], "float32") + ) + assert_structural_equal(sub_func_call.struct_info, relax.TensorStructInfo([3, 4], "float32")) + assert_structural_equal( + sub_func_call_var.struct_info, relax.TensorStructInfo([3, 4], "float32") + ) def test_fma_fuse_python(): @@ -137,22 +132,17 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): assert len(After.get_global_vars()) == 2 main = After["main"] ewise_fma_fused = After["ewise_fma_fused"] - - # check sub function call type inference - assert_structural_equal(ewise_fma_fused.body.checked_type, relax.DynTensorType(2, "float32")) sub_func_call = main.body.blocks[0].bindings[1].value sub_func_call_var = main.body.blocks[0].bindings[1].var - assert_structural_equal(sub_func_call.checked_type, relax.DynTensorType(2, "float32")) - assert_structural_equal(sub_func_call_var.checked_type, relax.DynTensorType(2, "float32")) - # check sub function call shape inference - assert isinstance(ewise_fma_fused.body.shape, relax.ShapeExpr) - assert ewise_fma_fused.body.shape.values[0] == 3 - assert ewise_fma_fused.body.shape.values[1] == 4 - assert sub_func_call.shape.values[0] == 3 - assert sub_func_call.shape.values[1] == 4 - assert sub_func_call_var.shape.values[0] == 3 - assert sub_func_call_var.shape.values[1] == 4 + # check sub function call type inference + assert_structural_equal( + ewise_fma_fused.body.struct_info, relax.TensorStructInfo([3, 4], "float32") + ) + assert_structural_equal(sub_func_call.struct_info, relax.TensorStructInfo([3, 4], "float32")) + assert_structural_equal( + sub_func_call_var.struct_info, relax.TensorStructInfo([3, 4], "float32") + ) def test_dataflowpass_fail(): @@ -216,33 +206,6 @@ def main(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): relax.transform.FailTestRewrite()(TestRemoveSymbolicVar) -def test_visit_shape(): - @tvm.script.ir_module - class TestVisitShape: - @R.function - def foo(x: R.Tensor(("m", "n"), "float32")): - gv0 = R.add(x, x) - return gv0 - - mod = TestVisitShape - - shape_expr = [] - - def fvisit(e): - if isinstance(e, relax.ShapeExpr): - nonlocal shape_expr - shape_expr.append(e) - - relax.analysis.post_order_visit(mod["foo"], 
fvisit) - - # should have visited ShapeExpr 3 times - # the first time being visited is x.shape - # the last two times are the call node's shape and gv0's shape - assert len(shape_expr) == 3 - assert shape_expr[0] == mod["foo"].params[0].shape - assert shape_expr[1] == shape_expr[2] - - def test_to_non_dataflow(): @tvm.script.ir_module class TestToNonDataflow: diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index a4bed0299f3c..bd871143946f 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -84,7 +84,7 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): n = tir.Var("n", dtype="int64") bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - bb.match_shape_binding(relax.MatchShape(x, (m,), var=None)) + bb.emit_normalized(relax.MatchShape(x, (m,), var=None)) y1 = bb.match_shape(y, (n,)) bb.emit_func_output(relax.ShapeExpr([m, n * 2])) mod = bb.get() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 859884856009..596608683544 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -15,16 +15,16 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Union, List +from typing import Optional, Union import pytest import tvm import tvm.testing from tvm import IRModule, relax, tir +from tvm.relax import DynTensorType from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tir as T -from tvm.relax import RuntimeDepShape, DynTensorType def _check( @@ -203,7 +203,7 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): n = tir.Var("n", dtype="int64") bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - bb.match_shape_binding(relax.MatchShape(x, (m,), var=None)) + bb.emit_normalized(relax.MatchShape(x, (m,), var=None)) y1 = bb.match_shape(y, (n,)) bb.emit_func_output(relax.ShapeExpr([m, n * 2])) _check(foo, bb.get()["foo"]) @@ -475,30 +475,22 @@ def foo( o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object) return o - def _check_type_shape(binding, expected_type, expected_shape): - tvm.ir.assert_structural_equal(binding.var.checked_type, expected_type) - tvm.ir.assert_structural_equal(binding.var.shape_, expected_shape) + def _check_struct_info(binding, expected_sinfo): + tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) # Cannot use block builder here because we need to check the annotated type, # which may be inconsistent with deduced type. 
assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) - m = foo.params[0].shape[1] + m = relax.get_shape_of(foo.params[0])[1] bindings = foo.body.blocks[0].bindings - _check_type_shape( - bindings[0], - relax.DynTensorType(ndim=2, dtype="float32"), - relax.ShapeExpr([tvm.tir.IntImm("int64", 32), m]), - ) - _check_type_shape(bindings[1], relax.DynTensorType(dtype=""), RuntimeDepShape()) - _check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), RuntimeDepShape()) - _check_type_shape(bindings[3], relax.DynTensorType(dtype=""), RuntimeDepShape()) - _check_type_shape(bindings[4], relax.ShapeType(), None) - _check_type_shape( - bindings[5], - relax.DynTensorType(ndim=2, dtype="int8"), - relax.ShapeExpr([tvm.tir.IntImm("int64", 1), tvm.tir.IntImm("int64", 1)]), - ) - _check_type_shape(bindings[6], relax.ObjectType(), None) + + _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) + _check_struct_info(bindings[5], relax.TensorStructInfo([1, 1], "int8")) + _check_struct_info(bindings[6], relax.ObjectStructInfo()) def test_annotate_override(): @@ -613,10 +605,6 @@ def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): s = R.add(x, x) return s - # TODO(relax-team): enable it after fix block builder - # Current error: `gv2.shape` is different: (10, 5) vs RuntimeDepShape() - # tvm.ir.assert_structural_equal(Mod0, Mod1) - with pytest.raises(OSError): @I.ir_module