Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M4a] Schedule Rule: Auto-Inline #9943

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
* \param into_consumer If allows to inline a block into its consumer
* \param into_cache_only If it only allows to inline into a block generated by cache_read/write
* \param inline_const_tensor Always inline constant tensors
* \param disallow_if_then_else Always disallow if-then-else-like constructs
* \param require_ordered Always require the read-to-write mapping to be ordered
Expand All @@ -125,7 +124,6 @@ class ScheduleRule : public runtime::ObjectRef {
*/
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
bool into_consumer, //
bool into_cache_only, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .auto_inline import AutoInline
from .schedule_rule import PyScheduleRule, ScheduleRule
67 changes: 67 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions"""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.AutoInline")
class AutoInline(ScheduleRule):
    """Rule that inlines spatial blocks if it satisfies some conditions

    Parameters
    ----------
    into_producer : bool
        Whether a block is allowed to be inlined into its producer
    into_consumer : bool
        Whether a block is allowed to be inlined into its consumer
    inline_const_tensor : bool
        Always inline constant tensors
    disallow_if_then_else : bool
        Always disallow if-then-else-like constructs
    require_injective : bool
        Always require the read-to-write mapping to be injective
    require_ordered : bool
        Always require the read-to-write mapping to be ordered
    disallow_op : Optional[List[str]]
        The operators that are disallowed in auto inline
    """

    def __init__(
        self,
        into_producer: bool,
        into_consumer: bool,
        inline_const_tensor: bool,
        disallow_if_then_else: bool,
        require_injective: bool,
        require_ordered: bool,
        disallow_op: Optional[List[str]] = None,
    ) -> None:
        # The arguments are forwarded positionally, so their order has to stay in sync with
        # the parameter order of the function registered as
        # "meta_schedule.ScheduleRuleAutoInline" on the C++ side.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleAutoInline,  # type: ignore # pylint: disable=no-member
            into_producer,
            into_consumer,
            inline_const_tensor,
            disallow_if_then_else,
            require_injective,
            require_ordered,
            disallow_op,
        )
47 changes: 47 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Default schedule rules"""
from tvm.meta_schedule.schedule_rule import (
AutoInline,
ScheduleRule,
)
from tvm.target import Target


def auto_inline(target: Target) -> ScheduleRule:
    """Create the default auto-inline schedule rule for a target.

    Parameters
    ----------
    target : Target
        The target whose kind name selects the configuration. Only the kinds
        "llvm" and "cuda" are handled.

    Returns
    -------
    rule : ScheduleRule
        The AutoInline rule configured for the kind of the given target.

    Raises
    ------
    NotImplementedError
        If the kind of the target is neither "llvm" nor "cuda".
    """
    kind_name = target.kind.name
    if kind_name == "llvm":
        # Conservative setting: inline into consumers only, with every check enabled
        options = {
            "into_producer": False,
            "into_consumer": True,
            "inline_const_tensor": True,
            "disallow_if_then_else": True,
            "require_injective": True,
            "require_ordered": True,
            "disallow_op": ["tir.exp"],
        }
    elif kind_name == "cuda":
        # Permissive setting: inline in both directions, with no extra restrictions
        options = {
            "into_producer": True,
            "into_consumer": True,
            "inline_const_tensor": True,
            "disallow_if_then_else": False,
            "require_injective": False,
            "require_ordered": False,
            "disallow_op": None,
        }
    else:
        raise NotImplementedError(f"{target.kind.name} is not supported")
    return AutoInline(**options)
174 changes: 174 additions & 0 deletions src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*!
 * \brief The outcome of the inline analysis on a specific block: whether the block is to be
 * inlined and, if so, in which direction.
 */
enum class InlineType : int32_t {
  kNoInline = 0,            //!< There is no opportunity to inline the block
  kInlineIntoConsumer = 1,  //!< The block is to be inlined into its consumer
  kInlineIntoProducer = 2,  //!< The block is to be inlined into its producer
};

/*!
 * \brief A schedule rule that inlines a spatial block when the block passes the checks
 * configured through the fields of this node.
 */
class AutoInlineNode : public ScheduleRuleNode {
 public:
  /*!
   * \brief Decide which kind of inlining, if any, is to be applied to the given block
   * \param sch The schedule the block belongs to
   * \param block_rv The block to be examined
   * \return The kind of inlining to perform, InlineType::kNoInline if none applies
   */
  inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv);

  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {}

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    // The schedule is modified in place and handed back as the single result
    switch (CheckInline(sch, block_rv)) {
      case InlineType::kInlineIntoConsumer:
        sch->ComputeInline(block_rv);
        break;
      case InlineType::kInlineIntoProducer:
        sch->ReverseComputeInline(block_rv);
        break;
      case InlineType::kNoInline:
        break;
    }
    return {sch};
  }

  /*! \brief Whether a block is allowed to be inlined into its producer */
  bool into_producer;
  /*! \brief Whether a block is allowed to be inlined into its consumer */
  bool into_consumer;
  /*! \brief Whether to always inline constant tensors */
  bool inline_const_tensor;
  /*! \brief Whether to always disallow if-then-else-like constructs */
  bool disallow_if_then_else;
  /*! \brief Whether auto inline requires the read-to-write mapping to be injective */
  bool require_injective;
  /*! \brief Whether auto inline requires the read-to-write mapping to be ordered */
  bool require_ordered;
  /*! \brief The operators that must not appear in a block for it to be auto-inlined */
  Array<Op> disallow_op;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("into_producer", &into_producer);
    v->Visit("into_consumer", &into_consumer);
    v->Visit("inline_const_tensor", &inline_const_tensor);
    v->Visit("disallow_if_then_else", &disallow_if_then_else);
    v->Visit("require_injective", &require_injective);
    v->Visit("require_ordered", &require_ordered);
    v->Visit("disallow_op", &disallow_op);
  }

  static constexpr const char* _type_key = "meta_schedule.AutoInline";
  TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode);
};

