Skip to content

Commit

Permalink
[REFACTOR] StructInfo M2: Cleanups on legacy shape related items (apa…
Browse files Browse the repository at this point in the history
…che#320)

* [REFACTOR] Remove shape function

* [WIP] Remove shape_, runtime_dep shape

* Remove shape_ pass Compile

* Remove RuntimeDepShape (#11)

* BlockBuilder: remove CanProveShapeEqual, consolidate binding emit to EmitNormalize

* Remove DimType, make get_shape_of API different from op.shape_of

Changes the init importing to direct import so the VSCode nagivator
can directly jump to the defintion point.

* Apply suggestions from code review

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>

* Clarify cases where struct info can be determinstically derived

* Fix remaining testcases

* Remove InferShape/Type per comment.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
  • Loading branch information
3 people authored and junrushao committed Feb 5, 2023
1 parent 8e957dd commit baea09a
Show file tree
Hide file tree
Showing 66 changed files with 520 additions and 1,516 deletions.
16 changes: 0 additions & 16 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,15 +369,6 @@ class RelayExprNode : public BaseExprNode {
*/
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of static shape analysis. It must be a RelayExpr
* and ObjectRef is used here to avoid cyclic typing.
*
* \note The value will be optional if a static shape can not be inferred.
* use .shape() instead to acesss an always defined shape expression.
*/
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();

/*!
* \brief Stores the result of structure information of the
* expression that encapsulate both static shape and
Expand All @@ -390,13 +381,6 @@ class RelayExprNode : public BaseExprNode {
*/
inline const Type& checked_type() const;

/*!
* \return An expression which corresponds to the shape of the expression.
*
* Only valid when the expression's type is a Tensor.
*/
RelayExpr shape() const;

/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const relax::ShapeTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const relax::ObjectTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const relax::DynTensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const relax::DimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const relax::PackedFuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand Down Expand Up @@ -123,7 +122,6 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(relax::ShapeTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(relax::ObjectTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(relax::DynTensorTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(relax::DimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(relax::PackedFuncTypeNode);
return vtable;
}
Expand Down
27 changes: 0 additions & 27 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,6 @@ TVM_DLL Type GetStaticType(const StructInfo& info);
*/
TVM_DLL StructInfo StructInfoFromType(const Type& type);

// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_
/*!
* \brief Get the corresponding struct info from static type.
* \param type The input type
* \param shape_hint The shape hint
* \return the corresponding struct info.
*/
TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional<Expr> shape_hint);

/*!
* \brief Get the corresponding legacy shape hint from struct info
* \param info The struct info.
* \return the corresponding legacy shape hint.
*/
TVM_DLL Optional<Expr> GetLegacyShapeHint(const StructInfo& info);

/*!
* \return Derive the call's ret value struct info from inputs.
* \param func_info The function struct info.
Expand Down Expand Up @@ -425,17 +409,6 @@ std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
*/
TVM_DLL Function RemoveAllUnused(const Function fn);

/*!
* \brief Given the argument vars and body, derives a return shape for a function with those args
* and that body. If the body's shape contains free shape vars (those not used in the args), the
* return shape is relaxed to RuntimeDepShape; otherwise, the body's shape is used.
*
* \param args The argument variables, ideally with the shape_ field filled in
* \param body The functino body, ideally with the shape_ field filled in
* \return An expression that can serve as the return shape for the function
*/
TVM_DLL Expr DeriveFuncRetShape(Array<Var> args, Expr body);

} // namespace relax
} // namespace tvm

Expand Down
35 changes: 0 additions & 35 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ class BlockBuilderNode : public Object {
*/
virtual NameTable* name_table() = 0;

/*!
* \brief Check if two shape expressions can be proven equal at compile time.
* \param lhs The input lhs shape.
* \param rhs The input rhs shape.
* \return Whether we can prove lhs shape is the same as the rhs shape.
*/
virtual bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs) = 0;

/*!
* \brief Get the context IRModule in this builder.
*
Expand Down Expand Up @@ -167,15 +159,6 @@ class BlockBuilderNode : public Object {
*/
virtual Var Emit(Expr expr, String name_hint = "") = 0;

/*!
* \brief Emits a variable binding, and returns the bound Var.
* \param binding The variable binding.
* \return The bound variable.
*
* \note This function requires binding to be pre-normalized.
*/
virtual Var Emit(VarBinding binding) = 0;

/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
Expand All @@ -185,15 +168,6 @@ class BlockBuilderNode : public Object {
*/
virtual Var EmitMatchShape(Expr value, Array<PrimExpr> pattern, String name_hint = "") = 0;

/*!
* \brief Emit a MatchShape binding.
* \param binding The MatchShape binding to be emitted.
* \return The variable bound to the MatchShape.
*
* \note This function requires binding to be pre-normalized.
*/
virtual Var EmitMatchShape(MatchShape binding) = 0;

/*!
* \brief Generate an output for the current dataflow block.
* \param output The output variable of the block.
Expand All @@ -202,15 +176,6 @@ class BlockBuilderNode : public Object {
*/
virtual Var EmitOutput(Expr output, String name_hint = "") = 0;

/*!
* \brief Generate an output for the current dataflow block.
* \param binding The output binding to output.
* \return The variable bound to \p output.
*
* \note This function requires binding to be pre-normalized.
*/
virtual Var EmitOutput(VarBinding binding) = 0;

/*!
* \brief Emit a binding that is already normalized.
*
Expand Down
27 changes: 0 additions & 27 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class OrPattern;
class AndPattern;
class NotPattern;
class ShapePattern;
class RuntimeDepShapePattern;
class TypePattern;
class DataTypePattern;
class AttrPattern;
Expand Down Expand Up @@ -112,8 +111,6 @@ class DFPattern : public ObjectRef {
TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
/*! \brief Syntatic Sugar for creating a ShapePattern */
TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
/*! \brief Syntatic Sugar for creating a RuntimeDepShapePattern */
TVM_DLL RuntimeDepShapePattern HasRuntimeDepShape() const;
/*! \brief Syntatic Sugar for duplicating the current pattern */
TVM_DLL DFPattern dup() const;

Expand Down Expand Up @@ -778,30 +775,6 @@ class ExternFuncPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode);
};

/*!
* \brief A pattern that asserting a root pattern has a runtime-dependent shape.
* \sa RuntimeDepShape
* \sa RuntimeDepShapePattern
*/
class RuntimeDepShapePatternNode : public DFPatternNode {
public:
DFPattern pattern; /*!< The root pattern to match */
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "relax.dpl.RuntimeDepShapePattern";
TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeDepShapePatternNode, DFPatternNode);
};

/*!
* \brief Managed reference to RuntimeDepShapePatternNode.
* \sa RuntimeDepShapePatternNode
*/
class RuntimeDepShapePattern : public DFPattern {
public:
TVM_DLL explicit RuntimeDepShapePattern(DFPattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(RuntimeDepShapePattern, DFPattern, RuntimeDepShapePatternNode);
};

/*! \brief Syntatic Sugar for creating a VarPattern with a name */
VarPattern IsVar(const String& name);
/*! \brief Syntatic Sugar for creating a ConstantPattern */
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/relax/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;

virtual R VisitDFPattern_(const RuntimeDepShapePatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataflowVarPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const GlobalVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -135,7 +133,6 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);

RELAX_DFPATTERN_FUNCTOR_DISPATCH(RuntimeDepShapePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
Expand Down Expand Up @@ -170,7 +167,6 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const WildcardPatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;

void VisitDFPattern_(const RuntimeDepShapePatternNode* op) override;
void VisitDFPattern_(const DataflowVarPatternNode* op) override;
void VisitDFPattern_(const GlobalVarPatternNode* op) override;
void VisitDFPattern_(const ExternFuncPatternNode* op) override;
Expand Down
Loading

0 comments on commit baea09a

Please sign in to comment.