Skip to content

Commit

Permalink
[microNPU] Add a pass to move allocate nodes to the outer scope (apac…
Browse files Browse the repository at this point in the history
…he#10725)

* [microNPU] Add a pass to move allocate nodes to the outer scope

Adds a pass called `HoistAllocates` to move allocate nodes to the top
of the body of the main function. In doing so, it opens the door to
other optimizations that need to swap the ordering of external calls.

Pass illustration:
(before)
```
allocate {
    extern_call {
        allocate {
            extern_call {

            }
        }
    }
}
```

(after)
```
allocate {
    allocate {
        extern_call
        extern_call
    }
}
```

Change-Id: Ibcfc3c75b15deebb5c6645a4923a6ddf683b37c4

* address comments

* uses prim func pass, rather than module pass.
* adds error message informing user to run this pass with LowerToTIR()
  pass for now.

Change-Id: I57757b9dc5bff0208034a974a341c09cce0294bc

* Support allocates when not followed by a sequence statement

With a test to back this case up.

Change-Id: I670809f5ee53b583a15d9b783852dda3089756e9

* Add new directory tir/contrib/ethosu to cmake build

Change-Id: I3e9f24adfe992ace4e03238a18a8378b03257e1a
  • Loading branch information
lhutton1 authored and pfk-beta committed Apr 11, 2022
1 parent 60a6db2 commit 2839703
Show file tree
Hide file tree
Showing 6 changed files with 437 additions and 8 deletions.
3 changes: 2 additions & 1 deletion cmake/modules/contrib/EthosU.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ if(USE_ETHOSU)
tvm_file_glob(GLOB COMPILER_ETHOSU_SRCS
src/relay/backend/contrib/ethosu/*
src/contrib/ethosu/cascader/*
src/contrib/ethosu/cascader/parts/*)
src/contrib/ethosu/cascader/parts/*
src/tir/contrib/ethosu/*)
list(APPEND COMPILER_SRCS ${COMPILER_ETHOSU_SRCS})
else()
# Keeping just utils.cc because it has Object definitions
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
import tvm._ffi # type: ignore

tvm._ffi._init_api("relay.ext.ethos-u", __name__)
tvm._ffi._init_api("tir.contrib.ethos-u", __name__)
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod = ethosu_passes.ReplaceOperators()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
if not disable_storage_rewrite:
mod = tvm.tir.transform.StorageRewrite()(mod)
Expand Down
28 changes: 21 additions & 7 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from .transform import get_copy_params
from .utils import get_weights_buffer, get_scale_bias_buffer

from .. import _ffi_api


def RemoveZeroStores():
"""This pass removes stores which just store zero to initialise buffers.
Expand All @@ -48,7 +50,7 @@ def _ftransform(f, mod, ctx):
)

return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.remove_zero_stores"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.remove_zero_stores"
)


Expand Down Expand Up @@ -207,7 +209,7 @@ def _ftransform(f, mod, ctx):
)

return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.replace_operators"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.replace_operators"
)


Expand Down Expand Up @@ -296,7 +298,7 @@ def _ftransform(f, mod, ctx):

def _divide_constants(mod):
transform_func = tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.divide_constants"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.divide_constants"
)
new_func = transform_func(mod)
return new_func, new_const_dict
Expand Down Expand Up @@ -549,7 +551,7 @@ def _encode_constants(mod):
for key, value in divided_const_dict.items():
const_dict[key] = value
transform_func = tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.encode_constants"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.encode_constants"
)
new_func = transform_func(mod)
return new_func, new_const_dict
Expand Down Expand Up @@ -584,7 +586,7 @@ def _ftransform(f, mod, ctx):
)

return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.annotate_allocates"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.annotate_allocates"
)


Expand Down Expand Up @@ -751,7 +753,7 @@ def _ftransform(f, mod, ctx):
)

return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.remove_concatenates"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.remove_concatenates"
)


Expand Down Expand Up @@ -795,9 +797,21 @@ def _ftransform(f, mod, ctx):

def _create_primfunc_without_constants(mod):
transform_func = tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants"
_ftransform, opt_level=0, name="tir.contrib.ethos-u.CreatePrimFuncWithoutConstants"
)
mod = transform_func(mod)
return mod, new_const_dict

return _create_primfunc_without_constants


def HoistAllocates() -> tvm.IRModule:
"""
Hoist allocate nodes up to the top of the body of the main function.
Returns
-------
tvm.IRModule
The new module with hoisted allocate nodes.
"""
return _ffi_api.HoistAllocates()
128 changes: 128 additions & 0 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/contrib/ethosu/passes.cc
*
* \brief Passes used in TIR lowering for the microNPU compiler.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {
namespace contrib {
namespace ethosu {

/*!
* \brief This mutator moves allocates to the top of the body of the main
* function.
*
* Note: This pass can currently only be run in conjunction with the
* LowerToTIR() pass as it expects a single primitive function called
* "main" that is being offloaded to the NPU.
*
* For example,
* Before:
* allocate {
* extern_call(...)
* allocate {
* extern_call(...)
* }
* }
*
* After:
* allocate {
* allocate {
* extern_call(...)
* extern_call(...)
* }
* }
*/
class HoistAllocatesMutator : public StmtExprMutator {
public:
HoistAllocatesMutator() {}

PrimFunc operator()(PrimFunc main_func) {
Stmt new_main_func_body = this->VisitStmt(main_func->body);

// Write all allocates that were removed in reverse order
for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
Allocate current_alloc = *it;
if (it != allocates_.rbegin()) {
new_main_func_body = SeqStmt({new_main_func_body});
}
new_main_func_body =
Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
current_alloc->condition, new_main_func_body, current_alloc->annotations,
current_alloc->span);
}

PrimFunc new_main_func =
PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map,
main_func->preflattened_buffer_map, main_func->attrs);
return new_main_func;
}

private:
Stmt VisitStmt_(const AllocateNode* op) override {
allocates_.push_back(GetRef<Allocate>(op));

// Skip the allocate node itself
if (const auto* seq = op->body.as<SeqStmtNode>()) {
// Traverse the allocate body recursively and flatten
Array<Stmt> new_stmts;
new_stmts.reserve(seq->seq.size());
for (const Stmt& old_stmt : seq->seq) {
new_stmts.push_back(VisitStmt(old_stmt));
}
return SeqStmt::Flatten(new_stmts);
} else {
return VisitStmt(op->body);
}
}

/*! A stack to store allocates as they are visited. */
std::vector<Allocate> allocates_;
};

/*!
* \brief A pass to hoist allocate nodes to the top of the body of the main function.
*
* \return tvm::transform::Pass
*/
tvm::transform::Pass HoistAllocates() {
auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
<< "Expected a single primitive function called 'main'. Please run the HoistAllocates pass "
"in conjunction with the LowerToTIR() pass.";
return HoistAllocatesMutator()(f);
};
return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates",
{});
}

TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);

} // namespace ethosu
} // namespace contrib
} // namespace tir
} // namespace tvm
Loading

0 comments on commit 2839703

Please sign in to comment.