/*!
 * \brief Decide which kind of inlining, if any, is to be applied to the given block.
 *
 * The checks run in a fixed order and the first one that fails yields kNoInline. A block
 * without reads is reported as kInlineIntoConsumer right away when `inline_const_tensor` is
 * set, before any of the later checks and without consulting `into_consumer`.
 *
 * \param sch The schedule the block belongs to
 * \param block_rv The block to be examined
 * \return The kind of inlining to perform, InlineType::kNoInline if none applies
 */
inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                              const tir::BlockRV& block_rv) {
  using namespace tvm::tir;
  StmtSRef block_sref = sch->GetSRef(block_rv);
  ScheduleState state = sch->state();
  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
  BlockRealize realize = GetBlockRealize(state, block_sref);
  // Check 1: exactly one buffer is written by the block
  if (block->writes.size() != 1) {
    return InlineType::kNoInline;
  }
  // Check 2: a block producing a constant tensor bypasses every remaining check
  if (inline_const_tensor && block->reads.empty()) {
    return InlineType::kInlineIntoConsumer;
  }
  // Check 3: none of the disallowed operators shows up in the block
  if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
    return InlineType::kNoInline;
  }
  // Check 4: the block is free of if-then-else-like constructs
  if (disallow_if_then_else && HasIfThenElse(realize)) {
    return InlineType::kNoInline;
  }
  // Check 5: every read region maps to the write region injectively and/or in order,
  // depending on which of the two requirements is switched on
  if (require_injective || require_ordered) {
    const BufferRegion& write_region = block->writes[0];
    for (const BufferRegion& read_region : block->reads) {
      const auto pattern = AnalyzeReadWritePattern(read_region, write_region);
      const bool is_injective = std::get</*injective=*/2>(pattern);
      const bool is_ordered = std::get</*ordered=*/3>(pattern);
      if ((require_injective && !is_injective) || (require_ordered && !is_ordered)) {
        return InlineType::kNoInline;
      }
    }
  }
  // Final check: try the consumer direction first, then the producer direction
  StmtSRef scope_block = tir::GetScopeRoot(state, block_sref,
                                           /*require_stage_pipeline=*/false,
                                           /*require_subtree_compact_dataflow=*/false);
  if (into_consumer && !GetConsumers(state, block_sref).empty() &&
      CanComputeInline(state, block_sref)) {
    return InlineType::kInlineIntoConsumer;
  }
  if (into_producer) {
    Array<StmtSRef> producers = GetProducers(state, block_sref);
    // Only a single producer that is a complete block in the scope is accepted
    const bool has_single_complete_producer =
        producers.size() == 1 && tir::IsCompleteBlock(state, producers[0], scope_block);
    if (has_single_complete_producer && CanReverseComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoProducer;
    }
  }
  return InlineType::kNoInline;
}

