From 78ee4c06afc7647cf68ed3e1de53571aa2647810 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 2 Jul 2021 15:02:23 +0100 Subject: [PATCH 1/8] [TIR][USMP] Added buffer info extraction pass This commit adds a pass that takes the main (call graph of operators) TIR PrimFunc and each operators also as TIR PrimFunc. The pass will traverse through all TIR PrimFunc starting the from main. Thereafter, it will extract information from tir.allocates. Among the information, the liveness conflicts are reported. * Added test for a linear model * Added test for parallel/serial mixed for loops * Added test for a substructure of inception-style model. * Exposed buffer_info creation to python * Added member functions to update pool info * Unit tests to cover functionality of buffer_info Change-Id: I5e163ac3e83c830629a5d34ed4407c9962701c60 --- include/tvm/tir/usmp/analysis.h | 33 + include/tvm/tir/usmp/utils.h | 181 ++++ python/tvm/script/tir/scope_handler.py | 1 + python/tvm/tir/__init__.py | 1 + python/tvm/tir/ir_builder.py | 1 + python/tvm/tir/usmp/__init__.py | 21 + python/tvm/tir/usmp/_ffi_api.py | 21 + python/tvm/tir/usmp/analysis/__init__.py | 20 + python/tvm/tir/usmp/analysis/_ffi_api.py | 21 + python/tvm/tir/usmp/analysis/analysis.py | 39 + python/tvm/tir/usmp/utils.py | 132 +++ src/tir/usmp/analysis/extract_buffer_info.cc | 278 ++++++ src/tir/usmp/utils.cc | 129 +++ ...st_tir_usmp_analysis_extract_bufferinfo.py | 849 ++++++++++++++++++ tests/python/unittest/test_tir_usmp_utils.py | 127 +++ 15 files changed, 1854 insertions(+) create mode 100644 include/tvm/tir/usmp/analysis.h create mode 100644 include/tvm/tir/usmp/utils.h create mode 100644 python/tvm/tir/usmp/__init__.py create mode 100644 python/tvm/tir/usmp/_ffi_api.py create mode 100644 python/tvm/tir/usmp/analysis/__init__.py create mode 100644 python/tvm/tir/usmp/analysis/_ffi_api.py create mode 100644 python/tvm/tir/usmp/analysis/analysis.py create mode 100644 python/tvm/tir/usmp/utils.py create mode 100644 src/tir/usmp/analysis/extract_buffer_info.cc create mode 100644 src/tir/usmp/utils.cc create mode 100644 tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py create mode 100644 tests/python/unittest/test_tir_usmp_utils.py diff --git a/include/tvm/tir/usmp/analysis.h b/include/tvm/tir/usmp/analysis.h new file mode 100644 index 000000000000..993e99a163a0 --- /dev/null +++ b/include/tvm/tir/usmp/analysis.h @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/analysis.h + * \brief Analysis utilities and passes for TIR Unified Static Memory Planner. + */ +#ifndef TVM_TIR_USMP_ANALYSIS_H_ +#define TVM_TIR_USMP_ANALYSIS_H_ + +namespace tvm { +namespace tir { +namespace usmp {} +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_ANALYSIS_H_ diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h new file mode 100644 index 000000000000..22938fead8b7 --- /dev/null +++ b/include/tvm/tir/usmp/utils.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/usmp/utils.h + * \brief Utilities for Unified Static Memory Planner + */ + +#ifndef TVM_TIR_USMP_UTILS_H_ +#define TVM_TIR_USMP_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +static constexpr const char* kTargetPoolReadWriteAccess = "rw"; +static constexpr const char* kTargetPoolReadOnlyAccess = "ro"; + +/*! + * \brief The pool information to be used by USMP + */ +struct PoolInfoNode : public Object { + /*! \brief The name of the memory pool */ + String pool_name; + /*! \brief The expected size hint to be used by the allocator. + * The size_hint is defaulted to -1 to indicate the pool is not + * size restricted. + */ + Integer size_hint_bytes; + /*! \brief The accessibility from each Target*/ + Map target_access; // 'rw' or 'ro' + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pool_name", &pool_name); + v->Visit("size_hint_bytes", &size_hint_bytes); + v->Visit("target_access", &target_access); + } + + bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const { + return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) && + equal(target_access, other->target_access); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pool_name); + hash_reduce(size_hint_bytes); + hash_reduce(target_access); + } + + static constexpr const char* _type_key = "tir.usmp.PoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object); +}; + +class PoolInfo : public ObjectRef { + public: + TVM_DLL PoolInfo(String pool_name, Map target_access, + Integer size_hint_bytes = -1); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); +}; + +/*! + * \brief The buffer information to be used by USMP + */ +struct BufferInfoNode : public Object { + /*! \brief The name of the buffer var */ + String name_hint; + /*! \brief The size in terms of bytes */ + Integer size_bytes; + /*! \brief The pool candidates that this buffer can get pooled to*/ + Array pool_candidates; + /*! \brief The byte alignment required within the pool */ + Integer alignment; + /*! \brief The liveness conflicting other buffer info objects */ + Array conflicts; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("size_bytes", &size_bytes); + v->Visit("pool_candidates", &pool_candidates); + v->Visit("alignment", &alignment); + v->Visit("conflicts", &conflicts); + } + + bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const { + return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) && + equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) && + equal(conflicts, other->conflicts); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name_hint); + hash_reduce(size_bytes); + hash_reduce(alignment); + hash_reduce(conflicts); + hash_reduce(pool_candidates); + } + /*! + * \brief Set the liveness conflicts of this BufferInfo + * + * \param conflicting_buffer_info_objs An array of BufferInfo that conflicts in liveness + */ + TVM_DLL void SetConflicts(Array conflicting_buffer_info_objs); + + static constexpr const char* _type_key = "tir.usmp.BufferInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object); +}; + +class BufferInfo : public ObjectRef { + public: + TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array pool_candidates, + Integer alignment = runtime::kDefaultWorkspaceAlignment); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode); +}; + +/*! + * \brief The pool allocation produced after the USMP algorithm + */ +struct PoolAllocationNode : public Object { + /*! \brief The assigned PoolInfo object */ + PoolInfo pool_info; + /*! \brief The byte offset where the tensor is supposed to be placed within the pool*/ + Integer byte_offset; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pool_info", &pool_info); + v->Visit("byte_offset", &byte_offset); + } + + bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const { + return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pool_info); + hash_reduce(byte_offset); + } + + static constexpr const char* _type_key = "tir.usmp.PoolAllocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object); +}; + +class PoolAllocation : public ObjectRef { + public: + TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode); +}; + +/*! + * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo + * + * \param buffer_info_map IR-bound BufferInfo map + */ +Array CreateArrayBufferInfo(const Map& buffer_info_map); + +static constexpr const char* kPoolCandidatesIRModAttr = "candidate_memory_pools"; + +} // namespace usmp +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_UTILS_H_ diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 4750ad7626e2..0ce02d4cc244 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -110,6 +110,7 @@ def __init__(self): def allocate(extents, dtype, scope, condition=True, annotations=None, span=None): condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) + return tvm.tir.Allocate( self.buffer_var, dtype, diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 428403a98f16..07ceb29ebf98 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -55,3 +55,4 @@ from . import transform from . import analysis from . import stmt_functor +from . import usmp diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 978c630b17ad..a71476b23e44 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -411,6 +411,7 @@ def allocate(self, dtype, shape, name="buf", scope=""): scope : str, optional The scope of the buffer. + Returns ------- buffer : BufferVar diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py new file mode 100644 index 000000000000..8aa0d4ccfe88 --- /dev/null +++ b/python/tvm/tir/usmp/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Unified Static Memory Planner""" + +from . import analysis +from .utils import BufferInfo diff --git a/python/tvm/tir/usmp/_ffi_api.py b/python/tvm/tir/usmp/_ffi_api.py new file mode 100644 index 000000000000..5899ef0c86ea --- /dev/null +++ b/python/tvm/tir/usmp/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp", __name__) diff --git a/python/tvm/tir/usmp/analysis/__init__.py b/python/tvm/tir/usmp/analysis/__init__.py new file mode 100644 index 000000000000..756e8c7204c5 --- /dev/null +++ b/python/tvm/tir/usmp/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Unified Static Memory Planner""" + +from .analysis import extract_buffer_info diff --git a/python/tvm/tir/usmp/analysis/_ffi_api.py b/python/tvm/tir/usmp/analysis/_ffi_api.py new file mode 100644 index 000000000000..36973f19905c --- /dev/null +++ b/python/tvm/tir/usmp/analysis/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir.usmp.analysis""" +import tvm._ffi + + +tvm._ffi._init_api("tir.usmp.analysis", __name__) diff --git a/python/tvm/tir/usmp/analysis/analysis.py b/python/tvm/tir/usmp/analysis/analysis.py new file mode 100644 index 000000000000..ff70355a967b --- /dev/null +++ b/python/tvm/tir/usmp/analysis/analysis.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""USMP Analysis Python API for passes""" +# pylint: disable=invalid-name +from . import _ffi_api +from ...function import PrimFunc +from ....ir.module import IRModule + + +def extract_buffer_info(main_func: PrimFunc, mod: IRModule): + """Convert Parallel For Loop to Serial. + + Parameters + ---------- + main_func: tvm.tir.PrimFunc + The main function containing calls to operator PrimFuncs. + mod : tvm.ir.IRModule + The full IRModule containing all PrimFuncs + + Returns + ------- + Map + extracted buffer info objects + """ + return _ffi_api.extract_buffer_info(main_func, mod) diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py new file mode 100644 index 000000000000..5658878fe149 --- /dev/null +++ b/python/tvm/tir/usmp/utils.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""USMP Utilities and Data Structures""" +# pylint: disable=invalid-name + +from typing import Dict, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.target import Target +from . import _ffi_api + +CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools" + + +@register_object("tir.usmp.BufferInfo") +class BufferInfo(Object): + """BufferInfo object holds information related to buffers + that are associated with tir.allocates and tir.allocate_consts + that will be used with USMP + + Parameters + ---------- + name_hint : str + The name associated with the buffer (derived from TIR) + + size_bytes : int + The size in bytes + + alignment : int + The byte alignment required in the workspace memory + + """ + + def __init__( + self, + name_hint: str, + size_bytes: int, + alignment: int = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.BufferInfo, # type: ignore # pylint: disable=no-member + name_hint, + size_bytes, + alignment, + ) + + def set_pool_candidates(self, pool_candidates: list): + """Sets the pool candidate names""" + _ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates) + + def set_pool_offsets(self, pool_name: str, pool_offset: int): + """Sets the pool offset by name""" + _ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset) + + def set_conflicts(self, conflicts: list): + """Sets the the conflicting array of buffer info objects""" + _ffi_api.BufferInfoSetConflicts(self, conflicts) + + +@register_object("tir.usmp.PoolInfo") +class PoolInfo(Object): + """PoolInfo object holds information related to memory pools + where the statically sized allocate nodes will pooled into. + + Parameters + ---------- + pool_name : str + The name of the memory pool + + target_access : Dict[Target, str] + A dictionary where keys describe which targets could + access the pool where value could take the values : + a) "rw" : read-write access + b) "ro" : write-only acesss + + size_hint_bytes : Optional[int] + The expected size hint to be used by the allocator. + The default value would be -1 which means the pool + is not size restricted. + + """ + + READ_WRITE_ACCESS = "rw" + READ_ONLY_ACCESS = "ro" + + def __init__( + self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1 + ): + self.__init_handle_by_constructor__( + _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + target_access, + size_hint_bytes, + ) + + +@register_object("tir.usmp.PoolAllocation") +class PoolAllocation(Object): + """PoolAllocation object holds information related to an allocation + that indicates an offset in a pool + + Parameters + ---------- + pool_info : PoolInfo + The PoolInfo to which this allocation corresponds to + + byte_offset : int + The offset in the pool where the allocate node should be placed + + """ + + def __init__(self, pool_info: PoolInfo, byte_offset: int): + self.__init_handle_by_constructor__( + _ffi_api.PoolAllocation, # type: ignore # pylint: disable=no-member + pool_info, + byte_offset, + ) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc new file mode 100644 index 000000000000..2d2eacda57ed --- /dev/null +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { +namespace usmp { + +class BufferInfoExtractor : public StmtExprVisitor { + public: + explicit BufferInfoExtractor(const IRModule& module) : module_(module) { + for (const auto& gv_func : module_->functions) { + functions.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + } + // Pushing a scope info for the initial body of the main function + scope_stack.push(ScopeInfo()); + } + Map operator()(const PrimFunc& func); + + private: + void VisitStmt(const Stmt& n) override; + void VisitStmt_(const AllocateNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const LoadNode* op) override; + void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const ForNode* op) override; + + void UpdateAliases(const Array& args, const PrimFunc& func); + + Map buffer_info_map; + Map buffer_info_start_stmt_idx; + Map buffer_info_end_stmt_idx; + Map allocate_var_to_stmt_map; + + std::unordered_set currently_live_allocates; + int current_stmt_idx = 0; + struct ScopeInfo { + For for_loop; + }; + std::stack scope_stack; + + Map functions; + IRModule module_; +}; + +void BufferInfoExtractor::VisitStmt(const Stmt& n) { + current_stmt_idx += 1; + StmtExprVisitor::VisitStmt(n); +} + +size_t static CalculateExtentsSize(const AllocateNode* op) { + size_t element_size_bytes = op->dtype.bytes(); + size_t num_elements = 1; + for (const auto& ext : op->extents) { + if (ext->IsInstance()) { + num_elements *= Downcast(ext)->value; + } else { + // We cant statically calculate workspace for dynamic shapes + num_elements = 0; + } + } + return (num_elements * element_size_bytes); +} + +void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { + const auto& currect_scope_info = scope_stack.top(); + const auto& type = Downcast(op->buffer_var->type_annotation); + const auto& storage_scope = type->storage_scope; + + // If the allocate is in a for loop, + // USMP currently only looks at serial for loops. + if ((!currect_scope_info.for_loop.defined()) || + (currect_scope_info.for_loop.defined() && + currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global")) { + // USMP can only work with buffers that have global storage_scope + auto size_bytes = CalculateExtentsSize(op); + // We only statically memory plan only allocates with known + // compile time sizes. + if (size_bytes) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesIRModAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesIRModAttr]); + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes, pool_candidates); + auto allocate = GetRef(op); + allocate_var_to_stmt_map.Set(op->buffer_var, allocate); + buffer_info_map.Set(allocate, buffer_info); + } + } + StmtExprVisitor::VisitStmt(op->body); +} + +void BufferInfoExtractor::VisitStmt_(const ForNode* op) { + ScopeInfo si{ + GetRef(op), + }; + scope_stack.push(si); + StmtExprVisitor::VisitStmt_(op); + scope_stack.pop(); +} + +void BufferInfoExtractor::VisitExpr_(const LoadNode* op) { + this->VisitExpr(op->buffer_var); + StmtExprVisitor::VisitExpr_(op); +} + +void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { + this->VisitExpr(op->buffer_var); + StmtExprVisitor::VisitStmt_(op); +} + +void BufferInfoExtractor::VisitExpr_(const VarNode* op) { + auto var = GetRef(op); + if (allocate_var_to_stmt_map.count(var)) { + auto allocate = allocate_var_to_stmt_map[var]; + if (buffer_info_start_stmt_idx.count(allocate) == 0) { + buffer_info_start_stmt_idx.Set(allocate, current_stmt_idx); + } + buffer_info_end_stmt_idx.Set(allocate, current_stmt_idx); + } + StmtExprVisitor::VisitExpr_(op); +} + +Array static GetMatchedBuffers(const PrimFunc& func) { + Array buffer_vars; + for (const auto& param : func->params) { + buffer_vars.push_back(func->buffer_map[param]->data); + } + return buffer_vars; +} + +void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimFunc& func) { + auto param_buffers = GetMatchedBuffers(func); + ICHECK(args.size() == param_buffers.size()); + for (size_t i = 0; i < args.size(); i++) { + auto arg = args[i]; + auto param_buf = param_buffers[i]; + // If tir.allocates are passed in to functions + // The function params are re-directed to point + // to the original allocate + if (arg->IsInstance()) { + auto load = Downcast(arg); + if (allocate_var_to_stmt_map.count(load->buffer_var)) { + allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[load->buffer_var]); + } + } else if (arg->IsInstance()) { + auto var = Downcast(arg); + if (allocate_var_to_stmt_map.count(var)) { + allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[var]); + } + } + } +} + +void BufferInfoExtractor::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::call_extern())) { + auto func = functions.at(Downcast(op->args[0])->value); + auto actual_args = Array(op->args.begin() + 1, op->args.end()); + this->UpdateAliases(actual_args, func); + this->VisitStmt(func->body); + } else if (op->op->IsInstance()) { + auto func = Downcast(op->op); + this->UpdateAliases(op->args, func); + this->VisitStmt(func->body); + } else { + StmtExprVisitor::VisitExpr_(op); + } +} + +Map BufferInfoExtractor::operator()(const PrimFunc& main_func) { + this->VisitStmt(main_func->body); + + enum LivenessEventType { START = 0, END = 1 }; + struct LivenessEvent { + size_t tick; + LivenessEventType le_type; + Allocate allocate; + bool operator==(const LivenessEvent& other) { + if (tick == other.tick && le_type == other.le_type && allocate == other.allocate) { + return true; + } + return false; + } + }; + + std::vector le_events; + for (const auto& kv : buffer_info_map) { + if (!kv.first->IsInstance()) { + continue; + } + auto allocate = Downcast(kv.first); + // If the allocate is not used; we remove it from the analysis + if (buffer_info_start_stmt_idx.count(allocate) == 0) { + continue; + } + LivenessEvent le_event_start; + le_event_start.allocate = allocate; + le_event_start.le_type = START; + le_event_start.tick = buffer_info_start_stmt_idx[allocate]; + le_events.push_back(le_event_start); + + LivenessEvent le_event_end; + le_event_end.allocate = allocate; + le_event_end.le_type = END; + le_event_end.tick = buffer_info_end_stmt_idx[allocate]; + le_events.push_back(le_event_end); + } + + std::sort(le_events.begin(), le_events.end(), + [](const LivenessEvent& lhs, const LivenessEvent& rhs) { + if (lhs.tick < rhs.tick) { + return true; + } else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) { + return true; + } + return false; + }); + std::unordered_set open_set; + for (const auto& le_event : le_events) { + if (le_event.le_type == START) { + for (const auto& open_allocate : open_set) { + buffer_info_map[open_allocate]->conflicts.push_back(buffer_info_map[le_event.allocate]); + buffer_info_map[le_event.allocate]->conflicts.push_back(buffer_info_map[open_allocate]); + } + open_set.insert(le_event.allocate); + } else { + ICHECK(le_event.le_type == END); + open_set.erase(le_event.allocate); + } + } + return this->buffer_info_map; +} + +Map ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) { + return BufferInfoExtractor(mod)(main_func); +} + +TVM_REGISTER_GLOBAL("tir.usmp.analysis.extract_buffer_info") + .set_body_typed([](PrimFunc main_func, IRModule mod) { + return (ExtractBufferInfo(main_func, mod)); + }); + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc new file mode 100644 index 000000000000..d80ba26f4b77 --- /dev/null +++ b/src/tir/usmp/utils.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/usmp/utils.cc + * \brief Utilities for Unified Static Memory Planner + */ + +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array pool_candidates, + Integer alignment) { + auto bufinfo_node = make_object(); + bufinfo_node->name_hint = name_hint; + bufinfo_node->size_bytes = size_bytes; + bufinfo_node->pool_candidates = pool_candidates; + bufinfo_node->alignment = alignment; + data_ = std::move(bufinfo_node); +} + +void BufferInfoNode::SetConflicts(Array conflicting_buffer_info_objs) { + this->conflicts = conflicting_buffer_info_objs; +} + +TVM_REGISTER_NODE_TYPE(BufferInfoNode); +TVM_REGISTER_GLOBAL("tir.usmp.BufferInfo") + .set_body_typed([](String name_hint, Integer size_bytes, Array pool_candidates, + Integer alignment) { + if (!alignment.defined()) { + return BufferInfo(name_hint, size_bytes, pool_candidates); + } + return BufferInfo(name_hint, size_bytes, pool_candidates, alignment); + }); +TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts") + .set_body_method(&BufferInfoNode::SetConflicts); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BufferInfoNode(\n" + << "name_hint=" << node->name_hint << ",\n size_bytes=" << node->size_bytes + << ",\n pool_candidates=" << node->pool_candidates + << ",\n alignment=" << node->alignment << ")"; + }); + +PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes) { + auto poolinfo_node = make_object(); + poolinfo_node->pool_name = pool_name; + poolinfo_node->size_hint_bytes = size_hint_bytes; + poolinfo_node->target_access = target_access; + data_ = std::move(poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(PoolInfoNode); +TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo") + .set_body_typed([](String pool_name, Map target_access, + Integer size_hint_bytes) { + return PoolInfo(pool_name, target_access, size_hint_bytes); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PoolInfoNode(\n" + << "pool_name=" << node->pool_name << ",\n target_access=" << node->target_access + << ",\n size_hint_bytes=" << node->size_hint_bytes << ")"; + }); + +PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) { + auto pool_allocation_node = make_object(); + pool_allocation_node->pool_info = pool_info; + pool_allocation_node->byte_offset = byte_offset; + data_ = std::move(pool_allocation_node); +} + +TVM_REGISTER_NODE_TYPE(PoolAllocationNode); +TVM_REGISTER_GLOBAL("tir.usmp.PoolAllocation") + .set_body_typed([](PoolInfo pool_info, Integer byte_offset) { + return PoolAllocation(pool_info, byte_offset); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PoolAllocationNode(\n" + << "pool_info=" << node->pool_info << ",\n byte_offset=" << node->byte_offset + << ")"; + }); + +Array CreateArrayBufferInfo(const Map& buffer_info_map) { + Array ret; + for (const auto& kv : buffer_info_map) { + auto buffer_info = kv.second; + ret.push_back(buffer_info); + } + return ret; +} + +TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") + .set_body_typed([](Map buffer_info_map) { + return (CreateArrayBufferInfo(buffer_info_map)); + }); + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py new file mode 100644 index 000000000000..7dea5a9d345c --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -0,0 +1,849 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.ir import Range +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir import PrimFunc +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +def _replace_stmt_with_buf_var_names(buffer_info_map): + """helper to replace tir.allocates with buffer names""" + new_buffer_info_map = dict() + for k, v in buffer_info_map.items(): + new_buffer_info_map[k.buffer_var.name] = v + return new_buffer_info_map + + +def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map): + """helper to check expected liveness conflicts""" + buf_info = buffer_info_map[main_buf_name] + for conflict in buf_info.conflicts: + assert conflict.name_hint in conflicting_buf_names + + +def _get_allocates(primfunc): + """helper to extract all allocate nodes by name""" + allocates = dict() + + def get_allocate(stmt): + if isinstance(stmt, tvm.tir.Allocate): + allocates[str(stmt.buffer_var.name)] = stmt + + stmt_functor.post_order_visit(primfunc.body, get_allocate) + return allocates + + +def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_linear(): + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = LinearStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( + tir_mod["tvmgen_default_run_model"], tir_mod + ) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # check conflicts + _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map) + _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map) + _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map) + _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map) + _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map) + + # check sizes + assert buffer_info_map["sid_8"].size_bytes == 802816 + assert buffer_info_map["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map["tensor_2"].size_bytes == 200704 + assert buffer_info_map["sid_9"].size_bytes == 301056 + + # check_pool_candidates + assert [ + pool_info.pool_name for pool_info in list(buffer_info_map["sid_8"].pool_candidates) + ] == ["fast_memory", "slow_memory"] + + +# fmt: off +@tvm.script.ir_module +class ParallelSerialMixedForLoops: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) + + +__tvm_meta__ = None +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class AllSerialForLoops: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) + + +__tvm_meta__ = None +# fmt: on + + +def test_parallel_serial_mixed_for_loops(): + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + all_serial_tir_mod = AllSerialForLoops + all_serial_tir_mod = assign_poolinfos_to_allocates_in_irmodule( + all_serial_tir_mod, [global_ws_pool] + ) + main_func = all_serial_tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # When all loops are serial all allocates are touched by USMP + assert len(buffer_info_map) == 3 + for name, _ in buffer_info_map.items(): + assert name in ["dummy_allocate", "Conv2dOutput_8", "PaddedInput_8"] + + parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops + parallel_serial_mixed_tir_mod = assign_poolinfos_to_allocates_in_irmodule( + parallel_serial_mixed_tir_mod, [global_ws_pool] + ) + main_func = parallel_serial_mixed_tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( + main_func, parallel_serial_mixed_tir_mod + ) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # USMP will not touch (yet) the allocates inside parallel for loops + assert len(buffer_info_map) == 2 + for name, _ in buffer_info_map.items(): + assert name in ["Conv2dOutput_8", "PaddedInput_8"] + + +# fmt: off +@tvm.script.ir_module +class InceptionStructure: + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused in T.serial(0, 28): + for ax2 in T.serial(0, 28): + for ax3_outer_init, ax3_inner_init in T.grid(3, 64): + T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init), T.uint8(0), True) + for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): + T.store(tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner), T.max(T.load("uint8", tensor_1.data, ((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)), T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), T.load("uint8", placeholder_1.data, ((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)), T.uint8(0), dtype="uint8")), True) + + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_6, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_1 = T.match_buffer(T_cast, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_2 in T.serial(0, 28): + for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): + T.store(T_cast_1.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2), T.cast(T.load("uint8", placeholder_7.data, ((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_concatenate", "tir.noalias": True}) + placeholder_12 = T.match_buffer(placeholder_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = T.match_buffer(T_concat, [1, 28, 28, 256], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_13 = T.match_buffer(placeholder_9, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_14 = T.match_buffer(placeholder_11, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_15 = T.match_buffer(placeholder_10, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_3 in T.serial(0, 28): + for ax2_3, ax3 in T.grid(28, 256): + T.store(T_concat_1.data, (((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3), T.if_then_else((224 <= ax3), T.load("uint8", placeholder_14.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)), T.if_then_else((192 <= ax3), T.load("uint8", placeholder_15.data, ((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)), T.if_then_else((64 <= ax3), T.load("uint8", placeholder_13.data, ((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)), T.load("uint8", placeholder_12.data, (((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)), dtype="uint8"), dtype="uint8"), dtype="uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_3 = T.match_buffer(T_cast_2, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput = T.allocate([200704], "int16", "global") + for i0_i1_fused in T.serial(0, 56): + for i2, i3 in T.grid(56, 64): + T.store(PaddedInput, (((i0_i1_fused*3584) + (i2*64)) + i3), T.load("int16", placeholder_19.data, (((i0_i1_fused*3584) + (i2*64)) + i3)), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, (T.load("int32", Conv2dOutput, ff) + (T.cast(T.load("int16", PaddedInput, ((ax0_ax1_fused_ax2_fused*64) + rc)), "int32")*T.cast(T.load("int16", placeholder_20.data, ((rc*64) + ff)), "int32"))), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_cast_3.data, ((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_inner_3)), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_22, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_26 = T.match_buffer(placeholder_23, [1, 1, 192, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_27 = T.match_buffer(placeholder_24, [1, 1, 1, 96], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_5 = T.match_buffer(T_cast_4, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_1 = T.allocate([150528], "int16", "global") + for i0_i1_fused_1 in T.serial(0, 28): + for i2_1, i3_1 in T.grid(28, 192): + T.store(PaddedInput_1, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1), T.load("int16", placeholder_25.data, (((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): + Conv2dOutput_1 = T.allocate([1], "int32", "global") + for ax3_1 in T.serial(0, 96): + T.store(Conv2dOutput_1, 0, 0, True) + for rc_1 in T.serial(0, 192): + T.store(Conv2dOutput_1, 0, (T.load("int32", Conv2dOutput_1, 0) + (T.cast(T.load("int16", PaddedInput_1, ((ax0_ax1_fused_ax2_fused_1*192) + rc_1)), "int32")*T.cast(T.load("int16", placeholder_26.data, ((rc_1*96) + ax3_1)), "int32"))), True) + T.store(T_cast_5.data, ((ax0_ax1_fused_ax2_fused_1*96) + ax3_1), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_1, 0) + T.load("int32", placeholder_27.data, ax3_1)), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", "tir.noalias": True}) + placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_2 = T.allocate([150528], "int16", "global") + for i0_i1_fused_2 in T.serial(0, 28): + for i2_2, i3_2 in T.grid(28, 192): + T.store(PaddedInput_2, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2), T.load("int16", placeholder_33.data, (((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_1, 0, True) + for rc_2 in T.serial(0, 192): + T.store(Conv2dOutput_2, ff_1, (T.load("int32", Conv2dOutput_2, ff_1) + (T.cast(T.load("int16", PaddedInput_2, ((ax0_ax1_fused_ax2_fused_2*192) + rc_2)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_2*64) + ff_1)), "int32"))), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_2, ax3_inner_4) + T.load("int32", placeholder_35.data, ax3_inner_4)), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast_1", "tir.noalias": True}) + placeholder_37 = T.match_buffer(placeholder_36, [1, 28, 28, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_11 = T.match_buffer(T_cast_10, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_3 = T.allocate([150528], "uint8", "global") + for ax0_ax1_fused_6 in T.serial(0, 28): + for ax2_6 in T.serial(0, 28): + for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): + T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1), T.uint8(0), True) + for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): + T.store(tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5), T.max(T.load("uint8", tensor_3, ((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)), T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), T.load("uint8", placeholder_37.data, (((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_7 in T.serial(0, 28): + for ax2_7, ax3_4 in T.grid(28, 192): + T.store(T_cast_11.data, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4), T.cast(T.load("uint8", tensor_3, (((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", "tir.noalias": True}) + placeholder_41 = T.match_buffer(placeholder_38, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_42 = T.match_buffer(placeholder_39, [1, 1, 192, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_43 = T.match_buffer(placeholder_40, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_13 = T.match_buffer(T_cast_12, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = T.allocate([150528], "int16", "global") + for i0_i1_fused_3 in T.serial(0, 28): + for i2_3, i3_3 in T.grid(28, 192): + T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_41.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): + Conv2dOutput_3 = T.allocate([1], "int32", "global") + for ax3_5 in T.serial(0, 32): + T.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in T.serial(0, 192): + T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_42.data, ((rc_3*32) + ax3_5)), "int32"))), True) + T.store(T_cast_13.data, ((ax0_ax1_fused_ax2_fused_3*32) + ax3_5), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_43.data, ax3_5)), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_47 = T.match_buffer(placeholder_44, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_48 = T.match_buffer(placeholder_45, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_49 = T.match_buffer(placeholder_46, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_15 = T.match_buffer(T_cast_14, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_4 = T.allocate([150528], "int16", "global") + for i0_i1_fused_4 in T.serial(0, 28): + for i2_4, i3_4 in T.grid(28, 192): + T.store(PaddedInput_4, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4), T.load("int16", placeholder_47.data, (((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)), True) + for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): + Conv2dOutput_4 = T.allocate([1], "int32", "global") + for ax3_6 in T.serial(0, 16): + T.store(Conv2dOutput_4, 0, 0, True) + for rc_4 in T.serial(0, 192): + T.store(Conv2dOutput_4, 0, (T.load("int32", Conv2dOutput_4, 0) + (T.cast(T.load("int16", PaddedInput_4, ((ax0_ax1_fused_ax2_fused_4*192) + rc_4)), "int32")*T.cast(T.load("int16", placeholder_48.data, ((rc_4*16) + ax3_6)), "int32"))), True) + T.store(T_cast_15.data, ((ax0_ax1_fused_ax2_fused_4*16) + ax3_6), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_4, 0) + T.load("int32", placeholder_49.data, ax3_6)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", "tir.noalias": True}) + placeholder_53 = T.match_buffer(placeholder_50, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_54 = T.match_buffer(placeholder_51, [3, 3, 16, 32], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_55 = T.match_buffer(placeholder_52, [1, 1, 1, 32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_17 = T.match_buffer(T_cast_16, [1, 28, 28, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_5 = T.allocate([14400], "int16", "global") + for i0_i1_fused_5 in T.serial(0, 30): + for i2_5, i3_5 in T.grid(30, 16): + T.store(PaddedInput_5, (((i0_i1_fused_5*480) + (i2_5*16)) + i3_5), T.if_then_else(((((1 <= i0_i1_fused_5) and (i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), T.load("int16", placeholder_53.data, ((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): + Conv2dOutput_5 = T.allocate([1], "int32", "global") + for ax3_7 in T.serial(0, 32): + T.store(Conv2dOutput_5, 0, 0, True) + for ry, rx, rc_5 in T.grid(3, 3, 16): + T.store(Conv2dOutput_5, 0, (T.load("int32", Conv2dOutput_5, 0) + (T.cast(T.load("int16", PaddedInput_5, (((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)), "int32")*T.cast(T.load("int16", placeholder_54.data, ((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)), "int32"))), True) + T.store(T_cast_17.data, ((ax0_ax1_fused_ax2_fused_5*32) + ax3_7), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_5, 0) + T.load("int32", placeholder_55.data, ax3_7)), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", "tir.noalias": True}) + placeholder_59 = T.match_buffer(placeholder_56, [1, 28, 28, 96], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_60 = T.match_buffer(placeholder_57, [3, 3, 96, 128], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_61 = T.match_buffer(placeholder_58, [1, 1, 1, 128], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_19 = T.match_buffer(T_cast_18, [1, 28, 28, 128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_6 = T.allocate([86400], "int16", "global") + for i0_i1_fused_6 in T.serial(0, 30): + for i2_6, i3_6 in T.grid(30, 96): + T.store(PaddedInput_6, (((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6), T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), T.load("int16", placeholder_59.data, ((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): + Conv2dOutput_6 = T.allocate([64], "int32", "global") + for ax3_outer_3 in T.serial(0, 2): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_6, ff_2, 0, True) + for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): + T.store(Conv2dOutput_6, ff_2, (T.load("int32", Conv2dOutput_6, ff_2) + (T.cast(T.load("int16", PaddedInput_6, (((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)), "int32")*T.cast(T.load("int16", placeholder_60.data, (((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)), "int32"))), True) + for ax3_inner_6 in T.serial(0, 64): + T.store(T_cast_19.data, (((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6), T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_6, ax3_inner_6) + T.load("int32", placeholder_61.data, ((ax3_outer_3*64) + ax3_inner_6))), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "T.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [3, 3, 64, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [1, 1, 1, 192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [1, 56, 56, 192], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + T.store(PaddedInput_8, (((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8), T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), T.load("int16", placeholder_71.data, ((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + for ff_4 in T.serial(0, 64): + T.store(Conv2dOutput_8, ff_4, 0, True) + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + T.store(Conv2dOutput_8, ff_4, (T.load("int32", Conv2dOutput_8, ff_4) + (T.cast(T.load("int16", PaddedInput_8, (((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)), "int32")*T.cast(T.load("int16", placeholder_72.data, (((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)), "int32"))), True) + for ax3_inner_8 in T.serial(0, 64): + T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_32 = T.allocate([301056], "int8", "global") + sid_20 = T.allocate([150528], "int8", "global") + sid_6 = T.allocate([401408], "int8", "global") + sid_9 = T.allocate([301056], "int8", "global") + sid_7 = T.allocate([401408], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + sid_2 = T.allocate([50176], "int8", "global") + sid_3 = T.allocate([301056], "int8", "global") + sid_19 = T.allocate([100352], "int8", "global") + sid_4 = T.allocate([150528], "int8", "global") + sid_5 = T.allocate([602112], "int8", "global") + sid_25 = T.allocate([25088], "int8", "global") + sid_26 = T.allocate([25088], "int8", "global") + sid_31 = T.allocate([25088], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_7, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", sid_6, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_5, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d", sid_5, sid_4, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_cast", sid_4, sid_3, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", sid_3, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_3, T.lookup_param("p9", dtype="handle"), T.lookup_param("p10", dtype="handle"), sid_20, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", sid_20, T.lookup_param("p11", dtype="handle"), T.lookup_param("p12", dtype="handle"), sid_19, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", sid_3, T.lookup_param("p13", dtype="handle"), T.lookup_param("p14", dtype="handle"), sid_26, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", sid_26, T.lookup_param("p15", dtype="handle"), T.lookup_param("p16", dtype="handle"), sid_25, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_inception_structure(): + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = InceptionStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) + main_func = tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # check conflicts + _verify_conflicts("sid_5", ["Conv2dOutput_8", "sid_4"], buffer_info_map) + _verify_conflicts( + "Conv2dOutput_2", ["PaddedInput_2", "sid_4", "sid_3", "sid_2"], buffer_info_map + ) + _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map) + _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map) + _verify_conflicts( + "sid_26", ["sid_19", "Conv2dOutput_4", "sid_2", "sid_4", "PaddedInput_5"], buffer_info_map + ) + _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_6"], buffer_info_map) + _verify_conflicts( + "PaddedInput_4", ["sid_19", "sid_2", "sid_4", "sid_3", "Conv2dOutput_4"], buffer_info_map + ) + _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map) + _verify_conflicts("tensor_3", ["sid_25", "sid_19", "sid_2", "sid_4", "sid_32"], buffer_info_map) + _verify_conflicts( + "sid_3", + [ + "sid_4", + "PaddedInput_2", + "Conv2dOutput_2", + "sid_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_32", ["tensor_3", "sid_25", "sid_19", "sid_2", "PaddedInput_3"], buffer_info_map + ) + _verify_conflicts("PaddedInput_8", ["sid_6", "Conv2dOutput_8"], buffer_info_map) + _verify_conflicts( + "Conv2dOutput_6", ["PaddedInput_6", "sid_2", "sid_4", "sid_3", "sid_19"], buffer_info_map + ) + _verify_conflicts( + "sid_4", + [ + "sid_5", + "sid_3", + "PaddedInput_2", + "Conv2dOutput_2", + "sid_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", + ], + buffer_info_map, + ) + _verify_conflicts("PaddedInput_2", ["sid_3", "sid_4", "Conv2dOutput_2"], buffer_info_map) + _verify_conflicts( + "Conv2dOutput_4", ["sid_19", "sid_2", "sid_4", "PaddedInput_4", "sid_26"], buffer_info_map + ) + _verify_conflicts( + "PaddedInput_1", ["sid_2", "sid_4", "sid_3", "Conv2dOutput_1"], buffer_info_map + ) + _verify_conflicts("sid_6", ["Conv2dOutput", "PaddedInput_8"], buffer_info_map) + _verify_conflicts("Conv2dOutput_8", ["PaddedInput_8", "sid_5"], buffer_info_map) + _verify_conflicts( + "sid_25", + [ + "Conv2dOutput_5", + "sid_19", + "sid_2", + "sid_4", + "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_6", ["sid_20", "sid_2", "sid_4", "sid_3", "Conv2dOutput_6"], buffer_info_map + ) + _verify_conflicts( + "sid_7", + [ + "tensor_2", + "PaddedInput", + ], + buffer_info_map, + ) + _verify_conflicts("sid_31", ["Conv2dOutput_3", "sid_25", "sid_19", "sid_2"], buffer_info_map) + _verify_conflicts("tensor_2", ["sid_8", "sid_7"], buffer_info_map) + _verify_conflicts( + "sid_2", + [ + "Conv2dOutput_2", + "sid_4", + "sid_3", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_3", ["sid_25", "PaddedInput_3", "sid_19", "sid_2", "sid_31"], buffer_info_map + ) + _verify_conflicts("PaddedInput", ["sid_7", "Conv2dOutput"], buffer_info_map) + _verify_conflicts( + "Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_4", "sid_3", "sid_20"], buffer_info_map + ) + _verify_conflicts( + "PaddedInput_5", ["sid_26", "sid_19", "sid_2", "sid_4", "Conv2dOutput_5"], buffer_info_map + ) + _verify_conflicts( + "PaddedInput_3", ["sid_32", "sid_25", "sid_19", "sid_2", "Conv2dOutput_3"], buffer_info_map + ) + _verify_conflicts( + "sid_19", + [ + "Conv2dOutput_6", + "sid_2", + "sid_4", + "sid_3", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_5", ["PaddedInput_5", "sid_19", "sid_2", "sid_4", "sid_25"], buffer_info_map + ) + _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map) + _verify_conflicts( + "sid_20", ["sid_2", "Conv2dOutput_1", "sid_4", "sid_3", "PaddedInput_6"], buffer_info_map + ) + + # check sizes + assert buffer_info_map["sid_20"].size_bytes == 150528 + assert buffer_info_map["tensor_2"].size_bytes == 200704 + assert buffer_info_map["sid_5"].size_bytes == 602112 + assert buffer_info_map["sid_9"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput_3"].size_bytes == 4 + assert buffer_info_map["sid_26"].size_bytes == 25088 + assert buffer_info_map["Conv2dOutput_2"].size_bytes == 256 + assert buffer_info_map["PaddedInput_5"].size_bytes == 28800 + assert buffer_info_map["sid_8"].size_bytes == 802816 + assert buffer_info_map["Conv2dOutput_5"].size_bytes == 4 + assert buffer_info_map["sid_3"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput"].size_bytes == 256 + assert buffer_info_map["PaddedInput_3"].size_bytes == 301056 + assert buffer_info_map["sid_32"].size_bytes == 301056 + assert buffer_info_map["PaddedInput_8"].size_bytes == 430592 + assert buffer_info_map["sid_4"].size_bytes == 150528 + assert buffer_info_map["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map["sid_6"].size_bytes == 401408 + assert buffer_info_map["Conv2dOutput_8"].size_bytes == 256 + assert buffer_info_map["sid_25"].size_bytes == 25088 + assert buffer_info_map["PaddedInput"].size_bytes == 401408 + assert buffer_info_map["sid_7"].size_bytes == 401408 + assert buffer_info_map["Conv2dOutput_1"].size_bytes == 4 + assert buffer_info_map["Conv2dOutput_4"].size_bytes == 4 + assert buffer_info_map["PaddedInput_2"].size_bytes == 301056 + assert buffer_info_map["sid_31"].size_bytes == 25088 + assert buffer_info_map["PaddedInput_1"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput_6"].size_bytes == 256 + assert buffer_info_map["PaddedInput_4"].size_bytes == 301056 + assert buffer_info_map["sid_2"].size_bytes == 50176 + assert buffer_info_map["tensor_3"].size_bytes == 150528 + assert buffer_info_map["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map["sid_19"].size_bytes == 100352 + assert buffer_info_map["PaddedInput_6"].size_bytes == 172800 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py new file mode 100644 index 000000000000..0127921fdb56 --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import ty +from tvm.tir import stmt_functor + + +# fmt: off +@tvm.script.tir +class LinearStructure: + def tvmgen_default_fused_cast_subtract(placeholder_2: ty.handle, placeholder_3: ty.handle, T_subtract: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = tir.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = tir.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in tir.serial(0, 224): + for ax2_1, ax3_inner_1 in tir.grid(224, 3): + tir.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (tir.cast(tir.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - tir.load("int16", placeholder_5.data, 0)), True) + + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: ty.handle, placeholder_63: ty.handle, placeholder_64: ty.handle, T_cast_20: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = tir.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = tir.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = tir.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = tir.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = tir.allocate([157323], "int16", "global") + for i0_i1_fused_7 in tir.serial(0, 229): + for i2_7, i3_7 in tir.grid(229, 3): + tir.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), tir.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), tir.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), tir.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in tir.serial(0, 12544): + Conv2dOutput_7 = tir.allocate([64], "int32", "global") + for ff_3 in tir.serial(0, 64): + tir.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in tir.grid(7, 7, 3): + tir.store(Conv2dOutput_7, ff_3, (tir.load("int32", Conv2dOutput_7, ff_3) + (tir.cast(tir.load("int16", PaddedInput_7, (((((tir.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (tir.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*tir.cast(tir.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in tir.serial(0, 64): + tir.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_7, ax3_inner_7) + tir.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: ty.handle, T_cast_6: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = tir.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = tir.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = tir.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in tir.serial(0, 56): + for ax2_4 in tir.serial(0, 56): + for ax3_init in tir.serial(0, 64): + tir.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), tir.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in tir.grid(9, 64): + tir.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), tir.max(tir.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), tir.if_then_else(((((ax0_ax1_fused_4*2) + tir.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + tir.floormod(rv0_rv1_fused_1, 3)) < 112)), tir.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (tir.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (tir.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), tir.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in tir.serial(0, 56): + for ax2_5, ax3_3 in tir.grid(56, 64): + tir.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), tir.cast(tir.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + def tvmgen_default_run_model(input: ty.handle, output: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + tir.attr("default", "device_id", 0) + tir.attr("default", "device_type", 1) + sid_9 = tir.allocate([301056], "int8", "global") + sid_8 = tir.allocate([802816], "int8", "global") + tir.evaluate(tir.call_extern("tvmgen_default_fused_cast_subtract", input, tir.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + tir.evaluate(tir.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, tir.lookup_param("p1", dtype="handle"), tir.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + tir.evaluate(tir.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_create_buffer_info(): + buffer_info_obj = tvm.tir.usmp.BufferInfo("buf1", 256) + assert buffer_info_obj.name_hint == "buf1" + assert buffer_info_obj.size_bytes == 256 + # default workspace alignment + assert buffer_info_obj.alignment == 1 + + buffer_info_obj = tvm.tir.usmp.BufferInfo("buf2", 512, 8) + assert buffer_info_obj.name_hint == "buf2" + assert buffer_info_obj.size_bytes == 512 + assert buffer_info_obj.alignment == 8 + + +def test_create_array_buffer_info(): + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + tir_mod = LinearStructure() + main_func = tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_array = fcreate_array_bi(buffer_info_map) + + current_offset = 0 + offsets = [] + for bi in buffer_info_array: + bi.set_pool_offsets("global", current_offset) + offsets.append(current_offset) + current_offset += bi.size_bytes + + bi_idx = 0 + for _, bi in buffer_info_map.items(): + assert bi.pool_name == "global" + assert bi.pool_offset == offsets[bi_idx] + bi_idx += 1 + + +if __name__ == "__main__": + pytest.main([__file__]) From 60664a273ed89b242cd517cd94689de3dad009b7 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 6 Oct 2021 12:24:01 +0100 Subject: [PATCH 2/8] [TIR][USMP] Added buffer info extraction pass Swap key-value pairs of returned values of the buffer_info extraction pass. Change-Id: Ia4f7289592bc776ef6189a41a7891038751bf31f --- src/tir/usmp/analysis/extract_buffer_info.cc | 35 ++++++++++--------- src/tir/usmp/utils.cc | 6 ++-- ...st_tir_usmp_analysis_extract_bufferinfo.py | 2 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 2d2eacda57ed..94a8d743f3a0 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -43,7 +43,7 @@ class BufferInfoExtractor : public StmtExprVisitor { // Pushing a scope info for the initial body of the main function scope_stack.push(ScopeInfo()); } - Map operator()(const PrimFunc& func); + Map operator()(const PrimFunc& func); private: void VisitStmt(const Stmt& n) override; @@ -56,7 +56,7 @@ class BufferInfoExtractor : public StmtExprVisitor { void UpdateAliases(const Array& args, const PrimFunc& func); - Map buffer_info_map; + Map buffer_info_map; Map buffer_info_start_stmt_idx; Map buffer_info_end_stmt_idx; Map allocate_var_to_stmt_map; @@ -117,7 +117,7 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes, pool_candidates); auto allocate = GetRef(op); allocate_var_to_stmt_map.Set(op->buffer_var, allocate); - buffer_info_map.Set(allocate, buffer_info); + buffer_info_map.Set(buffer_info, allocate); } } StmtExprVisitor::VisitStmt(op->body); @@ -200,16 +200,16 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) { } } -Map BufferInfoExtractor::operator()(const PrimFunc& main_func) { +Map BufferInfoExtractor::operator()(const PrimFunc& main_func) { this->VisitStmt(main_func->body); enum LivenessEventType { START = 0, END = 1 }; struct LivenessEvent { size_t tick; LivenessEventType le_type; - Allocate allocate; + BufferInfo buffer_info; bool operator==(const LivenessEvent& other) { - if (tick == other.tick && le_type == other.le_type && allocate == other.allocate) { + if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) { return true; } return false; @@ -218,22 +218,23 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ std::vector le_events; for (const auto& kv : buffer_info_map) { - if (!kv.first->IsInstance()) { + if (!kv.second->IsInstance()) { continue; } - auto allocate = Downcast(kv.first); + auto allocate = Downcast(kv.second); + auto buffer_info = Downcast(kv.first); // If the allocate is not used; we remove it from the analysis if (buffer_info_start_stmt_idx.count(allocate) == 0) { continue; } LivenessEvent le_event_start; - le_event_start.allocate = allocate; + le_event_start.buffer_info = buffer_info; le_event_start.le_type = START; le_event_start.tick = buffer_info_start_stmt_idx[allocate]; le_events.push_back(le_event_start); LivenessEvent le_event_end; - le_event_end.allocate = allocate; + le_event_end.buffer_info = buffer_info; le_event_end.le_type = END; le_event_end.tick = buffer_info_end_stmt_idx[allocate]; le_events.push_back(le_event_end); @@ -248,23 +249,23 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ } return false; }); - std::unordered_set open_set; + std::unordered_set open_set; for (const auto& le_event : le_events) { if (le_event.le_type == START) { - for (const auto& open_allocate : open_set) { - buffer_info_map[open_allocate]->conflicts.push_back(buffer_info_map[le_event.allocate]); - buffer_info_map[le_event.allocate]->conflicts.push_back(buffer_info_map[open_allocate]); + for (const auto& open_buffer_info : open_set) { + open_buffer_info->conflicts.push_back(le_event.buffer_info); + le_event.buffer_info->conflicts.push_back(open_buffer_info); } - open_set.insert(le_event.allocate); + open_set.insert(le_event.buffer_info); } else { ICHECK(le_event.le_type == END); - open_set.erase(le_event.allocate); + open_set.erase(le_event.buffer_info); } } return this->buffer_info_map; } -Map ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) { +Map ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) { return BufferInfoExtractor(mod)(main_func); } diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index d80ba26f4b77..119507d6d335 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -110,17 +110,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -Array CreateArrayBufferInfo(const Map& buffer_info_map) { +Array CreateArrayBufferInfo(const Map& buffer_info_map) { Array ret; for (const auto& kv : buffer_info_map) { - auto buffer_info = kv.second; + auto buffer_info = kv.first; ret.push_back(buffer_info); } return ret; } TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") - .set_body_typed([](Map buffer_info_map) { + .set_body_typed([](Map buffer_info_map) { return (CreateArrayBufferInfo(buffer_info_map)); }); diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 7dea5a9d345c..26a3b6161434 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -30,7 +30,7 @@ def _replace_stmt_with_buf_var_names(buffer_info_map): """helper to replace tir.allocates with buffer names""" new_buffer_info_map = dict() for k, v in buffer_info_map.items(): - new_buffer_info_map[k.buffer_var.name] = v + new_buffer_info_map[v.buffer_var.name] = k return new_buffer_info_map From 85bb587d21babda8c8c620db3d07dba8f37cce49 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 6 Oct 2021 18:42:08 +0100 Subject: [PATCH 3/8] [TIR][USMP] Added buffer info extraction pass Updating the USMP utility tests to include tests that test creation of PoolInfo and PoolAllocation Objects. Change-Id: I5d349d0ffcac6b0160072d832dd9d5418699228e --- python/tvm/tir/usmp/utils.py | 85 ++++---- ...st_tir_usmp_analysis_extract_bufferinfo.py | 14 +- tests/python/unittest/test_tir_usmp_utils.py | 204 ++++++++++++------ 3 files changed, 186 insertions(+), 117 deletions(-) diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index 5658878fe149..a0fe9612c441 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -17,7 +17,7 @@ """USMP Utilities and Data Structures""" # pylint: disable=invalid-name -from typing import Dict, Optional +from typing import Dict, Optional, List from tvm._ffi import register_object from tvm.runtime import Object @@ -27,6 +27,43 @@ CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools" +@register_object("tir.usmp.PoolInfo") +class PoolInfo(Object): + """PoolInfo object holds information related to memory pools + where the statically sized allocate nodes will pooled into. + + Parameters + ---------- + pool_name : str + The name of the memory pool + + target_access : Dict[Target, str] + A dictionary where keys describe which targets could + access the pool where value could take the values : + a) "rw" : read-write access + b) "ro" : write-only acesss + + size_hint_bytes : Optional[int] + The expected size hint to be used by the allocator. + The default value would be -1 which means the pool + is not size restricted. + + """ + + READ_WRITE_ACCESS = "rw" + READ_ONLY_ACCESS = "ro" + + def __init__( + self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1 + ): + self.__init_handle_by_constructor__( + _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + target_access, + size_hint_bytes, + ) + + @register_object("tir.usmp.BufferInfo") class BufferInfo(Object): """BufferInfo object holds information related to buffers @@ -41,7 +78,10 @@ class BufferInfo(Object): size_bytes : int The size in bytes - alignment : int + pool_candidates : List[PoolInfo] + The list of candidates pools this buffer could be placed + + alignment : Optional[int] The byte alignment required in the workspace memory """ @@ -50,12 +90,14 @@ def __init__( self, name_hint: str, size_bytes: int, - alignment: int = None, + pool_candidates: List[PoolInfo], + alignment: Optional[int] = None, ): self.__init_handle_by_constructor__( _ffi_api.BufferInfo, # type: ignore # pylint: disable=no-member name_hint, size_bytes, + pool_candidates, alignment, ) @@ -72,43 +114,6 @@ def set_conflicts(self, conflicts: list): _ffi_api.BufferInfoSetConflicts(self, conflicts) -@register_object("tir.usmp.PoolInfo") -class PoolInfo(Object): - """PoolInfo object holds information related to memory pools - where the statically sized allocate nodes will pooled into. - - Parameters - ---------- - pool_name : str - The name of the memory pool - - target_access : Dict[Target, str] - A dictionary where keys describe which targets could - access the pool where value could take the values : - a) "rw" : read-write access - b) "ro" : write-only acesss - - size_hint_bytes : Optional[int] - The expected size hint to be used by the allocator. - The default value would be -1 which means the pool - is not size restricted. - - """ - - READ_WRITE_ACCESS = "rw" - READ_ONLY_ACCESS = "ro" - - def __init__( - self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1 - ): - self.__init_handle_by_constructor__( - _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member - pool_name, - target_access, - size_hint_bytes, - ) - - @register_object("tir.usmp.PoolAllocation") class PoolAllocation(Object): """PoolAllocation object holds information related to an allocation diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 26a3b6161434..0f8a1ce75809 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -53,7 +53,7 @@ def get_allocate(stmt): return allocates -def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" def set_poolinfos(stmt): @@ -70,12 +70,12 @@ def set_poolinfos(stmt): return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) -def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): """helper to assing poolinfos to allocate nodes in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc): - ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) return ret @@ -158,7 +158,7 @@ def test_linear(): pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) tir_mod = LinearStructure - tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod = _assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( @@ -274,7 +274,7 @@ def test_parallel_serial_mixed_for_loops(): target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) all_serial_tir_mod = AllSerialForLoops - all_serial_tir_mod = assign_poolinfos_to_allocates_in_irmodule( + all_serial_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( all_serial_tir_mod, [global_ws_pool] ) main_func = all_serial_tir_mod["tvmgen_default_run_model"] @@ -287,7 +287,7 @@ def test_parallel_serial_mixed_for_loops(): assert name in ["dummy_allocate", "Conv2dOutput_8", "PaddedInput_8"] parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops - parallel_serial_mixed_tir_mod = assign_poolinfos_to_allocates_in_irmodule( + parallel_serial_mixed_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( parallel_serial_mixed_tir_mod, [global_ws_pool] ) main_func = parallel_serial_mixed_tir_mod["tvmgen_default_run_model"] @@ -634,7 +634,7 @@ def test_inception_structure(): target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) tir_mod = InceptionStructure - tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 0127921fdb56..0974f1fe38e4 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -17,110 +17,174 @@ import pytest import tvm -from tvm import tir, script -from tvm.script import ty +from tvm.script import tir as T from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target # fmt: off -@tvm.script.tir +@tvm.script.ir_module class LinearStructure: - def tvmgen_default_fused_cast_subtract(placeholder_2: ty.handle, placeholder_3: ty.handle, T_subtract: ty.handle) -> None: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = tir.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T_subtract_1 = tir.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dTpe="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - for ax0_ax1_fused_1 in tir.serial(0, 224): - for ax2_1, ax3_inner_1 in tir.grid(224, 3): - tir.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (tir.cast(tir.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - tir.load("int16", placeholder_5.data, 0)), True) + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: ty.handle, placeholder_63: ty.handle, placeholder_64: ty.handle, T_cast_20: ty.handle) -> None: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = tir.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_66 = tir.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_67 = tir.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_21 = tir.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_7 = tir.allocate([157323], "int16", "global") - for i0_i1_fused_7 in tir.serial(0, 229): - for i2_7, i3_7 in tir.grid(229, 3): - tir.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), tir.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), tir.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), tir.int16(0), dtype="int16"), True) - for ax0_ax1_fused_ax2_fused_7 in tir.serial(0, 12544): - Conv2dOutput_7 = tir.allocate([64], "int32", "global") - for ff_3 in tir.serial(0, 64): - tir.store(Conv2dOutput_7, ff_3, 0, True) - for ry_2, rx_2, rc_7 in tir.grid(7, 7, 3): - tir.store(Conv2dOutput_7, ff_3, (tir.load("int32", Conv2dOutput_7, ff_3) + (tir.cast(tir.load("int16", PaddedInput_7, (((((tir.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (tir.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*tir.cast(tir.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) - for ax3_inner_7 in tir.serial(0, 64): - tir.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_7, ax3_inner_7) + tir.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) - - def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: ty.handle, T_cast_6: ty.handle) -> None: + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) - placeholder_29 = tir.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_cast_7 = tir.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - tensor_2 = tir.allocate([200704], "uint8", "global") - for ax0_ax1_fused_4 in tir.serial(0, 56): - for ax2_4 in tir.serial(0, 56): - for ax3_init in tir.serial(0, 64): - tir.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), tir.uint8(0), True) - for rv0_rv1_fused_1, ax3_2 in tir.grid(9, 64): - tir.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), tir.max(tir.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), tir.if_then_else(((((ax0_ax1_fused_4*2) + tir.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + tir.floormod(rv0_rv1_fused_1, 3)) < 112)), tir.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (tir.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (tir.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), tir.uint8(0), dtype="uint8")), True) - for ax0_ax1_fused_5 in tir.serial(0, 56): - for ax2_5, ax3_3 in tir.grid(56, 64): - tir.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), tir.cast(tir.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) - - def tvmgen_default_run_model(input: ty.handle, output: ty.handle) -> None: + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body - tir.attr("default", "device_id", 0) - tir.attr("default", "device_type", 1) - sid_9 = tir.allocate([301056], "int8", "global") - sid_8 = tir.allocate([802816], "int8", "global") - tir.evaluate(tir.call_extern("tvmgen_default_fused_cast_subtract", input, tir.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) - tir.evaluate(tir.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, tir.lookup_param("p1", dtype="handle"), tir.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) - tir.evaluate(tir.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) __tvm_meta__ = None # fmt: on +def test_create_pool_info(): + target = Target("c") + pool_info = usmp_utils.PoolInfo( + pool_name="foo_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + assert pool_info.pool_name == "foo_workspace" + assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + # default pool size constraint + assert pool_info.size_hint_bytes == -1 + + pool_info = usmp_utils.PoolInfo( + pool_name="bar_workspace", + target_access={target: usmp_utils.PoolInfo.READ_ONLY_ACCESS}, + size_hint_bytes=1425, + ) + assert pool_info.pool_name == "bar_workspace" + assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_ONLY_ACCESS} + assert pool_info.size_hint_bytes == 1425 + + def test_create_buffer_info(): - buffer_info_obj = tvm.tir.usmp.BufferInfo("buf1", 256) + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + buffer_info_obj = tvm.tir.usmp.BufferInfo( + name_hint="buf1", size_bytes=256, pool_candidates=[global_ws_pool] + ) assert buffer_info_obj.name_hint == "buf1" assert buffer_info_obj.size_bytes == 256 + assert list(buffer_info_obj.pool_candidates) == [global_ws_pool] # default workspace alignment assert buffer_info_obj.alignment == 1 - buffer_info_obj = tvm.tir.usmp.BufferInfo("buf2", 512, 8) + buffer_info_obj = tvm.tir.usmp.BufferInfo("buf2", 512, [global_ws_pool], 8) assert buffer_info_obj.name_hint == "buf2" assert buffer_info_obj.size_bytes == 512 + assert list(buffer_info_obj.pool_candidates) == [global_ws_pool] assert buffer_info_obj.alignment == 8 +def test_create_pool_allocation(): + pool_info = usmp_utils.PoolInfo( + pool_name="foo_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + pool_allocation = usmp_utils.PoolAllocation(pool_info=pool_info, byte_offset=64) + assert pool_allocation.pool_info == pool_info + assert pool_allocation.byte_offset == 64 + + +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + def test_create_array_buffer_info(): + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") - tir_mod = LinearStructure() + tir_mod = LinearStructure + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) buffer_info_array = fcreate_array_bi(buffer_info_map) - - current_offset = 0 - offsets = [] - for bi in buffer_info_array: - bi.set_pool_offsets("global", current_offset) - offsets.append(current_offset) - current_offset += bi.size_bytes - - bi_idx = 0 - for _, bi in buffer_info_map.items(): - assert bi.pool_name == "global" - assert bi.pool_offset == offsets[bi_idx] - bi_idx += 1 + for buffer_info in buffer_info_array: + assert buffer_info in buffer_info_map.keys() if __name__ == "__main__": From 7cd3bc583d053d2a613d90dd4d7ab4e8c60badc8 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 7 Oct 2021 17:05:16 +0100 Subject: [PATCH 4/8] [TIR][USMP] Added buffer info extraction pass * Removing the unnecessary header : include/tvm/tir/usmp/analysis.h * Some nits and cleanup Change-Id: Iac3ddd9428c56cd8ef49cf643e797bf6fdf4e97a --- include/tvm/tir/usmp/analysis.h | 33 -------------------- src/tir/usmp/analysis/extract_buffer_info.cc | 4 +-- 2 files changed, 2 insertions(+), 35 deletions(-) delete mode 100644 include/tvm/tir/usmp/analysis.h diff --git a/include/tvm/tir/usmp/analysis.h b/include/tvm/tir/usmp/analysis.h deleted file mode 100644 index 993e99a163a0..000000000000 --- a/include/tvm/tir/usmp/analysis.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/analysis.h - * \brief Analysis utilities and passes for TIR Unified Static Memory Planner. - */ -#ifndef TVM_TIR_USMP_ANALYSIS_H_ -#define TVM_TIR_USMP_ANALYSIS_H_ - -namespace tvm { -namespace tir { -namespace usmp {} -} // namespace tir -} // namespace tvm - -#endif // TVM_TIR_USMP_ANALYSIS_H_ diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 94a8d743f3a0..0413fb2db75a 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -77,14 +77,14 @@ void BufferInfoExtractor::VisitStmt(const Stmt& n) { StmtExprVisitor::VisitStmt(n); } -size_t static CalculateExtentsSize(const AllocateNode* op) { +static size_t CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); size_t num_elements = 1; for (const auto& ext : op->extents) { if (ext->IsInstance()) { num_elements *= Downcast(ext)->value; } else { - // We cant statically calculate workspace for dynamic shapes + // We can't statically calculate workspace for dynamic shapes num_elements = 0; } } From 486facf2704ccf6b6fa2011e72ecb92273d40ea4 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 7 Oct 2021 17:19:06 +0100 Subject: [PATCH 5/8] [TIR][USMP] Added buffer info extraction pass * Change the class data members to have a trailing underscore Change-Id: I71809b3c73b0bc0cd133fad1392ae8c17c895ee4 --- src/tir/usmp/analysis/extract_buffer_info.cc | 60 ++++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 0413fb2db75a..d98a136bb7e4 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -38,10 +38,10 @@ class BufferInfoExtractor : public StmtExprVisitor { public: explicit BufferInfoExtractor(const IRModule& module) : module_(module) { for (const auto& gv_func : module_->functions) { - functions.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); } // Pushing a scope info for the initial body of the main function - scope_stack.push(ScopeInfo()); + scope_stack_.push(ScopeInfo()); } Map operator()(const PrimFunc& func); @@ -56,24 +56,24 @@ class BufferInfoExtractor : public StmtExprVisitor { void UpdateAliases(const Array& args, const PrimFunc& func); - Map buffer_info_map; - Map buffer_info_start_stmt_idx; - Map buffer_info_end_stmt_idx; - Map allocate_var_to_stmt_map; + Map buffer_info_map_; + Map buffer_info_start_stmt_idx_; + Map buffer_info_end_stmt_idx_; + Map allocate_var_to_stmt_map_; std::unordered_set currently_live_allocates; - int current_stmt_idx = 0; + int current_stmt_idx_ = 0; struct ScopeInfo { For for_loop; }; - std::stack scope_stack; + std::stack scope_stack_; - Map functions; + Map functions_; IRModule module_; }; void BufferInfoExtractor::VisitStmt(const Stmt& n) { - current_stmt_idx += 1; + current_stmt_idx_ += 1; StmtExprVisitor::VisitStmt(n); } @@ -92,7 +92,7 @@ static size_t CalculateExtentsSize(const AllocateNode* op) { } void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { - const auto& currect_scope_info = scope_stack.top(); + const auto& currect_scope_info = scope_stack_.top(); const auto& type = Downcast(op->buffer_var->type_annotation); const auto& storage_scope = type->storage_scope; @@ -116,8 +116,8 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { "un-restricted pool is assigned"; auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes, pool_candidates); auto allocate = GetRef(op); - allocate_var_to_stmt_map.Set(op->buffer_var, allocate); - buffer_info_map.Set(buffer_info, allocate); + allocate_var_to_stmt_map_.Set(op->buffer_var, allocate); + buffer_info_map_.Set(buffer_info, allocate); } } StmtExprVisitor::VisitStmt(op->body); @@ -127,9 +127,9 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) { ScopeInfo si{ GetRef(op), }; - scope_stack.push(si); + scope_stack_.push(si); StmtExprVisitor::VisitStmt_(op); - scope_stack.pop(); + scope_stack_.pop(); } void BufferInfoExtractor::VisitExpr_(const LoadNode* op) { @@ -144,12 +144,12 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { void BufferInfoExtractor::VisitExpr_(const VarNode* op) { auto var = GetRef(op); - if (allocate_var_to_stmt_map.count(var)) { - auto allocate = allocate_var_to_stmt_map[var]; - if (buffer_info_start_stmt_idx.count(allocate) == 0) { - buffer_info_start_stmt_idx.Set(allocate, current_stmt_idx); + if (allocate_var_to_stmt_map_.count(var)) { + auto allocate = allocate_var_to_stmt_map_[var]; + if (buffer_info_start_stmt_idx_.count(allocate) == 0) { + buffer_info_start_stmt_idx_.Set(allocate, current_stmt_idx_); } - buffer_info_end_stmt_idx.Set(allocate, current_stmt_idx); + buffer_info_end_stmt_idx_.Set(allocate, current_stmt_idx_); } StmtExprVisitor::VisitExpr_(op); } @@ -173,13 +173,13 @@ void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimF // to the original allocate if (arg->IsInstance()) { auto load = Downcast(arg); - if (allocate_var_to_stmt_map.count(load->buffer_var)) { - allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[load->buffer_var]); + if (allocate_var_to_stmt_map_.count(load->buffer_var)) { + allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]); } } else if (arg->IsInstance()) { auto var = Downcast(arg); - if (allocate_var_to_stmt_map.count(var)) { - allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[var]); + if (allocate_var_to_stmt_map_.count(var)) { + allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]); } } } @@ -187,7 +187,7 @@ void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimF void BufferInfoExtractor::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::call_extern())) { - auto func = functions.at(Downcast(op->args[0])->value); + auto func = functions_.at(Downcast(op->args[0])->value); auto actual_args = Array(op->args.begin() + 1, op->args.end()); this->UpdateAliases(actual_args, func); this->VisitStmt(func->body); @@ -217,26 +217,26 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ }; std::vector le_events; - for (const auto& kv : buffer_info_map) { + for (const auto& kv : buffer_info_map_) { if (!kv.second->IsInstance()) { continue; } auto allocate = Downcast(kv.second); auto buffer_info = Downcast(kv.first); // If the allocate is not used; we remove it from the analysis - if (buffer_info_start_stmt_idx.count(allocate) == 0) { + if (buffer_info_start_stmt_idx_.count(allocate) == 0) { continue; } LivenessEvent le_event_start; le_event_start.buffer_info = buffer_info; le_event_start.le_type = START; - le_event_start.tick = buffer_info_start_stmt_idx[allocate]; + le_event_start.tick = buffer_info_start_stmt_idx_[allocate]; le_events.push_back(le_event_start); LivenessEvent le_event_end; le_event_end.buffer_info = buffer_info; le_event_end.le_type = END; - le_event_end.tick = buffer_info_end_stmt_idx[allocate]; + le_event_end.tick = buffer_info_end_stmt_idx_[allocate]; le_events.push_back(le_event_end); } @@ -262,7 +262,7 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ open_set.erase(le_event.buffer_info); } } - return this->buffer_info_map; + return this->buffer_info_map_; } Map ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) { From 63a155c99643770e2b77ab13c305c59444fc37f0 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 21 Oct 2021 18:08:34 +0100 Subject: [PATCH 6/8] [TIR][USMP] Added buffer info extraction pass Adding more documentation for data structures and the approach Change-Id: Ide2bfffaeff9add86853b6992017264e5d796299 --- include/tvm/tir/usmp/utils.h | 28 +++++++-- python/tvm/tir/usmp/utils.py | 15 ++++- src/tir/usmp/analysis/extract_buffer_info.cc | 58 +++++++++++++++---- src/tir/usmp/utils.cc | 5 +- ...st_tir_usmp_analysis_extract_bufferinfo.py | 6 +- tests/python/unittest/test_tir_usmp_utils.py | 3 +- 6 files changed, 95 insertions(+), 20 deletions(-) diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 22938fead8b7..0110f427c707 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -33,7 +33,17 @@ namespace tvm { namespace tir { namespace usmp { +/*! + * \brief The string parameter to indicate read and write access to a pool + * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in + * python/tvm/tir/usmp/utils.py + */ static constexpr const char* kTargetPoolReadWriteAccess = "rw"; +/*! + * \brief The string parameter to indicate read only access to a pool + * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in + * python/tvm/tir/usmp/utils.py + */ static constexpr const char* kTargetPoolReadOnlyAccess = "ro"; /*! @@ -43,8 +53,8 @@ struct PoolInfoNode : public Object { /*! \brief The name of the memory pool */ String pool_name; /*! \brief The expected size hint to be used by the allocator. - * The size_hint is defaulted to -1 to indicate the pool is not - * size restricted. + * The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint + * to indicate the pool is not size restricted. */ Integer size_hint_bytes; /*! \brief The accessibility from each Target*/ @@ -71,10 +81,15 @@ struct PoolInfoNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object); }; +/*! + * \brief The PoolSize is unrestricted for the memory planner + */ +static const int kUnrestrictedPoolSizeHint = -1; + class PoolInfo : public ObjectRef { public: TVM_DLL PoolInfo(String pool_name, Map target_access, - Integer size_hint_bytes = -1); + Integer size_hint_bytes = kUnrestrictedPoolSizeHint); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); }; @@ -172,7 +187,12 @@ class PoolAllocation : public ObjectRef { */ Array CreateArrayBufferInfo(const Map& buffer_info_map); -static constexpr const char* kPoolCandidatesIRModAttr = "candidate_memory_pools"; +/*! + * \brief The allocate node attribute to indicate candidate memory pools. + * This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in + * python/tvm/tir/usmp/utils.py. + */ +static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools"; } // namespace usmp } // namespace tir diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index a0fe9612c441..0445775869e8 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -24,6 +24,10 @@ from tvm.target import Target from . import _ffi_api + +# The allocate node attribute to indicate candidate memory pools. +# This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in +# include/tvm/tir/usmp/utils.h CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools" @@ -50,11 +54,20 @@ class PoolInfo(Object): """ + # The string parameter to indicate read and write access to a pool + # This needs to be kept in sync with kTargetPoolReadWriteAccess in + # include/tvm/tir/usmp/utils.h READ_WRITE_ACCESS = "rw" + # The string parameter to indicate read only access to a pool + # This needs to be kept in sync with kTargetPoolReadOnlyAccess in + # include/tvm/tir/usmp/utils.h READ_ONLY_ACCESS = "ro" def __init__( - self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1 + self, + pool_name: str, + target_access: Dict[Target, str], + size_hint_bytes: Optional[int] = None, ): self.__init_handle_by_constructor__( _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index d98a136bb7e4..c6bd00e49299 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -18,8 +18,12 @@ */ /*! - * \file tir/analysis/usmp/convert_for_loops_serial.cc - * \brief Convert all for loops to serial for lesser memory consumption + * \file tir/analysis/usmp/extract_buffer_info.cc + * + * \brief This analysis pass consumes a TIR IRModule with a main function + * that defines a ordering in the calles to operators and produces BufferInfo + * objects that contains information about tir.allocate nodes and liveness + * conflicts between other tir.allocate nodes. */ #include #include @@ -34,6 +38,12 @@ namespace tvm { namespace tir { namespace usmp { +/*! \brief This class takes a TIR IRModule and a main PrimFunc that contains + * that defines a ordering in the calles to operators and produces BufferInfo + * objects that contains information about tir.allocate nodes and liveness + * conflicts between other tir.allocate nodes. + */ + class BufferInfoExtractor : public StmtExprVisitor { public: explicit BufferInfoExtractor(const IRModule& module) : module_(module) { @@ -63,6 +73,11 @@ class BufferInfoExtractor : public StmtExprVisitor { std::unordered_set currently_live_allocates; int current_stmt_idx_ = 0; + // This structure is supposed to contain information + // around the scope the visitor is currently in. + // We only check whether the current scope belong to + // a Serial ForKind. We are not planning for Parallel + // ForKind just yet. struct ScopeInfo { For for_loop; }; @@ -77,7 +92,7 @@ void BufferInfoExtractor::VisitStmt(const Stmt& n) { StmtExprVisitor::VisitStmt(n); } -static size_t CalculateExtentsSize(const AllocateNode* op) { +static Integer CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); size_t num_elements = 1; for (const auto& ext : op->extents) { @@ -85,10 +100,10 @@ static size_t CalculateExtentsSize(const AllocateNode* op) { num_elements *= Downcast(ext)->value; } else { // We can't statically calculate workspace for dynamic shapes - num_elements = 0; + return Integer(); } } - return (num_elements * element_size_bytes); + return Integer(num_elements * element_size_bytes); } void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { @@ -96,20 +111,25 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { const auto& type = Downcast(op->buffer_var->type_annotation); const auto& storage_scope = type->storage_scope; - // If the allocate is in a for loop, - // USMP currently only looks at serial for loops. + // If the allocate is in a for loop, USMP currently only looks at serial for loops. + // If its not a serial for loop, then memory planner will omit them in the current memory planning + // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work + // with buffers that have global storage_scope if ((!currect_scope_info.for_loop.defined()) || (currect_scope_info.for_loop.defined() && currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global")) { - // USMP can only work with buffers that have global storage_scope auto size_bytes = CalculateExtentsSize(op); // We only statically memory plan only allocates with known // compile time sizes. - if (size_bytes) { + if (size_bytes.defined()) { // By default, the core compiler is assumed to attach the a default pool to each allocate. - ICHECK(op->annotations.count(kPoolCandidatesIRModAttr)) + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) << "Every statically sized allocate node needs an pool candidate attribute"; - auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesIRModAttr]); + auto pool_candidates = + Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); + + // TODO(@manupa-arm): improve the error when the responsible component for attaching a single + // pool is added ICHECK(pool_candidates.size() > 0) << "The core compiler should at least attach a single PoolInfo. If there were no " "user-given arguments for memory pools, the default behaviour is a single size " @@ -203,6 +223,13 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) { Map BufferInfoExtractor::operator()(const PrimFunc& main_func) { this->VisitStmt(main_func->body); + // A liveness event is an event that when + // traversing the tir.Stmts where tir.allocate node + // begins or ceases to be Live. This particular struct + // is used to solve interval overlap problem using + // a sweep-line algorithm. For that, we need to record + // where the liveness event occurred in a chronological + // order. enum LivenessEventType { START = 0, END = 1 }; struct LivenessEvent { size_t tick; @@ -216,6 +243,8 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ } }; + // Create a vector of liveness events + // associated with each BufferNodes. std::vector le_events; for (const auto& kv : buffer_info_map_) { if (!kv.second->IsInstance()) { @@ -240,6 +269,9 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ le_events.push_back(le_event_end); } + // Sort the liveness events based on the chronological + // ordering. For events that are simultaneous, START event + // takes precedence. std::sort(le_events.begin(), le_events.end(), [](const LivenessEvent& lhs, const LivenessEvent& rhs) { if (lhs.tick < rhs.tick) { @@ -249,6 +281,9 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ } return false; }); + + // Traverse the liveness events using a open set to track what + // is live while updating the conflicts through out the linear traversal std::unordered_set open_set; for (const auto& le_event : le_events) { if (le_event.le_type == START) { @@ -258,7 +293,6 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ } open_set.insert(le_event.buffer_info); } else { - ICHECK(le_event.le_type == END); open_set.erase(le_event.buffer_info); } } diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 119507d6d335..a494c368344b 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -78,7 +78,10 @@ TVM_REGISTER_NODE_TYPE(PoolInfoNode); TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo") .set_body_typed([](String pool_name, Map target_access, Integer size_hint_bytes) { - return PoolInfo(pool_name, target_access, size_hint_bytes); + if (size_hint_bytes.defined()) { + return PoolInfo(pool_name, target_access, size_hint_bytes); + } + return PoolInfo(pool_name, target_access); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 0f8a1ce75809..46c6bf420c1b 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +import sys import tvm from tvm import tir, script @@ -79,6 +80,9 @@ def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +# These are test IRModules that contains varied topologies of operator graphs +# that includes a main TIR function that includes call to such operators. + # fmt: off @tvm.script.ir_module class LinearStructure: @@ -846,4 +850,4 @@ def test_inception_structure(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 0974f1fe38e4..53064fae7b46 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +import sys import tvm from tvm.script import tir as T @@ -188,4 +189,4 @@ def test_create_array_buffer_info(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__] + sys.argv[1:]) From c130133f0da789c7b1185186daf127a5e7e46b53 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 22 Nov 2021 12:08:48 +0000 Subject: [PATCH 7/8] [TIR][USMP] Added buffer info extraction pass * Added more documentation * Added functionality to handle multiple calls for the same PrimFunc with a test. Change-Id: Ib7c27b3cf17f415067a224f1e57d8b928f4c7c6f --- include/tvm/tir/usmp/utils.h | 19 +- src/tir/usmp/analysis/extract_buffer_info.cc | 351 ++++--- src/tir/usmp/utils.cc | 14 + ...st_tir_usmp_analysis_extract_bufferinfo.py | 910 ++++++++++++++++-- 4 files changed, 1078 insertions(+), 216 deletions(-) diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 0110f427c707..32a2bc6e292d 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -47,7 +47,7 @@ static constexpr const char* kTargetPoolReadWriteAccess = "rw"; static constexpr const char* kTargetPoolReadOnlyAccess = "ro"; /*! - * \brief The pool information to be used by USMP + * \brief Describes a pool of memory accessible by one or more targets. */ struct PoolInfoNode : public Object { /*! \brief The name of the memory pool */ @@ -94,7 +94,13 @@ class PoolInfo : public ObjectRef { }; /*! - * \brief The buffer information to be used by USMP + * \brief Describes an abstract memory buffer that will get allocated inside a pool. + * The actual memory buffer in represented by PoolAllocationNode after static memory planning. + * + * See also for relay-level counterparts: + * relay::StorageToken (graph_plan_memory.cc) + * relay::backend::StorageInfoNode (relay/backend/utils.h) + * Region (python/tvm/relay/transform/memory_plan.py) */ struct BufferInfoNode : public Object { /*! \brief The name of the buffer var */ @@ -103,7 +109,7 @@ struct BufferInfoNode : public Object { Integer size_bytes; /*! \brief The pool candidates that this buffer can get pooled to*/ Array pool_candidates; - /*! \brief The byte alignment required within the pool */ + /*! \brief The byte alignment required for buffers that will placed within the pool */ Integer alignment; /*! \brief The liveness conflicting other buffer info objects */ Array conflicts; @@ -194,6 +200,13 @@ Array CreateArrayBufferInfo(const Map& buffer_info */ static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools"; +/*! + * \brief Calculate the size of the extents in bytes + * + * \param op the allocate node + */ +Integer CalculateExtentsSize(const AllocateNode* op); + } // namespace usmp } // namespace tir } // namespace tvm diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index c6bd00e49299..c25578fd9779 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -21,7 +21,7 @@ * \file tir/analysis/usmp/extract_buffer_info.cc * * \brief This analysis pass consumes a TIR IRModule with a main function - * that defines a ordering in the calles to operators and produces BufferInfo + * that defines a ordering in the callees to operators and produces BufferInfo * objects that contains information about tir.allocate nodes and liveness * conflicts between other tir.allocate nodes. */ @@ -38,12 +38,21 @@ namespace tvm { namespace tir { namespace usmp { -/*! \brief This class takes a TIR IRModule and a main PrimFunc that contains - * that defines a ordering in the calles to operators and produces BufferInfo - * objects that contains information about tir.allocate nodes and liveness - * conflicts between other tir.allocate nodes. +/*! + * \brief The visitor class to obtain buffer information + * + * The visitor would initiate the traversal from the main + * function and visits into the operator PrimFuncs. It will + * crate unique BufferInfo objects for each Allocate node. + * + * Every time the buffer variable of the allocate node is referenced + * it will be recorded using the stmt index. However, note that + * the same buffer variable could be references multiple times + * from different calls. Thereafter, a sweep is done on all the + * BufferInfo objects using the per-call liveness events. In the sweep, + * The BufferInfo objects that are live together will be recorded as + * mutual conflicts of each other. */ - class BufferInfoExtractor : public StmtExprVisitor { public: explicit BufferInfoExtractor(const IRModule& module) : module_(module) { @@ -65,49 +74,163 @@ class BufferInfoExtractor : public StmtExprVisitor { void VisitStmt_(const ForNode* op) override; void UpdateAliases(const Array& args, const PrimFunc& func); + void RecordAllocateNodeInfo(const AllocateNode* op); + void VisitPrimFunc(const PrimFunc& func, const Call& call); + /*! + * \brief Maintains the mapping of BufferInfo to their associated TIR Statements. + */ Map buffer_info_map_; - Map buffer_info_start_stmt_idx_; - Map buffer_info_end_stmt_idx_; + /*! + * \brief Records the order of calls in the main for stability. + */ + std::set call_order_; + /*! + * \brief Records first access in-terms of Stmts to each buffer per call + * + * This is because multiple calls could happen to the same PrimFunc. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_info_start_stmt_idx_; + /*! + * \brief Records last access in-terms of Stmts to each buffer per call + * + * This is because multiple calls could happen to the same PrimFunc. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_info_end_stmt_idx_; + /*! + * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure + * that only one BufferInfo object is created. + */ Map allocate_var_to_stmt_map_; - - std::unordered_set currently_live_allocates; + /*! + * \brief Indicates a count of stmts visited so far to use as a metric of liveness + */ int current_stmt_idx_ = 0; - // This structure is supposed to contain information - // around the scope the visitor is currently in. - // We only check whether the current scope belong to - // a Serial ForKind. We are not planning for Parallel - // ForKind just yet. + /*! + * \brief This structure is supposed to contain information around the scope + * the visitor is currently in. + */ struct ScopeInfo { + /*! + * \brief We need to record access per call + */ + Call call; + /*! + * \brief Having access to PrimFunc metadata is useful + */ + PrimFunc func; + /*! + * \brief We currently support only serial for loops. Therefore + * need to know what kind of for loop the visitor is in. + */ For for_loop; + /*! + * \brief We record the live allocate_nodes because once in loops + * the liveness range has to be extended to the whole of the nested + * loops structure. + */ + std::unordered_set allocate_nodes; + /*! + * \brief This is recorded to extend the liveness of all allocates within + * nested loop structure. + */ + Integer initial_stmt_of_the_nested_loops; }; std::stack scope_stack_; + /*! + * \brief A liveness event is an event that when + * traversing the tir.Stmts where tir.allocate node + * begins or ceases to be Live. This particular struct + * is used to solve interval overlap problem using + * a sweep-line algorithm. For that, we need to record + * where the liveness event occurred in a chronological + * order. + */ + enum LivenessEventType { START = 0, END = 1 }; + struct LivenessEvent { + size_t tick; + LivenessEventType le_type; + BufferInfo buffer_info; + bool operator==(const LivenessEvent& other) { + if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) { + return true; + } + return false; + } + }; + /*! + * \brief We need to create unique buffer name is the same name is used in + * two allocate nodes for clarity for memory planning algorithms. + */ + std::string GetUniqueBufferName(std::string name); + + /*! + * \brief This is per buffer name counter to aid the generating the above + * unique name. + */ + std::unordered_map buffer_names; + /*! + * \brief The TIR main function calls by name to PrimFuncs to be able to + * support BYOC. Therefore, this Map records functions that are present + * in the IRModule by name/ + */ Map functions_; + /*! + * \brief The IRModule being analyzed. + */ IRModule module_; }; +std::string BufferInfoExtractor::GetUniqueBufferName(std::string name) { + if (buffer_names.find(name) == buffer_names.end()) { + buffer_names[name] = 1; + return name; + } else { + buffer_names[name] = buffer_names[name] + 1; + return name + std::to_string(buffer_names[name]); + } +} + void BufferInfoExtractor::VisitStmt(const Stmt& n) { current_stmt_idx_ += 1; StmtExprVisitor::VisitStmt(n); } -static Integer CalculateExtentsSize(const AllocateNode* op) { - size_t element_size_bytes = op->dtype.bytes(); - size_t num_elements = 1; - for (const auto& ext : op->extents) { - if (ext->IsInstance()) { - num_elements *= Downcast(ext)->value; - } else { - // We can't statically calculate workspace for dynamic shapes - return Integer(); - } +void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { + auto size_bytes = CalculateExtentsSize(op); + // We only statically memory plan only allocates with known + // compile time sizes. + if (size_bytes.defined() && + allocate_var_to_stmt_map_.find(op->buffer_var) == allocate_var_to_stmt_map_.end()) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); + + // TODO(@manupa-arm): improve the error when the responsible component for attaching a single + // pool is added + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + PrimFunc func = scope_stack_.top().func; + Optional tgt = func->GetAttr(tvm::attr::kTarget); + ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now"; + auto workspace_alignment = + tgt.value()->GetAttr("workspace-byte-alignment").value_or(16); + auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, + pool_candidates, workspace_alignment); + auto allocate = GetRef(op); + allocate_var_to_stmt_map_.Set(op->buffer_var, allocate); + buffer_info_map_.Set(buffer_info, allocate); } - return Integer(num_elements * element_size_bytes); } void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { - const auto& currect_scope_info = scope_stack_.top(); + ScopeInfo& current_scope_info = scope_stack_.top(); const auto& type = Downcast(op->buffer_var->type_annotation); const auto& storage_scope = type->storage_scope; @@ -115,40 +238,38 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) { // If its not a serial for loop, then memory planner will omit them in the current memory planning // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work // with buffers that have global storage_scope - if ((!currect_scope_info.for_loop.defined()) || - (currect_scope_info.for_loop.defined() && - currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global")) { - auto size_bytes = CalculateExtentsSize(op); - // We only statically memory plan only allocates with known - // compile time sizes. - if (size_bytes.defined()) { - // By default, the core compiler is assumed to attach the a default pool to each allocate. - ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) - << "Every statically sized allocate node needs an pool candidate attribute"; - auto pool_candidates = - Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); - - // TODO(@manupa-arm): improve the error when the responsible component for attaching a single - // pool is added - ICHECK(pool_candidates.size() > 0) - << "The core compiler should at least attach a single PoolInfo. If there were no " - "user-given arguments for memory pools, the default behaviour is a single size " - "un-restricted pool is assigned"; - auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes, pool_candidates); - auto allocate = GetRef(op); - allocate_var_to_stmt_map_.Set(op->buffer_var, allocate); - buffer_info_map_.Set(buffer_info, allocate); - } + + if (!current_scope_info.for_loop.defined()) { + RecordAllocateNodeInfo(op); + } else if (current_scope_info.for_loop.defined() && + current_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global") { + RecordAllocateNodeInfo(op); } StmtExprVisitor::VisitStmt(op->body); + current_scope_info.allocate_nodes.erase(GetRef(op)); } void BufferInfoExtractor::VisitStmt_(const ForNode* op) { - ScopeInfo si{ - GetRef(op), - }; + ScopeInfo si{scope_stack_.top().call, scope_stack_.top().func, GetRef(op), + scope_stack_.top().allocate_nodes, + scope_stack_.top().initial_stmt_of_the_nested_loops}; + if (!scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) { + si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_); + } + Call current_call = scope_stack_.top().call; scope_stack_.push(si); StmtExprVisitor::VisitStmt_(op); + // Extending the liveness to beginning of for-loop next and end of the current for-loop + for (const Allocate& allocate : scope_stack_.top().allocate_nodes) { + if (scope_stack_.top().initial_stmt_of_the_nested_loops->value < + buffer_info_start_stmt_idx_[current_call][allocate]) { + buffer_info_start_stmt_idx_[current_call].Set( + allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value); + } + if (current_stmt_idx_ > buffer_info_end_stmt_idx_[current_call][allocate]) { + buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + } + } scope_stack_.pop(); } @@ -164,12 +285,18 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { void BufferInfoExtractor::VisitExpr_(const VarNode* op) { auto var = GetRef(op); + Call current_call = scope_stack_.top().call; if (allocate_var_to_stmt_map_.count(var)) { auto allocate = allocate_var_to_stmt_map_[var]; - if (buffer_info_start_stmt_idx_.count(allocate) == 0) { - buffer_info_start_stmt_idx_.Set(allocate, current_stmt_idx_); + if (buffer_info_start_stmt_idx_[current_call].count(allocate) == 0) { + buffer_info_start_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + } + buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + + ScopeInfo& currect_scope_info = scope_stack_.top(); + if (currect_scope_info.for_loop.defined()) { + currect_scope_info.allocate_nodes.insert(Downcast(allocate)); } - buffer_info_end_stmt_idx_.Set(allocate, current_stmt_idx_); } StmtExprVisitor::VisitExpr_(op); } @@ -205,74 +332,78 @@ void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimF } } +void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call) { + ScopeInfo si{call, func, scope_stack_.top().for_loop, scope_stack_.top().allocate_nodes, + scope_stack_.top().initial_stmt_of_the_nested_loops}; + call_order_.insert(call); + scope_stack_.push(si); + this->VisitStmt(func->body); + scope_stack_.pop(); +} + void BufferInfoExtractor::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::call_extern())) { - auto func = functions_.at(Downcast(op->args[0])->value); - auto actual_args = Array(op->args.begin() + 1, op->args.end()); - this->UpdateAliases(actual_args, func); - this->VisitStmt(func->body); - } else if (op->op->IsInstance()) { + if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { + StringImm func_name = Downcast(op->args[0])->value; + if (functions_.find(func_name->value) != functions_.end()) { + auto func = functions_.at(func_name->value); + auto actual_args = Array(op->args.begin() + 1, op->args.end()); + this->UpdateAliases(actual_args, func); + VisitPrimFunc(func, GetRef(op)); + return; + } + } + if (op->op->IsInstance()) { auto func = Downcast(op->op); this->UpdateAliases(op->args, func); - this->VisitStmt(func->body); - } else { - StmtExprVisitor::VisitExpr_(op); + VisitPrimFunc(func, GetRef(op)); + return; } + StmtExprVisitor::VisitExpr_(op); } Map BufferInfoExtractor::operator()(const PrimFunc& main_func) { - this->VisitStmt(main_func->body); - - // A liveness event is an event that when - // traversing the tir.Stmts where tir.allocate node - // begins or ceases to be Live. This particular struct - // is used to solve interval overlap problem using - // a sweep-line algorithm. For that, we need to record - // where the liveness event occurred in a chronological - // order. - enum LivenessEventType { START = 0, END = 1 }; - struct LivenessEvent { - size_t tick; - LivenessEventType le_type; - BufferInfo buffer_info; - bool operator==(const LivenessEvent& other) { - if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) { - return true; - } - return false; - } - }; + VisitPrimFunc(main_func, Call()); // Create a vector of liveness events // associated with each BufferNodes. - std::vector le_events; - for (const auto& kv : buffer_info_map_) { - if (!kv.second->IsInstance()) { + std::vector le_events_timeline; + for (const auto& kv1 : buffer_info_map_) { + if (!kv1.second->IsInstance()) { continue; } - auto allocate = Downcast(kv.second); - auto buffer_info = Downcast(kv.first); - // If the allocate is not used; we remove it from the analysis - if (buffer_info_start_stmt_idx_.count(allocate) == 0) { - continue; + auto allocate = Downcast(kv1.second); + auto buffer_info = Downcast(kv1.first); + + ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size()); + ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size()); + + for (const Call& call : call_order_) { + Map buffer_info_starts = buffer_info_start_stmt_idx_[call]; + if (buffer_info_starts.find(allocate) != buffer_info_starts.end()) { + LivenessEvent le_event_start; + le_event_start.buffer_info = buffer_info; + le_event_start.le_type = START; + le_event_start.tick = buffer_info_starts[allocate]; + le_events_timeline.push_back(le_event_start); + } + } + + for (const Call& call : call_order_) { + Map buffer_info_ends = buffer_info_end_stmt_idx_[call]; + if (buffer_info_ends.find(allocate) != buffer_info_ends.end()) { + LivenessEvent le_event_end; + le_event_end.buffer_info = buffer_info; + le_event_end.le_type = END; + le_event_end.tick = buffer_info_ends[allocate]; + le_events_timeline.push_back(le_event_end); + } } - LivenessEvent le_event_start; - le_event_start.buffer_info = buffer_info; - le_event_start.le_type = START; - le_event_start.tick = buffer_info_start_stmt_idx_[allocate]; - le_events.push_back(le_event_start); - - LivenessEvent le_event_end; - le_event_end.buffer_info = buffer_info; - le_event_end.le_type = END; - le_event_end.tick = buffer_info_end_stmt_idx_[allocate]; - le_events.push_back(le_event_end); } // Sort the liveness events based on the chronological // ordering. For events that are simultaneous, START event // takes precedence. - std::sort(le_events.begin(), le_events.end(), + std::sort(le_events_timeline.begin(), le_events_timeline.end(), [](const LivenessEvent& lhs, const LivenessEvent& rhs) { if (lhs.tick < rhs.tick) { return true; @@ -285,11 +416,13 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ // Traverse the liveness events using a open set to track what // is live while updating the conflicts through out the linear traversal std::unordered_set open_set; - for (const auto& le_event : le_events) { + for (const auto& le_event : le_events_timeline) { if (le_event.le_type == START) { for (const auto& open_buffer_info : open_set) { open_buffer_info->conflicts.push_back(le_event.buffer_info); - le_event.buffer_info->conflicts.push_back(open_buffer_info); + if (le_event.buffer_info != open_buffer_info) { + le_event.buffer_info->conflicts.push_back(open_buffer_info); + } } open_set.insert(le_event.buffer_info); } else { diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index a494c368344b..b7177cc1635b 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -122,6 +122,20 @@ Array CreateArrayBufferInfo(const Map& buffer_info return ret; } +Integer CalculateExtentsSize(const AllocateNode* op) { + size_t element_size_bytes = op->dtype.bytes(); + size_t num_elements = 1; + for (const auto& ext : op->extents) { + if (ext->IsInstance()) { + num_elements *= Downcast(ext)->value; + } else { + // We can't statically calculate workspace for dynamic shapes + return Integer(); + } + } + return Integer(num_elements * element_size_bytes); +} + TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") .set_body_typed([](Map buffer_info_map) { return (CreateArrayBufferInfo(buffer_info_map)); diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index 46c6bf420c1b..fa645f1379ff 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -31,7 +31,7 @@ def _replace_stmt_with_buf_var_names(buffer_info_map): """helper to replace tir.allocates with buffer names""" new_buffer_info_map = dict() for k, v in buffer_info_map.items(): - new_buffer_info_map[v.buffer_var.name] = k + new_buffer_info_map[k.name_hint] = k return new_buffer_info_map @@ -72,7 +72,7 @@ def set_poolinfos(stmt): def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): - """helper to assing poolinfos to allocate nodes in a IRModule""" + """helper to assign poolinfos to allocate nodes in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc): @@ -80,6 +80,15 @@ def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + # These are test IRModules that contains varied topologies of operator graphs # that includes a main TIR function that includes call to such operators. @@ -139,7 +148,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -155,27 +164,27 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_linear(): + target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( - pool_name="fast_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + pool_name="fast_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = _assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) - buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( - tir_mod["tvmgen_default_run_model"], tir_mod - ) + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod) buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) # check conflicts - _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map) - _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map) - _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map) + _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map) _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map) _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map) + _verify_conflicts("sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map) + _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map) # check sizes assert buffer_info_map["sid_8"].size_bytes == 802816 @@ -218,7 +227,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -259,7 +268,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -273,15 +282,17 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_parallel_serial_mixed_for_loops(): + target = Target("c") global_ws_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) all_serial_tir_mod = AllSerialForLoops + all_serial_tir_mod = _assign_targets_to_primfuncs_irmodule(all_serial_tir_mod, target) all_serial_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( all_serial_tir_mod, [global_ws_pool] ) - main_func = all_serial_tir_mod["tvmgen_default_run_model"] + main_func = all_serial_tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod) buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) @@ -291,10 +302,13 @@ def test_parallel_serial_mixed_for_loops(): assert name in ["dummy_allocate", "Conv2dOutput_8", "PaddedInput_8"] parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops + parallel_serial_mixed_tir_mod = _assign_targets_to_primfuncs_irmodule( + parallel_serial_mixed_tir_mod, target + ) parallel_serial_mixed_tir_mod = _assign_poolinfos_to_allocates_in_irmodule( parallel_serial_mixed_tir_mod, [global_ws_pool] ) - main_func = parallel_serial_mixed_tir_mod["tvmgen_default_run_model"] + main_func = parallel_serial_mixed_tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info( main_func, parallel_serial_mixed_tir_mod ) @@ -593,7 +607,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol T.store(T_cast_23.data, (((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_8, ax3_inner_8) + T.load("int32", placeholder_73.data, ((ax3_outer_4*64) + ax3_inner_8))), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -633,183 +647,318 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_inception_structure(): + target = Target("c") global_ws_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) tir_mod = InceptionStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) - main_func = tir_mod["tvmgen_default_run_model"] + main_func = tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) # check conflicts - _verify_conflicts("sid_5", ["Conv2dOutput_8", "sid_4"], buffer_info_map) _verify_conflicts( - "Conv2dOutput_2", ["PaddedInput_2", "sid_4", "sid_3", "sid_2"], buffer_info_map + "PaddedInput_8", + [ + "sid_6", + "Conv2dOutput_8", + "sid_5", + ], + buffer_info_map, ) - _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map) - _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map) _verify_conflicts( - "sid_26", ["sid_19", "Conv2dOutput_4", "sid_2", "sid_4", "PaddedInput_5"], buffer_info_map + "sid_26", + [ + "PaddedInput_4", + "Conv2dOutput_4", + "PaddedInput_5", + ], + buffer_info_map, ) - _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_6"], buffer_info_map) _verify_conflicts( - "PaddedInput_4", ["sid_19", "sid_2", "sid_4", "sid_3", "Conv2dOutput_4"], buffer_info_map + "Conv2dOutput", + [ + "sid_6", + "PaddedInput", + ], + buffer_info_map, ) - _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map) - _verify_conflicts("tensor_3", ["sid_25", "sid_19", "sid_2", "sid_4", "sid_32"], buffer_info_map) _verify_conflicts( - "sid_3", + "sid_4", + [ + "sid_5", + "sid_3", + "tensor_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_2", + [ + "sid_8", + "sid_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_7", + [ + "sid_8", + "PaddedInput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_1", [ - "sid_4", - "PaddedInput_2", - "Conv2dOutput_2", - "sid_2", - "PaddedInput_1", - "Conv2dOutput_1", "sid_20", - "PaddedInput_6", - "Conv2dOutput_6", - "sid_19", + "PaddedInput_1", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_4", + [ + "sid_26", "PaddedInput_4", ], buffer_info_map, ) _verify_conflicts( - "sid_32", ["tensor_3", "sid_25", "sid_19", "sid_2", "PaddedInput_3"], buffer_info_map + "Conv2dOutput_2", + [ + "PaddedInput_2", + "sid_2", + ], + buffer_info_map, ) - _verify_conflicts("PaddedInput_8", ["sid_6", "Conv2dOutput_8"], buffer_info_map) _verify_conflicts( - "Conv2dOutput_6", ["PaddedInput_6", "sid_2", "sid_4", "sid_3", "sid_19"], buffer_info_map + "PaddedInput_3", + [ + "sid_32", + "sid_31", + "Conv2dOutput_3", + ], + buffer_info_map, ) _verify_conflicts( - "sid_4", + "sid_3", [ - "sid_5", - "sid_3", + "sid_4", "PaddedInput_2", - "Conv2dOutput_2", - "sid_2", "PaddedInput_1", - "Conv2dOutput_1", - "sid_20", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_6", + [ "PaddedInput_6", - "Conv2dOutput_6", "sid_19", - "PaddedInput_4", - "Conv2dOutput_4", - "sid_26", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_5", + [ "PaddedInput_5", - "Conv2dOutput_5", "sid_25", - "tensor_3", ], buffer_info_map, ) - _verify_conflicts("PaddedInput_2", ["sid_3", "sid_4", "Conv2dOutput_2"], buffer_info_map) _verify_conflicts( - "Conv2dOutput_4", ["sid_19", "sid_2", "sid_4", "PaddedInput_4", "sid_26"], buffer_info_map + "PaddedInput_7", + [ + "sid_9", + "sid_8", + "Conv2dOutput_7", + ], + buffer_info_map, ) _verify_conflicts( - "PaddedInput_1", ["sid_2", "sid_4", "sid_3", "Conv2dOutput_1"], buffer_info_map + "sid_7", + [ + "tensor_2", + "PaddedInput", + ], + buffer_info_map, ) - _verify_conflicts("sid_6", ["Conv2dOutput", "PaddedInput_8"], buffer_info_map) - _verify_conflicts("Conv2dOutput_8", ["PaddedInput_8", "sid_5"], buffer_info_map) _verify_conflicts( - "sid_25", + "sid_31", [ - "Conv2dOutput_5", - "sid_19", - "sid_2", - "sid_4", - "tensor_3", - "sid_32", "PaddedInput_3", "Conv2dOutput_3", - "sid_31", + "sid_25", + "sid_2", + "sid_19", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_6", ["sid_20", "sid_2", "sid_4", "sid_3", "Conv2dOutput_6"], buffer_info_map + "sid_5", + [ + "Conv2dOutput_8", + "PaddedInput_8", + "sid_4", + ], + buffer_info_map, ) _verify_conflicts( - "sid_7", + "sid_6", [ - "tensor_2", "PaddedInput", + "Conv2dOutput", + "PaddedInput_8", ], buffer_info_map, ) - _verify_conflicts("sid_31", ["Conv2dOutput_3", "sid_25", "sid_19", "sid_2"], buffer_info_map) - _verify_conflicts("tensor_2", ["sid_8", "sid_7"], buffer_info_map) _verify_conflicts( - "sid_2", + "sid_20", [ - "Conv2dOutput_2", - "sid_4", - "sid_3", "PaddedInput_1", "Conv2dOutput_1", - "sid_20", "PaddedInput_6", - "Conv2dOutput_6", - "sid_19", - "PaddedInput_4", - "Conv2dOutput_4", - "sid_26", - "PaddedInput_5", - "Conv2dOutput_5", - "sid_25", - "tensor_3", - "sid_32", - "PaddedInput_3", - "Conv2dOutput_3", - "sid_31", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_3", ["sid_25", "PaddedInput_3", "sid_19", "sid_2", "sid_31"], buffer_info_map + "Conv2dOutput_8", + [ + "PaddedInput_8", + "sid_5", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_1", + [ + "sid_3", + "sid_20", + "Conv2dOutput_1", + ], + buffer_info_map, ) - _verify_conflicts("PaddedInput", ["sid_7", "Conv2dOutput"], buffer_info_map) _verify_conflicts( - "Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_4", "sid_3", "sid_20"], buffer_info_map + "Conv2dOutput_3", + [ + "sid_31", + "PaddedInput_3", + ], + buffer_info_map, ) _verify_conflicts( - "PaddedInput_5", ["sid_26", "sid_19", "sid_2", "sid_4", "Conv2dOutput_5"], buffer_info_map + "PaddedInput", + [ + "sid_7", + "sid_6", + "Conv2dOutput", + ], + buffer_info_map, ) _verify_conflicts( - "PaddedInput_3", ["sid_32", "sid_25", "sid_19", "sid_2", "Conv2dOutput_3"], buffer_info_map + "PaddedInput_2", + [ + "sid_3", + "Conv2dOutput_2", + "sid_2", + ], + buffer_info_map, ) _verify_conflicts( "sid_19", [ "Conv2dOutput_6", + "PaddedInput_6", + "sid_31", "sid_2", - "sid_4", + "sid_25", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_4", + [ "sid_3", - "PaddedInput_4", + "sid_26", "Conv2dOutput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_5", + [ "sid_26", - "PaddedInput_5", "Conv2dOutput_5", "sid_25", - "tensor_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_6", + [ + "sid_20", + "Conv2dOutput_6", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_25", + [ + "Conv2dOutput_5", + "PaddedInput_5", + "sid_31", + "sid_2", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_3", + [ + "sid_4", "sid_32", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_32", + [ + "tensor_3", "PaddedInput_3", - "Conv2dOutput_3", - "sid_31", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_5", ["PaddedInput_5", "sid_19", "sid_2", "sid_4", "sid_25"], buffer_info_map + "sid_9", + [ + "PaddedInput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_2", + [ + "Conv2dOutput_2", + "PaddedInput_2", + "sid_31", + "sid_25", + "sid_19", + ], + buffer_info_map, ) - _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map) _verify_conflicts( - "sid_20", ["sid_2", "Conv2dOutput_1", "sid_4", "sid_3", "PaddedInput_6"], buffer_info_map + "sid_8", + [ + "PaddedInput_7", + "Conv2dOutput_7", + "tensor_2", + ], + buffer_info_map, ) # check sizes @@ -849,5 +998,558 @@ def test_inception_structure(): assert buffer_info_map["PaddedInput_6"].size_bytes == 172800 +# fmt: off +@tvm.script.ir_module +class MultipleCallsToSamePrimFuncModule: + @T.prim_func + def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_trans: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform_1", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [1, 3, 24, 12], dtype="float32") + T_layout_trans_1 = T.match_buffer(T_layout_trans, [1, 1, 24, 12, 3], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): + T.store(T_layout_trans_1.data, ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner, T.load("float32", placeholder_1.data, ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3), True) + + @T.prim_func + def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_conv2d_NCHWc", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 1, 24, 12, 3], dtype="float32") + placeholder_5 = T.match_buffer(placeholder_3, [1, 1, 3, 3, 3, 3], dtype="float32") + conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [1, 1, 24, 12, 3], dtype="float32") + # body + data_pad = T.allocate([1, 1, 26, 14, 3], "float32", "global") + for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): + T.store(data_pad, i0_i1_fused_i2_fused * 42 + i3 * 3 + i4, T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, T.load("float32", placeholder_4.data, i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39), T.float32(0), dtype="float32"), True) + for n_oc_chunk_fused_oh_fused in T.serial(0, 24): + conv2d_NCHWc_global = T.allocate([1, 1, 1, 12, 3], "float32", "global") + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 3, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 6, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 9, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 12, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 15, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 18, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 21, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 24, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 27, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 30, T.float32(0), True) + for oc_block_c_init in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c_init + 33, T.float32(0), True) + for kh, kw, ic_inner in T.grid(3, 3, 3): + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c, T.load("float32", conv2d_NCHWc_global, oc_block_c) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 3, T.load("float32", conv2d_NCHWc_global, oc_block_c + 3) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 6, T.load("float32", conv2d_NCHWc_global, oc_block_c + 6) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 9, T.load("float32", conv2d_NCHWc_global, oc_block_c + 9) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 12, T.load("float32", conv2d_NCHWc_global, oc_block_c + 12) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 15, T.load("float32", conv2d_NCHWc_global, oc_block_c + 15) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 18, T.load("float32", conv2d_NCHWc_global, oc_block_c + 18) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 21, T.load("float32", conv2d_NCHWc_global, oc_block_c + 21) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 24, T.load("float32", conv2d_NCHWc_global, oc_block_c + 24) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 27, T.load("float32", conv2d_NCHWc_global, oc_block_c + 27) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 30, T.load("float32", conv2d_NCHWc_global, oc_block_c + 30) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for oc_block_c in T.serial(0, 3): + T.store(conv2d_NCHWc_global, oc_block_c + 33, T.load("float32", conv2d_NCHWc_global, oc_block_c + 33) + T.load("float32", data_pad, kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33) * T.load("float32", placeholder_5.data, kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c), True) + for ow_inner, oc_block in T.grid(12, 3): + T.store(conv2d_NCHWc_1.data, n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block, T.load("float32", conv2d_NCHWc_global, ow_inner * 3 + oc_block), True) + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add_add_multiply_add", "tir.noalias": True}) + placeholder_11 = T.match_buffer(placeholder_6, [1, 3, 24, 12], dtype="float32") + placeholder_12 = T.match_buffer(placeholder_7, [1, 3, 24, 12], dtype="float32") + placeholder_13 = T.match_buffer(placeholder_8, [3, 1, 1], dtype="float32") + placeholder_14 = T.match_buffer(placeholder_9, [3, 1, 1], dtype="float32") + placeholder_15 = T.match_buffer(placeholder_10, [3, 1, 1], dtype="float32") + T_add_1 = T.match_buffer(T_add, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") + with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + for k in T.serial(0, 12): + T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + for i3 in T.serial(0, 12): + T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_11.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T.store(T_softmax_expsum, 0, T.float32(0), True) + for k in T.serial(0, 12): + T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + for i3 in T.serial(0, 12): + T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + for ax3 in T.serial(0, 12): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused * 12 + ax3, (T.load("float32", placeholder_12.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3) + T.load("float32", placeholder_13.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24))) * T.load("float32", placeholder_14.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)) + T.load("float32", placeholder_15.data, T.floordiv(ax0_ax1_fused_ax2_fused, 24)), True) + + @T.prim_func + def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", "tir.noalias": True}) + placeholder_18 = T.match_buffer(placeholder_16, [72, 12], dtype="float32") + placeholder_19 = T.match_buffer(placeholder_17, [2, 12, 6], dtype="float32") + T_relu_1 = T.match_buffer(T_relu, [72, 12], dtype="float32") + # body + for ax1_outer_ax0_outer_fused in T.serial(0, 18): + compute = T.allocate([8, 6], "float32", "global") + with T.allocate([8, 6], "float32", "global") as compute_global: + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 6, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 12, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 18, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 24, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 30, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 36, T.float32(0), True) + for x_c_init in T.serial(0, 6): + T.store(compute_global, x_c_init + 42, T.float32(0), True) + for k_outer in T.serial(0, 12): + for x_c in T.serial(0, 6): + T.store(compute_global, x_c, T.load("float32", compute_global, x_c) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 6, T.load("float32", compute_global, x_c + 6) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 12, T.load("float32", compute_global, x_c + 12) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 18, T.load("float32", compute_global, x_c + 18) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 24, T.load("float32", compute_global, x_c + 24) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 30, T.load("float32", compute_global, x_c + 30) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 36, T.load("float32", compute_global, x_c + 36) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_c in T.serial(0, 6): + T.store(compute_global, x_c + 42, T.load("float32", compute_global, x_c + 42) + T.load("float32", placeholder_18.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84) * T.load("float32", placeholder_19.data, T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner, T.load("float32", compute_global, x_inner_inner), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 6, T.load("float32", compute_global, x_inner_inner + 6), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 12, T.load("float32", compute_global, x_inner_inner + 12), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 18, T.load("float32", compute_global, x_inner_inner + 18), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 24, T.load("float32", compute_global, x_inner_inner + 24), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 30, T.load("float32", compute_global, x_inner_inner + 30), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 36, T.load("float32", compute_global, x_inner_inner + 36), True) + for x_inner_inner in T.serial(0, 6): + T.store(compute, x_inner_inner + 42, T.load("float32", compute_global, x_inner_inner + 42), True) + for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): + T.store(T_relu_1.data, T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner, T.max(T.load("float32", compute, ax0_inner_inner * 6 + ax1_inner_inner), T.float32(0)), True) + + @T.prim_func + def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape_1", "tir.noalias": True}) + placeholder_21 = T.match_buffer(placeholder_20, [1, 3, 24, 12], dtype="float32") + T_reshape_1 = T.match_buffer(T_reshape, [72, 12], dtype="float32") + # body + for ax0, ax1_inner in T.grid(72, 12): + T.store(T_reshape_1.data, ax0 * 12 + ax1_inner, T.load("float32", placeholder_21.data, ax0 * 12 + ax1_inner), True) + + @T.prim_func + def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform", "tir.noalias": True}) + placeholder_23 = T.match_buffer(placeholder_22, [1, 1, 24, 12, 3], dtype="float32") + T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T.store(T_layout_trans_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_23.data, ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused), True) + + @T.prim_func + def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_24, [72, 12], dtype="float32") + T_reshape_3 = T.match_buffer(T_reshape_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T.store(T_reshape_3.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner, T.load("float32", placeholder_25.data, ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner), True) + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add", "tir.noalias": True}) + placeholder_28 = T.match_buffer(placeholder_26, [1, 3, 24, 12], dtype="float32") + placeholder_29 = T.match_buffer(placeholder_27, [1, 3, 24, 12], dtype="float32") + T_add_3 = T.match_buffer(T_add_2, [1, 3, 24, 12], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([1, 1, 1, 12], "float32", "global") + with T.allocate([1, 1, 1], "float32", "global") as T_softmax_maxelem: + T.store(T_softmax_maxelem, 0, T.float32(-3.4028234663852886e+38), True) + for k in T.serial(0, 12): + T.store(T_softmax_maxelem, 0, T.max(T.load("float32", T_softmax_maxelem, 0), T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + k)), True) + T_softmax_exp = T.allocate([1, 1, 1, 12], "float32", "global") + for i3 in T.serial(0, 12): + T.store(T_softmax_exp, i3, T.exp(T.load("float32", placeholder_28.data, ax0_ax1_fused_ax2_fused * 12 + i3) - T.load("float32", T_softmax_maxelem, 0), dtype="float32"), True) + T_softmax_expsum = T.allocate([1, 1, 1], "float32", "global") + T.store(T_softmax_expsum, 0, T.float32(0), True) + for k in T.serial(0, 12): + T.store(T_softmax_expsum, 0, T.load("float32", T_softmax_expsum, 0) + T.load("float32", T_softmax_exp, k), True) + for i3 in T.serial(0, 12): + T.store(T_softmax_norm, i3, T.load("float32", T_softmax_exp, i3) / T.load("float32", T_softmax_expsum, 0), True) + for ax3 in T.serial(0, 12): + T.store(T_add_3.data, ax0_ax1_fused_ax2_fused * 12 + ax3, T.load("float32", placeholder_29.data, ax0_ax1_fused_ax2_fused * 12 + ax3) + T.load("float32", T_softmax_norm, ax3), True) + + @T.prim_func + def run_model(data: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + data_buffer = T.match_buffer(data, [1, 3, 24, 12], dtype="float32", align=16) + output_buffer = T.match_buffer(output, [1, 3, 24, 12], dtype="float32", align=16) + # body + sid_11 = T.allocate([3456], "int8", "global.workspace") + sid_5 = T.allocate([3456], "int8", "global.workspace") + sid_10 = T.allocate([3456], "int8", "global.workspace") + sid_6 = T.allocate([3456], "int8", "global.workspace") + sid_8 = T.allocate([3456], "int8", "global.workspace") + sid_2 = T.allocate([3456], "int8", "global.workspace") + sid_7 = T.allocate([3456], "int8", "global.workspace") + sid_3 = T.allocate([3456], "int8", "global.workspace") + sid_12 = T.allocate([3456], "int8", "global.workspace") + sid_4 = T.allocate([3456], "int8", "global.workspace") + sid_18 = T.allocate([3456], "int8", "global.workspace") + sid_19 = T.allocate([3456], "int8", "global.workspace") + sid_20 = T.allocate([3456], "int8", "global.workspace") + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", data_buffer.data, sid_8, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_8, T.cast(T.lookup_param("p0", dtype="handle"), "handle"), sid_7, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_7, sid_6, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", data_buffer.data, sid_12, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_12, T.cast(T.lookup_param("p1", dtype="handle"), "handle"), sid_11, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_11, sid_10, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add_add_multiply_add", sid_6, sid_10, T.cast(T.lookup_param("p2", dtype="handle"), "handle"), T.cast(T.lookup_param("p3", dtype="handle"), "handle"), T.cast(T.lookup_param("p4", dtype="handle"), "handle"), sid_5, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform_1", sid_5, sid_4, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", sid_4, T.cast(T.lookup_param("p5", dtype="handle"), "handle"), sid_3, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_layout_transform", sid_3, sid_2, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape_1", sid_5, sid_20, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", sid_20, T.cast(T.lookup_param("p6", dtype="handle"), "handle"), sid_19, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_reshape", sid_19, sid_18, dtype="int32")) + T.evaluate(T.tvm_call_cpacked("tvmgen_default_fused_nn_softmax_add", sid_2, sid_18, output_buffer.data, dtype="int32")) +# fmt: on + + +def test_multiple_calls_to_same_primfunc(): + target = Target("c") + global_ws_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = MultipleCallsToSamePrimFuncModule + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) + main_func = tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map) + + # check conflicts + _verify_conflicts( + "sid_18", + [ + "sid_19", + "sid_2", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_3", + [ + "data_pad", + "conv2d_NCHWc_global", + "sid_2", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_norm", + [ + "T_softmax_expsum", + "T_softmax_exp", + "sid_5", + "sid_6", + "T_softmax_maxelem", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_norm2", + [ + "T_softmax_expsum2", + "T_softmax_maxelem2", + "T_softmax_exp2", + "sid_18", + "sid_2", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_11", + [ + "compute", + "sid_12", + "compute_global", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_10", + [ + "sid_11", + "sid_6", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", + "T_softmax_maxelem", + "T_softmax_exp", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_5", + [ + "T_softmax_norm", + "T_softmax_expsum", + "T_softmax_exp", + "sid_6", + "T_softmax_maxelem", + "sid_10", + "sid_4", + "sid_20", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_expsum", + [ + "T_softmax_exp", + "T_softmax_norm", + "sid_5", + "sid_6", + "T_softmax_maxelem", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_8", + [ + "data_pad", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_expsum2", + [ + "T_softmax_maxelem2", + "T_softmax_exp2", + "sid_18", + "sid_2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_maxelem2", + [ + "T_softmax_exp2", + "sid_18", + "sid_2", + "T_softmax_expsum2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_12", + [ + "sid_11", + "compute", + "compute_global", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_19", + [ + "sid_20", + "compute", + "compute_global", + "sid_18", + ], + buffer_info_map, + ) + _verify_conflicts( + "conv2d_NCHWc_global", + [ + "data_pad", + "sid_7", + "sid_3", + "data_pad", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_exp2", + [ + "sid_18", + "sid_2", + "T_softmax_maxelem2", + "T_softmax_expsum2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_7", + [ + "conv2d_NCHWc_global", + "data_pad", + "sid_6", + ], + buffer_info_map, + ) + _verify_conflicts( + "data_pad", + [ + "sid_8", + "conv2d_NCHWc_global", + "sid_7", + "sid_4", + "sid_3", + "conv2d_NCHWc_global", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_20", + [ + "sid_5", + "sid_19", + "compute", + "compute_global", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_4", + [ + "sid_5", + "data_pad", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_exp", + [ + "T_softmax_expsum", + "T_softmax_norm", + "sid_5", + "sid_6", + "T_softmax_maxelem", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "compute_global", + [ + "sid_12", + "sid_11", + "compute", + "compute", + "sid_20", + "sid_19", + ], + buffer_info_map, + ) + _verify_conflicts( + "compute", + [ + "sid_11", + "sid_12", + "compute_global", + "sid_20", + "sid_19", + "compute_global", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_6", + [ + "sid_7", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", + "T_softmax_exp", + "T_softmax_maxelem", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_maxelem", + [ + "sid_6", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", + "T_softmax_exp", + "sid_10", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_2", + [ + "sid_3", + "sid_18", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", + "T_softmax_norm2", + ], + buffer_info_map, + ) + + if __name__ == "__main__": pytest.main([__file__] + sys.argv[1:]) From 64f1d552c492095357cc65aec9470f1fa9e3e38e Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 22 Nov 2021 14:41:22 +0000 Subject: [PATCH 8/8] [TIR][USMP] Added buffer info extraction pass * Attaching targets to PrimFuncs in the util test case Change-Id: I82960512659a346f6242b2b5789ec1120f8ea2cf --- tests/python/unittest/test_tir_usmp_utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py index 53064fae7b46..232bf6a151fc 100644 --- a/tests/python/unittest/test_tir_usmp_utils.py +++ b/tests/python/unittest/test_tir_usmp_utils.py @@ -173,13 +173,24 @@ def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + def test_create_array_buffer_info(): + target = Target("c") global_ws_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)