diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index e3fa847c3748..002fa51ee5e3 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -111,7 +111,7 @@ class PyMutatorNode : public MutatorNode { */ class Mutator : public runtime::ObjectRef { public: - /*! \brief Create a Mutator that mutates the tile size. */ + /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! * \brief Create a Mutator that mutates the parallel extent diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index af3485b679f1..e534ba14346e 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -21,5 +21,6 @@ """ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation +from .mutate_tile_size import MutateTileSize from .mutate_parallel import MutateParallel from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py new file mode 100644 index 000000000000..ff432a6633b9 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
@register_object("meta_schedule.MutateTileSize")
class MutateTileSize(Mutator):
    """Mutator that changes the decision of a Sample-Perfect-Tile instruction.

    The object takes no arguments; it is built by the FFI function
    ``MutatorMutateTileSize``.
    """

    def __init__(self) -> None:
        constructor = _ffi_api.MutatorMutateTileSize  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor)
/*!
 * \brief Calculate the product of elements in an array
 * \param array The array whose elements are multiplied together
 * \return The product of all elements; 1 when the array is empty
 * \note The product is accumulated in int64_t without any overflow check.
 */
int64_t Product(const std::vector<int64_t>& array) {
  int64_t result = 1;
  for (const int64_t x : array) {
    result *= x;
  }
  return result;
}
+ * \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace + * \param trace The trace + * \param inst The instructions found + * \param decision The decisions of the instructions found + */ +void FindSamplePerfectTile(const Trace& trace, std::vector* inst, + std::vector>* decision) { + static const InstructionKind& inst_sample_perfect_tile = + InstructionKind::Get("SamplePerfectTile"); + std::vector& instructions = *inst; + std::vector>& decisions = *decision; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + for (const auto& kv : trace->decisions) { + const Instruction& inst = kv.first; + const ObjectRef& decision = kv.second; + if (inst->kind.same_as(inst_sample_perfect_tile)) { + std::vector tiles = DowncastTilingDecision(decision); + if (tiles.size() >= 2 && Product(tiles) >= 2) { + instructions.push_back(inst); + decisions.push_back(tiles); + } + } + } +} + +/*! + * \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for + * cooperative fetch annotation + * \param trace The trace + * \param inst The instructions found + * \param decision The decisions of the instructions found + */ +void FindSampleVectorize(const Trace& trace, std::vector* inst, + std::vector* decision) { + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + std::vector& instructions = *inst; + std::vector& decisions = *decision; + std::unordered_set annotated; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + annotated.reserve(trace->decisions.size()); + // Find annotation with `meta_schedule_cooperative_fetch` + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_annotate)) { + ICHECK_EQ(inst->attrs.size(), 1); + ICHECK_EQ(inst->inputs.size(), 2); + if (Downcast(inst->attrs[0]) 
/*! \brief A memo table of integer factorizations, shared through a singleton and guarded by a mutex */
struct FactorMemo {
  /*!
   * \brief Find all factors of the input integer
   * \param n The integer to be factorized
   * \return The factors of the input integer in ascending order, including 1 and `n` itself;
   *  empty when `n` is not positive
   */
  static std::vector<int> Factorize(int n) {
    if (const std::vector<int>* result = Global()->Query(n)) {
      return *result;
    }
    std::vector<int> result;
    // The bound must be inclusive: with `i * i < n` the divisor i == sqrt(n) is never visited,
    // so perfect squares lose their root (4 -> {1, 4}) and 1 gets no factors at all.
    for (int64_t i = 1; i * i <= n; ++i) {
      if (n % i == 0) {
        result.push_back(static_cast<int>(i));
        // Do not add the square root twice
        if (i * i != n) {
          result.push_back(static_cast<int>(n / i));
        }
      }
    }
    std::sort(result.begin(), result.end());
    Global()->Add(n, result);
    return result;
  }

 private:
  /*!
   * \brief Look up a previously stored factorization
   * \param n The integer to look up
   * \return A pointer to the stored factors, or nullptr if `n` has not been stored yet
   * \note Entries are never erased or overwritten, and std::unordered_map keeps references to
   *  its elements valid across insertions, so the pointer stays usable after the lock is released.
   */
  const std::vector<int>* Query(int n) {
    std::lock_guard<std::mutex> lock(mutex);
    auto it = memo.find(n);
    if (it != memo.end()) {
      return &it->second;
    }
    return nullptr;
  }

  /*!
   * \brief Store the factorization of `n`; an existing entry for `n` is left untouched
   * \param n The integer that was factorized
   * \param result The factors of `n`
   */
  void Add(int n, std::vector<int> result) {
    std::lock_guard<std::mutex> lock(mutex);
    memo.emplace(n, std::move(result));
  }

  /*! \brief The singleton instance that backs `Factorize` */
  static FactorMemo* Global() {
    static FactorMemo singleton;
    return &singleton;
  }

  /*! \brief Factorizations computed so far, keyed by the factorized integer */
  std::unordered_map<int, std::vector<int>> memo;
  /*! \brief Guards every access to `memo` */
  std::mutex mutex;
};
Choose two loops, `x` and `y` + int x, y; + // select source + while (true) { + x = tir::SampleInt(rand_state, 0, n_splits); + if (tiles[x] <= 1) { + continue; + } + y = tir::SampleInt(rand_state, 0, n_splits - 1); + if (y >= x) { + ++y; + } + std::vector factors = FactorMemo::Factorize(tiles[x]); + // Step 2. Choose the divide factor + int64_t divide_factor; + if (y != n_splits - 1) { + divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())]; + } else { + int64_t limit = Downcast(inst->attrs[1])->value; + int max_factor_index = static_cast(factors.size()) - 1; + for (; max_factor_index >= 1; max_factor_index--) { + if (factors[max_factor_index] * tiles[y] <= limit) { + break; + } + } + if (max_factor_index == 0) { + if (n_splits <= 2) { + return NullOpt; + } + // Failed on this dst_idx, try next one. + continue; + } + divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)]; + } + tiles[x] /= divide_factor; + tiles[y] *= divide_factor; + return trace->WithDecision(inst, support::AsArray(tiles), + /*remove_postproc=*/true); + } +} + +Optional MutateSampleVectorize(const Trace& trace, Instruction inst, + int64_t original_decision, TRandState* rand_state) { + ICHECK_EQ(inst->attrs.size(), 2); + std::vector probs = + support::AsVector(Downcast>(inst->attrs[1])); + probs.erase(probs.begin() + original_decision); + int result = tir::MakeMultinomialSampler(rand_state, probs)(); + if (result >= original_decision) { + result += 1; + } + return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); +} + +Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector sample_perfect_tile_insts; + std::vector sample_vectorize_insts; + std::vector> sample_perfect_tile_tiles; + std::vector sample_vectorize_decisions; + FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles); + FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions); + int size_a 
= sample_perfect_tile_insts.size(); + int size_b = sample_vectorize_insts.size(); + if (size_a == 0 && size_b == 0) { + return NullOpt; + } + int n = tir::SampleInt(rand_state, 0, size_a + size_b); + if (n < size_a) { + return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n], + rand_state); + } else { + n -= size_a; + return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n], + rand_state); + } +} + +Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py new file mode 100644 index 000000000000..9e75497b6cc2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
def test_mutate_tile_size_matmul():
    """Apply MutateTileSize 100 times to one schedule and check each mutated tiling."""
    # `_sch` issues get_block, get_consumers, cache_write and get_loops before
    # sample_perfect_tile, so the sampling instruction sits at index 4 of the trace.
    tile_inst_index = 4
    mutator = _make_mutator(target=Target("llvm --num-cores=16"))
    sch = _sch(decisions=[[4, 32, 4, 1]])
    seen = {}
    for _ in range(100):
        mutated = mutator.apply(sch.trace)
        tile_inst = mutated.insts[tile_inst_index]
        assert tile_inst.kind.name == "SamplePerfectTile"
        tiles = [int(factor) for factor in mutated.decisions[tile_inst]]
        seen[str(tiles)] = tiles
        # The tiled loop has extent 512, so every mutated tiling must still multiply to 512.
        assert reduce(operator.mul, tiles, 1) == 512
    # The 100 mutations must produce more than 15 distinct tilings.
    assert len(seen) > 15