Skip to content

Commit

Permalink
[Dynamic] M2 for Pad-Einsum (S0) (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and tqchen committed Jun 17, 2023
1 parent e473781 commit 6ca9f26
Show file tree
Hide file tree
Showing 7 changed files with 500 additions and 374 deletions.
2 changes: 2 additions & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TIR_BUFFER_H_

#include <tvm/ir/expr.h>
#include <tvm/node/script_printer.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/tir/var.h>
Expand Down Expand Up @@ -162,6 +163,7 @@ class BufferNode : public Object {
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
};

/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
*/
TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
Span span = Span());
/*!
* \brief Protected write. This is only used on BufferStore's immediate RHS to indicate that
* out-of-bound access will not be performed.
* \param expr The expression to be protected.
* \return The result expression.
*/
TVM_DLL PrimExpr protected_write(PrimExpr expr);
/*!
* \brief Mark condition as likely.
* \param cond The condition
Expand Down
7 changes: 7 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
{cond, true_value, false_value}, span);
}

TVM_TIR_REGISTER_OP("protected_write");

PrimExpr protected_write(PrimExpr expr) {
static const Op& op = Op::Get("tir.protected_write");
return tir::Call(expr.dtype(), op, {expr}, Span());
}

// likely
PrimExpr likely(PrimExpr cond, Span span) {
if (is_const_int(cond)) return cond;
Expand Down
18 changes: 0 additions & 18 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -707,24 +707,6 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);

/*!
* \brief Check if buffer indices are all Vars and extr
* \param buffer_access The BufferLoad or BufferStore
* \return The indices if the indices are all Vars, otherwise NullOpt
*/
template <typename T>
Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
Array<Var> indices;
for (const PrimExpr& index : buffer_access->indices) {
const VarNode* var = index.as<VarNode>();
if (var == nullptr) {
return NullOpt;
}
indices.push_back(GetRef<Var>(var));
}
return indices;
}

/*!
* \brief Simplify non-trivial expressions
* \param expr The expression to be simplified
Expand Down
3 changes: 1 addition & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref
* \param block_sref The block sref that matches the Einsum pattern.
* \param padding The padding for each block iter.
*/
TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
const Array<Integer>& padding);
TVM_DLL void PadEinsum(ScheduleState self, StmtSRef block_sref, Array<Integer> padding);

/******** Schedule: Buffer transformation ********/
/*!
Expand Down
Loading

0 comments on commit 6ca9f26

Please sign in to comment.