-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hexagon] Add support for linked-in model parameters #8865
Changes from 5 commits
8abf2b7
9d0ba8c
2a89a7e
88fd6ea
583c489
70ed99e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -706,12 +706,38 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { | |
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target); | ||
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); | ||
std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon()); | ||
cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); | ||
|
||
std::vector<PrimFunc> funcs; | ||
Map<String, LinkedParam> linked_params; | ||
bool could_have_linked_params = target->GetAttr<Bool>("link-params").value_or(Bool(false)); | ||
|
||
for (auto kv : mod->functions) { | ||
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs"; | ||
if (could_have_linked_params && | ||
kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { | ||
// If `f` is the linked-params function, extract the parameters from the | ||
// attribute dictionary, and skip the codegen. | ||
auto attrs_dict = Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict); | ||
CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end()) | ||
<< "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!"; | ||
|
||
CHECK(linked_params.empty()) << "Multiple linked-param functions"; | ||
linked_params = | ||
Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]); | ||
continue; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit : this seems searching for whether at least one function have linked_params, if so I think we dont need to continue searching. A further suggestion is to break it out to a function for that check w/o needing to maintain found_linked_params bool. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is traversing all functions and
We need to skip the "linked-params" function, because the codegen path for it is different. We create the bool variable so that we know to generate the linked parameters function later on, but the variable is not the only effect of the loop, so there isn't much to gain by extracting it into a separate function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ack, so this is just to skip the linked-params function -- might be better to add a comment. I guess my original concern is it would be bit more readable if it was a different function to obtain the linked-params altogether rather than fusing it with the loop as you correctly note here -- it being not the only effect, WDYT? cc: @d-smirnov , another occurence where linked-param will be used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will still need to recognize it when adding functions to the codegen module (to avoid adding it). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ack and thanks for the explaination -- maybe a comment would help then :) |
||
} | ||
auto f = Downcast<PrimFunc>(kv.second); | ||
funcs.emplace_back(f); | ||
} | ||
|
||
cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); | ||
for (const PrimFunc& f : funcs) { | ||
cg->AddFunction(f); | ||
} | ||
if (!linked_params.empty()) { | ||
cg->LinkParameters(linked_params); | ||
} | ||
|
||
// Uncomment to get the LLVM module right out of codegen, before optimizations. | ||
// std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n"; | ||
std::unique_ptr<llvm::Module> module = cg->Finish(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,9 +15,11 @@ | |
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import numpy as np | ||
import os | ||
import re | ||
import tvm | ||
import tvm.relay | ||
import tvm.contrib.hexagon as hexagon | ||
|
||
|
||
|
@@ -107,7 +109,71 @@ def test_alloc_vtcm(): | |
assert "HexagonBackendFreeVTCM" in calls | ||
|
||
|
||
def test_linked_params_codegen(): | ||
if not check_prereq_and_setup(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (also ok for a follow-up) do you mind migrating this to the pytest decorator style and modifying the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'll do it, but I'd prefer to do it in a separate PR, since I'd need to change some extra files. |
||
return | ||
|
||
# A simple model (a single conv2d) to trigger parameter separation: | ||
mod_lines = [ | ||
'#[version = "0.0.5"]', | ||
"def @main(%input: Tensor[(1, 16, 16, 3), uint8], %weights: Tensor[(3, 3, 3, 3), uint8])" | ||
" -> Tensor[(1, 14, 14, 3), uint8] {", | ||
' nn.conv2d(%input, %weights, data_layout="NHWC", kernel_layout="HWIO", ' | ||
'kernel_size=[3, 3], out_dtype="uint8")', | ||
"}", | ||
] | ||
mod = tvm.parser.fromtext("\n".join(mod_lines)) | ||
# Make the params be 81 x 'T': | ||
params = {"weights": np.full([3, 3, 3, 3], fill_value=ord("T"), dtype=np.uint8)} | ||
|
||
target = tvm.target.hexagon("v68", link_params=True) | ||
|
||
with tvm.transform.PassContext(opt_level=3): | ||
lib = tvm.relay.build(mod, target=target, target_host=target, params=params) | ||
llvm_ir = lib.get_lib().get_source("ll") | ||
|
||
# The definition of the parameter: | ||
p0_def_re = r"@__tvm_param__p0 = internal constant \[81 x i8\] c\"T{81}\", align 128" | ||
assert re.search(p0_def_re, llvm_ir) | ||
|
||
# The body of the _lookup_linked_param function: | ||
linked_param_re = r"(define.*@_lookup_linked_param\(.*\).* {[^}]*})" | ||
linked_param_body = re.search(linked_param_re, llvm_ir, flags=re.MULTILINE) | ||
assert linked_param_body and linked_param_body.groups() | ||
|
||
# Reference to the parameter: | ||
p0_use_re = r"\[81 x i8\]\* @__tvm_param__p0" | ||
assert re.search(p0_use_re, linked_param_body.groups()[0]) | ||
|
||
""" | ||
A snippet of actual LLVM IR containing the definition of the linked | ||
parameter, and the the body of the _lookup_linked_param function. | ||
|
||
|
||
@__tvm_param__p0 = internal constant [81 x i8] c"TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT", align 128 | ||
|
||
define dllexport i32 @_lookup_linked_param(i8* nocapture readonly %0, i32* nocapture readnone %1, i32 %2, i8* nocapture %3, i32* nocapture %4, i8* nocapture readnone %5) local_unnamed_addr #2 { | ||
entry: | ||
%6 = bitcast i8* %0 to i64* | ||
%7 = load i64, i64* %6, align 8 | ||
%cond = icmp eq i64 %7, 1 | ||
br i1 %cond, label %case___tvm_param__p0, label %common.ret | ||
|
||
common.ret: ; preds = %entry, %case___tvm_param__p0 | ||
%storemerge = phi i32 [ 3, %case___tvm_param__p0 ], [ 4, %entry ] | ||
store i32 %storemerge, i32* %4, align 4 | ||
ret i32 0 | ||
|
||
case___tvm_param__p0: ; preds = %entry | ||
%8 = bitcast i8* %3 to i8** | ||
store i8* getelementptr inbounds ([81 x i8], [81 x i8]* @__tvm_param__p0, i32 0, i32 0), i8** %8, align 4 | ||
br label %common.ret | ||
} | ||
""" | ||
|
||
|
||
if __name__ == "__main__": | ||
test_basic() | ||
test_llvm_target_features() | ||
test_alloc_vtcm() | ||
test_linked_params_codegen() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we replacing the map here (as opposed to setting it) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by "replacing" the map?
Map
is anObjectRef
, so we're not really copying anything here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the reason for declaring the Map outside of the scope of loop is to 'Set' it -- hence the question.
Another way to ask the same thing : any reason to declare the Map outside of the for loop ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's used after the loop. The map holds the parameters that we pass to
LinkParameters
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, it seems this work because it is only set once in the loop.
Any reason not to codegen it here (rather than doing it down there) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to add a check there to make sure it's not set more than once.