diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index f88043b4b4fd..f232566785d9 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -20,3 +20,4 @@ design space. """ from .mutator import Mutator, PyMutator +from .mutate_compute_location import MutateComputeLocation diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py new file mode 100644 index 000000000000..bb361247bf62 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A mutator that mutates the compute-at location decision of SampleComputeLocation""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateComputeLocation") +class MutateComputeLocation(Mutator): + """A mutator that mutates the compute-at location decision of SampleComputeLocation""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc new file mode 100644 index 000000000000..3ed56df1b381 --- /dev/null +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::InstructionKind; +using tir::Trace; + +/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */ +class MutateComputeLocationNode : public MutatorNode { + public: + /*! \brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); + + public: + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; + + private: + struct Candidate { + /*! \brief The SampleComputeLocation instruction */ + Instruction inst; + /*! \brief The candidate compute-at locations */ + std::vector locs; + + explicit Candidate(Instruction inst, std::vector locs) + : inst(std::move(inst)), locs(std::move(locs)) {} + }; + + std::vector FindCandidates(const Trace& trace, TRandState* rand_state); +}; + +/*! + * \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated + * to at lease one other value + * \param trace The trace from which to find the instructions + * \return All the candidate instructions together with the candidate compute-at locations + */ +std::vector MutateComputeLocationNode::FindCandidates( + const Trace& trace, TRandState* rand_state) { + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + + static InstructionKind inst_sample_compute_location = + InstructionKind::Get("SampleComputeLocation"); + std::vector candidates; + + auto f_decision_provider = [&](const tir::Instruction& inst, // + const Array& inputs, // + const Array& attrs, // + const ObjectRef& decision) -> ObjectRef { + if (inst->kind.same_as(inst_sample_compute_location)) { + // Step 1. Extract the instruction input and the old decision. + ICHECK_EQ(inputs.size(), 1); + tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); + int old_decision = Downcast(decision)->value; + + // Step 2. Collect all the compute_at locations. + Array location_srefs; + std::vector location_indices; + std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref); + // Step 3. Remove the old decision. + auto it = std::find(location_indices.begin(), location_indices.end(), old_decision); + if (it != location_indices.end()) { + location_srefs.erase(location_srefs.begin() + (it - location_indices.begin())); + location_indices.erase(it); + } + ICHECK_EQ(location_srefs.size(), location_indices.size()); + // Step 4. Add a new candidate if there are at least one remaining compute-at position. + if (!location_srefs.empty()) { + candidates.emplace_back(inst, std::move(location_indices)); + } + } + return decision; + }; + trace->ApplyToSchedule(sch, // + /*remove_postproc=*/true, // + /*decision_provider=*/f_decision_provider); + return candidates; +} + +Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector candidates = FindCandidates(trace, rand_state); + if (candidates.empty()) { + return NullOpt; + } + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; + return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateComputeLocation() { + return Mutator(make_object()); +} + +TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") + .set_body_typed(Mutator::MutateComputeLocation); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py new file mode 100644 index 000000000000..20a977189da5 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateComputeLocation, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def add(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + + +# pylint: enable=invalid-name, no-member + + +def _sch(decision: int) -> Schedule: + sch = Schedule(add, debug_mask="all") + # pylint: disable=invalid-name + b0 = sch.get_block(name="move", func_name="main") + l1 = sch.sample_compute_location(block=b0, decision=decision) + sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateComputeLocation() + mutator.initialize_with_tune_context(TuneContext(mod=add, target=target)) + return mutator + + +def test_mutate_compute_location_add(): + mutator = _make_mutator( + target=Target("llvm"), + ) + sch = _sch(decision=4) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + assert not decision == 4 + results.add(decision) + assert len(results) == 9 + + +if __name__ == "__main__": + test_mutate_compute_location_add()