Add threadIdx filtering in Multi-Level-Tiling and Verify-GPU-Code (#20)
* Add threadIdx filtering in Multi-Level-Tiling and Verify-GPU-Code

* minor

* minor

* turn off debug flag
junrushao authored Jan 23, 2022
1 parent f748df4 commit 2f08d20
Showing 8 changed files with 311 additions and 20 deletions.
10 changes: 9 additions & 1 deletion include/tvm/tir/stmt.h
@@ -1373,9 +1373,17 @@ constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperati
*/
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

- /*! \brief Mark that tensor core is enbaled in the PrimExpr */
+ /*! \brief Mark that tensor core is enabled in the PrimExpr */
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";

/*! \brief The lower bound (inclusive) of the allowed thread extent product in thread bindings */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
"meta_schedule.thread_extent_low_inclusive";

/*! \brief The upper bound (inclusive) of the allowed thread extent product in thread bindings */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
"meta_schedule.thread_extent_high_inclusive";

/*!
* \brief Mark a block as generated by cache_read or cache_write block.
* 0 means cache_read; 1 means cache_write.
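Together, the two new keys give Verify-GPU-Code an inclusive range that the product of a block's threadIdx extents must fall into. Below is a minimal, hypothetical sketch of attaching them from the Python side; the PrimFunc, the block name, and the bounds 32/1024 are illustrative and not part of this commit, and the TVMScript syntax assumes a recent TVM build:

```python
import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    for i in T.serial(128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            B[vi] = A[vi] + T.float32(1)

sch = tir.Schedule(add_one)
block = sch.get_block("B")
# Only accept schedules whose threadIdx extents multiply to a value in [32, 1024].
sch.annotate(block, "meta_schedule.thread_extent_low_inclusive", 32)
sch.annotate(block, "meta_schedule.thread_extent_high_inclusive", 1024)
```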
53 changes: 53 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -20,6 +20,56 @@

#include "../utils.h"

namespace tvm {
namespace tir {

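/*! \brief Check that the product of threadIdx extents stays inside the range annotated on each visited block. */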
class ThreadExtentChecker : private StmtVisitor {
public:
static bool Check(const Stmt& stmt) {
try {
ThreadExtentChecker().VisitStmt(stmt);
return true;
} catch (const dmlc::Error& e) {
return false;
}
}

private:
void VisitStmt_(const ForNode* loop) {
if (IsThreadIdx(GetThreadScope(loop))) {
if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
thread_extent_product *= *p_ext;
StmtVisitor::VisitStmt_(loop);
thread_extent_product /= *p_ext;
return;
} else {
throw dmlc::Error("Dynamic thread extent");
}
}
StmtVisitor::VisitStmt_(loop);
}

void VisitStmt_(const BlockNode* block) {
if (Optional<Integer> low_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
if (Optional<Integer> high_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) {
int64_t low = low_inclusive.value()->value;
int64_t high = high_inclusive.value()->value;
if (!(low <= thread_extent_product && thread_extent_product <= high)) {
throw dmlc::Error("Thread extent");
}
}
}
StmtVisitor::VisitStmt_(block);
}

int64_t thread_extent_product = 1;
};

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

@@ -66,6 +116,9 @@ class VerifyGPUCodeNode : public PostprocNode {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
return false;
}
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
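Because the checker runs before any lowering pass, candidates with dynamic or out-of-range thread extents are rejected cheaply. A small Python sketch of the accept/reject rule it implements (the function name and defaults are illustrative, not the C++ API):

```python
def thread_extents_ok(extents, low=1, high=1024):
    # Mirror of ThreadExtentChecker: a dynamic (unknown) extent rejects the
    # candidate, as the C++ code does by throwing dmlc::Error; otherwise the
    # product of all threadIdx extents must lie in [low, high] inclusive.
    product = 1
    for ext in extents:
        if ext is None:  # dynamic thread extent
            return False
        product *= ext
    return low <= product <= high

assert thread_extents_ok([8, 16])       # 128 threads, inside [1, 1024]
assert not thread_extents_ok([64, 32])  # 2048 threads, above the upper bound
assert not thread_extents_ok([None])    # dynamic extent
```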
41 changes: 39 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -287,7 +287,17 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
}

// Inherited from ScheduleRuleNode
- void InitializeWithTuneContext(const TuneContext& context) final {}
+ void InitializeWithTuneContext(const TuneContext& context) final {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
this->max_threads_per_block_ = v.value()->value;
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
this->thread_warp_size_ = v.value()->value;
} else {
LOG(INFO) << "'thread_warp_size' is not defined in the target";
}
}
}

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final {
if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
@@ -331,6 +341,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
std::vector<int> s_indices_;
/*! \brief The indices of reduction tiles in `structure` */
std::vector<int> r_indices_;
/*! \brief The size of the thread warp */
int thread_warp_size_;
/*! \brief The maximum number of threads per block */
int max_threads_per_block_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("structure", &structure);
@@ -342,6 +356,8 @@
// `reuse_write_` is not visited
// `s_indices_` is not visited
// `r_indices_` is not visited
// `thread_warp_size_` is not visited
// `max_threads_per_block_` is not visited
}

static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling";
@@ -419,19 +435,27 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv));
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
if (iter_types[i] == IterVarType::kDataPar) {
idx = &s_indices_;
if (spatial_loop_product != -1) {
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
spatial_loop_product *= *extent;
} else {
spatial_loop_product = -1;
}
}
} else if (iter_types[i] == IterVarType::kCommReduce) {
idx = &r_indices_;
} else {
continue;
}
// Do the split
int n_tiles = idx->size();
- LoopRV loop = loops[i];
Array<ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
@@ -453,6 +477,17 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
tiles[i] = {fused};
}
state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
if (this->thread_warp_size_ != -1) {
int64_t low_inclusive = 1;
int64_t high_inclusive = this->max_threads_per_block_;
if (spatial_loop_product > 2 * this->thread_warp_size_) {
low_inclusive = this->thread_warp_size_;
}
sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive,
Integer(low_inclusive));
sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive,
Integer(high_inclusive));
}
return {state};
}

@@ -578,6 +613,8 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
}
}
n->thread_warp_size_ = -1;
n->max_threads_per_block_ = -1;
return ScheduleRule(n);
}

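Net effect: TileLoopNest only imposes a warp-sized lower bound when the spatial iteration domain spans more than two warps, and always caps the extent at the target's thread budget. A worked Python sketch of the bound selection, under assumed target values (warp size 32, 1024 threads per block):

```python
def thread_extent_bounds(spatial_loop_product, thread_warp_size=32,
                         max_threads_per_block=1024):
    # Mirrors the C++ logic above: a product of -1 stands for a dynamic
    # spatial extent, in which case the lower bound stays at 1.
    low = 1
    if spatial_loop_product > 2 * thread_warp_size:
        low = thread_warp_size
    return low, max_threads_per_block

assert thread_extent_bounds(256 * 256) == (32, 1024)  # large: at least one warp
assert thread_extent_bounds(16) == (1, 1024)          # small: any extent allowed
assert thread_extent_bounds(-1) == (1, 1024)          # dynamic: no lower bound
```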
8 changes: 4 additions & 4 deletions src/meta_schedule/space_generator/post_order_apply.cc
@@ -106,10 +106,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode {

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
- tir::Schedule sch = tir::Schedule::Traced( //
- /*mod=*/mod_, //
- /*rand_state=*/ForkSeed(&this->rand_state_), //
- /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, //
+ tir::Schedule sch = tir::Schedule::Traced( //
+ /*mod=*/mod_, //
+ /*rand_state=*/ForkSeed(&this->rand_state_), //
+ /*debug_mode=*/0, //
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

std::vector<ScheduleAndUnvisitedBlocks> stack;
Expand Down
10 changes: 4 additions & 6 deletions tests/python/meta_schedule/run_ansor_cuda.sh
@@ -4,8 +4,8 @@ RPC_HOST="192.168.6.66"
RPC_PORT="4445"
RPC_KEY="jetson-agx-xavier"
TARGET="nvidia/jetson-agx-xavier"
- NUM_TRIALS=800
LOG_DIR=$HOME/logs/ansor-cuda/
+ NUM_TRIALS=2000

mkdir -p $LOG_DIR

@@ -23,19 +23,17 @@ run () {
2>&1 | tee "$LOG_DIR/$name.log"
}

- # Single op
run C1D
run C2D
- run C3D
run CAP
run DEP
run DIL
run GMM
run GRP
- run NRM
- run SFM
run T2D
- # Subgraph
run C2d-BN-RELU
run TBG

+ run C3D
+ run NRM
+ run SFM
10 changes: 4 additions & 6 deletions tests/python/meta_schedule/run_meta_schedule_cuda.sh
@@ -4,7 +4,7 @@ RPC_HOST="192.168.6.66"
RPC_PORT="4445"
RPC_KEY="jetson-agx-xavier"
TARGET="nvidia/jetson-agx-xavier"
- LOG_DIR=/tmp/logs/ms-cuda/
+ LOG_DIR=$HOME/logs/ms-cuda/
NUM_TRIALS=2000

mkdir -p $LOG_DIR
@@ -25,19 +25,17 @@ run () {
2>&1 | tee "$work_dir/$name.log"
}

- # Single op
run C1D
run C2D
- run C3D
run CAP
run DEP
run DIL
run GMM
run GRP
- run NRM
- run SFM
run T2D
- # Subgraph
run C2d-BN-RELU
run TBG

+ run C3D
+ run NRM
+ run SFM
