diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index f31e515c7913..2397caffc13e 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -346,6 +346,12 @@ TVM_DLL Pass PointerValueTypeRewrite(); */ TVM_DLL Pass HoistIfThenElse(); +/*! + * \brief Lower block init stmt into IfThenElse stmts + * \return The pass. + */ +TVM_DLL Pass LowerInitBlock(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 40dd170c5414..8bd63bdfef21 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -536,3 +536,14 @@ def HoistIfThenElse(variant=None): return _ffi_api.HoistIfThenElseBasic() elif variant is None: return _ffi_api.HoistIfThenElse() + + +def LowerInitBlock(): + """Lower block init stmt into IfThenElse stmts + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerInitBlock() diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc new file mode 100644 index 000000000000..c8aca5195085 --- /dev/null +++ b/src/tir/transforms/lower_init_block.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Lower block init stmt into branch stmt + * \file lower_reduction.cc + */ +#include +#include +#include + +namespace tvm { +namespace tir { + +class InitBlockLower : public StmtMutator { + private: + Stmt VisitStmt_(const BlockNode* block) final { + if (!block->init.defined()) { + return StmtMutator::VisitStmt_(block); + } + Stmt init = DoLowering(block->init.value(), block->iter_vars); + Stmt body = VisitStmt(block->body); + auto n = CopyOnWrite(block); + n->init = NullOpt; + n->body = SeqStmt::Flatten(init, body); + return Block(n); + } + + static Stmt DoLowering(const Stmt& init, const Array& iter_vars) { + std::vector conditions; + for (const IterVar& var : iter_vars) { + if (var->iter_type == IterVarType::kCommReduce) { + conditions.push_back(equal(var->var, var->dom->min)); + } + } + // Handle the case where there is no condition + if (conditions.empty()) { + return init; + } + // Concat the conditions with logical and (&&) + PrimExpr cond = conditions[0]; + for (size_t i = 1; i < conditions.size(); ++i) { + cond = logical_and(cond, conditions[i]); + } + return IfThenElse(cond, init); + } +}; + +PrimFunc LowerInitBlock(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + fptr->body = InitBlockLower()(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass LowerInitBlock() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return LowerInitBlock(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerReduction", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py new file mode 100644 index 000000000000..3fb8331d39fc --- /dev/null +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm.script import ty + + +@tvm.script.tir +class WithInit: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + with tir.init(): + B[i] = tir.float32(0) + B[i] += A[i, j, k] + + +@tvm.script.tir +class WithBranch: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + if (j == 0) and (k == 32): + B[i] = tir.float32(0) + B[i] += A[i, j, k] + + +def test_lower_reduction(): + origin_mod = WithInit() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, WithBranch(), True) + + +if __name__ == "__main__": + test_lower_reduction()