diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a1697d807db9..76826fdf7c5a 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize(); */ TVM_DLL Pass FP8StorageLegalize(); +/*! + * \brief Inline calls to private functions + * + * \return The pass. + */ +TVM_DLL Pass InlinePrivateFunctions(); + /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index a46b2d10373f..42c9aecd18e7 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -230,6 +230,17 @@ def StorageRewrite(): return _ffi_api.StorageRewrite() # type: ignore +def InlinePrivateFunctions(): + """Inline calls to private functions + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InlinePrivateFunctions() # type: ignore + + def PointerValueTypeRewrite(): """ Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc new file mode 100644 index 000000000000..a47c852067fa --- /dev/null +++ b/src/tir/transforms/inline_private_functions.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inline_private_functions.cc + * \brief Inline private functions to their callsite + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace transform { + +namespace { + +template +using PSet = std::unordered_set; + +template +using PMap = std::unordered_map; + +PMap> CollectCallMap(const IRModule& mod) { + struct Visitor : StmtExprVisitor { + GlobalVar current; + PMap> caller_lookup; + + void VisitExpr_(const CallNode* op) { + if (auto gvar = op->op.as()) { + caller_lookup[gvar.value()].insert(current); + } + StmtExprVisitor::VisitExpr_(op); + } + } visitor; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + visitor.current = gvar; + visitor(prim_func->body); + } + } + + return visitor.caller_lookup; +} + +PSet CollectRecursiveFunctions(const IRModule& mod) { + // Collect all direct callers + auto call_map = CollectCallMap(mod); + + // Propagate to find all indirect callers + while (true) { + bool made_change = false; + for (const auto& [callee, callers] : call_map) { + for (const auto& caller : callers) { + if (auto it = call_map.find(caller); it != call_map.end()) { + PSet& indirect_callers = it->second; + + auto res = indirect_callers.insert(callee); + made_change = made_change || res.second; + } + } + } + if (!made_change) { + break; + } + } + + // Filter all GlobalVars that can be called by themselves, either + // directly or indirectly. + PSet recursive_funcs; + for (const auto& [caller, callees] : call_map) { + if (callees.count(caller)) { + recursive_funcs.insert(caller); + } + } + return recursive_funcs; +} + +Map CollectInlinablePrimFuncs(const IRModule& mod) { + auto recursive_functions = CollectRecursiveFunctions(mod); + + Map output; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto prim_func = opt.value(); + + // Only inline private functions. Externally-exposed functions + // must be preserved so to avoid breaking callsites outside of + // the IRModule. + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + + // We do not currently implement any analysis for termination of + // a function. If a recursive function requires runtime checks + // in order to terminate, we would keep inlining until the + // recursive visits segfault. + bool is_recursive = recursive_functions.count(gvar); + + // We do not currently support inlining of functions that accept + // buffer arguments. + bool has_buffer_arguments = prim_func->buffer_map.size(); + + // We do not currently support inlining of schedulable TIR + // functions. To support this use case, repeated names in + // `tir::Block` nodes resulting from multiple calls to the same + // inlined function will need to be de-duplicated. + bool has_block_node = prim_func->body.as(); + + if (!is_exposed && !is_recursive && !has_buffer_arguments && !has_block_node) { + output.Set(gvar, prim_func); + } + } + } + + return output; +} + +class PrimFuncInliner : StmtExprMutator { + public: + explicit PrimFuncInliner(Map inlinable_funcs) + : inlinable_funcs_(inlinable_funcs) { + for (const auto& [gvar, callee] : inlinable_funcs_) { + removable_funcs_.insert(gvar); + } + } + + PrimFunc VisitFunc(PrimFunc func) { + current_target_ = func->GetAttr(tvm::attr::kTarget); + auto new_body = VisitStmt(func->body); + current_target_ = NullOpt; + + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } + + return func; + } + + PSet GetRemovableFunctions() const { return removable_funcs_; } + + private: + Stmt VisitStmt_(const EvaluateNode* eval) override { + if (auto call = eval->value.as()) { + if (auto gvar = call->op.as()) { + if (auto opt_callee = inlinable_funcs_.Get(gvar.value())) { + auto callee = opt_callee.value(); + + bool is_same_target = [&]() -> bool { + auto callee_target = callee->GetAttr(tvm::attr::kTarget); + if (current_target_ && callee_target) { + return callee_target.value()->str() == current_target_.value()->str(); + } else { + return true; + } + }(); + + if (is_same_target) { + Stmt inlined = InlineArguments(gvar.value(), callee, call->args); + return VisitStmt(inlined); + } + } + } + } + + return StmtExprMutator::VisitStmt_(eval); + } + + PrimExpr VisitExpr_(const CallNode* call) override { + // Any callee that hasn't been inlined at this point must be kept + // in the output IRModule. + if (auto gvar = call->op.as()) { + removable_funcs_.erase(gvar.value()); + } + return StmtExprMutator::VisitExpr_(call); + } + + Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array& args) const { + CHECK_EQ(callee->params.size(), args.size()) + << "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" + << callee->params << "), but is called with " << args.size() << " arguments (" << args + << ")"; + + ICHECK(callee->buffer_map.empty()) + << "Inlining of PrimFuncs with buffer arguments is not yet supported, " + << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; + + Map param_map; + for (size_t i = 0; i < callee->params.size(); i++) { + param_map.Set(callee->params[i], args[i]); + } + + callee = Specialize(callee, param_map); + + return callee->body; + } + + // Map from GlobalVar to PrimFuncs which may be inlined. + Map inlinable_funcs_; + + /* \brief Set of callees that may be removed + * + * Some constructs may not be inlined (e.g. if the call site occurs + * outside of an Evaluate node). For these cases, the output + * IRModule must still contain the callee. + */ + PSet removable_funcs_; + + Optional current_target_ = NullOpt; +}; + +} // namespace + +Pass InlinePrivateFunctions() { + auto pass_func = [](IRModule mod, PassContext ctx) { + auto inlinable_prim_funcs = CollectInlinablePrimFuncs(mod); + + if (inlinable_prim_funcs.empty()) { + // Early bail-out if the module has no inlinable PrimFuncs. + return mod; + } + + PrimFuncInliner mutator(std::move(inlinable_prim_funcs)); + IRModule updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto updated = mutator.VisitFunc(opt.value()); + if (!updated.same_as(base_func)) { + updates->Add(gvar, updated); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + for (const auto& gvar : mutator.GetRemovableFunctions()) { + write_ptr->Remove(gvar); + } + mod = ConvertSSA()(mod); + } + + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py new file mode 100644 index 000000000000..2edf74ebfb3d --- /dev/null +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm.testing +from tvm.script import ir as I, tir as T + + +class BaseTestCase: + def test_well_formed(self): + After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) + tvm.tir.analysis.verify_well_formed(After) + + def test_produces_expected(self): + After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) + tvm.ir.assert_structural_equal(self.Expected, After) + + +class TestSimple(BaseTestCase): + """Simple case directly acting on PrimFunc""" + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + for i in range(64): + Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + A = T.decl_buffer([16, 16], "float32", data=A_data) + B = T.decl_buffer([16], "float32", data=B_data) + for i in range(16): + B[i] = 0.0 + for j in range(16): + B[i] = B[i] + A[i, j] + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + for i in range(64): + A_view_data: T.handle("float32") = T.address_of(A[i, 0]) + Aview = T.decl_buffer([16, 16], "float32", data=A_view_data) + B_view_data: T.handle("float32") = T.address_of(B[i, 0]) + Bview = T.decl_buffer([16], "float32", data=B_view_data) + for j in range(16): + Bview[j] = 0.0 + for k in range(16): + Bview[j] = Bview[j] + Aview[j, k] + + +class TestRetainCrossFunctionSubroutines(BaseTestCase): + """Do not inline functions that cross device boundaries + + When lowering TIR, calls for which the callsite and callee have + different targets are used at some stages, before being further + lowered to explicit device kernel launches. Since inlining the + function would remove this cross-device information, + InlinePrivateSubroutines should not inline these cases. + """ + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + T.func_attr({"target": T.target("llvm")}) + for i in range(64): + Before.subroutine(T.address_of(A[i, 0]), T.address_of(B[i, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + A = T.decl_buffer([16, 16], "float32", data=A_data) + B = T.decl_buffer([16], "float32", data=B_data) + for i in range(16): + B[i] = 0.0 + for j in range(16): + B[i] = B[i] + A[i, j] + + Expected = Before + + +class TestRetainRecursiveSubroutines(BaseTestCase): + """Do not inline recursive functions + + To avoid potentially infinite loops at compile-time, disable + inlining of recursive functions. If inlining of these functions + would be useful, this restriction may be relaxed with improved + analysis of the subroutine. + """ + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + Before.subroutine(T.address_of(A[0]), 16) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), A_size: T.int32): + A = T.decl_buffer(A_size, "float32", data=A_data) + A[1] = A[0] + A[1] + + if A_size > 1: + Before.subroutine(T.address_of(A[1]), A_size - 1) + + Expected = Before + + +class TestDeduplicateBlockName(BaseTestCase): + """Block names must be de-duplicated after inlining""" + + @pytest.mark.xfail(reason="Inlining of schedulable TIR not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer([2, 16], "float32"), B: T.Buffer([2, 16], "float32")): + Before.subroutine(T.address_of(A[0, 0]), T.address_of(B[0, 0])) + Before.subroutine(T.address_of(A[1, 0]), T.address_of(B[1, 0])) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + with T.block("scalar_mul"): + B[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): + with T.LetStmt(T.address_of(A[0, 0]), var=T.handle("float32")) as A_data_1: + A_1 = T.decl_buffer(16, "float32", data=A_data_1) + B_data_1: T.handle("float32") = T.address_of(B[0, 0]) + B_1 = T.decl_buffer(16, "float32", data=B_data_1) + for i in range(16): + with T.block("scalar_mul_1"): + B_1[i] = A_1[i] * 2.0 + + with T.LetStmt(T.address_of(A[1, 0]), var=T.handle("float32")) as A_data_2: + A_2 = T.decl_buffer(16, "float32", data=A_data_2) + B_data_2: T.handle("float32") = T.address_of(B[1, 0]) + B_2 = T.decl_buffer(16, "float32", data=B_data_2) + for i in range(16): + with T.block("scalar_mul_2"): + B_2[i] = A_2[i] * 2.0 + + +class TestInlineCallOccurringInExpression(BaseTestCase): + """Inline a Call node that is used in a function + + The current implementation only replaces `tir.Call` instances that + occur in a `tir.Evaluate` context. This is the primary use case, + used in destination-passing style. + + This unit test is marked as xfail. If/when the implementation + supports inlining of function calls occurring as part of an + expression, the annotation can be removed. + """ + + @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tir.Evaluate is not yet supported") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = Before.subroutine(i) + + @T.prim_func(private=True) + def subroutine(i: T.int32) -> T.float32: + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + T.ret(retval) + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + cos = T.cos(T.cast(i, "float32")) + sin = T.sin(T.cast(i, "float32")) + retval = cos * cos + sin * sin + A[i] = retval + + +class TestInlineFunctionWithBufferArguments(BaseTestCase): + """Inline a function that accepts buffer arguments + + The current implementation does not support this usage. This unit + test is provided to display a possible user interaction, and is + marked with `@pytest.mark.xfail`. If/when the implementation + supports inlining of function calls with buffer arguments, the + annotation can be removed. + """ + + @pytest.mark.xfail(reason="Inlining of PrimFuncs with buffer arguments") + def test_produces_expected(self): + super().test_produces_expected(self) + + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + Before.subroutine( + T.tvm_stack_make_array( + A.data, + T.tvm_stack_make_shape(*A.shape, dtype="handle"), + 0, + len(A.shape), + 0.0, + A.elem_offset, + dtype="handle", + ) + ) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + +if __name__ == "__main__": + tvm.testing.main()