Skip to content

Commit

Permalink
[MetaSchedule] Mutator: Mutate compute location (apache#10028)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
7 people authored and yuanfz98 committed Jan 24, 2022
1 parent 8728fb1 commit 985f4cc
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
design space.
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_compute_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A mutator that mutates the compute-at location decision of SampleComputeLocation"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateComputeLocation")
class MutateComputeLocation(Mutator):
"""A mutator that mutates the compute-at location decision of SampleComputeLocation"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member
)
131 changes: 131 additions & 0 deletions src/meta_schedule/mutator/mutate_compute_location.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::InstructionKind;
using tir::Trace;

/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */
class MutateComputeLocationNode : public MutatorNode {
public:
/*! \brief JSON representation of the workload */
std::string json_mod_;

void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode);

public:
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {
this->json_mod_ = SaveJSON(context->mod.value());
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;

private:
struct Candidate {
/*! \brief The SampleComputeLocation instruction */
Instruction inst;
/*! \brief The candidate compute-at locations */
std::vector<int> locs;

explicit Candidate(Instruction inst, std::vector<int> locs)
: inst(std::move(inst)), locs(std::move(locs)) {}
};

std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);
};

/*!
* \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated
* to at lease one other value
* \param trace The trace from which to find the instructions
* \return All the candidate instructions together with the candidate compute-at locations
*/
std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates(
const Trace& trace, TRandState* rand_state) {
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
/*rand_state=*/ForkSeed(rand_state), //
/*debug_mode=*/0, //
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);

static InstructionKind inst_sample_compute_location =
InstructionKind::Get("SampleComputeLocation");
std::vector<MutateComputeLocationNode::Candidate> candidates;

auto f_decision_provider = [&](const tir::Instruction& inst, //
const Array<ObjectRef>& inputs, //
const Array<ObjectRef>& attrs, //
const ObjectRef& decision) -> ObjectRef {
if (inst->kind.same_as(inst_sample_compute_location)) {
// Step 1. Extract the instruction input and the old decision.
ICHECK_EQ(inputs.size(), 1);
tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0]));
int old_decision = Downcast<Integer>(decision)->value;

// Step 2. Collect all the compute_at locations.
Array<tir::StmtSRef> location_srefs;
std::vector<int> location_indices;
std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref);
// Step 3. Remove the old decision.
auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
if (it != location_indices.end()) {
location_srefs.erase(location_srefs.begin() + (it - location_indices.begin()));
location_indices.erase(it);
}
ICHECK_EQ(location_srefs.size(), location_indices.size());
// Step 4. Add a new candidate if there are at least one remaining compute-at position.
if (!location_srefs.empty()) {
candidates.emplace_back(inst, std::move(location_indices));
}
}
return decision;
};
trace->ApplyToSchedule(sch, //
/*remove_postproc=*/true, //
/*decision_provider=*/f_decision_provider);
return candidates;
}

Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) {
std::vector<Candidate> candidates = FindCandidates(trace, rand_state);
if (candidates.empty()) {
return NullOpt;
}
const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())];
return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true);
}

Mutator Mutator::MutateComputeLocation() {
return Mutator(make_object<MutateComputeLocationNode>());
}

TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation")
.set_body_typed(Mutator::MutateComputeLocation);

} // namespace meta_schedule
} // namespace tvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.mutator import MutateComputeLocation, Mutator
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule

# pylint: disable=invalid-name, no-member


@T.prim_func
def add(a: T.handle, b: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main"})
A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32")
B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32")
A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32")
# body
for i, j, k in T.grid(2048, 2048, 2048):
with T.block("move"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
T.reads([A[vi, vj, vk]])
T.writes([A_cached[vi, vj, vk]])
A_cached[vi, vj, vk] = A[vi, vj, vk]
for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
with T.block("add"):
vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2)
vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2)
vk = T.axis.spatial(2048, k0 * 32 + k1)
T.reads([A_cached[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1)


# pylint: enable=invalid-name, no-member


def _sch(decision: int) -> Schedule:
sch = Schedule(add, debug_mask="all")
# pylint: disable=invalid-name
b0 = sch.get_block(name="move", func_name="main")
l1 = sch.sample_compute_location(block=b0, decision=decision)
sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)
# pylint: enable=invalid-name
return sch


def _make_mutator(target: Target) -> Mutator:
mutator = MutateComputeLocation()
mutator.initialize_with_tune_context(TuneContext(mod=add, target=target))
return mutator


def test_mutate_compute_location_add():
mutator = _make_mutator(
target=Target("llvm"),
)
sch = _sch(decision=4)
results = set()
for _ in range(100):
trace = mutator.apply(sch.trace)
decision = trace.decisions[trace.insts[-2]]
assert not decision == 4
results.add(decision)
assert len(results) == 9


if __name__ == "__main__":
test_mutate_compute_location_add()

0 comments on commit 985f4cc

Please sign in to comment.