forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[microNPU] Add a pass to move allocate nodes to the outer scope (apache#10725)
* [microNPU] Add a pass to move allocate nodes to the outer scope Adds a pass called `HoistAllocates` to move allocate nodes to the top of the body of the main function. In doing so, it opens the door to other optimizations that need to swap the ordering of external calls. Pass illustration: (before) ``` allocate { extern_call { allocate { extern_call { } } } } ``` (after) ``` allocate { allocate { extern_call extern_call } } ``` Change-Id: Ibcfc3c75b15deebb5c6645a4923a6ddf683b37c4 * address comments * uses prim func pass, rather than module pass. * adds error message informing user to run this pass with LowerToTIR() pass for now. Change-Id: I57757b9dc5bff0208034a974a341c09cce0294bc * Support allocates when not followed by a sequence statement With a test to back this case up. Change-Id: I670809f5ee53b583a15d9b783852dda3089756e9 * Add new directory tir/contrib/ethosu to cmake build Change-Id: I3e9f24adfe992ace4e03238a18a8378b03257e1a
- Loading branch information
Showing
6 changed files
with
437 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tir/contrib/ethosu/passes.cc | ||
* | ||
* \brief Passes used in TIR lowering for the microNPU compiler. | ||
*/ | ||
#include <tvm/tir/builtin.h> | ||
#include <tvm/tir/function.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
namespace contrib { | ||
namespace ethosu { | ||
|
||
/*! | ||
* \brief This mutator moves allocates to the top of the body of the main | ||
* function. | ||
* | ||
* Note: This pass can currently only be run in conjunction with the | ||
* LowerToTIR() pass as it expects a single primitive function called | ||
* "main" that is being offloaded to the NPU. | ||
* | ||
* For example, | ||
* Before: | ||
* allocate { | ||
* extern_call(...) | ||
* allocate { | ||
* extern_call(...) | ||
* } | ||
* } | ||
* | ||
* After: | ||
* allocate { | ||
* allocate { | ||
* extern_call(...) | ||
* extern_call(...) | ||
* } | ||
* } | ||
*/ | ||
class HoistAllocatesMutator : public StmtExprMutator { | ||
public: | ||
HoistAllocatesMutator() {} | ||
|
||
PrimFunc operator()(PrimFunc main_func) { | ||
Stmt new_main_func_body = this->VisitStmt(main_func->body); | ||
|
||
// Write all allocates that were removed in reverse order | ||
for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) { | ||
Allocate current_alloc = *it; | ||
if (it != allocates_.rbegin()) { | ||
new_main_func_body = SeqStmt({new_main_func_body}); | ||
} | ||
new_main_func_body = | ||
Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents, | ||
current_alloc->condition, new_main_func_body, current_alloc->annotations, | ||
current_alloc->span); | ||
} | ||
|
||
PrimFunc new_main_func = | ||
PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map, | ||
main_func->preflattened_buffer_map, main_func->attrs); | ||
return new_main_func; | ||
} | ||
|
||
private: | ||
Stmt VisitStmt_(const AllocateNode* op) override { | ||
allocates_.push_back(GetRef<Allocate>(op)); | ||
|
||
// Skip the allocate node itself | ||
if (const auto* seq = op->body.as<SeqStmtNode>()) { | ||
// Traverse the allocate body recursively and flatten | ||
Array<Stmt> new_stmts; | ||
new_stmts.reserve(seq->seq.size()); | ||
for (const Stmt& old_stmt : seq->seq) { | ||
new_stmts.push_back(VisitStmt(old_stmt)); | ||
} | ||
return SeqStmt::Flatten(new_stmts); | ||
} else { | ||
return VisitStmt(op->body); | ||
} | ||
} | ||
|
||
/*! A stack to store allocates as they are visited. */ | ||
std::vector<Allocate> allocates_; | ||
}; | ||
|
||
/*! | ||
* \brief A pass to hoist allocate nodes to the top of the body of the main function. | ||
* | ||
* \return tvm::transform::Pass | ||
*/ | ||
tvm::transform::Pass HoistAllocates() { | ||
auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) { | ||
ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main")) | ||
<< "Expected a single primitive function called 'main'. Please run the HoistAllocates pass " | ||
"in conjunction with the LowerToTIR() pass."; | ||
return HoistAllocatesMutator()(f); | ||
}; | ||
return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.HoistAllocates", | ||
{}); | ||
} | ||
|
||
// Register HoistAllocates as a global function under the name
// "tir.contrib.ethos-u.HoistAllocates".
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates); | ||
|
||
} // namespace ethosu | ||
} // namespace contrib | ||
} // namespace tir | ||
} // namespace tvm |
Oops, something went wrong.