/*!
 * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
 * \param into_producer Whether a block is allowed to be inlined into its producer
 * \param into_consumer Whether a block is allowed to be inlined into its consumer
 * \param inline_const_tensor Always inline constant tensors
 * \param disallow_if_then_else Always disallow if-then-else-like constructs
 * \param require_injective Always require the read-to-write mapping to be injective
 * \param require_ordered Always require the read-to-write mapping to be ordered
 * \param disallow_op The names of the operators that are disallowed in auto inline; each name
 * is looked up with Op::Get. When not defined, the list of disallowed operators stays empty.
 * \return The schedule rule created
 */
ScheduleRule ScheduleRule::AutoInline(bool into_producer,          //
                                      bool into_consumer,          //
                                      bool inline_const_tensor,    //
                                      bool disallow_if_then_else,  //
                                      bool require_injective,      //
                                      bool require_ordered,        //
                                      Optional<Array<String>> disallow_op) {
  ObjectPtr<AutoInlineNode> node = make_object<AutoInlineNode>();
  // Plain flags are copied over as they are
  node->into_producer = into_producer;
  node->into_consumer = into_consumer;
  node->inline_const_tensor = inline_const_tensor;
  node->disallow_if_then_else = disallow_if_then_else;
  node->require_injective = require_injective;
  node->require_ordered = require_ordered;
  node->disallow_op.clear();
  if (!disallow_op.defined()) {
    return ScheduleRule(node);
  }
  // Translate the operator names into operator objects
  Array<String> names = disallow_op.value();
  node->disallow_op.reserve(names.size());
  for (const String& name : names) {
    node->disallow_op.push_back(Op::Get(name));
  }
  return ScheduleRule(node);
}

// Register the node type, and expose the factory ScheduleRule::AutoInline as the global function
// "meta_schedule.ScheduleRuleAutoInline". The Python class AutoInline constructs its handle
// through `_ffi_api.ScheduleRuleAutoInline` and passes the arguments positionally, so the
// parameter order of the factory has to stay in sync with the Python side.
TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
    .set_body_typed(ScheduleRule::AutoInline);

} // namespace meta_schedule
} // namespace tvm
45 changes: 45 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
Expand Down Expand Up @@ -442,6 +443,50 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);

/*!
 * \brief Checks if the given AST contains the specific operators
 * \param stmt The AST statement to be checked
 * \param ops The list of operators to be checked
 * \return A boolean indicating whether the AST contains the specific operators
 * \note NOTE(review): presumably true as soon as any one operator in `ops` appears in `stmt`
 * (the auto-inline rule uses it to reject blocks containing a disallowed operator) --
 * confirm against the definition.
 */
bool HasOp(const Stmt& stmt, const Array<Op>& ops);

/*!
 * \brief Checks if the given AST statement contains an if-then-else-like construct.
 * The following constructs are counted:
 * 1) IfThenElse statement
 * 2) Select expression
 * 3) The operator `tir.if_then_else`
 * 4) non-constant-true Block predicates
 * \param stmt The AST statement to be checked
 * \return A boolean indicating whether the statement contains the if-then-else pattern
 */
bool HasIfThenElse(const Stmt& stmt);

/*!
 * \brief Given the read/write region, extract the pattern of their index correspondence
 * namely, the mapping from read index to the write index.
 * \param read_region The read region
 * \param write_region The write region
 * \return A tuple of booleans, the extracted pattern
 * 0) exists: if the pattern is found
 * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once
 *    e.g. A[i, j] = B[i, i, j], where both `i` and `j` are mapped
 * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once.
 *    e.g. A[i, j] = B[i], where `i` is mapped once and `j` is not mapped at all
 * 3) ordered: if the mapping is ordered
 * 4) no_const_read: if there is no constant indexing in the read indices,
 *    a counter-example being A[i, j] = B[0, i, j], which reads at the constant index `0`
 * 5) no_shift_read: if there is no constant shift in the read indices,
 *    a counter-example being A[i, j] = B[i + 1, j], which reads `i` shifted by `1`
 * \note The examples of 1) and 2) satisfy the respective property, whereas the examples of
 * 4) and 5) violate it, as follows from the definitions given in each item.
 */
std::tuple</*exists=*/bool,
           /*surjective=*/bool,
           /*injective=*/bool,
           /*ordered=*/bool,
           /*no_const_read=*/bool,
           /*no_shift_read=*/bool>
AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);

} // namespace tir
} // namespace tvm

Expand Down
Loading