Skip to content

Commit

Permalink
[MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
Browse files Browse the repository at this point in the history
* [MetaSchedule][M4a] Mutator: Mutate-Tile-Size

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

* Python 3.8 has no `math.prod`

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
7 people authored Jan 29, 2022
1 parent 6a274af commit 4d0dac3
Show file tree
Hide file tree
Showing 5 changed files with 399 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class PyMutatorNode : public MutatorNode {
*/
class Mutator : public runtime::ObjectRef {
public:
/*! \brief Create a Mutator that mutates the tile size. */
/*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
* \brief Create a Mutator that mutates the parallel extent
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
from .mutate_tile_size import MutateTileSize
from .mutate_parallel import MutateParallel
from .mutate_unroll import MutateUnroll
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_tile_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates the decision of instruction Sample-Perfect-Tile"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateTileSize")
class MutateTileSize(Mutator):
"""Mutator that mutates the decision of instruction Sample-Perfect-Tile"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member
)
273 changes: 273 additions & 0 deletions src/meta_schedule/mutator/mutate_tile_size.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <mutex>
#include <unordered_map>

#include "../utils.h"

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::InstructionKind;
using tir::Trace;

/*!
* \brief Downcast the decision of Sample-Perfect-Tile to an array of integers
* \param decision The decision of Sample-Perfect-Tile
* \return The result of downcast
*/
std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
}

/*!
* \brief Calculate the product of elements in an array
* \param array The array
* \return The product of elements in the array
*/
int64_t Product(const std::vector<int64_t>& array) {
int64_t result = 1;
for (int64_t x : array) {
result *= x;
}
return result;
}

/*! \brief A mutator that mutates the decision of instruction Sample-Perfect-Tile */
class MutateTileSizeNode : public MutatorNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.MutateTileSize";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode);

public:
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
};

/*!
* \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace
* \param trace The trace
* \param inst The instructions found
* \param decision The decisions of the instructions found
*/
void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
std::vector<std::vector<int64_t>>* decision) {
static const InstructionKind& inst_sample_perfect_tile =
InstructionKind::Get("SamplePerfectTile");
std::vector<Instruction>& instructions = *inst;
std::vector<std::vector<int64_t>>& decisions = *decision;
instructions.reserve(trace->decisions.size());
decisions.reserve(trace->decisions.size());
for (const auto& kv : trace->decisions) {
const Instruction& inst = kv.first;
const ObjectRef& decision = kv.second;
if (inst->kind.same_as(inst_sample_perfect_tile)) {
std::vector<int64_t> tiles = DowncastTilingDecision(decision);
if (tiles.size() >= 2 && Product(tiles) >= 2) {
instructions.push_back(inst);
decisions.push_back(tiles);
}
}
}
}

/*!
* \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for
* cooperative fetch annotation
* \param trace The trace
* \param inst The instructions found
* \param decision The decisions of the instructions found
*/
void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
std::vector<int64_t>* decision) {
static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
std::vector<Instruction>& instructions = *inst;
std::vector<int64_t>& decisions = *decision;
std::unordered_set<const Object*> annotated;
instructions.reserve(trace->decisions.size());
decisions.reserve(trace->decisions.size());
annotated.reserve(trace->decisions.size());
// Find annotation with `meta_schedule_cooperative_fetch`
for (const Instruction& inst : trace->insts) {
if (inst->kind.same_as(inst_annotate)) {
ICHECK_EQ(inst->attrs.size(), 1);
ICHECK_EQ(inst->inputs.size(), 2);
if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
ICHECK(ann_val);
annotated.insert(ann_val);
}
}
}
// Find sampling instruction that generates the annotation
for (const auto& kv : trace->decisions) {
const Instruction& inst = kv.first;
const ObjectRef& decision = kv.second;
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
if (annotated.count(inst->outputs[0].get())) {
const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
instructions.push_back(inst);
decisions.push_back(d->value);
}
}
}
}

struct FactorMemo {
/*!
* \brief Find all factors of the input integer
* \param n The integer to be factorized
* \return The factors of the input integer
*/
static std::vector<int> Factorize(int n) {
if (const std::vector<int>* result = Global()->Query(n)) {
return *result;
}
std::vector<int> result;
for (int64_t i = 1; i * i < n; ++i) {
if (n % i == 0) {
result.push_back(i);
if (i * i != n) {
result.push_back(n / i);
}
}
}
std::sort(result.begin(), result.end());
Global()->Add(n, result);
return result;
}

private:
const std::vector<int>* Query(int n) {
std::unique_lock<std::mutex> lock(mutex);
auto it = memo.find(n);
if (it != memo.end()) {
return &it->second;
}
return nullptr;
}

void Add(int n, std::vector<int> result) {
std::unique_lock<std::mutex> lock(mutex);
memo.emplace(n, std::move(result));
}

static FactorMemo* Global() {
static FactorMemo singleton;
return &singleton;
}

std::unordered_map<int, std::vector<int>> memo;
std::mutex mutex;
};

Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
std::vector<int64_t> tiles, TRandState* rand_state) {
int n_splits = tiles.size();
// Step 1. Choose two loops, `x` and `y`
int x, y;
// select source
while (true) {
x = tir::SampleInt(rand_state, 0, n_splits);
if (tiles[x] <= 1) {
continue;
}
y = tir::SampleInt(rand_state, 0, n_splits - 1);
if (y >= x) {
++y;
}
std::vector<int> factors = FactorMemo::Factorize(tiles[x]);
// Step 2. Choose the divide factor
int64_t divide_factor;
if (y != n_splits - 1) {
divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())];
} else {
int64_t limit = Downcast<Integer>(inst->attrs[1])->value;
int max_factor_index = static_cast<int>(factors.size()) - 1;
for (; max_factor_index >= 1; max_factor_index--) {
if (factors[max_factor_index] * tiles[y] <= limit) {
break;
}
}
if (max_factor_index == 0) {
if (n_splits <= 2) {
return NullOpt;
}
// Failed on this dst_idx, try next one.
continue;
}
divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)];
}
tiles[x] /= divide_factor;
tiles[y] *= divide_factor;
return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
/*remove_postproc=*/true);
}
}

Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
int64_t original_decision, TRandState* rand_state) {
ICHECK_EQ(inst->attrs.size(), 2);
std::vector<double> probs =
support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
probs.erase(probs.begin() + original_decision);
int result = tir::MakeMultinomialSampler(rand_state, probs)();
if (result >= original_decision) {
result += 1;
}
return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
}

Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
std::vector<Instruction> sample_perfect_tile_insts;
std::vector<Instruction> sample_vectorize_insts;
std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;
std::vector<int64_t> sample_vectorize_decisions;
FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
int size_a = sample_perfect_tile_insts.size();
int size_b = sample_vectorize_insts.size();
if (size_a == 0 && size_b == 0) {
return NullOpt;
}
int n = tir::SampleInt(rand_state, 0, size_a + size_b);
if (n < size_a) {
return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
rand_state);
} else {
n -= size_a;
return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
rand_state);
}
}

Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }

TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize);

} // namespace meta_schedule
} // namespace tvm
Loading

0 comments on commit 4d0dac3

Please sign in to comment.