Fix cooperative fetching #17

Merged: 1 commit, Jan 22, 2022
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -141,7 +141,7 @@ class ScheduleRule : public runtime::ObjectRef {
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
-  * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
+  * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
@@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Array<String>> tile_binds, //
bool use_tensor_core, //
Optional<Integer> max_innermost_factor, //
-      Optional<Integer> vector_load_max_len,       //
+      Optional<Array<Integer>> vector_load_lens,   //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -57,7 +57,7 @@ class MultiLevelTiling(ScheduleRule):
Whether to apply tensor core wmma intrinsic for the computation
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
-    vector_load_max_len : Optional[int]
+    vector_load_lens : Optional[List[int]]
The length of vector lane in vectorized cooperative fetching.
None means disable vectorization
reuse_read : Optional[ReuseType]
@@ -72,7 +72,7 @@ def __init__(
tile_binds: Optional[List[str]] = None,
use_tensor_core: bool = False,
max_innermost_factor: Optional[int] = None,
-        vector_load_max_len: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
@@ -82,7 +82,7 @@ def __init__(
tile_binds,
use_tensor_core,
max_innermost_factor,
-            vector_load_max_len,
+            vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
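For reference, a minimal sketch of how the renamed parameter is used after this change. It is illustrative only, not part of the diff; the import path and the candidate list are assumptions, and the reuse options are left unset:

    from tvm.meta_schedule.schedule_rule import MultiLevelTiling

    # The rule now takes an explicit list of candidate vector lengths for
    # cooperative fetching; one candidate is sampled per cache read instead
    # of splitting the fused loop by a single maximum factor.
    rule = MultiLevelTiling(
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],  # was: vector_load_max_len=4
        reuse_read=None,
        reuse_write=None,
    )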
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -111,7 +111,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
structure="SSRSRS",
tile_binds=None,
max_innermost_factor=64,
-        vector_load_max_len=None,
+        vector_load_lens=None,
reuse_read=None,
reuse_write=ReuseType(
req="may",
@@ -124,7 +124,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
@@ -147,7 +147,7 @@ def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule:
tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
use_tensor_core=True,
max_innermost_factor=64,
-        vector_load_max_len=4,
+        vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/tune.py
@@ -97,7 +97,7 @@ def _sch_rules() -> List[ScheduleRule]:
structure="SSRSRS",
tile_binds=None,
max_innermost_factor=64,
-            vector_load_max_len=None,
+            vector_load_lens=None,
reuse_read=None,
reuse_write=M.ReuseType(
req="may",
@@ -154,7 +154,7 @@ def _sch_rules() -> List[ScheduleRule]:
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
reuse_read=M.ReuseType(
req="must",
levels=[4],
108 changes: 84 additions & 24 deletions src/meta_schedule/mutator/mutate_tile_size.cc
@@ -33,7 +33,7 @@ using tir::Trace;
* \param decision The decision of Sample-Perfect-Tile
* \return The result of downcast
*/
- std::vector<int64_t> DowncastDecision(const ObjectRef& decision) {
+ std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
}
@@ -73,34 +73,62 @@ class MutateTileSizeNode : public MutatorNode {
* \param decision The decision selected
* \return Whether a decision is found
*/
- bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instruction* inst,
-                            std::vector<int64_t>* decision) {
+ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
+                            std::vector<std::vector<int64_t>>* decision) {
static const InstructionKind& inst_sample_perfect_tile =
InstructionKind::Get("SamplePerfectTile");
-   std::vector<Instruction> instructions;
-   std::vector<std::vector<int64_t>> decisions;
+   std::vector<Instruction>& instructions = *inst;
+   std::vector<std::vector<int64_t>>& decisions = *decision;
instructions.reserve(trace->decisions.size());
decisions.reserve(trace->decisions.size());
for (const auto& kv : trace->decisions) {
const Instruction& inst = kv.first;
const ObjectRef& decision = kv.second;
-     if (!inst->kind.same_as(inst_sample_perfect_tile)) {
-       continue;
+     if (inst->kind.same_as(inst_sample_perfect_tile)) {
+       std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+       if (tiles.size() >= 2 && Product(tiles) >= 2) {
+         instructions.push_back(inst);
+         decisions.push_back(tiles);
+       }
+     }
-     std::vector<int64_t> tiles = DowncastDecision(decision);
-     if (tiles.size() >= 2 && Product(tiles) >= 2) {
-       instructions.push_back(inst);
-       decisions.push_back(tiles);
-     }
}

+ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
+                          std::vector<int64_t>* decision) {
+   static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
+   static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+   std::vector<Instruction>& instructions = *inst;
+   std::vector<int64_t>& decisions = *decision;
+   std::unordered_set<const Object*> annotated;
+   instructions.reserve(trace->decisions.size());
+   decisions.reserve(trace->decisions.size());
+   annotated.reserve(trace->decisions.size());
+   // Find annotation with `meta_schedule_cooperative_fetch`
+   for (const Instruction& inst : trace->insts) {
+     if (inst->kind.same_as(inst_annotate)) {
+       ICHECK_EQ(inst->attrs.size(), 1);
+       ICHECK_EQ(inst->inputs.size(), 2);
+       if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
+         const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
+         ICHECK(ann_val);
+         annotated.insert(ann_val);
+       }
+     }
+   }
-   int n = instructions.size();
-   if (n > 0) {
-     int i = tir::SampleInt(rand_state, 0, n);
-     *inst = instructions[i];
-     *decision = decisions[i];
-     return true;
+   // Find sampling instruction that generates the annotation
+   for (const auto& kv : trace->decisions) {
+     const Instruction& inst = kv.first;
+     const ObjectRef& decision = kv.second;
+     if (inst->kind.same_as(inst_sample_categorical)) {
+       ICHECK_EQ(inst->outputs.size(), 1);
+       if (annotated.count(inst->outputs[0].get())) {
+         const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+         instructions.push_back(inst);
+         decisions.push_back(d->value);
+       }
+     }
+   }
-   return false;
}

struct FactorMemo {
@@ -146,12 +174,8 @@ struct FactorMemo {
std::mutex mutex_;
};

- Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
-   Instruction inst;
-   std::vector<int64_t> tiles;
-   if (!FindSamplePerfectTile(trace, rand_state, &inst, &tiles)) {
-     return NullOpt;
-   }
+ Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
+                                      std::vector<int64_t> tiles, TRandState* rand_state) {
int n_splits = tiles.size();
// Step 1. Choose two loops, `x` and `y`
int x, y;
@@ -194,6 +218,42 @@ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
}
}

+ Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
+                                       int64_t original_decision, TRandState* rand_state) {
+   ICHECK_EQ(inst->attrs.size(), 2);
+   std::vector<double> probs =
+       support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
+   probs.erase(probs.begin() + original_decision);
+   int result = tir::MakeMultinomialSampler(rand_state, probs)();
+   if (result >= original_decision) {
+     result += 1;
+   }
+   return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
+ }

+ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
+   std::vector<Instruction> sample_perfect_tile_insts;
+   std::vector<Instruction> sample_vectorize_insts;
+   std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;
+   std::vector<int64_t> sample_vectorize_decisions;
+   FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
+   FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
+   int size_a = sample_perfect_tile_insts.size();
+   int size_b = sample_vectorize_insts.size();
+   if (size_a == 0 && size_b == 0) {
+     return NullOpt;
+   }
+   int n = tir::SampleInt(rand_state, 0, size_a + size_b);
+   if (n < size_a) {
+     return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
+                                 rand_state);
+   } else {
+     n -= size_a;
+     return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
+                                  rand_state);
+   }
+ }

Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }

TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
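Taken together, `Apply` now pools both kinds of mutable decisions and picks one uniformly at random; when the pick is a cooperative-fetch `SampleCategorical`, its decision is resampled with the original value excluded. A minimal Python sketch of that exclude-and-shift resampling (illustrative only, not part of the diff):

    import random

    def mutate_categorical(probs, original):
        # Drop the original choice, sample an index from the remaining
        # weights, then shift past the removed slot -- mirroring the
        # probs.erase(...) / result += 1 logic in MutateSampleVectorize.
        weights = probs[:original] + probs[original + 1:]
        result = random.choices(range(len(weights)), weights=weights)[0]
        return result + 1 if result >= original else result

    new_decision = mutate_categorical([0.25, 0.25, 0.25, 0.25], original=2)
    assert new_decision != 2  # the original decision is never re-drawn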
24 changes: 14 additions & 10 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -322,7 +322,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
/*! \brief The maximum size of the innermost factor */
int max_innermost_factor;
/*! \brief The length of vector lane in vectorized cooperative fetching */
-   int vector_load_max_len;
+   std::vector<int> vector_load_lens;
/*! \brief Data reuse configuration for reading */
ReuseConfig reuse_read_;
/*! \brief Data reuse configuration for writing */
@@ -337,7 +337,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
v->Visit("tile_binds", &tile_binds);
v->Visit("use_tensor_core", &use_tensor_core);
v->Visit("max_innermost_factor", &max_innermost_factor);
v->Visit("vector_load_max_len", &vector_load_max_len);
// `vector_load_lens` is not visited
// `reuse_read_` is not visited
// `reuse_write_` is not visited
// `s_indices_` is not visited
@@ -491,12 +491,14 @@ inline std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const
LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
buffer_loops.end()});
// Annotate cooperative fetching
-   if (vector_load_max_len > 0) {
-     // cooperative fetch + vectorized loading
-     // Split into inner and outer, vectorize the inner loop
-     Array<ExprRV> factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len);
-     // Add cooperative fetching
-     sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]);
+   if (!vector_load_lens.empty()) {
+     int n = vector_load_lens.size();
+     double prob = 1.0 / n;
+     ExprRV vector_load_len =
+         sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+                                Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
+     sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
+                   vector_load_len);
}
}
State new_state = state;
@@ -545,7 +547,7 @@ inline std::vector<State> MultiLevelTilingNode::FuseWriteReuse(State state) const
ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
bool use_tensor_core,
Optional<Integer> max_innermost_factor,
-                                             Optional<Integer> vector_load_max_len,
+                                             Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
@@ -561,7 +563,7 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
tir::TensorIntrin::Get("wmma_fill");
}
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
-   n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value;
+   n->vector_load_lens = vector_load_lens.defined()
+                             ? support::AsVector<Integer, int>(vector_load_lens.value())
+                             : std::vector<int>();
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
for (int i = 0, len = structure.size(); i < len; ++i) {
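With this change, `AddReadReuse` no longer splits the fused loop with `SamplePerfectTile`; it samples one of the configured lane lengths with uniform probability and records it as an annotation, deferring the actual loop rewrite to the step that consumes the annotation. A rough Python equivalent of the new sampling step (a sketch assuming the standard `tir.Schedule` sampling API and the `meta_schedule.cooperative_fetch` annotation key):

    def annotate_cooperative_fetch(sch, cache_read_block, vector_load_lens):
        n = len(vector_load_lens)
        if n == 0:
            return  # an empty candidate list disables vectorization
        vector_load_len = sch.sample_categorical(
            candidates=vector_load_lens,
            probs=[1.0 / n] * n,  # uniform weights, as in the C++ above
        )
        sch.annotate(cache_read_block, "meta_schedule.cooperative_fetch",
                     vector_load_len)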