[Metaschedule] MultiLevelTiling for wide vector architectures (apache#12845)

* [Metaschedule] Introduce MultiLevelTiling for wide vector architecture

* update test

* format

* cpplint
masahi authored and xinetzone committed Nov 25, 2022
1 parent c14c837 commit f537d1b
Showing 7 changed files with 307 additions and 12 deletions.
15 changes: 15 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -187,6 +187,21 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

/*!
* \brief Extension of MultiLevelTiling for backends with wide vectors.
* The loop over the innermost spatial axis of the output buffer is always vectorized with the
* maximum vector length.
* \param structure The tiling structure. 'SSRSRS' is recommended.
* \param vector_length_in_bits The length of a vector register in bits.
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
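As a quick illustration of the MultiLevelTilingWideVector doc comment above (not part of the diff): the rule pins the innermost tile factor to the number of output elements that fit in one vector register, i.e. vector_length_in_bits / out_dtype.bits(). A minimal Python sketch, with a hypothetical helper name:

# Illustrative only: how the fixed innermost tile factor is derived.
def innermost_vector_factor(vector_length_in_bits: int, out_dtype_bits: int) -> int:
    """Number of output elements that fill one vector register."""
    return vector_length_in_bits // out_dtype_bits

# 1024-bit vector registers (e.g. Hexagon HVX) with a float16 output buffer:
assert innermost_vector_factor(1024, 16) == 64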
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -28,6 +28,7 @@
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
MultiLevelTilingWideVector,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
37 changes: 37 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -187,3 +187,40 @@ def __init__(
reuse_write.as_dict() if reuse_write is not None else None,
use_software_pipeline,
)


@register_object("meta_schedule.MultiLevelTilingWideVector")
class MultiLevelTilingWideVector(ScheduleRule):
"""Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
spatial axis of the output buffer is always vectorized with the maximum vector length.

Parameters
----------
structure : str
The tiling structure. 'SSRSRS' is recommended.
vector_length_in_bits: int
The length of a vector register in bits.
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
structure: str,
vector_length_in_bits: int,
max_innermost_factor: Optional[int] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingWideVector, # type: ignore # pylint: disable=no-member
structure,
vector_length_in_bits,
max_innermost_factor,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
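For reference, the new rule is constructed from Python like any other schedule rule. A minimal usage sketch, mirroring the parameters of the Hexagon unit test added later in this commit:

# Minimal usage sketch; parameters mirror the new Hexagon unit test.
from tvm import meta_schedule as ms

rule = ms.schedule_rule.MultiLevelTilingWideVector(
    structure="SRSRS",
    vector_length_in_bits=1024,  # Hexagon HVX-sized vector registers
    max_innermost_factor=64,
    reuse_read=None,
    reuse_write=None,
)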
35 changes: 24 additions & 11 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -166,6 +166,17 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
return results;
}

Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
int n_tiles) const {
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
return splits;
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
Schedule& sch = state->sch;
const BlockRV& block_rv = state->block_rv;
@@ -179,6 +190,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;

if (iter_types[i] == IterVarType::kDataPar) {
idx = &s_indices_;
if (spatial_loop_product != -1) {
@@ -193,17 +205,18 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
} else {
continue;
}
// Do the split
int n_tiles = idx->size();
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);

const int n_tiles = idx->size();

if (n_tiles == 1) {
tiles[idx->at(0)].push_back(loop);
} else {
auto splits = SplitLoop(sch, block_rv, loop, n_tiles);

// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
}
}
}
// Step 3. Reorder to organize the tiles
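The refactor above only moves the existing sample-then-split step into an overridable SplitLoop hook and skips the split when a loop maps to a single tile level. In Python terms, the base behaviour corresponds roughly to the sketch below (hypothetical helper name; the real logic stays in C++):

# Rough Python rendering of the base-class SplitLoop step (sketch only).
from typing import List

from tvm import tir


def split_loop(sch: tir.Schedule, loop: "tir.schedule.LoopRV", n_tiles: int,
               max_innermost_factor: int) -> List["tir.schedule.LoopRV"]:
    """Sample a perfect n-way tiling for the loop and apply it."""
    factors = sch.sample_perfect_tile(loop, n=n_tiles,
                                      max_innermost_factor=max_innermost_factor)
    return sch.split(loop, factors=factors)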
3 changes: 3 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -161,6 +161,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
tir::LoopRV loop, int n_tiles) const;

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

120 changes: 120 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "../../tir/schedule/analysis.h"
#include "../../tir/schedule/transform.h"
#include "../utils.h"
#include "multi_level_tiling.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;
using tir::Schedule;

/*!
* \brief Extension of MultiLevelTiling for backends with wide vectors.
* The loop over the innermost spatial axis of the output buffer is always vectorized with the
* maximum vector length.
*/
class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
public:
size_t vector_length_in_bits;

static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode);

protected:
Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
};

Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
LoopRV loop_rv, int n_tiles) const {
const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref);
ICHECK(block_node && block_node->writes.size() == 1);

const auto out_dtype = block_node->writes[0]->buffer->dtype;
const int vec_len = vector_length_in_bits / out_dtype.bits();

// Determine if this loop is over the innermost axis of the output buffer.
// In the example below, we look for a loop whose loop var is bound to the axis co.

// for (i0, 0, 1) {
// for (i1, 0, 56) {
// for (i2, 0, 56) {
// for (i3, 0, 64) {
// for (i4, 0, 3) {
// for (i5, 0, 3) {
// for (i6, 0, 64) {
// block conv2d_nhwc(...) {
// ...
// bind(co, i3)
// ...
// writes([conv2d_nhwc[n, h, w, co]])
// ...
// conv2d_nhwc[n, h, w, co] = ...
// }
const size_t innermost_axis = block_node->writes[0]->region.size() - 1;
const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis];

if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) {
// If this is not the innermost spatial loop, split the loop in the normal way.
return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles);
} else {
// We split the innermost spatial loop in a way that always uses the maximum vector length.
const int64_t* extent_int = tir::GetLoopIntExtent(loop);
if (extent_int && *extent_int > vec_len) {
Array<tir::LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv,
/*factors=*/{NullOpt, PrimExpr(vec_len)});
Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile(
/*loop=*/inner_splits[0],
/*n=*/n_tiles - 1,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> outer_splits = sch->Split(
/*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
outer_splits.push_back(inner_splits[1]);
return outer_splits;
} else {
Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
factors.push_back(loop->extent);
return sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
}
}
}

ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
node->vector_length_in_bits = vector_length_in_bits->value;
return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector")
.set_body_typed(ScheduleRule::MultiLevelTilingWideVector);

} // namespace meta_schedule
} // namespace tvm
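To make the split strategy above concrete: when the loop extent is larger than the vector length, the innermost factor is pinned to vec_len and only the remaining outer iterations are sampled. A rough Python rendering of that branch (sketch only; hypothetical helper name):

# Sketch of the innermost-axis branch of SplitLoop above, in Python schedule calls.
from typing import List

from tvm import tir


def split_innermost_spatial(sch: tir.Schedule, loop: "tir.schedule.LoopRV",
                            n_tiles: int, vec_len: int,
                            max_innermost_factor: int) -> List["tir.schedule.LoopRV"]:
    # Pin the innermost factor to the full vector length ...
    outer, inner = sch.split(loop, factors=[None, vec_len])
    # ... and let the sampler tile only the remaining outer iterations.
    outer_factors = sch.sample_perfect_tile(outer, n=n_tiles - 1,
                                            max_innermost_factor=max_innermost_factor)
    return list(sch.split(outer, factors=outer_factors)) + [inner]

In the conv2d test below, the output dtype is float16 and vector_length_in_bits is 1024, so vec_len is 64. The co axis has extent exactly 64, so the else-branch applies and the axis is split with the fixed factors [1, 1, 64], which shows up as i3_0 * 64 + i3_1_1 * 64 + i3_2 in the expected sketch.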
108 changes: 107 additions & 1 deletion tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from tvm import meta_schedule as ms
from tvm import te
from tvm import te, target
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.schedule_rule import get_rules
from tvm.meta_schedule.testing.space_generation import check_sketches
Expand Down Expand Up @@ -521,9 +521,115 @@ def sum_with_trivial_block_iter(
assert not sch.trace.simplified(remove_postproc=True).insts


def test_multi_level_tiling_hexagon():
@T.prim_func
def cpu_conv2d_nhwc(
inputs: T.Buffer[(1, 56, 56, 64), "float16"],
weight: T.Buffer[(3, 3, 64, 64), "float16"],
conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"],
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16")
for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
with T.block("PadInput"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57,
inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1],
T.float16(0),
dtype="float16",
)
for (
i0_0,
i1_0,
i2_0,
i3_0,
i4_0,
i5_0,
i6_0,
i0_1_1,
i1_1_1,
i2_1_1,
i3_1_1,
i4_1,
i5_1,
i6_1,
i0_2,
i1_2,
i2_2,
i3_2,
) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0)
h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2)
w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2)
co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2)
rh = T.axis.reduce(3, i4_1 + i4_0)
rw = T.axis.reduce(3, i5_0 + i5_1)
rc = T.axis.reduce(64, i6_0 * 4 + i6_1)
T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co])
T.writes(conv2d_nhwc[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure": "SRSRS"})
with T.init():
conv2d_nhwc[n, h, w, co] = T.float16(0)
conv2d_nhwc[n, h, w, co] = (
conv2d_nhwc[n, h, w, co]
+ PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co]
)

target_hexagon = target.hexagon("v69", num_cores=4)

I = 64
O = 64
H = 56
W = 56

mod = te.create_prim_func(
te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16")
)

actual = ms.TuneContext(
mod=mod,
target=Target(target_hexagon, host=target_hexagon),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules=[
ms.schedule_rule.MultiLevelTilingWideVector(
structure="SRSRS",
vector_length_in_bits=1024,
max_innermost_factor=64,
reuse_read=None,
reuse_write=None,
)
],
task_name="test",
).generate_design_space()

decision_0 = [
("SamplePerfectTile", [1, 1, 1]),
("SamplePerfectTile", [1, 14, 4]),
("SamplePerfectTile", [2, 2, 14]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [16, 4]),
]

check_sketches(
mod,
sketches=actual,
expected_mods=[cpu_conv2d_nhwc],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
test_cpu_matmul()
test_cpu_matmul_relu()
test_cuda_matmul()
test_cuda_matmul_relu()
test_cuda_sum_with_trivial_block_iter()
test_multi_level_tiling_hexagon()
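A small self-contained sanity check of the arithmetic behind decision_0 and the expected sketch (illustrative only, not part of the test; the variable names are made up):

# Illustrative check: with 1024-bit vectors and a float16 output, the innermost
# co axis (extent 64) is consumed entirely by the fixed vector split, so no
# SamplePerfectTile decision is recorded for it; the three 3-way samples cover
# n/h/w and the three 2-way samples cover rh/rw/rc.
vec_len = 1024 // 16
assert vec_len == 64

# The sampled spatial tile sizes must multiply back to the loop extents.
spatial_tiles = {"n": [1, 1, 1], "h": [1, 14, 4], "w": [2, 2, 14]}
extents = {"n": 1, "h": 56, "w": 56}
for axis, tile in spatial_tiles.items():
    assert tile[0] * tile[1] * tile[2] == extents[axis]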
