diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 68948196ff6b..037606253adc 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, */ BlockFrame Block(String name, bool no_realize = false); +namespace axis { + +/*! + * \brief The spatial block axis defining function. + * \param dom The domain of the iteration variable. + * \param binding The binding value of the iteration variable. + * \param dtype The data type of the iteration variable. + * \return The iteration variable. + */ +Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); + +/*! + * \brief The reduced block axis defining function. + * \param dom The domain of the iteration variable. + * \param binding The binding value of the iteration variable. + * \param dtype The data type of the iteration variable. + * \return The iteration variable. + */ +Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); + +/*! + * \brief The scanning block axis defining function. + * \param dom The domain of the iteration variable. + * \param binding The binding value of the iteration variable. + * \param dtype The data type of the iteration variable. + * \return The iteration variable. + */ +Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); + +/*! + * \brief The opaque block axis defining function. + * \param dom The domain of the iteration variable. + * \param binding The binding value of the iteration variable. + * \param dtype The data type of the iteration variable. + * \return The iteration variable. + */ +Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); + +/*! + * \brief The block axis remapping function. + * \param kinds The types of the iteration variables. + * \param bindings The binding values of the iteration variables. + * \param dtype The data types of the iteration variables. + * \return The iteration variables. + */ +Array Remap(String kinds, Array bindings, DataType dtype = DataType::Int(32)); + +} // namespace axis + /*! * \brief The serial For statement. * \param start The minimum value of iteration. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a5cdf8a3a105..40cd99c744d7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -20,7 +20,7 @@ from numbers import Integral from typing import Any, Dict, List, Optional, Union, Tuple -from tvm.ir import Type +from tvm.ir import Range, Type from tvm.tir import ( Buffer, BufferLoad, @@ -344,6 +344,160 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore +def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: + """The range constructor. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr]] + The domain. + + Returns + ------- + res : Range + The Range. + """ + if isinstance(dom, Range): + return dom + if isinstance(dom, (list, tuple)): + return Range(dom[0], dom[1]) + return Range(0, dom) + + +class axis: # pylint: disable=invalid-name + @staticmethod + def spatial( + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + ) -> Var: + """The spatial block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisSpatial( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def reduce( + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + ) -> Var: + """The reduced block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisReduce( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def scan( + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + ) -> Var: + """The scanning block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisScan( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def opaque( + dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" + ) -> Var: + """The opaque block axis defining function. + + Parameters + ---------- + dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + The domain of the iteration variable. + + binding : PrimExpr + The binding value of the iteration variable. + + dtype : str + The data type of the iteration variable. + + Returns + ------- + res : Var + The iteration variable. + """ + return _ffi_api.AxisOpaque( # pylint: disable=no-member # type: ignore + _as_range(dom), binding, dtype + ) + + @staticmethod + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: + """The block axis remapping function. + + Parameters + ---------- + kinds : str + The types of the iteration variables. + + bindings : List[PrimExpr] + The binding values of the iteration variables. + + dtype : str + The data types of the iteration variables. + + Returns + ------- + res : Var + The iteration variables. + """ + iter_vars = _ffi_api.AxisRemap( # pylint: disable=no-member # type: ignore + kinds, bindings, dtype + ) + return iter_vars[0] if len(iter_vars) == 1 else iter_vars + + S = spatial # pylint: disable=invalid-name + R = reduce # pylint: disable=invalid-name + + def serial( start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None ) -> frame.ForFrame: @@ -843,6 +997,7 @@ def var(dtype, name="") -> Var: "match_buffer", "preflattened_buffer", "block", + "axis", "serial", "parallel", "vectorized", diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 22c7face7084..5013e321728e 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -173,6 +173,86 @@ BlockFrame Block(String name, bool no_realize) { return BlockFrame(n); } +namespace axis { + +IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { + if (Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { + BlockFrame frame = opt_frame.value(); + frame->iter_vars.push_back(iter_var); + frame->iter_values.push_back(binding); + } else { + LOG(FATAL) << "TypeError: The last frame is not BlockFrame"; + } + return iter_var; +} + +#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \ + Var Method(Range dom, PrimExpr binding, DataType dtype) { \ + ICHECK(dom.defined()) << Name << " axis must have a domain"; \ + int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \ + return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \ + /*iter_type=*/Kind, /*thread_tag=*/""), \ + binding) \ + ->var; \ + } +TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial"); +TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction"); +TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan"); +TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque"); +#undef TVM_TIR_IR_BUILDER_AXIS + +Array Remap(String kinds, Array bindings, DataType dtype) { + using namespace tvm::tir; + Array results; + ICHECK_EQ(kinds.size(), bindings.size()); + int n = bindings.size(); + results.reserve(n); + for (int i = 0; i < n; ++i) { + char c = kinds.c_str()[i]; + PrimExpr e = bindings[i]; + const VarNode* v = e.as(); + ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap"; + Range dom{nullptr}; + for (const auto& frame : IRBuilder::Current()->frames) { + if (const auto* for_frame = frame.as()) { + ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); + int n = for_frame->doms.size(); + for (int i = 0; i < n; ++i) { + if (for_frame->vars[i].get() == v) { + dom = for_frame->doms[i]; + break; + } + } + if (dom.defined()) { + break; + } + } + } + ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef(v); + DataType dtype = v->dtype; + if (c == 'S') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("", dtype), + /*iter_type=*/IterVarType::kDataPar, + /*thread_tag=*/""), + e) + ->var); + } else if (c == 'R') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("", dtype), + /*iter_type=*/IterVarType::kCommReduce, + /*thread_tag=*/""), + e) + ->var); + } else { + LOG(FATAL) << "Unknown axis kind: " << c; + } + } + return results; +} + +} // namespace axis + #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) { \ PrimExpr min = start; \ @@ -304,6 +384,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(P TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 9cbfd75e2280..d893ebc545c6 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -114,6 +114,49 @@ def test_ir_builder_tir_block(): assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) +def test_ir_builder_tir_axis(): + with IRBuilder() as ib: + a = T.var("int32", "a") + b = T.var("int32", "b") + c = T.var("int32", "c") + d = T.var("int32", "d") + with T.block("block"): + T.axis.spatial(8, a) + T.axis.reduce(16, b) + T.axis.scan(32, c) + T.axis.opaque(64, d) + T.evaluate(0) + + # the block generated by IRBuilder + block_realize_actual = ib.get() + + # the expected block + var_a = tir.Var("a", "int32") + var_b = tir.Var("b", "int32") + var_c = tir.Var("c", "int32") + var_d = tir.Var("d", "int32") + block_expected = tir.Block( + iter_vars=[ + tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar), + tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce), + tir.IterVar((0, 32), tir.Var("", "int32"), iter_type=tir.IterVar.Ordered), + tir.IterVar((0, 64), tir.Var("", "int32"), iter_type=tir.IterVar.DimInfo), + ], + reads=[], + writes=[], + name_hint="block", + body=tir.Evaluate(0), + annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, + ) + block_realize_expected = tir.BlockRealize( + iter_values=[var_a, var_b, var_c, var_d], + predicate=True, + block=block_expected, + ) + # Check if the generated ir is expected + assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) + + def test_ir_builder_tir_for(): with IRBuilder() as ib: with T.serial(128) as a: