From 020bea042faeb874bf989025632c036298f6dae6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 Feb 2023 22:31:47 +0800 Subject: [PATCH] [Unity] Relax TVMScript Parser. (#13932) This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu --- include/tvm/ir/expr.h | 1 + include/tvm/script/ir_builder/ir/frame.h | 11 +- include/tvm/script/ir_builder/ir/ir.h | 17 + include/tvm/script/ir_builder/relax/frame.h | 293 +++++ include/tvm/script/ir_builder/relax/ir.h | 137 +++ python/tvm/ir/expr.py | 50 +- python/tvm/script/ir_builder/base.py | 6 +- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 + .../tvm/script/ir_builder/relax/__init__.py | 20 + .../tvm/script/ir_builder/relax/_ffi_api.py | 20 + python/tvm/script/ir_builder/relax/frame.py | 55 + python/tvm/script/ir_builder/relax/ir.py | 407 +++++++ python/tvm/script/parser/__init__.py | 3 +- python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/entry.py | 4 + python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 +- python/tvm/script/parser/ir/parser.py | 4 + python/tvm/script/parser/relax/__init__.py | 17 +- python/tvm/script/parser/relax/entry.py | 327 +++++ python/tvm/script/parser/relax/parser.py | 276 +++++ python/tvm/script/parser/tir/entry.py | 4 +- python/tvm/script/parser/tir/parser.py | 26 + src/ir/module.cc | 30 +- src/script/ir_builder/ir/frame.cc | 12 +- src/script/ir_builder/ir/ir.cc | 41 +- src/script/ir_builder/ir/utils.h | 49 + src/script/ir_builder/relax/frame.cc | 273 +++++ src/script/ir_builder/relax/ir.cc | 236 ++++ src/script/ir_builder/relax/utils.h | 119 ++ src/script/ir_builder/tir/frame.cc | 15 +- src/script/ir_builder/tir/utils.h | 2 +- .../python/relax/test_tvmscript_ir_builder.py | 153 +++ tests/python/relax/test_tvmscript_parser.py | 1062 +++++++++++++++++ 35 files changed, 3726 insertions(+), 45 deletions(-) create mode 100644 include/tvm/script/ir_builder/relax/frame.h create mode 100644 include/tvm/script/ir_builder/relax/ir.h create mode 100644 python/tvm/script/ir_builder/relax/__init__.py create mode 100644 python/tvm/script/ir_builder/relax/_ffi_api.py create mode 100644 python/tvm/script/ir_builder/relax/frame.py create mode 100644 python/tvm/script/ir_builder/relax/ir.py create mode 100644 python/tvm/script/parser/relax/entry.py create mode 100644 python/tvm/script/parser/relax/parser.py create mode 100644 src/script/ir_builder/ir/utils.h create mode 100644 src/script/ir_builder/relax/frame.cc create mode 100644 src/script/ir_builder/relax/ir.cc create mode 100644 src/script/ir_builder/relax/utils.h create mode 100644 tests/python/relax/test_tvmscript_ir_builder.py create mode 100644 tests/python/relax/test_tvmscript_parser.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index d4ba628d36cf..c662067a0486 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -462,6 +462,7 @@ class GlobalVarNode : public RelayExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..dacfc361a6c7 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h new file mode 100644 index 000000000000..0f544d3abcc2 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/*! \brief The base ir_builder frame for the relax dialect. */ +class RelaxFrameNode : public IRBuilderFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); +}; + +class RelaxFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + + protected: + RelaxFrame() = default; +}; + +/*! \brief The base ir_builder frame for frames with SeqExpr + i.e. Functions, If branches + */ +class SeqExprFrameNode : public RelaxFrameNode { + public: + /*! \brief The binding blocks inside the frame. */ + Array binding_blocks; + /*! \brief The frame output expr. `NullOpt` when undefined. */ + Optional output; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() override; + void ExitWithScope() override; +}; + +class SeqExprFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); +}; + +/*! \brief The ir_builder frame for the relax function. */ +class FunctionFrameNode : public SeqExprFrameNode { + public: + /*! + * \brief The function name. + * \note The name will not be specified in constructor, so it is "Optional", + * However, we must specify the name by `R.func_name` before exit this frame. + */ + Optional name; + /*! \brief The function params. */ + Array params; + /*! + * \brief The function return struct info. + * \note Usually the function return type can be deduced by the function body. + * But we can use this field to specify a more "accurate" return type. + * i.e. If the `ret_struct_info` is None, try to use the deduced type from body + * If the `ret_struct_info` is not None, we can still take body.struct_info + * if we ret_struct_info is base of body.struct_info. If not, we will + * take the specified `ret_struct_info`. + */ + Optional ret_struct_info; + + /*! \brief The function attributes. */ + Map attrs; + /*! \brief The block builder to create Relax function. */ + tvm::relax::BlockBuilder block_builder; + + void VisitAttrs(tvm::AttrVisitor* v) { + SeqExprFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("params", ¶ms); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + // `block_builder` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + + public: + void ExitWithScope() final; +}; + +class FunctionFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); +}; + +/*! \brief The ir_builder frame for relax binding blocks. */ +class BlockFrameNode : public RelaxFrameNode { + public: + /*! \brief The flag that indicates whether the block is a dataflow block. */ + bool is_dataflow; + /*! \brief The variables emitted in this block. */ + Array emitted_vars; + /*! + * \brief A boolean indicating if the dataflow block is ended of construction. + * If it is true, any new binding trying to be emitted into this block will cause an error. + * \note Only used for a dataflow block. + */ + bool block_ended; + /*! + * \brief The output vars of the dataflow block. + * \note Only used for a dataflow block. + */ + Array output_vars; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("is_dataflow", &is_dataflow); + v->Visit("emitted_vars", &emitted_vars); + v->Visit("output_vars", &output_vars); + // `block_ended` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); +}; + +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public RelaxFrameNode { + public: + /*! \brief The condition of the if statement. */ + tvm::relax::Expr condition; + /*! \brief The Bindings in the true branch. */ + Optional then_expr; + /*! \brief The Bindings in the false branch. */ + Optional else_expr; + /*! \brief The Binding var. */ + tvm::relax::Var var; + /*! \brief The binding var name. */ + String var_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_expr", &then_expr); + v->Visit("else_expr", &else_expr); + v->Visit("var", &var); + v->Visit("var_name", &var_name); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ +class ThenFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ +class ElseFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); +}; + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h new file mode 100644 index 000000000000..72aab6684ebf --- /dev/null +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/////////////////////////////// Function //////////////////////////////// + +/*! + * \brief Start a function frame. + * \return The created ir_builder Function frame. + */ +TVM_DLL FunctionFrame Function(); + +/*! + * \brief Add a parameter to the last function frame. + * \param name The name of the parameter. + * \param struct_info The struct_info of the parameter. + * \return The created function parameter var. + */ +TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); + +/*! + * \brief Specify the name of the last function frame. + * \param name The function name. + */ +TVM_DLL void FuncName(const String& name); + +/*! + * \brief Specify the attrs of the last function frame. + * \param attrs The function attrs. + */ +TVM_DLL void FuncAttrs(Map attrs); + +/*! + * \brief Specify the return struct info of the last function frame. + * \param ret_sinfo The return struct info. + */ +TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); + +/*! + * \brief Specify the return value of the last function frame. + * \param value The return value. + */ +TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); + +///////////////////////////// BindingBlock ////////////////////////////// + +/*! + * \brief Start a binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame BindingBlock(); + +/*! + * \brief Start a dataflow binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame Dataflow(); + +/*! + * \brief Expose the dataflow block output variables as global ones + * \param vars The output variables of a dataflow block + */ +TVM_DLL void DataflowBlockOutput(const Array& vars); + +////////////////////////////// Bindings //////////////////////////////// + +/*! + * \brief Emit a binding to the last binding block frame. + * \param value The right side value of the bindings to be emitted. + * \param annotate_struct_info The optional struct info annotation for the emitted value. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var Emit( + const tvm::relax::Expr& value, + const Optional& annotate_struct_info = NullOpt); + +/*! + * \brief Emit a match_cast binding to the last binding block frame. + * \param value The value of the MatchCast to be emitted. + * \param struct_info The struct info of the MatchCast to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info); + +///////////////////////////// If Then Else ///////////////////////////// + +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ +IfFrame If(tvm::relax::Expr condition); +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ +ThenFrame Then(); +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index f90468de66c6..721e12e7f8d9 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -93,10 +93,17 @@ def __call__(self, *args): A call taking the variable as a function. """ # pylint: disable=import-outside-toplevel + + # TODO(@relax-team): replace with Relax base class after it's introduced if all(isinstance(x, RelayExpr) for x in args): - from tvm import relay + if all(is_relax_expr(x) for x in args): + from tvm import relax + + return relax.Call(self, args) + else: + from tvm import relay - return relay.Call(self, args) + return relay.Call(self, args) arg_types = [type(x) for x in args] raise RuntimeError( "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) @@ -185,3 +192,42 @@ def from_min_extent(min_value, extent, span=None): The constructed range. """ return _ffi_api.Range_from_min_extent(min_value, extent, span) + + +# TODO(@relax-team): remove when we have a RelaxExpr base class +def is_relax_expr(expr: RelayExpr) -> bool: + """check if a RelayExpr is a Relax expresssion. + + Parameters + ---------- + expr : RelayExpr + The expression to check. + + Returns + ------- + res : bool + If the expression is Relax expression, return True; otherwise return False. + """ + from tvm import relax # pylint: disable=import-outside-toplevel + + if isinstance( + expr, + ( + relax.Call, + relax.Constant, + relax.Tuple, + relax.TupleGetItem, + relax.If, + relax.Var, + relax.DataflowVar, + relax.ShapeExpr, + relax.SeqExpr, + relax.Function, + relax.ExternFunc, + relax.PrimValue, + relax.StringImm, + relax.DataTypeImm, + ), + ): + return True + return False diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..b35bbd0a7df5 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..946be263a779 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..796d6f3aad04 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/relax/__init__.py b/python/tvm/script/ir_builder/relax/__init__.py new file mode 100644 index 000000000000..f0905acf34e3 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Package tvm.script.ir_builder.relax""" +from . import frame +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py new file mode 100644 index 000000000000..6e2098cf88af --- /dev/null +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder.relax""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py new file mode 100644 index 000000000000..97e181fbe4be --- /dev/null +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IR Builder Frame for Relax dialect""" +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.relax.RelaxFrame") +class RelaxFrame(IRBuilderFrame): + """The base ir_builder frame for the relax dialect.""" + + +@_register_object("script.ir_builder.relax.SeqExprFrame") +class SeqExprFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.FunctionFrame") +class FunctionFrame(SeqExprFrame): + """The ir_builder frame for the relax function.""" + + +@_register_object("script.ir_builder.relax.BlockFrame") +class BlockFrame(RelaxFrame): + """The ir_builder frame for relax binding blocks.""" + + +@_register_object("script.ir_builder.relax.IfFrame") +class IfFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.ThenFrame") +class ThenFrame(SeqExprFrame): + ... + + +@_register_object("script.ir_builder.relax.ElseFrame") +class ElseFrame(SeqExprFrame): + ... diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py new file mode 100644 index 000000000000..647ef8f25af7 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -0,0 +1,407 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name +"""IRBuilder for Relax dialect""" + +import builtins +import functools +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import tvm +from tvm import DataType, relax +from tvm.ir import PrimExpr +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const + +############################### Operators ############################### +from tvm.relax.op import ( + add, + assert_op, + call_builtin_with_ctx, + call_tir, + invoke_closure, + make_closure, + multiply, + null_value, + print, + shape_of, +) +from tvm.relax.struct_info import StructInfo +from tvm.relax.utils import args_converter +from tvm.runtime import Object as tvm_Object +from tvm.runtime import ObjectGeneric + +from . import _ffi_api, frame + +##################### Python Native Function Alias ###################### + +py_print = builtins.print +py_tuple = tuple +py_str = str + + +############################### Function ################################ + + +def function() -> frame.FunctionFrame: + """Start a function frame. + Returns + ------- + frame: FunctionFrame + The constructed function frame. + """ + return _ffi_api.Function() # type: ignore[attr-defined] # pylint: disable=no-member + + +def arg(name: py_str, struct_info: StructInfo) -> Var: + """Add a parameter to the last function frame. + Parameters + ---------- + name: str + The name of the parameter. + struct_info: StructInfo + The Struct Info of the parameter + + Returns + ------- + var: Var + The created function parameter var. + """ + + return _ffi_api.Arg(name, struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_name(name: py_str) -> None: + """Specify the name of the last function frame. + Parameters + ---------- + name: str + The function name. + """ + return _ffi_api.FuncName(name) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_attr(attrs: Dict[py_str, tvm_Object]) -> None: + """Specify the attrs of the last function frame. + Parameters + ---------- + attrs: Dict[str, Object] + The function attrs. + """ + return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_struct_info(ret_sinfo: StructInfo) -> None: + """Specify the return struct info of the last function frame. + Parameters + ---------- + ret_type: StructInfo + The function return struct info. + """ + return _ffi_api.FuncRetStructInfo(ret_sinfo) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_value(value: Expr) -> None: + """Specify the return value of the last function frame. + Parameters + ---------- + value: Expr + The function return value. + """ + return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################# BindingBlock ############################## + + +def dataflow() -> frame.BlockFrame: + """Start a dataflow binding block frame. + Returns + ------- + frame: frame.BlockFrame + The created ir_builder Block frame. + """ + return _ffi_api.Dataflow() # type: ignore[attr-defined] # pylint: disable=no-member + + +def output(*vars: Tuple[Var]) -> None: + """Expose the dataflow block output variables as global ones. + Parameters + ---------- + vars: Tuple[Var] + The output variables of a dataflow block. + """ + return _ffi_api.DataflowBlockOutput(vars) # type: ignore[attr-defined] # pylint: disable=no-member + + +################################## Ops ################################# + + +@args_converter.auto +def call_packed( + func: py_str, + *args: Expr, + sinfo_args: Union[StructInfo, List[StructInfo]], + **kwargs: Any, +) -> Call: + """Create a relax Call, which calls a packed function. + Parameters + ---------- + func: str + The name of extern function. + *args : Expr + The arguments. + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments. + kwargs: Expr + The keyword arguments. + + Returns + ------- + call: Call + The created Relax Call + """ + op = ExternFunc(func) + if sinfo_args is None: + raise ValueError("R.call_packed is required to have type_args") + if isinstance(sinfo_args, py_tuple): # type: ignore + sinfo_args = list(sinfo_args) + elif not isinstance(sinfo_args, list): + sinfo_args = [sinfo_args] + for i, sinfo_arg in enumerate(sinfo_args): + if callable(sinfo_arg): + sinfo_arg = sinfo_arg() + # Convert possible StructInfoProxy to StructInfo + if isinstance(sinfo_arg, ObjectGeneric): + sinfo_arg = sinfo_arg.asobject() + sinfo_args[i] = sinfo_arg + + is_default = False + if "attrs_type_key" in kwargs: + attrs_type_key = kwargs["attrs_type_key"] + kwargs.pop("attrs_type_key") + else: + attrs_type_key = "DictAttrs" + is_default = True + attrs = None + if kwargs or not is_default: + attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) + + return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) + + +def _sinfo_arg_wrapper(func): + """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" + + def _convert_tensor_type(args): + if isinstance(args, (list, py_tuple)): # type: ignore + new_args = [_convert_tensor_type(x) for x in args] + return type(args)(new_args) + if isinstance(args, dict): + return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} + if inspect.isfunction(args): + args = args() + if isinstance(args, ObjectGeneric): + args = args.asobject() + return args + + @functools.wraps(func) + def wrapped(*args, **kwargs): + return func(*_convert_tensor_type(args), **_convert_tensor_type(kwargs)) + + return wrapped # type: ignore + + +invoke_closure = _sinfo_arg_wrapper(invoke_closure) # pylint: disable=invalid-name + +call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name + +############################### Bindings ############################### + + +def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: Expr + The right side value of the bindings to be emitted. + + annotate_struct_info: Optional[StructInfo] + The optional struct info annotation for the emitted value. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.Emit(value, annotate_struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: + """Emit a match_cast binding to the last binding block frame. + Parameters + ---------- + value: Expr + The value of the MatchCast to be emitted. + struct_info: StructInfo + The struct_info of the MatchCast to be emitted. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore + + +############################# If Then Else ############################# + + +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if frame. + Parameters + ---------- + condition : Expr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then frame. + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else frame. + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### R.tuple ################################ + + +def tuple(*fields: Expr) -> Expr: + """Create a tuple expression. + Parameters + ---------- + *fields : Expr + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + if len(fields) == 0: + fields = py_tuple() + + return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### PrimValue ############################## + + +def prim_value(value: PrimExpr) -> Expr: + """Create a prim value expression. + Parameters + ---------- + value : PrimExpr + The value of the prim value. + Returns + ------- + res : Expr + The result prim value. + """ + return relax.PrimValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def str(value: py_str) -> Expr: + """Create a string imm expression. + Parameters + ---------- + value : str + The value of the str. + Returns + ------- + res : Expr + The result str. + """ + return relax.StringImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def dtype(value: Union[py_str, DataType]) -> Expr: + """Create a dtype imm expression. + Parameters + ---------- + value : dtype + The value of the dtype. + Returns + ------- + res : Expr + The result dtype. + """ + return relax.DataTypeImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### Importer ############################### + +__all__ = [ + "Else", + "If", + "Then", + "TupleGetItem", + "add", + "arg", + "assert_op", + "call_packed", + "call_tir", + "call_builtin_with_ctx", + "const", + "dataflow", + "dtype", + "emit", + "emit_match_cast", + "func_attr", + "func_name", + "func_ret_struct_info", + "func_ret_value", + "function", + "invoke_closure", + "make_closure", + "multiply", + "null_value", + "output", + "prim_value", + "print", + "shape_of", + "str", + "tuple", +] diff --git a/python/tvm/script/parser/__init__.py b/python/tvm/script/parser/__init__.py index 5161a2601c49..678297799e6d 100644 --- a/python/tvm/script/parser/__init__.py +++ b/python/tvm/script/parser/__init__.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import _core, ir, tir +from . import _core, ir, tir, relax from ._core import parse from .ir import ir_module from .tir import prim_func +from .relax import function diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..3c01b54a9f1a 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -43,6 +43,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) if extra_vars is None: import tvm # pylint: disable=import-outside-toplevel from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import relax # pylint: disable=import-outside-toplevel from tvm.script.parser import tir # pylint: disable=import-outside-toplevel extra_vars = { @@ -51,6 +52,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "relax": relax, + "R": relax, + "tvm": tvm, } source = Source(program) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..105164ed5ffc 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -388,6 +403,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..13b3e298590f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index feb8e683401c..04f3fea21c2b 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -15,7 +15,18 @@ # specific language governing permissions and limitations # under the License. """Initial impl of relax parser for sugars""" -from tvm.relax import TensorStructInfo, ShapeStructInfo +from ...ir_builder.relax import * # pylint: disable=redefined-builtin +from ...ir_builder.relax import ir as _relax +from . import parser as _parser +from .entry import Callable, Object, Prim, Shape, Tensor, Tuple, function, match_cast -Tensor = TensorStructInfo -Shape = ShapeStructInfo +__all__ = _relax.__all__ + [ + "Callable", + "Object", + "Prim", + "Shape", + "Tensor", + "Tuple", + "function", + "match_cast", +] diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py new file mode 100644 index 000000000000..d93f9a2826bc --- /dev/null +++ b/python/tvm/script/parser/relax/entry.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +import inspect +from typing import Any +from typing import Callable as _Callable +from typing import Dict, List, Optional, Set, TypeVar, Union + +from tvm.relax import ( + Expr, + FuncStructInfo, + Function, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + StructInfo, + TensorStructInfo, + TupleStructInfo, +) +from tvm.runtime import ObjectGeneric +from tvm.tir import PrimExpr + +from .._core import parse, utils + +FType = TypeVar("FType", bound=_Callable) + +############################## R.function ############################## + + +def function(f: FType) -> Union[Function, FType]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if utils.is_defined_in_class(inspect.stack(), f): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(function, "dispatch_token", "relax") + + +############################# Struct Info ############################## + + +class StructInfoProxy(ObjectGeneric): + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo: + raise NotImplementedError() + + def get_symbolic_vars(self) -> Set[str]: + return {} + + def asobject(self): + return self.as_struct_info(None) + + +############################### R.Tensor ############################### + + +def _eval_shape(expr: Union[str, PrimExpr], dict_globals: Optional[Dict[str, Any]]) -> PrimExpr: + if isinstance(expr, str): + code = compile(expr, "", "eval") + return eval(code, dict_globals or {}) # pylint: disable=eval-used + else: + return expr + + +class TensorProxy(StructInfoProxy): + shape: Optional[List[Union[str, PrimExpr]]] + dtype: str + ndim: int + + def __init__( + self, + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, + ) -> None: + self.shape = shape + self.dtype = dtype + self.ndim = ndim + super().__init__() + + def get_symbolic_vars(self) -> Set[str]: + if self.shape is None: + return {} + else: + return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: + if self.shape is None: + return TensorStructInfo(None, self.dtype, self.ndim) + else: + if dict_globals is None and any([isinstance(s, str) for s in self.shape]): + raise ValueError( + "String-defined shape expr is only allowed when parsing function parameters " + "and return annotations for TVMScript." + ) + shape = [_eval_shape(s, dict_globals) for s in self.shape] + return TensorStructInfo(shape, self.dtype, self.ndim) + + +def Tensor( + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, +) -> TensorProxy: + # scalar tensor case + if shape is not None and len(shape) == 0: + shape = [] + if isinstance(shape, str) and dtype is None: + dtype = shape + shape = None + + if shape is not None and not isinstance(shape, (tuple, list)): + raise ValueError(f"shape must be a list or tuple, but got: {shape}") + return TensorProxy(shape, dtype, ndim) + + +############################## R.Callable ############################## + + +class CallableProxy(StructInfoProxy): + params: List[StructInfoProxy] + ret: StructInfoProxy + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + Parameters + ---------- + params : List[StructInfoProxy] + The argument StructInfoProxy + + ret : StructInfoProxy + The return StructInfoProxy. + + """ + + def __init__( + self, + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, + ) -> None: + if not isinstance(params, (list, tuple)): + params = [params] + # convert `R.Tensor` to `R.Tensor()` + self.params = [param() if callable(param) else param for param in params] + self.ret = ret() if callable(ret) else ret + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[p.get_symbolic_vars() for p in self.params]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo: + params = [param.as_struct_info(dict_globals) for param in self.params] + ret = self.ret.as_struct_info(dict_globals) + return FuncStructInfo(params, ret) + + +def Callable( + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, +) -> CallableProxy: + return CallableProxy(params, ret) + + +############################### R.Tuple ################################ + + +class TupleProxy(StructInfoProxy): + fields: List[StructInfoProxy] + """The type of tuple values. + + Parameters + ---------- + fields : List[StructInfoProxy] + The fields in the tuple + """ + + def __init__( + self, + *fields: List[StructInfoProxy], + ) -> None: + if len(fields) == 1 and isinstance(fields[0], (tuple, list)): + fields = fields[0] + # convert `R.Tensor` to `R.Tensor()` + self.fields = [field() if callable(field) else field for field in fields] + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[f.get_symbolic_vars() for f in self.fields]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo: + fields = [field.as_struct_info(dict_globals) for field in self.fields] + return TupleStructInfo(fields) + + +def Tuple(*fields: List[StructInfoProxy]) -> TupleProxy: + return TupleProxy(*fields) + + +############################### R.Shape ################################ + + +class ShapeProxy(StructInfoProxy): + values: Optional[List[PrimExpr]] + ndim: int + """The type of shape values. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __init__( + self, + values: Optional[List[PrimExpr]] = None, + ndim: int = -1, + ) -> None: + self.values = values + self.ndim = ndim + + def get_symbolic_vars(self) -> Set[str]: + if self.values is None: + return {} + else: + return {v for v in self.values if isinstance(v, str) and v.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + values = [_eval_shape(v, dict_globals) for v in self.values] if self.values else None + return ShapeStructInfo(values, self.ndim) + + +def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy: + return ShapeProxy(values, ndim) + + +############################### R.Object ################################ + + +class ObjectProxy(StructInfoProxy): + """The proxy fo ObjectStructInfo. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __init__(self) -> None: + pass + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return ObjectStructInfo() + + +def Object() -> ObjectProxy: + return ObjectProxy() + + +################################ R.Prim ################################ + + +class PrimProxy(StructInfoProxy): + dtype: str + """The type of shape values. + + Parameters + ---------- + dtype : str + The data type. + """ + + def __init__(self, dtype: str) -> None: + self.dtype = dtype + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return PrimStructInfo(self.dtype) + + +def Prim(dtype: str) -> PrimProxy: + return PrimProxy(dtype) + + +############################ R.match_cast ############################# +class MatchCastPair: + value: Expr + struct_info: StructInfo + + def __init__(self, value: Expr, struct_info: StructInfo) -> None: + self.value = value + self.struct_info = struct_info + + +def match_cast(value: Expr, struct_info: StructInfo): + if value is None: + raise ValueError("value of match_cast cannot be None") + if struct_info is None: + raise ValueError("struct_info of match_cast cannot be None") + return MatchCastPair(value, struct_info) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py new file mode 100644 index 000000000000..ef26ddd6e921 --- /dev/null +++ b/python/tvm/script/parser/relax/parser.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import functools +import numbers +from typing import Any, Optional + +from tvm import relax, tir +from tvm.ir import structural_equal +from tvm.relax import StructInfo +from tvm.relax.utils import convert_to_expr +from tvm.script.ir_builder.relax.frame import BlockFrame + +from ...ir_builder import ir as I +from ...ir_builder import relax as R +from ...ir_builder.base import IRBuilder +from .._core import Parser, dispatch, doc +from .entry import MatchCastPair, StructInfoProxy, TupleProxy + + +def bind_assign_value( + self: Parser, + node: doc.expr, + var_name: str, + value: Any, + anno_sinfo: Optional[StructInfo] = None, +) -> Any: + var_table = self.var_table.get() + + if isinstance(value, tir.Var): + if value.name and var_name != value.name: + self.report_error( + node, + "Cannot define TIR variables with different names. The LHS of binding should " + "has the same name provided in RHS.", + ) + if var_name in var_table: + prev_value = var_table[var_name] + if not isinstance(prev_value, tir.Var): + self.report_error( + node, + "Cannot redefine a non-TIR-variable object to a TIR variable. Please " + "define the TIR variable with another name.", + ) + if prev_value.dtype != value.dtype: + self.report_error( + node, + "Expected the same dtype for TIR vars " + f"but got {value.dtype} vs {prev_value.dtype}", + ) + return prev_value + IRBuilder.name(var_name, value) + return value + + if isinstance(value, tuple): + value = convert_to_expr(value) + if isinstance(value, numbers.Number): + value = R.const(value) + + if isinstance(value, relax.Expr): + var = R.emit(value, anno_sinfo) + elif isinstance(value, MatchCastPair): + if anno_sinfo is not None and not structural_equal(anno_sinfo, value.struct_info): + self.report_error( + node, "Cannot specify inconsistent annotation for a match cast pair. " + ) + var = R.emit_match_cast(value.value, value.struct_info) + else: + raise TypeError(f"Unsupported type {type(value)} in assignment") + + IRBuilder.name(var_name, var) + return var + + +def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: + try: + annotation = self.eval_expr(node) + if annotation is None: + return TupleProxy([]) + if callable(annotation): + annotation = annotation() + if isinstance(annotation, StructInfoProxy): + return annotation + else: + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: + var_table = self.var_table.get() if eval_str else None + try: + return eval_struct_info_proxy(self, node).as_struct_info(var_table) + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: + # Collect symbolic vars from parameters + symbolic_vars = set() + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) + symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + + # Define symbolic vars to the current var_table frame + for var_name in symbolic_vars: + self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + with self.with_dispatch_token("relax"): + with R.function(): + R.func_name(node.name) + collect_symbolic_var_from_params(self, node) + + if node.returns is not None: + ann_sinfo = eval_struct_info(self, node.returns, eval_str=True) + R.func_ret_struct_info(ann_sinfo) + + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="relax", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + collect_symbolic_var_from_params(self, node) + + if node.returns is None: + # Use ObjectStructInfo as unknown return type + # NOTE: Cannot use VoidStructInfo here because the return type can be refined later. + ret_sinfo = relax.ObjectStructInfo() + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params = [] + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + params.append(relax.Var(arg.arg, param_sinfo)) + + func_signature = relax.Function.create_empty(params, ret_sinfo) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) + + +@dispatch.register(token="relax", type_name="pre_token_switch") +def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=unused-argument + ir_builder = IRBuilder() + ir_builder.__enter__() + + +@dispatch.register(token="relax", type_name="post_token_switch") +def post_token_switch(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder.current() + result = ir_builder.get() + ir_builder.__exit__(None, None, None) + var = R.emit(result) + IRBuilder.name(node.name, var) + self.var_table.add(node.name, var, allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + value = self.eval_expr(node.value) + if value is not None: + self.report_error(node, f"Unsupported Expr stmt type {value}.") + + +@dispatch.register(token="relax", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + param = R.arg(arg.arg, param_sinfo) + + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="relax", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr) -> StructInfo: + return eval_struct_info(self, node, eval_str=False) + + +@dispatch.register(token="relax", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + # Currently only `with R.dataflow()` is supported + if len(node.items) != 1: + self.report_error(node, "Only one item is allowed.") + item = node.items[0] + if item.optional_vars is not None: + self.report_error( + item.context_expr, + "Relax syntax doesn't allow binding expressions in `with` to variables", + ) + frame = self.eval_expr(item.context_expr) + with self.var_table.with_frame(): + with frame: + self.visit(node.body) + if isinstance(frame, BlockFrame) and frame.is_dataflow: + output_vars = frame.output_vars + for var in output_vars: + self.var_table.add(var.name_hint, var, allow_shadowing=True) + + +@dispatch.register(token="relax", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_assign_value, + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + anno_sinfo = self.visit_tvm_annotation(node.annotation) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=functools.partial(bind_assign_value, anno_sinfo=anno_sinfo), + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="Return") +def visit_return(self: Parser, node: doc.Assign) -> None: + value = self.eval_expr(node.value) + value = convert_to_expr(value) + R.func_ret_value(value) + + +@dispatch.register(token="relax", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + if node.orelse is None: + raise ValueError("Else statements are required for relax dialect.") + with R.If(self.eval_expr(node.test)) as if_frame: + with self.var_table.with_frame(): + with R.Then(): + self.visit_body(node.body) + with self.var_table.with_frame(): + with R.Else(): + self.visit_body(node.orelse) + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 5796db40ec06..8ebcbb2133e5 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -470,3 +471,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/ir/module.cc b/src/ir/module.cc index 42ced9612045..f30d98b50185 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,20 +63,30 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; if (equal.IsPathTracingEnabled()) { const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); + + // Update GlobalVar before equality check + if (functions.size() != other->functions.size()) return false; + for (const auto& gv : this->GetGlobalVars()) { + if (!other->ContainGlobalVar(gv->name_hint)) return false; + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } + for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), obj_path_pair->rhs_path->Attr("functions") ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; } if (type_definitions.size() != other->type_definitions.size()) return false; + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } + for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; ObjectPathPair type_def_paths = { obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), obj_path_pair->rhs_path->Attr("type_definitions") @@ -86,13 +96,23 @@ bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) } return true; } + + if (functions.size() != other->functions.size()) return false; + // Update GlobalVar before equality check + for (const auto& gv : this->GetGlobalVars()) { + if (!other->ContainGlobalVar(gv->name_hint)) return false; + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; } if (type_definitions.size() != other->type_definitions.size()) return false; + // Update GlobalTypeVar remap + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; } return true; diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..addf12928435 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..da2330b5772b 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -17,9 +17,12 @@ * under the License. */ #include +#include #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +30,48 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + if (func_signature->struct_info_.defined()) { + gv->struct_info_ = tvm::relax::GetStructInfo(func_signature); + } else if (const auto* prim_func = func_signature.as()) { + gv->struct_info_ = + tvm::relax::FuncStructInfo::OpaqueFunc(tvm::relax::StructInfoFromType(prim_func->ret_type)); + } else { + LOG(FATAL) << "Unsupported function type: " << func_signature->GetTypeKey(); + } + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc new file mode 100644 index 000000000000..c78b9e73c534 --- /dev/null +++ b/src/script/ir_builder/relax/frame.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +void SeqExprFrameNode::ExitWithScope() { + // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call + // its `ExitBlockFrame` and check if there is any more unended BlockFrame. + if (Optional block_frame = IRBuilder::Current()->GetLastFrame()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) + << "ValueError: There is some remaining BlockFrame that is not properly popped out."; + } + RelaxFrameNode::ExitWithScope(); +} + +void SeqExprFrameNode::EnterWithScope() { + RelaxFrameNode::EnterWithScope(); + BindingBlock()->EnterWithScope(); +} + +void FunctionFrameNode::ExitWithScope() { + using ir::IRModuleFrame; + using tvm::relax::Expr; + IRBuilder builder = IRBuilder::Current(); + SeqExprFrameNode::ExitWithScope(); + // Step 1: Create the function. + CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " + "`return` to return an Expr"; + this->block_builder->BeginScope(params); + Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + this->block_builder->EndScope(); + tvm::relax::Function func(/*params=*/params, + /*body=*/body, + /*ret_struct_info=*/ret_struct_info, + /*attrs=*/dict_attrs); + // Step 2: Update IRModule. + if (builder->frames.empty()) { + // Case 0. No outer frame, return function directly + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + // Case 1. A global function of an IRModule + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert Relax.Function"; + } +} + +void BlockFrameNode::EnterWithScope() { + // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the + // last block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + // Block frames cannot appear consecutively. + ICHECK(!IRBuilder::Current()->GetLastFrame()); + } + // Step 2. Deal with the new block frame. + RelaxFrameNode::EnterWithScope(); + Optional func_frame = IRBuilder::Current()->FindFrame(); + CHECK(func_frame.defined()) + << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure " + "creating the block under Relax function scope."; + const tvm::relax::BlockBuilder& block_builder = func_frame.value()->block_builder; + if (is_dataflow) { + block_builder->BeginDataflowBlock(); + } else { + block_builder->BeginBindingBlock(); + } +} + +class DataflowBlockRewriter : public tvm::relax::ExprMutator { + public: + static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block, + const Array& output_vars) { + DataflowBlockRewriter rewriter(output_vars); + return Downcast(rewriter.VisitBindingBlock(block)); + } + + private: + explicit DataflowBlockRewriter(const Array& output_vars) { + for (const tvm::relax::Var& var : output_vars) { + output_var_set_.insert(var.get()); + } + } + + tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final { + auto it = output_var_set_.find(op); + if (it != output_var_set_.end()) { + // Rewrite dataflow vars to global vars + auto n = make_object(*op); + tvm::relax::Var new_var(n); + this->var_remap_[op->vid] = new_var; + return new_var; + } else { + return GetRef(op); + } + } + + private: + std::unordered_set output_var_set_; +}; + +void BlockFrameNode::ExitWithScope() { + // Step 1. Pop the current frame out of the frame stack. + RelaxFrameNode::ExitWithScope(); + + // Step 2. Get the constructed binding block from the block builder. The block should have at + // lease one binding - otherwise, the block is not supposed to be created. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::BindingBlock block = block_builder->EndBlock(); + if (block->bindings.empty()) { + return; + } + + // Step 3. Rewrite the dataflow block. + if (is_dataflow) { + // Step 3.1. Rewrite block binding + block = DataflowBlockRewriter::Rewrite(Downcast(block), output_vars); + + // Step 3.2. Collect global vars' reference in bindings + Map new_global_vars; + for (const tvm::relax::Binding& binding : block->bindings) { + if (!binding->var->IsInstance()) { + new_global_vars.Set(binding->var->vid, binding->var); + } + } + + // Step 3.3. Rewrite output vars + Array new_output_vars; + for (const auto& var : output_vars) { + auto it = new_global_vars.find(var->vid); + ICHECK(it != new_global_vars.end()); + new_output_vars.push_back((*it).second); + } + output_vars = std::move(new_output_vars); + } + + // Step 3. Get the last frame from the IRBuilder frame stack. + Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); + ICHECK(opt_last_frame.defined()); + RelaxFrame last_frame = opt_last_frame.value(); + + // Step 4. Since we popped out any possible block frame when entering the "with" scope of the + // current frame, the last frame cannot be a block frame. + ICHECK(!last_frame->IsInstance()); + + // Step 5. Push the block frame into the corresponding field of the last frame. + if (const auto* seq_frame = last_frame.as()) { + ICHECK(!seq_frame->output.defined()) + << "The function is not expected to have output values when emitting blocks."; + auto frame = GetRef(seq_frame); + frame->binding_blocks.push_back(block); + } else { + LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame " + "or a block frame. However, the last frame is \"" + << last_frame->GetTypeKey() << "\"."; + } + + // Step 6. Start another binding block when a dataflow block ended. + if (is_dataflow) { + BindingBlock()->EnterWithScope(); + } +} + +void IfFrameNode::EnterWithScope() { + const Array& frames = IRBuilder::Current()->frames; + for (const IRBuilderFrame& frame : frames) { + const auto* block_frame = frame.as(); + if (block_frame && block_frame->is_dataflow) { + LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block."; + } + } + RelaxFrameNode::EnterWithScope(); +} + +void IfFrameNode::ExitWithScope() { + RelaxFrameNode::ExitWithScope(); + CHECK(then_expr.defined()) + << "ValueError: The body of then part is expected to be defined before exiting."; + CHECK(then_expr.defined()) + << "ValueError: The body of else part is expected to be defined before exiting."; + auto body = tvm::relax::If(condition, then_expr.value(), else_expr.value()); + var = Emit(body); + IRBuilder::Name(var_name, var); +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Then"); + CHECK(!frame->then_expr.defined()) + << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Then"); + frame->then_expr = output; + frame->var_name = var_name; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Else"); + CHECK(frame->then_expr.defined()) << "The else branch should follow then branch"; + CHECK(!frame->else_expr.defined()) + << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Else"); + frame->else_expr = output; + CHECK(frame->var_name == var_name) + << "This last binding of both branches must have the same variable."; +} + +TVM_REGISTER_NODE_TYPE(FunctionFrameNode); +TVM_REGISTER_NODE_TYPE(SeqExprFrameNode); +TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc new file mode 100644 index 000000000000..ece645243c82 --- /dev/null +++ b/src/script/ir_builder/relax/ir.cc @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +///////////////////////////////// Vars ////////////////////////////////// + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::VarNode; + using tvm::relax::IdNode; + const VarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::DataflowVarNode; + using tvm::relax::IdNode; + const DataflowVarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +/////////////////////////////// Function //////////////////////////////// + +FunctionFrame Function() { + ObjectPtr n = make_object(); + const IRBuilder& ir_builder = IRBuilder::Current(); + Optional mod = NullOpt; + if (const Optional mod_frame = ir_builder->GetLastFrame()) { + mod = tvm::IRModule(mod_frame.value()->functions); + } + n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); + return FunctionFrame(n); +} + +tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) { + FunctionFrame frame = FindFunctionFrame("R.Arg"); + tvm::relax::Var var(name, struct_info); + frame->params.push_back(var); + return var; +} + +void FuncName(const String& name) { + FunctionFrame frame = FindFunctionFrame("R.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() + << "\""; + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + FunctionFrame frame = FindFunctionFrame("R.func_attr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; +} + +void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { + FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); + if (frame->ret_struct_info.defined()) { + LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " + << frame->ret_struct_info.value(); + } + frame->ret_struct_info = ret_sinfo; +} + +void FuncRetValue(const tvm::relax::Expr& value) { + // Step 0. Normalize the value. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Expr normalized_value = block_builder->Normalize(value); + + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of + // a function body. Therefore if there is any unended block frame when dealing with function + // return, we should end the block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->FindFrame()) + << "All block frame are supposed to be popped out already"; + } + // Step 2. Add the output value to the function frame. + FunctionFrame frame = FindFunctionFrame("return"); + CHECK(!frame->output.defined()) + << "ValueError: Relax functions don't support multiple return statement. Please make sure " + "the return statement appears at the end of function."; + + frame->output = std::move(normalized_value); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); + +///////////////////////////// BindingBlock ////////////////////////////// + +BlockFrame Dataflow() { + ObjectPtr n = make_object(); + n->is_dataflow = true; + n->block_ended = false; + return BlockFrame(n); +} + +BlockFrame BindingBlock() { + ObjectPtr n = make_object(); + n->is_dataflow = false; + n->block_ended = false; + return BlockFrame(n); +} + +void DataflowBlockOutput(const Array& vars) { + // Step 1. Check that we're in a Dataflow block that is not ended. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined() && block_frame.value()->is_dataflow) + << "ValueError: `R.output` should appear inside a dataflow block. However, the current " + "innermost block is not a dataflow block."; + CHECK(!block_frame.value()->block_ended) + << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; + + // Step 2. Mark the block frame ended of construction, so that any followup binding after this + // mark in the dataflow block will lead to an error. + block_frame.value()->block_ended = true; + + // Step 3. All the output variables must be global variables and must be emitted by this dataflow + // block. + const Array& emitted_vars = block_frame.value()->emitted_vars; + for (const tvm::relax::Var& var : vars) { + CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) + << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " + "all dataflow block output variables are emitted exactly by this block."; + block_frame.value()->output_vars.push_back(var); + } +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") + .set_body_typed(DataflowBlockOutput); + +/////////////////////////////// Bindings /////////////////////////////// + +tvm::relax::Var Emit(const tvm::relax::Expr& expr, + const Optional& annotate_struct_info) { + using tvm::relax::GetStructInfo; + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + if (annotate_struct_info.defined()) { + const auto& sinfo = annotate_struct_info.value(); + if (!expr->struct_info_.defined()) { + UpdateStructInfo(expr, sinfo); + } else { + CHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != tvm::relax::BaseCheckResult::kFailL0) + << "Invalid annotation. Got rhs value struct info: " << GetStructInfo(expr) + << ", given struct info: " << sinfo; + } + } + tvm::relax::Var var = block_builder->Emit(expr); + block_frame->emitted_vars.push_back(var); + return var; +} + +tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + + tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); + block_frame->emitted_vars.push_back(var); + return var; +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); + +///////////////////////////// If Then Else ///////////////////////////// + +IfFrame If(tvm::relax::Expr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_expr = NullOpt; + n->else_expr = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h new file mode 100644 index 000000000000..ae91d05769bd --- /dev/null +++ b/src/script/ir_builder/relax/utils.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +inline FunctionFrame FindFunctionFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method + << "' is called under R.function()"; + throw; +} + +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under R.if_()"; + } + throw; +} + +inline tvm::relax::BlockBuilder GetBlockBuilder() { + Optional frame = IRBuilder::Current()->FindFrame(); + CHECK(frame.defined()) << "ValueError: Relax Function frame not find. Please ensure " + "assignment is called under R.function()"; + return frame.value()->block_builder; +} + +inline BlockFrame CheckBlockFrameExistAndUnended() { + // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new + // bindings into this block, and we should throw exceptions. + + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined()) << "ValueError: Block frame not find"; + CHECK(!block_frame.value()->block_ended) + << "ValueError: New binding is not allowed after dataflow block output."; + return block_frame.value(); +} + +inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { + // Step 0. Check frame type + std::string method; + if (frame->IsInstance()) { + method = "R.Then"; + } else if (frame->IsInstance()) { + method = "R.Else"; + } else { + ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); + } + + // Step 1. Check non-empty block and last binding is non-dataflow + CHECK(!frame->binding_blocks.empty()) + << "Empty body is not allowed for '" << method << "' statements."; + const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); + CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; + + // Step 2. Collect body from the last binding. + tvm::relax::Expr body; + const tvm::relax::Binding& last_binding = last_block->bindings.back(); + if (const auto* var_binding = last_binding.as()) { + CHECK(!var_binding->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = var_binding->var->name_hint(); + } else if (const auto* match_cast = last_binding.as()) { + CHECK(!match_cast->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = match_cast->var->name_hint(); + } else { + ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); + } + + // Step 3. Re-collect binding blocks to remove the last binding. + Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); + new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings)); + + return tvm::relax::SeqExpr(new_blocks, body); +} + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 485757063867..e8f125adc053 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -81,7 +81,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py new file mode 100644 index 000000000000..12d8b114b862 --- /dev/null +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax, tir +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder.base import IRBuilder + + +def test_function_simple(): + """ + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + out = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return out + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + R.func_attr({"Primitive": 1}) + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) + out = R.emit( + R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) + IRBuilder.name("out", out) + R.func_ret_value(out) + func = ir_builder.get() + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit( + relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) + bb.emit_func_output(out) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + # check names + assert func.params[0].name_hint == "x" + assert func.body.body.name_hint == "out" + + +def test_match_cast(): + """ + @R.function + def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + m = T.var("int64") + n = T.var("int64") + _ = R.match_cast(x, R.Tensor((m,), "float32")) + y1 = R.match_cast(x, R.Tensor((n,), "float32")) + return (m, n * 2) + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32")) + y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) + IRBuilder.name("y1", y1) + R.func_ret_value(relax.ShapeExpr([m, n * 2])) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) + y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32")) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + + +def test_dataflow_block(): + """ + @R.function + def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): + # block 0 + with R.dataflow(): + lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) + gv: Tensor((128, 128), "float32") = lv0 + R.output(gv) + return gv + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + with R.dataflow() as df: + lv0 = R.emit( + R.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + IRBuilder.name("lv0", lv0) + gv = R.emit(lv0) + IRBuilder.name("gv", gv) + R.output(gv) + (gv,) = df.output_vars + R.func_ret_value(gv) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + tvm.ir.assert_structural_equal(func, bb.get()["foo"]) + + +def test_regression_py_print(): + # Test that the py_print directs to python builtin print + from tvm.script.ir_builder.relax.ir import py_print # pylint: disable=import-outside-toplevel + + assert py_print == print + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py new file mode 100644 index 000000000000..34b02fdbb8c3 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser.py @@ -0,0 +1,1062 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax, tir, topi +from tvm.relax import DynTensorType +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]] = None, +): + # TODO(relax-team): enable roundtrip testing when printer is ready + # test = parsed.script(show_meta=True) + # roundtrip_mod = tvm.script.parse(test) + # tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_simple_func(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + R.func_attr({"Primitive": 1}) + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +def test_error_report(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + # error: a = b = c is not allowed. + gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + +def test_mismatch_cast_dims_and_ndim(): + with pytest.raises(Exception): + + @R.function + def f( + x: R.Tensor((2, 3), "float32", ndim=3) + ): # error: ndim and the shape dims are mismatch + return x + + +def test_unexpected_num_kw_args(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim=1, foo=2)): # error: unexpected kw args foo + return x + + +def test_unexpected_ndim(): + with pytest.raises(Exception): + + @R.function + # error: dim is expected to be non-negative int or -1 for unknown + def f(x: R.Tensor(dtype="float32", ndim=-2)): + return x + + +def test_unexpected_ndim_type(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim="1")): # error: dim is expected to be int + return x + + +def test_unexpected_tir_cast_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m",), "float32")): + m = T.var("int64") + # tir.cast expects 2 arguments, but got 3 + return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) + + +def test_unexpected_tir_max_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m", "n"), "float32")): + m = T.var("int64") + # tir.max expects 2 arguments, but got 1 + return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32")) + + +def test_func_type_annotation_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x, y): # error: the parameter type annotation is missing + z = R.add(x, y) + y = z + return y + + +def test_if_mismatch_var_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + z = R.add(w, w) # error: The binding var is expected to `y` + return z + + +def test_unassigned_call_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor): + R.add(x, x) + return x + + +def test_simple_module(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + # TODO(Siyuan): Need to change to `TestModule.tir_func` + gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + + _check(TestModule, bb.get()) + + +def test_relax_tensor_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): + y = R.add(x, x) + z = R.multiply(x, y) + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64", "m") + n = T.var("int64", "n") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + @R.function + def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int32") # The shape dtype should be int64 + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + def _expected(name: str): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([m, n], "float32")) + bb = relax.BlockBuilder() + with bb.function(name, (x,)): + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))) + bb.emit_func_output(out) + return bb.get()[name] + + _check(foo, _expected("foo")) + _check(bar, _expected("bar")) + + +def test_shadowing(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + y = R.add(x, x) + z = R.multiply(x, y) + y = R.add(x, y) + y = z + y = R.multiply(y, x) + z = y + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + y = bb.emit(relax.op.add(x, y)) + y = bb.emit(z) + y = bb.emit(relax.op.multiply(y, x)) + z = bb.emit(y) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_match_cast(): + @R.function + def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): + m = T.var("int64") + n = T.var("int64") + x0 = R.match_cast(x, R.Tensor([m], "float32")) + with R.dataflow(): + y0 = R.match_cast(y, R.Tensor([n], "float32")) + gv = y0 + R.output(gv) + return (x0, (m, n * 2)) + + x = relax.Var("x", R.Tensor("float32")) + y = relax.Var("y", R.Tensor("float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + y2 = relax.Var("y", R.Tensor([n], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + x0 = bb.match_cast(x, R.Tensor([m], "float32")) + with bb.dataflow(): + y0 = bb.match_cast(y, R.Tensor([n], "float32")) + bb.emit_output(y0) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])])) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) + gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) + return (gv0, gv1) + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) + bb.emit_func_output(relax.Tuple((gv0, gv1))) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return_2(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + return (x0, (n + 1, m, 1)) + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])])) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_binding(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + t0 = (x, x0) + t1 = (x, (n, m), t0) + return t1 + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + t0 = bb.emit(relax.Tuple([x, x0])) + t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0])) + bb.emit_func_output(t1) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_get_item(): + @R.function + def foo(x: R.Tensor, y: R.Tensor): + t1 = R.tuple(x, y) + t2 = (x, y) + a = t1[0] + b = R.TupleGetItem(t2, 1) + c = R.add(a, b) + return c + + x = relax.Var("x", R.Tensor()) + y = relax.Var("y", R.Tensor()) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + t1 = bb.emit(relax.Tuple([x, y])) + t2 = bb.emit(relax.Tuple([x, y])) + a = bb.emit(relax.TupleGetItem(t1, 0)) + b = bb.emit(relax.TupleGetItem(t2, 1)) + c = bb.emit(relax.op.add(a, b)) + bb.emit_func_output(c) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv = lv1 + R.output(gv) + return gv + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block_advanced(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + m = T.var("int64") + n = T.var("int64") + lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) + gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) + gv4 = gv3 + gv5 = gv2 + R.output(gv5, gv4) + gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) + return gv7 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))) + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))) + lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv21 = bb.emit( + relax.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + ) + gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) + gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv32 = bb.emit_output(gv31) + gv22 = bb.emit_output(gv21) + gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))) + gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))) + bb.emit_func_output(gv5) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_binding_after_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32")) + return gv + + +def test_dataflow_output_global_var(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + R.output(gv0, gv1) + return gv1 + + +def test_dataflow_multiple_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + R.output(gv) + return gv + + +def test_dataflow_output_outside_dataflow_block(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + return gv + + +def test_dataflow_scope_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(ndim=2)): + with R.dataflow(): + y = R.add(x, x) + z = R.multiply(y, x) + w = R.add(z, x) + R.output(y, w) + t = R.multiply(y, z) # z is not in the outer scope + return t + + +def test_return_without_binding(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_multiple_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + return x + + +def test_function_without_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + + +def test_tensor_type_without_args(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32")) + return v + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))) + bb.emit_func_output(v) + + _check(foo, bb.get()["foo"]) + + +def test_direct_return(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_call_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + relax.Call( + relax.ExternFunc("vm.builtin.copy"), + (x,), + None, + sinfo_args=[R.Tensor((32, 32), "float32")], + ) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_annotation(): + @R.function + def foo( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m",), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.var("int64", "m") + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) + return o + + def _check_struct_info(binding, expected_sinfo): + tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo) + + # Cannot use block builder here because we need to check the annotated type, + # which may be inconsistent with deduced type. + assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + m = relax.get_shape_of(foo.params[0])[1] + bindings = foo.body.blocks[0].bindings + + _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) + _check_struct_info(bindings[5], relax.ObjectStructInfo()) + + +def test_annotate_override(): + @R.function + def foo(x: R.Tensor): + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = R.add(x, y) + return z + + assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + y_bind, z_bind = foo.body.blocks[0].bindings + assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo) + assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo) + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def test(x: R.Tensor): + # Error: x is of Tensor StructInfo, which can not annotate to R.Shape. + z: R.Shape = x + return z + + @R.function + def bar(x: R.Tensor): + # x is of Tensor StructInfo, the annotation of `z` is ignored. + z: R.Object = x + return z + + assert isinstance(bar.ret_struct_info, relax.TensorStructInfo) + (z_bind,) = bar.body.blocks[0].bindings + assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) + + +def test_call_tir_empty_shape(): + @R.function + def foo(x: R.Tensor((), "float32")): + z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32")) + return z + + (z_bind,) = foo.body.blocks[0].bindings + shape_expr = z_bind.value.sinfo_args[0].shape + + assert isinstance(shape_expr, relax.ShapeExpr) + assert len(shape_expr.values) == 0 + + +def test_call_tir_empty_tuple_arg(): + bb = relax.BlockBuilder() + dummy_param = relax.Var("dummy_param", R.Tensor(())) + with bb.function("foo", [dummy_param]): + output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) + bb.emit_func_output(output) + + _check(bb.get()) + + +def test_call_tir_with_tir_var(): + @I.ir_module + class Module: + @R.function + def main( + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + ) -> R.Tensor(("n * 2",), "float32"): + n = T.var("int64") + y = R.call_tir(copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + return y + + @T.prim_func + def copy(var_x: T.handle, var_y: T.handle, n: T.int64): + X = T.match_buffer(var_x, (n * 2,), dtype="float32") + Y = T.match_buffer(var_y, (n * 2,), dtype="float32") + for i in T.grid(n * 2): + with T.block("block"): + vi = T.axis.remap("S", [i]) + Y[vi] = X[vi] + + _check(Module) + + +def test_local_function(): + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + main_bindings = main.body.blocks[0].bindings + assert len(main_bindings) == 3 + outer_func = main_bindings[0].value + assert isinstance(outer_func, relax.Function) + + outer_func_bindings = outer_func.body.blocks[0].bindings + assert len(outer_func_bindings) == 1 + inner_func = outer_func_bindings[0].value + assert isinstance(inner_func, relax.Function) + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) + return z + + bindings = TestModule["f"].body.blocks[0].bindings + assert len(bindings) == 2 + tir_func = bindings[0].value + assert isinstance(tir_func, tir.PrimFunc) + + +def test_cross_function_call(): + @I.ir_module + class Mod0: + @R.function + def foo(x: R.Tensor((10, 5), "float32")): + s = R.add(x, x) + return s + + @R.function + def main(x: R.Tensor((10, 5), "float32")): + inner = foo + gv1 = inner(x) + gv2 = foo(x) + return (inner, gv1, gv2) + + @I.ir_module + class Mod1: + @R.function + def main(x: R.Tensor((10, 5), "float32")): + inner = foo + gv1 = inner(x) + gv2 = foo(x) + return (inner, gv1, gv2) + + @R.function + def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): + s = R.add(x, x) + return s + + +def test_if_branch(): + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return y + + cond, x = foo.params + y_bind = foo.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + def check_call(call, op, args): + assert isinstance(call, relax.Call) + if isinstance(op, str): + assert call.op.name == op + else: + assert call.op == op + tvm.ir.assert_structural_equal(call.args, args) + + w_bind = ite.true_branch.blocks[0].bindings[0] + # the seq exprts in the branches are normalized to bind any call + # in the seq expr "body" to a var + y_bind = ite.true_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.add", [x, x]) + check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) + + w_bind = ite.false_branch.blocks[0].bindings[0] + y_bind = ite.false_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.multiply", [x, x]) + check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) + + +def test_if_inside_dataflow(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + with R.dataflow(): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + R.output(y) + return y + + +def test_var_if_scoping_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w # error: The w is not defined in the outer scope + + +def test_if_branch_var_scope(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w + + +def test_erase_to_well_defined(): + @R.function + def foo(x: R.Tensor): + q = x + m, n = T.var("int64"), T.var("int64") + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) + _check(foo) + + +def test_empty_tuple(): + @R.function + def foo(x: R.Tuple()): + y: R.Tuple() = R.tuple() + return y + + x = relax.Var("x", relax.TupleStructInfo([])) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.Tuple([])) + bb.emit_func_output(y) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape_computing(): + # Tensor Case 1 + @R.function + def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): + z = R.add(x, y) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + z = bb.emit(relax.op.add(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + # Tensor Case 2 + @R.function + def bar( + x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") + ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): + m = T.var("int64") + z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m], "float32")) + y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) + bb = relax.BlockBuilder() + with bb.function("bar", (x, y)): + z = bb.emit( + relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")) + ) + bb.emit_func_output(z) + + _check(bar, bb.get()["bar"]) + + # Shape Case + @R.function + def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): + m = T.var("int64") + z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.ShapeStructInfo([m])) + y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + bb = relax.BlockBuilder() + with bb.function("baz", (x, y)): + z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + bb.emit_func_output(z) + + _check(baz, bb.get()["baz"]) + + # Error Case + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined + z = R.add(x, x) + return z + + +# TODO(relax-team): enable this when vm ops are ready +@pytest.mark.xfail +def test_vm_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + m = T.var("int64") + n = T.var("int64") + storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0) + alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32") + tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0) + _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n))) + gv = tensor + return alloc, gv + + +def test_prim_value(): + @R.function + def foo(): + gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_string_imm(): + @R.function + def foo(): + gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_datatype_imm(): + @R.function + def foo(): + gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_function_void_return_type(): + @tvm.script.ir_module + class Foo: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")): + res = mul(x) + return res + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")): + res = R.multiply(x, x) + return res + + _check(Foo) + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. + assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo) + assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo) + + @tvm.script.ir_module + class Bar: + @R.function + def main(x1: R.Tensor((3, 3), dtype="float32")): + res1 = mul(x1) + return res1 + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")) -> None: + res = R.multiply(x, x) + return res + + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. + _check(Bar) + tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([])) + tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([])) + + +def test_class_normalize(): + @tvm.script.ir_module + class InputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + return R.multiply(R.add(x, x), R.add(x, x)) + + # The parser automatically normalizes the input AST to the following ANF form + @tvm.script.ir_module + class OutputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + gv = R.add(x, x) + gv1 = R.add(x, x) + return R.multiply(gv, gv1) + + _check(InputModule, OutputModule) + + +if __name__ == "__main__": + test_cross_function_call() + tvm.testing.main()