Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] Remove ir_pass in favor of analysis/transform. #5415

Merged
merged 1 commit into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/tvm/te/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ TVM_DLL void AutoInlineInjective(Schedule sch);
*/
Map<IterVar, Range> InferBound(const Schedule& sch);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
* \param stmt The stmt to be verified.
* \return true if there is any buffer_bind_scope attribute found,
* otherwise, false.
*/
bool VerifyCompactBuffer(const Stmt& stmt);

/*!
* \brief Schedule s' dependent operations.
*
Expand All @@ -72,7 +81,6 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);


/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
Expand Down
105 changes: 101 additions & 4 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
#define TVM_TIR_ANALYSIS_H_

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>

#include <string>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -59,7 +60,47 @@ struct ExprDeepEqual {
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Whether the expression have side effect.
* \param expr The expression to be checked.
* \return whether expression have side effect
*/
TVM_DLL bool HasSideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr,
std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) {
return var.get() == node;
});
}


/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
* That is: each Var is defined and assigned once(in Let/For)
*
* \param func The function to be verified.
* \return Whether IR is in SSA form.
*
* \note All passes in TIR consume and produce SSA form.
*/
TVM_DLL bool VerifySSA(const PrimFunc& func);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
Expand All @@ -68,11 +109,67 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param mod The module to be verified.
* \param func The function to be verified.
* \return Success of memory verification.
*/
void VerifyMemory(const IRModule& mod);
TVM_DLL bool VerifyMemory(const PrimFunc& func);

/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
* \param func The function to be checked
* \param constraints The dict to specify constraints to check.
* Possible keys are
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_threads_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
*
* If one key is missing in this argument, the pass won't check for that item.
* \return valid Whether it is a valid GPU code
*
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func,
Map<std::string, PrimExpr> constraints);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassContext;

/*!
* \brief Pass variant of VerifySSA.
*
* \returns The pass.
* \sa tvm::tir::VerifySSA
*/
TVM_DLL Pass VerifySSA();

/*!
* \brief Pass variant of VerifyMemory.
*
* \returns The pass.
* \sa tvm::tir::VerifyMemory
*/
TVM_DLL Pass VerifyMemory();

/*!
* \brief Pass variant of VerifyGPUCode.
*
* \param constraints The dict to specify constraints to check.
*
* \returns The pass.
* \sa tvm::tir::VerifyGPUCode
*/
TVM_DLL Pass VerifyGPUCode(Map<std::string, PrimExpr> constraints);

} // namespace transform
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
145 changes: 0 additions & 145 deletions include/tvm/tir/ir_pass.h

This file was deleted.

19 changes: 17 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ TVM_DLL Pass InstrumentBoundCheckers();
*/
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);


/*!
* \brief Remap the thread axis
*
Expand All @@ -241,7 +240,6 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*/
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);


/*!
* \brief Lower custom datatypes.
*
Expand All @@ -251,6 +249,13 @@ TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
*/
TVM_DLL Pass LowerCustomDatatypes();

/*!
* \brief Decorate all the function's body as device function.
*
* \return The pass.
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Split the function into a host function and device functions.
*
Expand Down Expand Up @@ -334,6 +339,16 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass NarrowDataType(int target_bits);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \return The pass.
*/
TVM_DLL Pass PointerValueTypeRewrite();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

import tvm._ffi
from tvm import nd, rpc as _rpc, target as _target
from tvm.tir import ir_pass
from tvm.error import TVMError
from tvm.target import build_config
from tvm.driver import build
Expand Down Expand Up @@ -616,7 +615,7 @@ def gpu_verify_pass(**kwargs):
This pass will check memory usage and number of threads per block.
"""
def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
valid = tvm.analysis.verify_gpu_code(f, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return f
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
Expand Down Expand Up @@ -111,7 +110,7 @@ def form_irmodule(sch, args, name, binds):
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

compact = ir_pass.VerifyCompactBuffer(stmt)
compact = schedule.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)

stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
Expand Down Expand Up @@ -246,9 +245,8 @@ def _build_for_device(input_mod, target, target_host):

mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
tvm.tir.analysis.verify_memory(mod_mixed)

opt_mixed = []
opt_mixed = [tvm.tir.transform.VerifyMemory()]
if len(mod_mixed.functions) == 1:
opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
if BuildConfig.current().detect_global_barrier:
Expand Down
Loading