Skip to content

Commit

Permalink
TIR debug info
Browse files Browse the repository at this point in the history
  • Loading branch information
driazati committed Oct 13, 2022
1 parent a61c1ad commit a2fdb8f
Show file tree
Hide file tree
Showing 19 changed files with 764 additions and 93 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Used in CI to communicate between Python and Jenkins
.docker-image-names/

# Printed TIR code on disk
*.tir
2 changes: 2 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ class IntImm : public PrimExpr {
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};

/*!
Expand Down Expand Up @@ -572,6 +573,7 @@ class FloatImm : public PrimExpr {
TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};

/*!
Expand Down
30 changes: 30 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class StringImm : public PrimExpr {
public:
TVM_DLL StringImm(String value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
};

/*!
Expand Down Expand Up @@ -117,6 +118,7 @@ class Cast : public PrimExpr {
public:
TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
};

/*!
Expand Down Expand Up @@ -165,6 +167,7 @@ class Add : public PrimExpr {
public:
TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
};

/*! \brief a - b */
Expand All @@ -181,6 +184,7 @@ class Sub : public PrimExpr {
public:
TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
};

/*! \brief a * b */
Expand All @@ -197,6 +201,7 @@ class Mul : public PrimExpr {
public:
TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
};

/*!
Expand All @@ -216,6 +221,7 @@ class Div : public PrimExpr {
public:
TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
};

/*!
Expand All @@ -235,6 +241,7 @@ class Mod : public PrimExpr {
public:
TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
};

/*! \brief Floor division, floor(a/b) */
Expand All @@ -251,6 +258,7 @@ class FloorDiv : public PrimExpr {
public:
TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
};

/*! \brief The remainder of the floordiv */
Expand All @@ -267,6 +275,7 @@ class FloorMod : public PrimExpr {
public:
TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
};

/*! \brief min(a, b) */
Expand All @@ -283,6 +292,7 @@ class Min : public PrimExpr {
public:
TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
};

/*! \brief max(a, b) */
Expand All @@ -299,6 +309,7 @@ class Max : public PrimExpr {
public:
TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
};

/*!
Expand Down Expand Up @@ -347,6 +358,7 @@ class EQ : public PrimExpr {
public:
TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
};

/*! \brief a != b */
Expand All @@ -363,6 +375,7 @@ class NE : public PrimExpr {
public:
TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
};

/*! \brief a < b */
Expand All @@ -379,6 +392,7 @@ class LT : public PrimExpr {
public:
TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
};

/*! \brief a <= b */
Expand All @@ -395,6 +409,7 @@ class LE : public PrimExpr {
public:
TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
};

/*! \brief a > b */
Expand All @@ -411,6 +426,7 @@ class GT : public PrimExpr {
public:
TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
};

/*! \brief a >= b */
Expand All @@ -427,6 +443,7 @@ class GE : public PrimExpr {
public:
TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
};

/*! \brief a && b */
Expand Down Expand Up @@ -466,6 +483,7 @@ class And : public PrimExpr {
public:
TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
};

/*! \brief a || b */
Expand Down Expand Up @@ -505,6 +523,7 @@ class Or : public PrimExpr {
public:
TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
};

/*! \brief !a */
Expand Down Expand Up @@ -540,6 +559,7 @@ class Not : public PrimExpr {
public:
TVM_DLL Not(PrimExpr a, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
};

/*!
Expand Down Expand Up @@ -591,6 +611,7 @@ class Select : public PrimExpr {
TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
};

/*!
Expand Down Expand Up @@ -706,6 +727,7 @@ class ProducerLoad : public PrimExpr {
TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
};

/*!
Expand Down Expand Up @@ -765,6 +787,7 @@ class Load : public PrimExpr {
TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
};

/*!
Expand Down Expand Up @@ -817,6 +840,7 @@ class Ramp : public PrimExpr {
public:
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};

/*! \brief Create a vector where all the elements are value. */
Expand Down Expand Up @@ -856,6 +880,7 @@ class Broadcast : public PrimExpr {
public:
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};

/*!
Expand Down Expand Up @@ -902,6 +927,7 @@ class Let : public PrimExpr {
public:
TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
};

/*!
Expand Down Expand Up @@ -948,6 +974,7 @@ class Call : public PrimExpr {
public:
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down Expand Up @@ -995,6 +1022,7 @@ class Shuffle : public PrimExpr {
TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
};

// Reduce operator
Expand Down Expand Up @@ -1124,6 +1152,7 @@ class Reduce : public PrimExpr {
int value_index, Array<PrimExpr> init, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};

/*! \brief Any shape. */
Expand Down Expand Up @@ -1159,6 +1188,7 @@ class Any : public PrimExpr {
TVM_DLL Any(Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};

/*
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class LetStmt : public Stmt {
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
};

/*!
Expand Down Expand Up @@ -158,6 +159,7 @@ class AttrStmt : public Stmt {
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
};

/*!
Expand Down Expand Up @@ -206,6 +208,7 @@ class AssertStmt : public Stmt {
TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
};

/*!
Expand Down Expand Up @@ -271,6 +274,7 @@ class Store : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
};

/*!
Expand Down Expand Up @@ -442,6 +446,7 @@ class ProducerStore : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
};

/*!
Expand Down Expand Up @@ -505,6 +510,7 @@ class ProducerRealize : public Stmt {
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
};

/*!
Expand Down Expand Up @@ -679,6 +685,7 @@ class AllocateConst : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode);
};

/*! \brief Declare a buffer that can be used in the body */
Expand Down Expand Up @@ -812,6 +819,7 @@ class SeqStmt : public Stmt {
};

TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode);
};

/*!
Expand Down Expand Up @@ -858,6 +866,7 @@ class IfThenElse : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
};

/*!
Expand Down Expand Up @@ -897,6 +906,7 @@ class Evaluate : public Stmt {
explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*!
Expand Down Expand Up @@ -1054,6 +1064,7 @@ class While : public Stmt {
TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
};

/*!
Expand Down Expand Up @@ -1098,6 +1109,7 @@ class Prefetch : public Stmt {
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
};

/*!
Expand Down Expand Up @@ -1202,6 +1214,7 @@ class MatchBufferRegion : public ObjectRef {
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);

TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode);
};

/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,13 @@ TVM_DLL Pass LowerAsyncDMA();
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Add TIR-printer output as debug information to all ops in the module
* \return The pass.
*/

TVM_DLL Pass InstallDebugSpans();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
Expand Down
Loading

0 comments on commit a2fdb8f

Please sign in to comment.