From 9efbd96288a0cedb67ff0db8a637382f6c2a1418 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Sun, 17 Oct 2021 14:58:28 -0700 Subject: [PATCH] Redesign IRBuilder to BlockBuilder (#22) * init * update * update * test case working * update and add multi block test case * check in * fixes * fix * update * add * update * add * update * address comments. Co-authored-by: Altan Haan --- include/tvm/relax/block_builder.h | 204 ++++++++ include/tvm/relax/expr_functor.h | 44 +- include/tvm/relax/ir_builder.h | 305 ------------ python/tvm/relax/__init__.py | 4 +- .../relax/{ir_builder.py => block_builder.py} | 118 +++-- src/printer/relax_script_printer.cc | 14 +- src/relax/ir/block_builder.cc | 294 +++++++++++ src/relax/ir/expr_functor.cc | 254 +++++----- src/relax/ir/ir_builder.cc | 460 ------------------ src/relax/transform/fma_rewrite.cc | 25 +- src/relax/transform/memory_rewrite.cc | 42 +- src/relax/transform/shape_lower.cc | 139 ++---- tests/python/relax/test_analysis.py | 75 +-- ...test_irbuilder.py => test_blockbuilder.py} | 93 ++-- tests/python/relax/test_transform.py | 35 +- 15 files changed, 935 insertions(+), 1171 deletions(-) create mode 100644 include/tvm/relax/block_builder.h delete mode 100644 include/tvm/relax/ir_builder.h rename python/tvm/relax/{ir_builder.py => block_builder.py} (65%) create mode 100644 src/relax/ir/block_builder.cc delete mode 100644 src/relax/ir/ir_builder.cc rename tests/python/relax/{test_irbuilder.py => test_blockbuilder.py} (80%) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h new file mode 100644 index 000000000000..8cf2677b9b31 --- /dev/null +++ b/include/tvm/relax/block_builder.h @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/block_builder.h + * \brief The utility for constructing Relax binding blocks. + */ +#ifndef TVM_RELAX_BLOCK_BUILDER_H_ +#define TVM_RELAX_BLOCK_BUILDER_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +class BlockBuilder; + +/*! + * \brief Utility data structure for generating unique names for IR construction. + */ +class NameTable { + public: + /*! + * \brief Generate a unique name with a specified prefix. + * \param prefix The name prefix. + * \return The generated name. + */ + inline std::string GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = alloc_map_.find(prefix); + if (it != alloc_map_.end()) { + while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { + } + } + alloc_map_[unique_prefix] = 0; + return unique_prefix; + } + + private: + std::unordered_map alloc_map_; +}; + +/*! + * \brief A builder that provides APIs to build Relax binding blocks. + */ +class BlockBuilderNode : public Object { + public: + BlockBuilderNode(std::shared_ptr name_table) : name_table_(name_table) {} + + ~BlockBuilderNode(); + + BlockBuilderNode() { name_table_ = std::make_shared(); } + + /*! \brief Begin to build a DataflowBlock. */ + void BeginDataflowBlock(); + /*! \brief Begin to build a BindingBlock. */ + void BeginBindingBlock(); + /*! + * \brief End building a BindingBlock. + * \return The BindingBlock being built. + */ + BindingBlock EndBlock(); + /*! + * \brief Check if the block being built is DataflowBlock or not. + * \return A boolean that indicates if the block being built is DataflowBlock or not. + */ + inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; } + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The new variable that \p expr is bound to. + */ + virtual Var Emit(const Expr& expr, std::string name_hint = ""); + /*! + * \brief Emits a variable binding, and returns the bound Var. + * \param binding The variable binding. + * \return The bound variable. + */ + virtual Var Emit(const VarBinding& binding); + /*! + * \brief Emit a MatchShape. + * \param value The value of the MatchShape to be emitted. + * \param pattern The pattern of the MatchShape to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to the MatchShape. + */ + Var EmitMatchShape(const Expr& value, const Array& pattern, std::string name_hint = ""); + /*! + * \brief Emit a MatchShape binding. + * \param binding The MatchShape binding to be emitted. + * \return The variable bound to the MatchShape. + */ + Var EmitMatchShape(const MatchShape& binding); + /*! + * \brief Generate an output for the current dataflow block. + * \param output The output variable of the block. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to \p output. + */ + Var EmitOutput(const Expr& output, std::string name_hint = ""); + /*! + * \brief Generate an output for the current dataflow block. + * \param binding The output binding to output. + * \return The variable bound to \p output. + */ + Var EmitOutput(const VarBinding& binding); + /*! + * \brief Lookup a var in the binding table \p var_map_. + * \param var The input var. + * \return The Expr bound to the input \p var. + */ + Expr LookupVar(const Var& var); + /*! + * \brief Check if two shape expressions can be proven equal at compile time. + * \param lhs The input lhs shape. + * \param rhs The input rhs shape. + * \return Whether we can prove lhs shape is the same as the rhs shape. + */ + bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + /*! + * \brief Normalize an Expr to complete its shape and type. + * \param expr The input expr. + * \return The expr with normalized shape and type. + */ + Expr Normalize(const Expr& expr); + /*! + * \brief Create a BlockBuilder. + * \return The created BlockBuilder. + */ + TVM_DLL static BlockBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.BlockBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); + + private: + Var Emit(const Expr& expr, bool is_dataflow, std::string name_hint); + + protected: + /*! + * \brief A representation of a block frame. + * + * A block frame is a record containing the bindings needed + * to build a binding block, and a boolean to indicate if the + * block being built is a DataflowBlock or not. + */ + struct BlockFrame { + Array bindings; + bool is_dataflow; + }; + friend class BlockBuilder; + /*! + * \brief Get the current block frame. + * \return The current block frame. + */ + BlockFrame* CurrentFrame(); + /*! \brief A stack to store block frames. */ + std::stack block_stack_; + /*! \brief A diagnostic context for reporting errors. */ + DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); + /*! \brief A binding table that maps var to value. */ + // TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder + std::unordered_map var_map_; + /*! \brief A name table to get unique names for IR construction. */ + std::shared_ptr name_table_; +}; + +class BlockBuilder : public ObjectRef { + public: + TVM_DLL explicit BlockBuilder(std::shared_ptr name_table); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BLOCK_BUILDER_H_ diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index fc7d6f0a5229..6956eb7f368c 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -27,8 +27,8 @@ #include #include +#include #include -#include #include #include #include @@ -167,6 +167,9 @@ class ExprVisitor : public ExprFunctor { virtual void VisitMatchShape(const MatchShape& binding); virtual void VisitBindingBlock(const BindingBlock& block); virtual void VisitDataflowBlock(const DataflowBlock& block); + + protected: + std::unordered_map visit_counter_; }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -180,11 +183,22 @@ void PostOrderVisit(const Expr& node, std::function fvisit); */ class ExprMutator : public ExprFunctor { public: + ExprMutator() { + name_table_ = std::make_shared(); + builder_ = BlockBuilder(name_table_); + } + /*! * \brief Mutate is alias for VisitExpr * \return expr. */ - Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } + Expr Mutate(const Expr& expr) { + if (memo_.count(expr) == 0) { + memo_[expr] = this->VisitExpr(expr); + } + return Downcast(memo_[expr]); + } + Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const TupleNode* op) override; @@ -208,28 +222,32 @@ class ExprMutator : public ExprFunctor { * visitor for types which transform them appropriately. */ virtual Type VisitType(const Type& t); - virtual void VisitBinding(const Binding& binding, IRBuilder& builder); - virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder); - virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder); + + virtual void VisitBinding(const Binding& binding); + virtual Var VisitVarBinding(const VarBinding& binding); + virtual void VisitMatchShape(const MatchShape& binding); virtual BindingBlock VisitBindingBlock(const BindingBlock& block); virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); protected: - IRBuilder builder_; + Expr MutateWithPrologue(const Expr& expr, bool is_dataflow); + /*! \brief Look up the value binded to a var. */ + Expr LookupVar(Var var); + // A remapping table: pre var -> post var + std::unordered_map var_remap_; + std::unordered_map memo_; + std::shared_ptr name_table_; + BlockBuilder builder_; }; +// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks /*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes */ class DataflowMutator : public ExprMutator { public: - virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); - virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder); + void VisitBinding(const Binding& binding) final; - protected: - /*! \brief Look up the value binded to a var. */ - Expr LookupVar(Var var); - // A remapping table: pre var -> post var - std::unordered_map pre_post_var_map_; + virtual Var VisitDataflowVarBinding(const VarBinding& binding); }; } // namespace relax diff --git a/include/tvm/relax/ir_builder.h b/include/tvm/relax/ir_builder.h deleted file mode 100644 index 9d3ec6e68953..000000000000 --- a/include/tvm/relax/ir_builder.h +++ /dev/null @@ -1,305 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relax/ir_builder.h - * \brief The utility for constructing Relax AST. - */ -#ifndef TVM_RELAX_IR_BUILDER_H_ -#define TVM_RELAX_IR_BUILDER_H_ - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace relax { - -using relay::Call; - -class IRBuilder; -class LazyIRBuilder; - -/*! - * \brief The state of Relax function node being built. - */ -struct RelaxFunction { - /*! \brief The function name. */ - Optional func_name = NullOpt; - /*! \brief The function parameters. */ - Array params; - /*! \brief The bindings in the function. */ - std::vector bindings; - /*! \brief The binding blocks in the function. */ - std::vector binding_blocks; - /*! \brief The return of the function. */ - Expr ret = Tuple(); - /*! \brief The FunctionNode being built. */ - Function func; -}; - -/*! - * \brief A builder that provides APIs to build Relax AST. - */ -class IRBuilderNode : public Object { - public: - /*! - * \brief Fill the function name and parameters. - */ - void FillFuncNameParam(const Array& params, const std::string& func_name); - /*! - * \brief Build a function node. - */ - void BuildFunction(); - /*! - * \brief Build a binding block. - */ - virtual void BuildBlock(); - /*! - * \brief Emit a Call, and return a newly created Var binded to the Call. - * \param call The Call to be emitted. - * \return The variable being created and binded to \p call. - */ - virtual Var Emit(const Call& call); - /*! - * \brief Emit a var binding. - * \param binding The VarBinding to be emitted. - * \return The VarNode of the VarBinding \p binding. - */ - virtual Var Emit(const VarBinding& binding); - /*! - * \brief Emit a Call, and bind it to a Var. - * \param var The Var to be binded with. \p var is reused implicitly if the shape - * and type of \p call matches \p var. Otherwise a new Var is created. - * \param call The Call to be emitted. - * \return The Var to be binded with \p var. - */ - virtual Var Emit(const Var& var, const Call& call); - /*! - * \brief Emit a MatchShape. - * \param value The value of the MatchShape to be emitted. - * \param pattern The pattern of the MatchShape to be emitted. - * \return The variable being binded to the MatchShape. - */ - Var EmitMatchShape(const Expr& value, const Array& pattern); - /*! - * \brief Generate an output for the current dataflow block or function. - * \param output The output variable of the block/function. - * \return The variable being binded to \p output. - */ - Var EmitOutput(const Expr& output); - /*! - * \brief Lookup a var in the binding table \p var_map_. - */ - Expr LookupVar(const Var& var); - /*! - * \brief Get the function being built. - */ - Function Get(); - /*! - * \brief Get binding blocks being built. - */ - std::vector GetBlocks(); - /*! - * \brief Check if two shape expressions can be proven equal at compile time. - * \param lhs The input lhs shape. - * \param rhs The input rhs shape. - * \return Whether we can prove lhs shape == rhs shape. - */ - bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); - /*! - * \brief Normalize an Expr to complete its shape and type. - * \param expr The input expr. - * \return The expr with normalized shape and type. - */ - Expr Normalize(const Expr& expr); - /*! - * \brief Create a IRBuilder. - * \return The created IRBuilder. - */ - TVM_DLL static IRBuilder Create(); - - /*! \brief A flag tracking if currently inside a dataflow block or not. */ - bool is_dataflow_ = false; - - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.IRBuilder"; - TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderNode, Object); - - protected: - /*! \brief The state of the function currently being built. */ - RelaxFunction func_; - /*! \brief A global variable counter for naming global variables. */ - int global_var_counter_ = 0; - /*! \brief A dataflow variable counter for naming dataflow variables. */ - int dataflow_var_counter_ = 0; - /*! \brief A diagnostic context for reporting errors. */ - DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); - /*! \brief A binding table that maps var to value. */ - std::unordered_map var_map_; -}; - -class IRBuilder : public ObjectRef { - public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); -}; - -/*! \brief Auxiliary scope for building Relax function node, - * similar to python's with syntax. - * - * \code - * { - * With scope(ir_builder); - * // build function node. - * } - */ -class FunctionScopeNode : public Object { - public: - IRBuilder ir_builder; - void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); } - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.FunctionScope"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionScopeNode, Object); -}; - -class FunctionScope : public ObjectRef { - public: - TVM_DLL FunctionScope(IRBuilder ib); - TVM_DEFINE_OBJECT_REF_METHODS(FunctionScope, ObjectRef, FunctionScopeNode); - class Internal; - - private: - // Classes to get the Python `with` like syntax. - friend class Internal; - friend class With; - // The entry of a function scope. - TVM_DLL void EnterWithScope(); - // The exit of a function scope. - TVM_DLL void ExitWithScope(); -}; - -/*! \brief Auxiliary scope for building Relax dataflow block, - * similar to python's with syntax. - * - * \code - * { - * With scope(ir_builder); - * // build dataflow block. - * } - */ -class DataflowScopeNode : public Object { - public: - IRBuilder ir_builder; - void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); } - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.DataflowScope"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowScopeNode, Object); -}; - -class DataflowScope : public ObjectRef { - public: - TVM_DLL DataflowScope(IRBuilder ib); - TVM_DEFINE_OBJECT_REF_METHODS(DataflowScope, ObjectRef, DataflowScopeNode); - class Internal; - - private: - // Classes to get the Python `with` like syntax. - friend class Internal; - friend class With; - // The entry of a dataflow scope. - TVM_DLL void EnterWithScope(); - // The exit of a dataflow scope. - TVM_DLL void ExitWithScope(); -}; - -/*! - * \brief A lazy builder to construct dataflow block in a copy-on-write fashion. - */ -class LazyIRBuilderNode : public IRBuilderNode { - public: - /*! - * \brief Emit a Call in a copy-on-write way. - * If no bindings in a dataflow block need to be rewritten, reuse the original variable instead of - * emiting one. If any binding in the block needs to be rewritten, reconstruct the whole block - * from scratch by emiting all previous bindings. - * \param call The Call to be emitted. - * \return The variable being created and binded to \p call. - */ - virtual Var Emit(const Call& call); - /*! - * \brief Emit a var binding in a copy-on-write way. - * \param binding The VarBinding to be emitted. - * \return The Var of the \p binding. - */ - virtual Var Emit(const VarBinding& binding); - /*! - * \brief Emit a Call, and bind it to a Var in a copy-on-write way. - * \param var The Var to be binded with. - * \param call The Call to be emitted. - * \return The Var to be binded with \p var. - */ - virtual Var Emit(const Var& var, const Call& call); - /*! - * \brief Emit an output for the current dataflow block or function in a copy-on-write way. - * \param binding The VarBinding to be emitted. - * \return The variable being binded to \p output. - */ - virtual Var EmitOutput(const VarBinding& binding); - /*! - * \brief Build a binding block. - */ - virtual void BuildBlock(); - /*! - * \brief Create a LazyIRBuilder. - * \return The created LazyIRBuilder. - */ - TVM_DLL static LazyIRBuilder Create(const DataflowBlock& block); - - void VisitAttrs(AttrVisitor* v) {} - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.LazyIRBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(LazyIRBuilderNode, IRBuilderNode); - - private: - /*! \brief Original DataflowBlock before rewriting. */ - DataflowBlock df_block_; - /*! \brief index in the \p bindings. */ - int64_t index_ = 0; - /*! \brief A flag tracking if current dataflow block needs to be rewritten or not. */ - bool is_rewrite_ = false; -}; - -class LazyIRBuilder : public IRBuilder { - public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LazyIRBuilder, IRBuilder, LazyIRBuilderNode); -}; - - -} // namespace relax -} // namespace tvm - -#endif // TVM_RELAX_IR_BUILDER_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 27e32fe71627..53b8e3c42ed6 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -19,7 +19,7 @@ from . import expr from . import ty from . import vm -from . import ir_builder +from . import block_builder from . import op from . import parser from . import analysis @@ -69,7 +69,7 @@ from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs # IRBuilder -IRBuilder = ir_builder.IRBuilder +BlockBuilder = block_builder.BlockBuilder # Parser from .parser import script diff --git a/python/tvm/relax/ir_builder.py b/python/tvm/relax/block_builder.py similarity index 65% rename from python/tvm/relax/ir_builder.py rename to python/tvm/relax/block_builder.py index 34fb139859ec..cd1fca8a6e88 100644 --- a/python/tvm/relax/ir_builder.py +++ b/python/tvm/relax/block_builder.py @@ -18,41 +18,48 @@ from typing import List, Optional, Union, Dict from tvm.relay.expr import Tuple from tvm.runtime import Object +from tvm import relax as rx from .expr import * from tvm._ffi.base import _LIB, check_call from . import _ffi_api -@tvm._ffi.register_object("relax.FunctionScope") -class FunctionScope(Object): +class FunctionScope(object): """Auxiliary scope for function""" def __init__(self, irbuilder): - self.__init_handle_by_constructor__(_ffi_api.CreateFunctionScope, irbuilder) + self._ib = irbuilder def __enter__(self): - return self + _ffi_api.BlockBuilderBeginBindingBlock(self._ib) def __exit__(self, ptype, value, trace): - _ffi_api.ExitFunctionScope(self) + block = _ffi_api.BlockBuilderEndBlock(self._ib) + if len(block.bindings) > 0: + self._ib._blocks.append(block) -@tvm._ffi.register_object("relax.DataflowScope") -class DataflowScope(Object): +class DataflowScope(object): """Auxiliary scope for Dataflow block""" def __init__(self, irbuilder): - self.__init_handle_by_constructor__(_ffi_api.CreateDataflowScope, irbuilder) + self._ib = irbuilder def __enter__(self): - _ffi_api.EnterDataflowScope(self) + block = _ffi_api.BlockBuilderEndBlock(self._ib) + if len(block.bindings) > 0: + self._ib._blocks.append(block) + _ffi_api.BlockBuilderBeginDataflowBlock(self._ib) def __exit__(self, ptype, value, trace): - _ffi_api.ExitDataflowScope(self) + block = _ffi_api.BlockBuilderEndBlock(self._ib) + if len(block.bindings) > 0: + self._ib._blocks.append(block) + _ffi_api.BlockBuilderBeginBindingBlock(self._ib) -@tvm._ffi.register_object("relax.IRBuilder") -class IRBuilder(Object): +@tvm._ffi.register_object("relax.BlockBuilder") +class BlockBuilder(Object): """A builder to build Relax IR for testing and dev. Examples @@ -71,12 +78,22 @@ class IRBuilder(Object): lv0 = ib.emit(rx.add(x, y)) lv1 = ib.emit(rx.multiply(lv0, y)) gv0 = ib.emit_output(lv1) - ib.emit_output(gv0) + ib.emit_func_output(gv0) func = ib.get() """ def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.IRBuilderCreate) + self._blocks = [] + self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate) + + def _begin_dataflow_block(self) -> None: + _ffi_api.BlockBuilderBeginDataflowBlock(self) + + def _begin_binding_block(self) -> None: + _ffi_api.BlockBuilderBeginBindingBlock(self) + + def _end_block(self) -> BindingBlock: + return _ffi_api.BlockBuilderEndBlock(self) def function(self, params: Optional[Union[Var, Tuple, List[Var]]] = None, @@ -87,10 +104,10 @@ def function(self, ---------- params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional The parameters of the function. - + name : str, optional The name of the function. If provided, the function is global, otherwise local. - + Returns ------- ret: FunctionScope @@ -101,12 +118,13 @@ def function(self, if not isinstance(params, (list, tuple)): params = [params] - _ffi_api.IRBuilderFillFuncNameParam(self, params, name) + self._func_params = params + self._func_name = name return FunctionScope(self) def dataflow(self) -> DataflowScope: """Annotate a Relax dataflow block. - + Returns ------- ret: DataflowScope @@ -114,8 +132,7 @@ def dataflow(self) -> DataflowScope: """ return DataflowScope(self) - def emit(self, - call: relay.Call) -> Var: + def emit(self, call: relay.Call) -> Var: """Emit a call node. This infers the shape and type of the CallNode, create a variable, and bind the CallNode to the variable. @@ -130,11 +147,9 @@ def emit(self, ret : tvm.relax.Var A newly created variable that gets binded to the call code. """ - return _ffi_api.IRBuilderEmit(self, call) - - def match_shape(self, - value: Expr, - pattern: List[PrimExpr]): + return _ffi_api.BlockBuilderEmit(self, call) + + def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var: """Emit a MatchShape. Parameters @@ -144,23 +159,22 @@ def match_shape(self, pattern : List[PrimExpr] The pattern of the MatchShape to be emitted. - + Returns ------- ret : tvm.relax.Var A newly created variable that gets binded to the call code. """ - return _ffi_api.IRBuilderEmitMatchShape(self, value, pattern) + return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern) - def emit_output(self, - output: Union[Expr, Tuple, List[Expr]]) -> None: + def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: """Emit output for the current dataflow block or function. Parameters ---------- output : Expr | Tuple | List[Expr] The output of the current block/function. - + Returns ------- ret : tvm.relax.Var @@ -168,40 +182,50 @@ def emit_output(self, """ if isinstance(output, (list, tuple)): output = Tuple(output) - return _ffi_api.IRBuilderEmitOutput(self, output) + return _ffi_api.BlockBuilderEmitOutput(self, output) + + def emit_func_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: + """Emit output for the function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. - def normalize(self, - expr: Expr) -> Expr: + Returns + ------- + ret : tvm.relax.Var + The return variable which gets binded to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + self._func_ret = output + + def normalize(self, expr: Expr) -> Expr: """Normalize an Expr to complete its shape and type. Parameters ---------- expr : Expr The input expr. - + Returns ------- ret : Expr The expr with normalized shape and type. """ - return _ffi_api.IRBuilderNormalize(self, expr) + return _ffi_api.BlockBuilderNormalize(self, expr) def get(self) -> Function: """Return the function being built. - + Returns ------- ret : tvm.relax.Function A Relax function node being built. """ - return _ffi_api.IRBuilderGet(self) - - def get_blocks(self) -> List[BindingBlock]: - """Return the binding blocks being built. - - Returns - ------- - ret : List[tvm.relax.BindingBlock] - A list of binding blocks being built. - """ - return _ffi_api.IRBuilderGetBlocks(self) + seqe = rx.SeqExpr(self._blocks, self._func_ret) + func = rx.Function( + self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name) + ) + return func diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index dcf27f0d18b0..b2bd4a945afa 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -42,7 +43,7 @@ class RelaxScriptPrinter : public relax::IRFunctor, TVM_DLL Doc Print(const ObjectRef& node); private: - std::unordered_map name_alloc_map_; + NameTable name_table_; std::unordered_map var_id_map_; std::unordered_map dim_var_map_; @@ -533,16 +534,7 @@ Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = if (prefix.empty()) { prefix = fallback; } - // TODO(@altanh): more robust name legalization - std::replace(prefix.begin(), prefix.end(), '.', '_'); - std::string unique_prefix = prefix; - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (name_alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { - } - } - name_alloc_map_[unique_prefix] = 0; - return Doc::Text(unique_prefix); + return Doc::Text(name_table_.GetUniqueName(prefix)); } String AsRelaxScript(const ObjectRef& mod) { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc new file mode 100644 index 000000000000..6f32123fb2d9 --- /dev/null +++ b/src/relax/ir/block_builder.cc @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/block_builder.cc + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(BlockBuilderNode); + +BlockBuilderNode::~BlockBuilderNode() { + if (!block_stack_.empty()) { + LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; + } +} + +BlockBuilder BlockBuilderNode::Create() { + BlockBuilder ret(make_object()); + return ret; +} + +void BlockBuilderNode::BeginDataflowBlock() { this->block_stack_.push({{}, true}); } + +void BlockBuilderNode::BeginBindingBlock() { this->block_stack_.push({{}, false}); } + +BindingBlock BlockBuilderNode::EndBlock() { + BlockFrame* cur_frame = CurrentFrame(); + BindingBlock ret = cur_frame->is_dataflow ? DataflowBlock(cur_frame->bindings) + : BindingBlock(cur_frame->bindings); + block_stack_.pop(); + return ret; +} + +Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { + auto op_map = Op::GetAttrMap("FInferShape"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } + } + return NullOpt; +} + +Type InferType(const Call& call, DiagnosticContext diag_ctx) { + auto op_map = Op::GetAttrMap("FInferType"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op](call, diag_ctx); + } + } + return Type(); +} + +Var BlockBuilderNode::Emit(const Expr& expr, std::string name_hint) { + return Emit(expr, CurrentFrame()->is_dataflow, name_hint); +} + +Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + + if (name_hint.empty()) { + name_hint = is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + Var var = is_dataflow ? DataflowVar(vid, NullOpt, NullOpt) : Var(vid, NullOpt, NullOpt); + + // do eager inference for Calls + if (const CallNode* call_node = expr.as()) { + // TypeInference::InferCall(...) + const Call& call = GetRef(call_node); + + Optional inferred_shape = InferShape(call, this->diag_ctx_); + Type inferred_type = InferType(call, this->diag_ctx_); + + var->shape_ = inferred_shape; + var->checked_type_ = inferred_type; + + Call new_call = Call(call->op, call->args, call->attrs, call->type_args, call->span); + new_call->checked_type_ = inferred_type; + new_call->shape_ = inferred_shape; + + cur_frame->bindings.push_back(VarBinding(var, new_call)); + this->var_map_[var->vid] = new_call; + } else if (const VarNode* var_node = expr.as()) { + const Var& lhs_var = GetRef(var_node); + if (lhs_var->shape_.defined()) { + var->shape_ = lhs_var->shape_; + } + if (lhs_var->checked_type_.defined()) { + var->checked_type_ = lhs_var->checked_type_; + } + cur_frame->bindings.push_back(VarBinding(var, lhs_var)); + this->var_map_[var->vid] = lhs_var; + } + + else { + cur_frame->bindings.push_back(VarBinding(var, expr)); + this->var_map_[var->vid] = expr; + } + + return var; +} + +Var BlockBuilderNode::Emit(const VarBinding& binding) { + BlockFrame* cur_frame = CurrentFrame(); + if (cur_frame->is_dataflow) { + ICHECK(binding->var.as()); + } + cur_frame->bindings.push_back(binding); + this->var_map_[binding->var->vid] = binding->value; + return binding->var; +} + +Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array& pattern, + std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + + if (name_hint.empty()) { + name_hint = cur_frame->is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + Var var = + cur_frame->is_dataflow ? DataflowVar(vid, NullOpt, NullOpt) : Var(vid, NullOpt, NullOpt); + + if (value->checked_type().as()) { + var->checked_type_ = ShapeType(Span()); + } else if (const DynTensorTypeNode* tty = value->checked_type().as()) { + ShapeExpr shape = ShapeExpr(pattern); + var->shape_ = shape; + DataType dtype = tty->dtype; + var->checked_type_ = DynTensorType(pattern.size(), dtype); + } else { + this->diag_ctx_.EmitFatal( + Diagnostic::Error(value->span) + << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."); + } + + MatchShape match_shape = MatchShape(value, pattern, var); + cur_frame->bindings.push_back(match_shape); + return var; +} + +Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) { + BlockFrame* cur_frame = CurrentFrame(); + if (cur_frame->is_dataflow) { + ICHECK(!binding->var.as()) + << "cannot bind DataflowVar outside dataflow block."; + } + cur_frame->bindings.push_back(binding); + return binding->var; +} + +Var BlockBuilderNode::EmitOutput(const Expr& output, std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + + return Emit(output, false, name_hint); +} + +Var BlockBuilderNode::EmitOutput(const VarBinding& binding) { + BlockFrame* cur_frame = CurrentFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + ICHECK(!binding->var.as()) << "EmitOutput can only emit Var bindings."; + + cur_frame->bindings.push_back(binding); + this->var_map_[binding->var->vid] = binding->value; + return binding->var; +} + +Expr BlockBuilderNode::LookupVar(const Var& var) { + auto it = this->var_map_.find(var->vid); + if (it == this->var_map_.end()) { + this->diag_ctx_.EmitFatal(Diagnostic::Error(var->span) + << "The var to be looked up is not in the binding table."); + } + return it->second; +} + +bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { + if (lhs == rhs) { + return true; + } + const auto* lhs_shape = lhs.as(); + const auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + size_t lhs_ndim = lhs_shape->values.size(); + size_t rhs_ndim = rhs_shape->values.size(); + if (lhs_ndim != rhs_ndim) { + return false; + } + arith::Analyzer analyzer; + for (size_t i = 0; i < lhs_ndim; ++i) { + PrimExpr lhs_dim = lhs_shape->values[i]; + PrimExpr rhs_dim = rhs_shape->values[i]; + if (!analyzer.CanProveEqual(lhs_dim, rhs_dim)) { + return false; + } + } + return true; + } + return false; +} + +// TODO(@altanh, @yuchen): emit expr in ssa form +Expr BlockBuilderNode::Normalize(const Expr& expr) { + if (expr.as()) { + Call call = Downcast(expr); + // Shape inference + auto inferred_shape = InferShape(call, this->diag_ctx_); + if (inferred_shape.defined()) { + if (auto* shape_expr = inferred_shape.value().as()) { + call->shape_ = GetRef(shape_expr); + } + } + // Type inference + auto inferred_type = InferType(call, this->diag_ctx_); + call->checked_type_ = inferred_type; + return call; + } + return expr; +} + +BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() { + ICHECK(!block_stack_.empty()) << "no block is being built"; + return &block_stack_.top(); +} + +BlockBuilder::BlockBuilder(std::shared_ptr name_table) { + ObjectPtr n = make_object(); + n->name_table_ = name_table; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") + .set_body_typed([](BlockBuilder builder) { builder->BeginDataflowBlock(); }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock").set_body_typed([](BlockBuilder builder) { + builder->BeginBindingBlock(); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_typed([](BlockBuilder builder) { + return builder->EndBlock(); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit") + .set_body_typed([](BlockBuilder builder, const Call& call) { return builder->Emit(call); }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShape") + .set_body_typed([](BlockBuilder builder, const Expr& value, const Array& pattern) { + return builder->EmitMatchShape(value, pattern); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") + .set_body_typed([](BlockBuilder builder, const Expr& output) { + return builder->EmitOutput(output); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_typed([](BlockBuilder builder, const Expr& expr) { + return builder->Normalize(expr); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index a630f7945c32..ad1b1238e65c 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -26,24 +26,20 @@ */ #include #include +#include #include #include -#include namespace tvm { namespace relax { -void ExprVisitor::VisitExpr_(const ConstantNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } -void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitSpan(op->span); - for (auto field : op->fields) { + for (Expr field : op->fields) { this->VisitExpr(field); } } @@ -64,7 +60,7 @@ void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { void ExprVisitor::VisitExpr_(const FunctionNode* op) { this->VisitSpan(op->span); - for (auto param : op->params) { + for (Var param : op->params) { this->VisitExpr(param); } @@ -75,11 +71,11 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->op); - for (auto ty_arg : op->type_args) { + for (Type ty_arg : op->type_args) { this->VisitType(ty_arg); } - for (auto arg : op->args) { + for (Expr arg : op->args) { this->VisitExpr(arg); } } @@ -91,38 +87,28 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->false_branch); } -void ExprVisitor::VisitExpr_(const OpNode* op) { - return; -} +void ExprVisitor::VisitExpr_(const OpNode* op) {} void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->tuple); } -void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); } -void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitSpan(op->span); - for (auto block : op->blocks) { + for (BindingBlock block : op->blocks) { this->VisitBindingBlock(block); } this->VisitExpr(op->body); } -void ExprVisitor::VisitType(const Type& t) { - return; -} +void ExprVisitor::VisitType(const Type& t) {} -void ExprVisitor::VisitSpan(const Span& span) { - return; -} +void ExprVisitor::VisitSpan(const Span& span) {} void ExprVisitor::VisitBinding(const Binding& binding) { if (binding.as()) { @@ -134,9 +120,7 @@ void ExprVisitor::VisitBinding(const Binding& binding) { } } -void ExprVisitor::VisitVarBinding(const VarBinding& binding) { - this->VisitExpr(binding->value); -} +void ExprVisitor::VisitVarBinding(const VarBinding& binding) { this->VisitExpr(binding->value); } void ExprVisitor::VisitMatchShape(const MatchShape& binding) { this->VisitExpr(binding->value); @@ -149,14 +133,14 @@ void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { if (block.as()) { this->VisitDataflowBlock(Downcast(block)); } else { - for (auto binding : block->bindings) { + for (Binding binding : block->bindings) { this->VisitBinding(binding); } } } void ExprVisitor::VisitDataflowBlock(const DataflowBlock& block) { - for (auto binding : block->bindings) { + for (Binding binding : block->bindings) { this->VisitBinding(binding); } } @@ -183,24 +167,26 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit") -.set_body_typed([](Expr expr, PackedFunc f) { +TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); - // ================== // ExprMutator -Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { + return GetRef(op); +} -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { + return GetRef(op); +} Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; bool all_fields_unchanged = true; - for (auto field : op->fields) { - auto new_field = this->Mutate(field); + for (Expr field : op->fields) { + Expr new_field = this->Mutate(field); fields.push_back(new_field); all_fields_unchanged &= new_field.same_as(field); } @@ -214,7 +200,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { Expr ExprMutator::VisitExpr_(const VarNode* op) { if (op->type_annotation.defined()) { - auto type = this->VisitType(op->type_annotation.value()); + Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { return Var(op->vid, Downcast(op->shape()), type, op->span); } @@ -225,7 +211,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { if (op->type_annotation.defined()) { - auto type = this->VisitType(op->type_annotation.value()); + Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { return DataflowVar(op->vid, Downcast(op->shape()), type, op->span); } @@ -237,14 +223,14 @@ Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array params; bool all_params_unchanged = true; - for (auto param : op->params) { + for (Var param : op->params) { Var new_param = Downcast(this->Mutate(param)); params.push_back(new_param); all_params_unchanged &= param.same_as(new_param); } - auto ret_type = this->VisitType(op->ret_type); - auto body = this->Mutate(op->body); + Type ret_type = this->VisitType(op->ret_type); + Expr body = this->MutateWithPrologue(op->body, false); if (all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { return GetRef(op); @@ -254,19 +240,19 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { } Expr ExprMutator::VisitExpr_(const CallNode* call_node) { - auto new_op = this->Mutate(call_node->op); + Expr new_op = this->Mutate(call_node->op); bool unchanged = call_node->op.same_as(new_op); tvm::Array ty_args; - for (auto ty_arg : call_node->type_args) { - auto new_ty_arg = this->VisitType(ty_arg); + for (Type ty_arg : call_node->type_args) { + Type new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); unchanged &= new_ty_arg.same_as(ty_arg); } tvm::Array call_args; - for (auto arg : call_node->args) { - auto new_arg = this->Mutate(arg); + for (Expr arg : call_node->args) { + Expr new_arg = this->Mutate(arg); call_args.push_back(new_arg); unchanged &= new_arg.same_as(arg); } @@ -279,9 +265,9 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { } Expr ExprMutator::VisitExpr_(const IfNode* op) { - auto guard = this->Mutate(op->cond); - auto true_b = this->Mutate(op->true_branch); - auto false_b = this->Mutate(op->false_branch); + Expr guard = this->Mutate(op->cond); + Expr true_b = this->MutateWithPrologue(op->true_branch, false); + Expr false_b = this->MutateWithPrologue(op->false_branch, false); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { return GetRef(op); @@ -301,20 +287,33 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { } } -Expr ExprMutator::VisitExpr_(const ShapeExprNode* op) { return GetRef(op); } +Expr ExprMutator::VisitExpr_(const ShapeExprNode* op) { + return GetRef(op); +} -Expr ExprMutator::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } +Expr ExprMutator::VisitExpr_(const ExternFuncNode* op) { + return GetRef(op); +} Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); - blocks.push_back(new_block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } all_blocks_unchanged &= block.same_as(new_block); } + builder_->BeginBindingBlock(); Expr body = this->Mutate(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + if (all_blocks_unchanged && body.same_as(op->body)) { return GetRef(op); } else { @@ -324,98 +323,129 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { Type ExprMutator::VisitType(const Type& t) { return t; } -void ExprMutator::VisitBinding(const Binding& binding, IRBuilder& builder) { - Binding new_binding; +void ExprMutator::VisitBinding(const Binding& binding) { if (binding.as()) { - this->VisitVarBinding(Downcast(binding), builder); + this->VisitVarBinding(Downcast(binding)); } else if (binding.as()) { - this->VisitMatchShape(Downcast(binding), builder); + this->VisitMatchShape(Downcast(binding)); } else { LOG(FATAL) << "Wrong type."; } } -Var ExprMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) { - Expr new_value = this->Mutate(binding->value); - if (!binding->var.as()) { - return builder->EmitOutput(new_value); +Var ExprMutator::VisitVarBinding(const VarBinding& binding) { + Expr new_value = builder_->Normalize(this->Mutate(binding->value)); + Var new_var = Downcast(this->Mutate(binding->var)); + // TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it + // in this method... + // if (new_value->shape_.defined()) { + // if (new_var->shape_.defined()) { + // new_var = Var(new_var->vid, NullOpt, new_var->type_annotation, new_var->span); + // } + // new_var->shape_ = new_value->shape_; + // } + // if (new_value->checked_type_.defined()) { + // if (new_var->checked_type_.defined()) { + + // } + // new_var = Var(new_var->vid, new_var->shape_, NullOpt, new_var->span); + // new_var->checked_type_ = new_value->checked_type_; + // } + + if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) || + !StructuralEqual()(new_var->checked_type(), new_value->checked_type())) { + new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span); + if (new_value->shape_.defined()) { + new_var->shape_ = new_value->shape_; + } + // TODO(@yuchen, @altanh): checked_type_.defined() needs to change depends on how to represent unknown type + if (new_value->checked_type_.defined()){ + new_var->checked_type_ = new_value->checked_type_; + } + } + + this->var_remap_[binding->var] = new_var; + + if (builder_->CurrentBlockIsDataFlow() && !binding->var.as()) { + return builder_->EmitOutput(VarBinding(new_var, new_value)); } else { - return builder->Emit(Downcast(new_value)); + return builder_->Emit(VarBinding(new_var, new_value)); } } -void ExprMutator::VisitMatchShape(const MatchShape& binding, IRBuilder& builder) { - this->Mutate(binding->value); - this->Mutate(ShapeExpr(binding->pattern)); +void ExprMutator::VisitMatchShape(const MatchShape& binding) { + Expr new_value = this->Mutate(binding->value); + Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern)); + Var new_var = Downcast(this->Mutate(binding->var)); + builder_->EmitMatchShape( + MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { if (block.as()) { return this->VisitDataflowBlock(Downcast(block)); - } else{ - this->builder_ = IRBuilderNode::Create(); - for (auto binding : block->bindings) { - this->VisitBinding(binding, this->builder_); + } else { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); } - auto blocks = this->builder_->GetBlocks(); - return blocks.back(); + return builder_->EndBlock(); } } BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { - this->builder_ = LazyIRBuilderNode::Create(block); - { - With scope(this->builder_); - for (auto binding : block->bindings) { - this->VisitBinding(binding, this->builder_); - } + builder_->BeginDataflowBlock(); + for (auto binding : block->bindings) { + this->VisitBinding(binding); } - return this->builder_->GetBlocks().back(); + return builder_->EndBlock(); } -Expr ExprMutator::VisitExpr(const Expr& expr) { - Expr new_expr = ExprFunctor::VisitExpr(expr); - return new_expr; +Expr ExprMutator::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } + +Expr ExprMutator::MutateWithPrologue(const Expr& expr, bool is_dataflow) { + if (is_dataflow) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + Expr ret = this->Mutate(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + return ret; } +Expr ExprMutator::LookupVar(Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return builder_->LookupVar(it->first); + } else { + return builder_->LookupVar(var); + } +} // ================== // DataflowMutator -BindingBlock DataflowMutator::VisitDataflowBlock(const DataflowBlock& block) { - this->builder_ = LazyIRBuilderNode::Create(block); - { - With scope(this->builder_); - for (auto binding : block->bindings) { - if (auto* var_binding = binding.as()) { - Var var = this->VisitVarBinding(Downcast(binding), this->builder_); - this->pre_post_var_map_[var_binding->var] = var; - } +void DataflowMutator::VisitBinding(const Binding& binding) { + if (binding.as()) { + VarBinding var_binding = Downcast(binding); + if (builder_->CurrentBlockIsDataFlow()) { + var_remap_[var_binding->var] = this->VisitDataflowVarBinding(var_binding); + } else { + var_remap_[var_binding->var] = ExprMutator::VisitVarBinding(var_binding); } + } else { + ExprMutator::VisitBinding(binding); } - return this->builder_->GetBlocks().back(); } -Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& builder) { - Expr new_value = this->Mutate(binding->value); - Var new_var; - if (new_value.as()) { - new_var = builder->Emit(Downcast(new_value)); - } - if (!binding->var.as()) { - new_var = builder->EmitOutput(new_value); - } - pre_post_var_map_[binding->var] = new_var; - return new_var; +Var DataflowMutator::VisitDataflowVarBinding(const VarBinding& binding) { + return ExprMutator::VisitVarBinding(binding); } -Expr DataflowMutator::LookupVar(Var var) { - auto it = pre_post_var_map_.find(var); - if (it != pre_post_var_map_.end()) { - return builder_->LookupVar(it->first); - } else { - return builder_->LookupVar(var); - } -} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/ir_builder.cc b/src/relax/ir/ir_builder.cc deleted file mode 100644 index 864afdf5421f..000000000000 --- a/src/relax/ir/ir_builder.cc +++ /dev/null @@ -1,460 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relax/ir_builder.cc - */ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace relax { - -TVM_REGISTER_NODE_TYPE(IRBuilderNode); -TVM_REGISTER_NODE_TYPE(LazyIRBuilderNode); -TVM_REGISTER_NODE_TYPE(FunctionScopeNode); -TVM_REGISTER_NODE_TYPE(DataflowScopeNode); - -IRBuilder IRBuilderNode::Create() { - IRBuilder ret(make_object()); - return ret; -} - -void IRBuilderNode::FillFuncNameParam(const Array& params, const std::string& func_name) { - if (!func_name.empty()) { - this->func_.func_name = GlobalVar(func_name); - } - for (Var param : params) { - this->var_map_[param] = param; - } - this->func_.params = params; -} - -void IRBuilderNode::BuildFunction() { - SeqExpr seq = SeqExpr(this->func_.binding_blocks, this->func_.ret); - this->func_.func = Function(this->func_.func_name, this->func_.params, seq, {}); - this->global_var_counter_ = 0; -} - -void IRBuilderNode::BuildBlock() { - if (!this->func_.bindings.empty()) { - if (is_dataflow_) { - this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings)); - } else { - this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings)); - } - this->func_.bindings.clear(); - } - this->dataflow_var_counter_ = 0; - this->is_dataflow_ = !this->is_dataflow_; -} - -Optional InferShape(const Call& call, DiagnosticContext diag_ctx) { - auto op_map = Op::GetAttrMap("FInferShape"); - if (call->op.as()) { - Op op = Downcast(call->op); - if (op_map.count(op)) { - return op_map[op](call, diag_ctx); - } - } - return NullOpt; -} - -Type InferType(const Call& call, DiagnosticContext diag_ctx) { - auto op_map = Op::GetAttrMap("FInferType"); - if (call->op.as()) { - Op op = Downcast(call->op); - if (op_map.count(op)) { - return op_map[op](call, diag_ctx); - } - } - return VoidType(); -} - -Var IRBuilderNode::Emit(const Call& call) { - Var var; - if (is_dataflow_) { - var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt); - } else { - var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); - } - - // Shape inference - auto inferred_shape = InferShape(call, this->diag_ctx_); - if (inferred_shape.defined()) { - if (auto* shape_expr = inferred_shape.value().as()) { - call->shape_ = GetRef(shape_expr); - var->shape_ = call->shape_; - } - } - // Type inference - auto inferred_type = InferType(call, this->diag_ctx_); - call->checked_type_ = inferred_type; - var->checked_type_ = inferred_type; - - this->func_.bindings.emplace_back(VarBinding(var, call)); - this->var_map_[var] = call; - return var; -} - -Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array& pattern) { - Var var; - if (is_dataflow_) { - var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt); - } else { - var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); - } - if (value->checked_type().as()) { - var->checked_type_ = ShapeType(Span()); - } else if (value->checked_type().as()){ - ShapeExpr shape = ShapeExpr(pattern); - var->shape_ = shape; - DataType dtype = (Downcast(value->checked_type()))->dtype; - var->checked_type_ = DynTensorType(pattern.size(), dtype); - } else { - this->diag_ctx_.EmitFatal(Diagnostic::Error(value->span) - << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."); - } - - MatchShape match_shape = MatchShape(value, pattern, var); - this->func_.bindings.emplace_back(match_shape); - return var; -} - -Var IRBuilderNode::Emit(const VarBinding& binding) { - // FIXME(yuchen or ziheng): consider binding in normal block) - if (!binding->var.as()) { - return EmitOutput(binding->value); - } else { - this->func_.bindings.emplace_back(binding); - this->var_map_[binding->var] = binding->value; - return binding->var; - } -} - -Var IRBuilderNode::Emit(const Var& var, const Call& call) { - Expr normalized_call = Normalize(call); - // Reuse the input var if the shape and type of the call matches the var - if (CanProveShapeEqual(var->shape(), call->shape()) && StructuralEqual()(var->checked_type(), call->checked_type())) { - this->func_.bindings.emplace_back(VarBinding(var, normalized_call)); - this->var_map_[var] = normalized_call; - return var; - } else { - Var new_var; - if (normalized_call->shape_.defined()) { - new_var->shape_ = normalized_call->shape_; - } - this->func_.bindings.emplace_back(VarBinding(new_var, normalized_call)); - this->var_map_[new_var] = normalized_call; - return new_var; - } -} - -Var IRBuilderNode::EmitOutput(const Expr& output) { - Var ret; - if (is_dataflow_) { - ret = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); - ret->shape_ = output->shape_; - ret->checked_type_ = output->checked_type_; - this->func_.bindings.emplace_back(VarBinding(ret, output)); - this->var_map_[ret] = output; - } else { - this->func_.ret = output; - } - return ret; -} - -Expr IRBuilderNode::LookupVar(const Var& var) { - auto it = this->var_map_.find(var); - if (it == this->var_map_.end()) { - this->diag_ctx_.EmitFatal(Diagnostic::Error(var->span) - << "The var to be looked up is not in the binding table."); - } - return it->second; -} - -Function IRBuilderNode::Get() { - return this->func_.func; -} - -std::vector IRBuilderNode::GetBlocks() { - this->BuildBlock(); - return this->func_.binding_blocks; -} - -bool IRBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { - if (lhs == rhs) { - return true; - } - const auto* lhs_shape = lhs.as(); - const auto* rhs_shape = rhs.as(); - if (lhs_shape && rhs_shape) { - size_t lhs_ndim = lhs_shape->values.size(); - size_t rhs_ndim = rhs_shape->values.size(); - if (lhs_ndim != rhs_ndim) { - return false; - } - arith::Analyzer analyzer; - for (size_t i = 0; i < lhs_ndim; ++i) { - PrimExpr lhs_dim = lhs_shape->values[i]; - PrimExpr rhs_dim = rhs_shape->values[i]; - if (!analyzer.CanProveEqual(lhs_dim, rhs_dim)) { - return false; - } - } - return true; - } - return false; -} - -Expr IRBuilderNode::Normalize(const Expr& expr) { - if (expr.as()) { - Call call = Downcast(expr); - // Shape inference - auto inferred_shape = InferShape(call, this->diag_ctx_); - if (inferred_shape.defined()) { - if (auto* shape_expr = inferred_shape.value().as()) { - call->shape_ = GetRef(shape_expr); - } - } - // Type inference - auto inferred_type = InferType(call, this->diag_ctx_); - call->checked_type_ = inferred_type; - return call; - } - return expr; -} - -class FunctionScope::Internal { - public: - static void ExitScope(FunctionScope scope) { scope.ExitWithScope(); } -}; - -FunctionScope::FunctionScope(IRBuilder ib) { - ObjectPtr n = make_object(); - n->ir_builder = std::move(ib); - data_ = std::move(n); -} - -void FunctionScope::ExitWithScope() { - this->get()->ir_builder->BuildBlock(); - this->get()->ir_builder->BuildFunction(); -} - -class DataflowScope::Internal { - public: - static void EnterScope(DataflowScope scope) { scope.EnterWithScope(); } - - static void ExitScope(DataflowScope scope) { scope.ExitWithScope(); } -}; - -DataflowScope::DataflowScope(IRBuilder ib) { - ObjectPtr n = make_object(); - n->ir_builder = std::move(ib); - data_ = std::move(n); -} - -void DataflowScope::EnterWithScope() { this->get()->ir_builder->BuildBlock(); } - -void DataflowScope::ExitWithScope() { this->get()->ir_builder->BuildBlock(); } - -LazyIRBuilder LazyIRBuilderNode::Create(const DataflowBlock& block) { - LazyIRBuilder ret(make_object()); - ret->df_block_ = block; - return ret; -} - -Var LazyIRBuilderNode::Emit(const Call& call) { - if (is_rewrite_) { - index_++; - return IRBuilderNode::Emit(call); - } - Expr expr = Downcast(this->df_block_->bindings[index_])->value; - Call old_call = Downcast(expr); - if (call.same_as(old_call)) { - VarBinding binding = Downcast(this->df_block_->bindings[index_++]); - this->var_map_[binding->var] = binding->value; - return binding->var; - } - else { - is_rewrite_ = true; - for (int i = 0; i < index_; i++) { - Expr expr = Downcast(this->df_block_->bindings[i])->value; - IRBuilderNode::Emit(Downcast(expr)); - } - index_++; - return IRBuilderNode::Emit(call); - } -} - -Var LazyIRBuilderNode::Emit(const VarBinding& binding) { - if (!binding->var.as()) { - return IRBuilderNode::EmitOutput(binding->value); - } - if (is_rewrite_) { - index_++; - return IRBuilderNode::Emit(binding); - } - Binding old_binding = this->df_block_->bindings[index_]; - if (binding.same_as(old_binding)) { - index_++; - this->var_map_[binding->var] = binding->value; - return binding->var; - } - else { - is_rewrite_ = true; - for (int i = 0; i < index_; i++) { - if (!binding->var.as()) { - IRBuilderNode::EmitOutput(binding->value); - } else { - Expr expr = Downcast(this->df_block_->bindings[i])->value; - IRBuilderNode::Emit(Downcast(expr)); - } - } - index_++; - Call call = Downcast(binding->value); - return IRBuilderNode::Emit(call); - } -} - -Var LazyIRBuilderNode::Emit(const Var& var, const Call& call) { - if (is_rewrite_) { - index_++; - return IRBuilderNode::Emit(var, call); - } - Expr expr = Downcast(this->df_block_->bindings[index_])->value; - Call old_call = Downcast(expr); - if (call.same_as(old_call)) { - index_++; - this->var_map_[var] = call; - return var; - } - else { - is_rewrite_ = true; - for (int i = 0; i < index_; i++) { - VarBinding old_binding = Downcast(this->df_block_->bindings[i]); - // Reuse the old bindings - IRBuilderNode::Emit(old_binding); - } - index_++; - return IRBuilderNode::Emit(var, call); - } -} - -Var LazyIRBuilderNode::EmitOutput(const VarBinding& binding) { - if (is_rewrite_) { - index_++; - return IRBuilderNode::EmitOutput(binding->value); - } - Binding old_binding = this->df_block_->bindings[index_]; - if (binding.same_as(old_binding)) { - index_++; - this->var_map_[binding->var] = binding->value; - return binding->var; - } - else { - is_rewrite_ = true; - for (int i = 0; i < index_; i++) { - if (!binding->var.as()) { - IRBuilderNode::EmitOutput(binding->value); - } else { - Expr expr = Downcast(this->df_block_->bindings[i])->value; - IRBuilderNode::Emit(Downcast(expr)); - } - } - index_++; - return IRBuilderNode::EmitOutput(binding->value); - } -} - -void LazyIRBuilderNode::BuildBlock() { - if (!this->func_.bindings.empty()) { - if (is_dataflow_) { - if (is_rewrite_) { - this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings)); - } - else { - this->func_.binding_blocks.emplace_back(this->df_block_); - } - } else { - this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings)); - } - this->func_.bindings.clear(); - } - this->dataflow_var_counter_ = 0; - this->is_dataflow_ = !this->is_dataflow_; -} - -TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create); - -TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam") - .set_body_typed([](IRBuilder builder, const Array& params, const std::string& func_name) { - return builder->FillFuncNameParam(params, func_name); - }); - -TVM_REGISTER_GLOBAL("relax.IRBuilderBuildFunction").set_body_typed([](IRBuilder builder) { - return builder->BuildFunction(); -}); - -TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder, const Call& call) { - return builder->Emit(call); -}); - -TVM_REGISTER_GLOBAL("relax.IRBuilderEmitMatchShape").set_body_typed([](IRBuilder builder, const Expr& value, const Array& pattern) { - return builder->EmitMatchShape(value, pattern); -}); - -TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput") - .set_body_typed([](IRBuilder builder, const Expr& output) { - return builder->EmitOutput(output); - }); - -TVM_REGISTER_GLOBAL("relax.IRBuilderNormalize") - .set_body_typed([](IRBuilder builder, const Expr& expr) { - return builder->Normalize(expr); - }); - -TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) { - return builder->Get(); -}); - -TVM_REGISTER_GLOBAL("relax.IRBuilderGetBlocks").set_body_typed([](IRBuilder builder) { - return Array(builder->GetBlocks()); -}); - -TVM_REGISTER_GLOBAL("relax.CreateFunctionScope").set_body_typed([](IRBuilder ib) { - return FunctionScope(ib); -}); - -TVM_REGISTER_GLOBAL("relax.ExitFunctionScope").set_body_typed(FunctionScope::Internal::ExitScope); - -TVM_REGISTER_GLOBAL("relax.CreateDataflowScope").set_body_typed([](IRBuilder ib) { - return DataflowScope(ib); -}); - -TVM_REGISTER_GLOBAL("relax.EnterDataflowScope").set_body_typed(DataflowScope::Internal::EnterScope); - -TVM_REGISTER_GLOBAL("relax.ExitDataflowScope").set_body_typed(DataflowScope::Internal::ExitScope); - -} // namespace relax -} // namespace tvm diff --git a/src/relax/transform/fma_rewrite.cc b/src/relax/transform/fma_rewrite.cc index c308d402e705..8108832ff068 100644 --- a/src/relax/transform/fma_rewrite.cc +++ b/src/relax/transform/fma_rewrite.cc @@ -41,23 +41,26 @@ namespace relax { // lv0 = add(k, b) // z0 = ewise_fma(a, lv0, c) -class EwiseFMARewriter : public DataflowMutator { - Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { +class EwiseFMARewriter : public ExprMutator { + Expr VisitExpr_(const CallNode* call) override { + Expr expr = ExprMutator::VisitExpr_(call); + call = expr.as(); + static const Op& add_op = Op::Get("relax.add"); static const Op& multiply_op = Op::Get("relax.multiply"); static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); - // TODO: shape & dtype check - const CallNode* op1 = binding->value.as(); - if (op1 && (op1->op == add_op)) { - Expr value = LookupVar(Downcast(op1->args[0])); - const CallNode* op2 = value.as(); - if (op2 && op2->op == multiply_op) { - Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {}); - return ir_builder->Emit(binding->var, fma_call); + if (call->op == add_op) { + // NOTE: assumes df block is completely SSA + Expr value = LookupVar(Downcast(call->args[0])); + const CallNode* mul = value.as(); + if (mul && mul->op == multiply_op) { + Call fma_call = Call(ewise_fma_op, {mul->args[0], mul->args[1], call->args[1]}, {}, {}); + return fma_call; } } - return ir_builder->Emit(binding); + + return GetRef(call); } }; diff --git a/src/relax/transform/memory_rewrite.cc b/src/relax/transform/memory_rewrite.cc index ae9832eb9012..39b4a56b3fd1 100644 --- a/src/relax/transform/memory_rewrite.cc +++ b/src/relax/transform/memory_rewrite.cc @@ -18,11 +18,12 @@ */ /*! * \file src/relax/transform/memory_rewrite.cc - * \brief + * \brief */ #include #include #include + #include "../../relay/transforms/pattern_utils.h" namespace tvm { @@ -36,7 +37,7 @@ namespace relax { // lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) // rx.call_packed(op.identity, x, lv0) -class ExplicitMemMutator : public DataflowMutator { +class ExplicitMemMutator : public ExprMutator { Expr ComputeStorageSize(const Expr& shape, const Type& type) const { DynTensorType tensor_type = Downcast(type); DataType dtype = DataType(tensor_type->dtype); @@ -62,26 +63,39 @@ class ExplicitMemMutator : public DataflowMutator { return ret; } - Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { + BindingBlock VisitBindingBlock(const BindingBlock& block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + Expr VisitExpr_(const CallNode* call) override { + // post-order mutation + Expr expr = ExprMutator::VisitExpr_(call); + call = expr.as(); + // TODO(@yuchen, @altanh): using mutate cause infinite recursion + // Expr expr = ExprMutator::Mutate(GetRef(call)); + static const Op& call_dps_op = Op::Get("relax.call_dps"); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); - const CallNode* op = binding->value.as(); - if(op && op->op == call_dps_op) { - // switch current DataflowBlock to an impure BindingBlock - ir_builder->is_dataflow_ = false; - ShapeExpr output_shape = Downcast(op->args[0]); - Type arg_type = Downcast(op->args[2])->fields[0]->checked_type(); + if (call->op == call_dps_op) { + ShapeExpr output_shape = Downcast(call->args[0]); + Type arg_type = Downcast(call->args[2])->fields[0]->checked_type(); Expr output_size = ComputeStorageSize(output_shape, arg_type); - Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]})); - return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor})); + Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc"); + builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_"); + return tensor; } - return ir_builder->Emit(binding); + + return GetRef(call); } }; -Expr ExplicitMemRewrite(const Expr& e) { - return ExplicitMemMutator().Mutate(e); +Expr ExplicitMemRewrite(const Expr& e) { + return ExplicitMemMutator().Mutate(e); } TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite") diff --git a/src/relax/transform/shape_lower.cc b/src/relax/transform/shape_lower.cc index f60d9bdb4f8a..f955842a1ab7 100644 --- a/src/relax/transform/shape_lower.cc +++ b/src/relax/transform/shape_lower.cc @@ -18,42 +18,24 @@ */ /*! * \file src/relax/transform/shape_lower.cc - * \brief + * \brief */ #include #include -#include #include +#include #include + #include "../../printer/text_printer.h" namespace tvm { namespace relax { -// Replace ShapeExpr with corresponding Var -class ShapeReplacer : public ExprMutator { - public: - explicit ShapeReplacer(Map mapping) { - mapping_ = mapping; - } - Expr VisitExpr_(const ShapeExprNode* op) override { - return mapping_.at(GetRef(op)); - } - - private: - Map mapping_; -}; - - class ShapeLowerMutator : public ExprMutator { public: - static DataType ShapeDType() { - return DataType::Int(32); - }; + static DataType ShapeDType() { return DataType::Int(32); }; - explicit ShapeLowerMutator(IRModule mod) { - mod_ = mod; - } + explicit ShapeLowerMutator(IRModule mod) { mod_ = mod; } IRModule Lower() { ret_mod_ = IRModule(); @@ -73,71 +55,62 @@ class ShapeLowerMutator : public ExprMutator { ret_mod_->Add(p.first, Downcast(new_func)); } return ret_mod_; - } + } - void VisitMatchShape(const MatchShape& binding, - IRBuilder& builder) override { + void VisitMatchShape(const MatchShape& binding) override { Expr value = binding->value; Array pattern = binding->pattern; - Array indexes; + Array indices; for (size_t i = 0; i < pattern.size(); ++i) { IntImm idx = expr2slot_.at(pattern[i]); - indexes.push_back(idx); + indices.push_back(idx); } - ShapeExpr indexes_(indexes); - Call call(ExternFunc("decode_shape"), {value, shape_heap_, indexes_}); - builder->Emit(call); + builder_->Emit(Call(ExternFunc("decode_shape"), {value, shape_heap_, ShapeExpr(indices)}), "_"); + } + + Expr VisitExpr_(const ShapeExprNode* node) override { + tir::PrimFunc func = CalculateShape(GetRef(node)); + GlobalVar shape_func_var(name_table_->GetUniqueName("shape_func")); + // TODO make sure shape_heap doesnt get redefined by local funcs? + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + ret_mod_->Add(shape_func_var, func); + + // construct shape + Array indices; + for (PrimExpr e : node->values) { + indices.push_back(expr2slot_.at(e)); + } + return builder_->Emit(Call(ExternFunc("construct_shape"), {shape_heap_, ShapeExpr(indices)}), + "sh"); } Expr VisitExpr_(const FunctionNode* node) override { - Expr visited_func = ExprMutator::VisitExpr_(node); - const auto* visited = visited_func.as(); - ICHECK(visited); - const auto* seq = visited->body.as(); - ICHECK(seq); - - // prologue block: allocate shape heap - ShapeExpr heap_size({heap_size_}); - Call alloc_heap_call(ExternFunc("relax.alloc_shape_heap"), {heap_size}); - VarBinding binding(shape_heap_, alloc_heap_call); - BindingBlock prologue({binding}); - - // process body - IRBuilder ib = IRBuilderNode::Create(); - Array shapes = CollectShapeExpr(seq->body); - Map mapping; - for (ShapeExpr shape : shapes) { - // generate tir shape function - tir::PrimFunc func = CalculateShape(shape); - GlobalVar shape_func_var("shape_func" + std::to_string(shape_func_counter_++)); - ib->Emit(Call(shape_func_var, {shape_heap_})); - ret_mod_->Add(shape_func_var, func); - - // construct shape - Array indexes; - for (PrimExpr e : shape->values) { - indexes.push_back(expr2slot_.at(e)); - } - ShapeExpr indexes_(indexes); - Call call(ExternFunc("construct_shape"), {shape_heap_, indexes_}); - Var shape_var = ib->Emit(call); - mapping.Set(shape, shape_var); + Array params; + for (Var param : node->params) { + params.push_back(Downcast(this->Mutate(param))); } - Expr new_body = ShapeReplacer(mapping).Mutate(seq->body); + Type ret_type = this->VisitType(node->ret_type); - // epilogue block: kill the shape heap - Call free_heap_call(ExternFunc("relax.free_shape_heap"), {shape_heap_}); - ib->Emit(free_heap_call); + builder_->BeginBindingBlock(); + builder_->Emit(VarBinding( + shape_heap_, Call(ExternFunc("relax.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); + + Expr new_body = this->Mutate(node->body); - // process blocks Array blocks; - blocks.push_back(prologue); - blocks.insert(blocks.end(), seq->blocks.begin(), seq->blocks.end()); - blocks.push_back(ib->GetBlocks().back()); + if (const SeqExprNode* seq = new_body.as()) { + blocks.push_back(builder_->EndBlock()); + blocks.insert(blocks.end(), seq->blocks.begin(), seq->blocks.end()); + builder_->BeginBindingBlock(); + new_body = seq->body; + } - SeqExpr new_seq(blocks, new_body); - return Function(visited->name, visited->params, new_seq, visited->ret_type); + builder_->Emit(Call(ExternFunc("relax.free_shape_heap"), {shape_heap_}), "_"); + blocks.push_back(builder_->EndBlock()); + new_body = SeqExpr(blocks, new_body); + + return Function(node->name, params, new_body, ret_type); } tir::PrimFunc CalculateShape(ShapeExpr s) { @@ -172,19 +145,7 @@ class ShapeLowerMutator : public ExprMutator { }; tir::PostOrderVisit(expr, func); return ret; - } - - Array CollectShapeExpr(Expr expr) const { - Array ret; - auto func = [&ret](const Expr& e) { - if (e->IsInstance()) { - ret.push_back(Downcast(e)); - } - }; - PostOrderVisit(expr, func); - return ret; - } - + } Map PrepareExpr2Slot(Function expr) const { int cnt = 0; @@ -192,7 +153,7 @@ class ShapeLowerMutator : public ExprMutator { auto func = [&](const Expr& e) { if (e->IsInstance()) { ShapeExpr shape = Downcast(e); - for (auto prim_e: shape->values) { + for (auto prim_e : shape->values) { if (ret.count(prim_e) == 0) { IntImm idx(ShapeDType(), cnt++); ret.Set(prim_e, idx); @@ -215,9 +176,7 @@ class ShapeLowerMutator : public ExprMutator { Map expr2slot_; }; - -TVM_REGISTER_GLOBAL("relax.transform.shape_lower") -.set_body_typed([](IRModule mod) { +TVM_REGISTER_GLOBAL("relax.transform.shape_lower").set_body_typed([](IRModule mod) { return ShapeLowerMutator(mod).Lower(); }); diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index f322fa177b2b..e8434408c57c 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -19,6 +19,7 @@ from tvm import relax as rx from tvm.ir import structural_equal + def test_dispatch_var(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") @@ -27,9 +28,11 @@ def test_dispatch_var(): v0 = rx.Var("v0", [m, n], dtype0) v1 = rx.DataflowVar("v1", [n], dtype1) t = None + def fvisit(e): nonlocal t t = type(e) + rx.analysis.post_order_visit(v0, fvisit) assert t == type(v0) rx.analysis.post_order_visit(v1, fvisit) @@ -43,63 +46,65 @@ def test_post_order_visit(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y]): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) lv1 = ib.emit(rx.op.multiply(lv0, y)) gv0 = ib.emit_output(lv1) - ib.emit_output(gv0) + ib.emit_func_output(gv0) expr = ib.get() names = [] + def fvisit(e): nonlocal names if isinstance(e, tvm.ir.op.Op): names.append(e.name) + rx.analysis.post_order_visit(expr.body, fvisit) assert names == ["relax.add", "relax.multiply"] -def test_lazy_irbuilder(): - m = tir.Var("m", "int32") - n = tir.Var("n", "int32") - dtype0 = rx.DynTensorType(rank=2, dtype="float16") - dtype1 = rx.DynTensorType(rank=2, dtype="float16") - x = rx.Var("x", [m, n], dtype0) - y = rx.Var("y", [m, n], dtype1) - ib = rx.IRBuilder() +# def test_lazy_irbuilder(): +# m = tir.Var("m", "int32") +# n = tir.Var("n", "int32") +# dtype0 = rx.DynTensorType(rank=2, dtype="float16") +# dtype1 = rx.DynTensorType(rank=2, dtype="float16") +# x = rx.Var("x", [m, n], dtype0) +# y = rx.Var("y", [m, n], dtype1) +# ib = rx.BlockBuilder() - # This program should not be rewritten by the fma_rewriter - with ib.function([x, y]): - with ib.dataflow() as df: - lv0 = ib.emit(rx.op.multiply(x, y)) - lv1 = ib.emit(rx.op.multiply(lv0, y)) - gv0 = ib.emit_output(lv1) - ib.emit_output(gv0) - expr = ib.get() +# # This program should not be rewritten by the fma_rewriter +# with ib.function([x, y]): +# with ib.dataflow() as df: +# lv0 = ib.emit(rx.op.multiply(x, y)) +# lv1 = ib.emit(rx.op.multiply(lv0, y)) +# gv0 = ib.emit_output(lv1) +# ib.emit_func_output(gv0) +# expr = ib.get() + +# # before rewrite +# block0 = expr.body.blocks[0] +# v0 = expr.body.blocks[0].bindings[1].var +# s0 = expr.body.blocks[0].bindings[1].value +# assert isinstance(s0, tvm.relay.Call) +# assert s0.op.name == "relax.multiply" - # before rewrite - block0 = expr.body.blocks[0] - v0 = expr.body.blocks[0].bindings[1].var - s0 = expr.body.blocks[0].bindings[1].value - assert isinstance(s0, tvm.relay.Call) - assert s0.op.name == "relax.multiply" +# # after rewrite (the bindings and the dataflow block are reused) +# func = rx.transform.fma_rewrite(expr) - # after rewrite (the bindings and the dataflow block are reused) - func = rx.transform.fma_rewrite(expr) +# block1 = func.body.blocks[0] +# v1 = func.body.blocks[0].bindings[1].var +# s1 = func.body.blocks[0].bindings[1].value - block1 = func.body.blocks[0] - v1 = func.body.blocks[0].bindings[1].var - s1 = func.body.blocks[0].bindings[1].value - - # the dataflow block and vars are reused - assert block0 == block1 - assert v1 == v0 - assert s1 == s0 +# # the dataflow block and vars are reused +# assert block0 == block1 +# assert v1 == v0 +# assert s1 == s0 if __name__ == "__main__": test_dispatch_var() test_post_order_visit() - test_lazy_irbuilder() + # test_lazy_irbuilder() diff --git a/tests/python/relax/test_irbuilder.py b/tests/python/relax/test_blockbuilder.py similarity index 80% rename from tests/python/relax/test_irbuilder.py rename to tests/python/relax/test_blockbuilder.py index 715fec905c77..676bb8eeab96 100644 --- a/tests/python/relax/test_irbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -15,49 +15,38 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations # must import to defer parsing of annotations import tvm from tvm import tir from tvm import relay from tvm import relax as rx -def test_dataflow_block(): +def test_block_builder(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") dtype0 = rx.DynTensorType(rank=2, dtype="float16") dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) - ib = rx.IRBuilder() - - with ib.dataflow() as df: - lv0 = ib.emit(rx.op.add(x, y)) - - assert lv0.name_hint == "lv0" - assert lv0.shape[0] == m - assert lv0.shape[1] == n - assert lv0.checked_type.rank == 2 - assert lv0.checked_type.dtype == "float16" - - lv1 = ib.emit(rx.op.multiply(lv0, y)) - assert lv1.name_hint == "lv1" - - b0 = ib.match_shape(x, [m, n]) - - gv0 = ib.emit_output(lv1) - assert gv0.name_hint == "gv0" - assert gv0.shape[0] == m - assert gv0.shape[1] == n - assert gv0.checked_type.rank == 2 - assert gv0.checked_type.dtype == "float16" - assert isinstance(gv0, rx.Var) - - blocks = ib.get_blocks() - assert len(blocks) == 1 - assert len(blocks[-1].bindings) == 4 - for i in [0, 1, 3]: - assert isinstance(blocks[-1].bindings[i], rx.VarBinding) - assert isinstance(blocks[-1].bindings[2], rx.MatchShape) + ib = rx.BlockBuilder() + + ib._begin_binding_block() + gv0 = ib.emit(rx.op.add(x, y)) + ib._begin_dataflow_block() + lv0 = ib.emit(rx.op.multiply(gv0, y)) + gv1 = ib.emit_output(rx.op.multiply(lv0, lv0)) + b0 = ib._end_block() + ib._begin_dataflow_block() + lv1 = ib.emit(rx.op.multiply(gv0, y)) + gv2 = ib.emit_output(rx.op.multiply(lv1, lv1)) + b1 = ib._end_block() + gv3 = ib.emit(rx.op.add(x, y)) + b2 = ib._end_block() + + assert isinstance(b0, rx.DataflowBlock) + assert isinstance(b1, rx.DataflowBlock) + assert not isinstance(b2, rx.DataflowBlock) def test_function_single_block(): @@ -67,17 +56,17 @@ def test_function_single_block(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y]): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) - assert lv0.name_hint == "lv0" + assert lv0.name_hint == "lv" lv1 = ib.emit(rx.op.multiply(lv0, y)) assert lv1.name_hint == "lv1" gv0 = ib.emit_output(lv1) - assert gv0.name_hint == "gv0" - ib.emit_output(gv0) + assert gv0.name_hint == "gv" + ib.emit_func_output(gv0) func = ib.get() assert func.params[0] == x @@ -98,21 +87,21 @@ def test_function_multi_blocks(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y], "func"): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) - assert lv0.name_hint == "lv0" + assert lv0.name_hint == "lv" gv0 = ib.emit_output(lv0) - assert gv0.name_hint == "gv0" + assert gv0.name_hint == "gv" gv1 = ib.emit(rx.op.add(gv0, gv0)) assert gv1.name_hint == "gv1" with ib.dataflow() as df: - lv0 = ib.emit(rx.op.add(gv1, gv1)) - assert lv0.name_hint == "lv0" + lv1 = ib.emit(rx.op.add(gv1, gv1)) + assert lv1.name_hint == "lv1" gv2 = ib.emit_output(gv1) - ib.emit_output(gv2) + ib.emit_func_output(gv2) func = ib.get() assert gv2.shape[0] == m @@ -139,7 +128,7 @@ def test_binary_shape_type_deduction(): y = rx.Var("y", [n], dtype1) z = rx.Var("z", [5], dtype1) w = rx.Var("w", [k], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y, z, w]): with ib.dataflow() as df: @@ -168,10 +157,8 @@ def test_binary_shape_type_deduction(): assert isinstance(lv3.checked_type, rx.DynTensorType) assert lv3.checked_type.rank == 1 assert lv3.checked_type.dtype == "float16" - - gv0 = ib.emit_output(lv3) - - ib.emit_output(gv0) + gv0 = ib.emit_output(lv3) + ib.emit_func_output(gv0) assert isinstance(gv0.shape, tvm.relay.Call) assert isinstance(gv0.checked_type, rx.DynTensorType) assert gv0.checked_type.rank == 1 @@ -185,7 +172,7 @@ def test_emit_match_shape(): x = rx.Var("tensor_value", type_annotation=type_anno0) shape_anno = [16, 8] y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y]): with ib.dataflow() as df: @@ -201,11 +188,11 @@ def test_emit_match_shape(): # lv1: Shape = match_shape(shape, [m, n]) lv1 = ib.match_shape(y, [m, n]) assert lv1.checked_type == rx.ShapeType() - gv0 = ib.emit_output(lv1) - - ib.emit_output(gv0) + gv0 = ib.emit_output(lv1) - block = ib.get_blocks()[-1] + ib.emit_func_output(gv0) + func = ib.get() + block = func.body.blocks[0] b0, b1 = block.bindings[:2] assert isinstance(b0, rx.MatchShape) assert isinstance(b1, rx.MatchShape) @@ -228,7 +215,7 @@ def test_normalize(): dtype1 = rx.DynTensorType(rank=1, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() add_call = rx.op.multiply(x, y) assert isinstance(add_call.shape, relay.Call) @@ -240,7 +227,7 @@ def test_normalize(): if __name__ == "__main__": - test_dataflow_block() + test_block_builder() test_function_single_block() test_function_multi_blocks() test_binary_shape_type_deduction() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 3b3248a43ca8..b0a2df064a53 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -22,6 +22,7 @@ from tvm.ir import structural_equal import numpy as np + def test_fma_rewrite(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") @@ -29,13 +30,12 @@ def test_fma_rewrite(): dtype1 = rx.DynTensorType(rank=2, dtype="float16") x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [m, n], dtype1) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function([x, y]): with ib.dataflow() as df: lv0 = ib.emit(rx.op.multiply(x, y)) - lv1 = ib.emit(rx.op.add(lv0, y)) - gv0 = ib.emit_output(lv1) - ib.emit_output(gv0) + gv0 = ib.emit_output(rx.op.add(lv0, y)) + ib.emit_func_output(gv0) expr = ib.get() # before rewrite @@ -49,20 +49,17 @@ def test_fma_rewrite(): # after rewrite func = rx.transform.fma_rewrite(expr) - v1 = func.body.blocks[0].bindings[1].var s1 = func.body.blocks[0].bindings[1].value assert isinstance(s1, tvm.relay.Call) assert s1.op.name == "relax.ewise_fma" assert structural_equal(v1.shape, rx.ShapeExpr([m, n])) assert structural_equal(s1.shape, rx.ShapeExpr([m, n])) - - # The var binded to the fma call is reused because the shape - # and type of var are unchanged after rewriting - assert lv1 == v0 - assert type(func.body.blocks[0].bindings[2].var) == rx.Var - assert type(func.body.blocks[0].bindings[2].value) == rx.DataflowVar + # The var binded to the fma call is reused because the shape + # and type of var are unchanged after rewriting + assert gv0 == v0 + assert type(func.body.blocks[0].bindings[1].var) == rx.Var def test_explicit_memory_rewrite(): @@ -71,12 +68,11 @@ def test_explicit_memory_rewrite(): shape_anno = [m, n] type_anno = rx.DynTensorType(2, "float32") x = rx.Var("x", shape_anno, type_anno) - ib = rx.IRBuilder() + ib = rx.BlockBuilder() with ib.function(x): with ib.dataflow() as df: - lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x]) - gv0 = ib.emit_output(lv0) - ib.emit_output(gv0) + gv0 = ib.emit_output(rx.call_dps([m, n], rx.extern("test.op.identity"), [x])) + ib.emit_func_output(gv0) expr = ib.get() # before rewrite @@ -90,7 +86,7 @@ def test_explicit_memory_rewrite(): # the dataflow block has changed to binding block due to the rewriting block = func.body.blocks[0] - assert isinstance(block, rx.BindingBlock) + assert not isinstance(block, rx.DataflowBlock) s1 = block.bindings[0].value assert isinstance(s1, tvm.relay.Call) @@ -100,18 +96,21 @@ def test_explicit_memory_rewrite(): s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" + # rx.parser.pretty_print(func) + @rx.script class Mod: def foo(x: Tensor[_, "float32"]) -> Shape: relax.match_shape(x.shape, (n, m)) - return (n*2, m*3) + return (n * 2, m * 3) + def test_shape_lowering(): mod = Mod() new_mod = rx.transform.shape_lower(mod) assert isinstance(new_mod, tvm.IRModule) - assert isinstance(new_mod["shape_func0"], tvm.tir.function.PrimFunc) + assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) assert isinstance(new_mod["foo"], tvm.relax.expr.Function) code = rx.parser.astext(new_mod) assert "alloc_shape_heap" in code