Skip to content

Commit

Permalink
[REFACTOR][ARCH] Introduce StructInfo M0 (apache#314)
Browse files Browse the repository at this point in the history
* [IR] Introduce StructInfo

* StructInfoFunctor and Analysis Support

* [TVMScript] Parse type/shape annotation with StructInfo

* remove runtime type assign

* Remove type/shape during parsing (#2)

* Normalizer prep: simple checks and legacy function renaming.

* Struct info deduction in BlockBuilder.

* Two TODOs

* StructInfo Normalizer Fixes (#3)

* StructInfo AST Fix

* Fix Extern Func Deduction and shape mutator.

* Update VoidStructInfo & globalvar (#4)

* Fix passes and proper sinfo propagation.

* Refactor EraseToWellDefined to Enable Remapping

* [WIP] First stab at symbolic param tracking

* Update EraseToWellDefined to support symbolic shape return (apache#5)

* fix R.shape with ndim (apache#6)

* Remove update shape/type

* Address review comment, AnnotateTypeShape=>AnnotateStructInfo

* Update include/tvm/script/ir_builder/relax/frame.h

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>

* Address comments

* Update printer to use structinfo (apache#7)

* Update Error mechanism to prep for obj loc based reporting

* Symbolic shape aware function call return value derivation.

The main flow works as follows:
- Match and populate shape_var_map and var_map by visit each pair of
  param and call arguments.
- Call EraseToWellDefined to map the ret parameter to new result.

* [ANALYSIS] Refactor well-form to only look at struct info.

* Update comments according to reviews.

* Update include/tvm/relax/struct_info.h

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Tianqi Chen <tqchen>
Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
  • Loading branch information
3 people authored Dec 21, 2022
1 parent 1517017 commit e617343
Show file tree
Hide file tree
Showing 68 changed files with 4,971 additions and 2,692 deletions.
27 changes: 27 additions & 0 deletions include/tvm/ir/diagnostic.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ class DiagnosticNode : public Object {
DiagnosticLevel level;
/*! \brief The span at which to report an error. */
Span span;
/*!
* \brief The object location at which to report an error.
*
* The object loc provides a location when span is not always
* available during transformation. The error reporter can
* still pick up loc->span if necessary.
*/
ObjectRef loc;
/*! \brief The diagnostic message. */
String message;

Expand Down Expand Up @@ -86,6 +94,18 @@ class Diagnostic : public ObjectRef {
static DiagnosticBuilder Warning(Span span);
static DiagnosticBuilder Note(Span span);
static DiagnosticBuilder Help(Span span);
// variants uses object location
static DiagnosticBuilder Bug(ObjectRef loc);
static DiagnosticBuilder Error(ObjectRef loc);
static DiagnosticBuilder Warning(ObjectRef loc);
static DiagnosticBuilder Note(ObjectRef loc);
static DiagnosticBuilder Help(ObjectRef loc);
// variants uses object ptr.
static DiagnosticBuilder Bug(const Object* loc);
static DiagnosticBuilder Error(const Object* loc);
static DiagnosticBuilder Warning(const Object* loc);
static DiagnosticBuilder Note(const Object* loc);
static DiagnosticBuilder Help(const Object* loc);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode);
};
Expand All @@ -104,6 +124,11 @@ class DiagnosticBuilder {
/*! \brief The span of the diagnostic. */
Span span;

/*!
* \brief The object location at which to report an error.
*/
ObjectRef loc;

template <typename T>
DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*)
stream_ << val;
Expand All @@ -117,6 +142,8 @@ class DiagnosticBuilder {

DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {}

DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {}

operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); }

private:
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,13 @@ class RelayExprNode : public BaseExprNode {
*/
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();

/*!
* \brief Stores the result of structure information of the
* expression that encapsulate both static shape and
* runtime information such as shape.
*/
mutable Optional<ObjectRef> struct_info_ = Optional<ObjectRef>();

/*!
* \return The checked_type
*/
Expand Down Expand Up @@ -471,6 +478,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
}

bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class PrimType : public Type {
* \brief Constructor
* \param dtype The corresponding dtype.
*/
TVM_DLL explicit PrimType(runtime::DataType dtype);
TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
Expand Down
246 changes: 245 additions & 1 deletion include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,266 @@

/*!
* \file tvm/relax/analysis.h
* \brief The set of Relax specific analysis passes.
* \brief The set of Relax specific analysis on IR.
*/
#ifndef TVM_RELAX_ANALYSIS_H_
#define TVM_RELAX_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/function.h>

#include <functional>
#include <utility>

namespace tvm {
namespace relax {
//-----------------------------------
// Shape expression analysis
//----------------------------------
/*!
* \brief Can prove the two symbolic shape arrays equals to each other.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana The analyzer used for integer analysis.
* \return The prove result.
*
* \note This function does best effort prove, which means
* if result is false, there is still possibility that
* two shapes equals to each other during runtime.
*/
TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
arith::Analyzer* ana);

/*!
* \brief Can prove the two symbolic shape expressions equals to each other.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana The analyzer used for integer analysis.
*
* \note This function does best effort prove, which means
* if result is false, there is still possibility that
* two shapes equals to each other during runtime.
*/
TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);

//-----------------------------------
// Foundational StructInfo analysis
//-----------------------------------
/*!
* \brief Get the corresponding static type from a given struct info.
* \param info The struct info.
* \return the corresponding static type.
*/
TVM_DLL Type GetStaticType(const StructInfo& info);

/*!
* \brief Get the corresponding struct info from static type.
* \param type The input type
* \return the corresponding struct info.
*/
TVM_DLL StructInfo StructInfoFromType(const Type& type);

// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_
/*!
* \brief Get the corresponding struct info from static type.
* \param type The input type
* \param shape_hint The shape hint
* \return the corresponding struct info.
*/
TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional<Expr> shape_hint);

/*!
* \brief Get the corresponding legacy shape hint from struct info
* \param info The struct info.
* \return the corresponding legacy shape hint.
*/
TVM_DLL Optional<Expr> GetLegacyShapeHint(const StructInfo& info);

/*!
* \return Derive the call's ret value struct info from inputs.
* \param func_info The function struct info.
* \param call The call expression to be derived.
* \param ctx The builder context.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return The derived struct info of the call.
* \note call->op field is ignored during derivation and we only rely on information
* presented by func_sinfo.
*/
TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);

/*!
* \brief Erase the info to a corresponding more coarse grained
* struct info that is still well-defined(with all the vars in scope).
*
* When we are returning a StructInfo to another scope,
* it is important to remember that StructInfo may carry
* dependencies on var that is not defined the other scope.
*
* In such cases, it is important to call EraseToWellDefined to get
* another StructInfo that **only** contains the vars that are defined
* in the target scope.
*
* For example, consider the following function
*
* \code
*
* @R.function
* def f(x: R.Tensor[(n, m)]):
* k = tir.Var("k", "int64")
* v0 = opaque_fn(x)
* v1 = match_cast(v0, R.Tensor[(n, k)])
* v2 : R.Tensor[(n + 1, k + 2)] = pad(v1)
* return v2
*
* \endcode
*
* In the above code, the return value y have shape `(n + 1, k + 2)`,
* However, at the level of function signature, only n, m are defined,
* k is undefined here.
*
* When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}),
* we will obtain R.Tensor(ndim=2), which is an erased info that does not depend
* on k(which is undefined from parameter signature).
*
* However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}),
* Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined.
*
* We can also make these var map to return a different expression.
* For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m})
* will give us R.Tensor[(3, m)], where n get replaced by 2.
*
* Use this function in the following scenarios:
* - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr
* - Decide the deduced return struct_info of a function that can be fully decided by params.
*
* \param info The struct info.
* \param f_shape_var_map callback function to specify
* whether a symbolic shape var is defined and the value it maps to,
* return nullopt if var is undefined.
* \param f_var_defined callback function to specify
* whether a var is defined in the target scope and the value it maps to,
* return nullopt if var is undefined.
* \param ana Optional context analyzer to prove symbolic expression equality.
*
* \return the corresponding erased struct info.
*/
TVM_DLL StructInfo
EraseToWellDefined(const StructInfo& info,
std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
arith::Analyzer* ana = nullptr);

/*!
* \brief EraseToWellDefined variant with map.
* \param info The struct info.
* \param f_shape_var_map callback function to specify
* whether a symbolic shape var is defined and the value it maps to,
* return nullopt if var is undefined.
* \param f_var_defined callback function to specify
* whether a var is defined in the target scope and the value it maps to,
* return nullopt if var is undefined.
* \param ana Optional context analyzer to prove symbolic expression equality.
*
* \return the corresponding erased struct info.
*/
TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map<tir::Var, PrimExpr> shape_var_map,
Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);

/*!
* \brief Fine grained result of base check.
*
* This analysis comes with different levels of checking failures
* that can help to customize the compilation decisions.
*
* For a given pair of lhs_struct_info, rhs_struct_info. We adopt
* the following terminology:
* - LSet = {value | value mactches lhs_struct_info}
* - RSet = {value | value mactches rhs_struct_info}
*
* See the definition of each level below.
*/
enum class BaseCheckResult {
/*!
* \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty
*/
kFailL0 = 0,
/*!
* \brief LSet is not superset of RSet by only looking at static information.
*
* \note This level will trigger static type checking error when lhs is param and rhs is arg.
*/
kFailL1 = 1,
/*!
* \brief WLSet is not superset of RSet because of mismatch in value information.
*
* L1-level mismatches in params of FuncStructInfo is categorized as
* If lhs is FuncStructInfo, then L1-level mismatch in its params
* is categorized as L2-level mismatch for lhs.
*
* Design considerations for functions:
* - (a) We want to be able to erase type/value in function signature
* when we unify function struct info and preserve simpler representations.
* - (b) We automatically insert match_cast at function boundary, so
* we can erase (int)->int argument as (object)->int.
* The input shape/type mismatch will be detected by runtime checks at function boundary.
* This behavior is also consistent with the PackedFunc behavior.
*
* \note This level means there is no problem about static known information.
* It is OK for the checker to do best effort and return this value.
*/
kFailL2 = 2,
/*! \brief LSet is superset of RSet. */
kPass = 3
};

/*!
* \brief Run a base check to see if base subsumes derived.
*
* This function returns fine-grained base-check result on reasons of failure.
*
* \param base The base struct info.
* \param derived The derived struct info.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return Whether the relation holds.
*
* \sa BaseCheckResult
*/
TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
arith::Analyzer* ana = nullptr);

/*!
* \brief Check the relation of two struct info to see if one subsumes another one.
*
* \param base The base struct info.
* \param derived The derived struct info.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return Whether the relation holds.
*/
TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
arith::Analyzer* ana = nullptr);

/*!
* \brief Unify the two struct info their least common ancestor.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return The unified information.
*/
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);

//-----------------------------------
// General IR analysis
//----------------------------------
/*!
* \brief Check if the IRModule is well formed.
*
Expand Down
Loading

0 comments on commit e617343

Please sign in to comment.