-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Partial Evaluation #2714
Conversation
This is now feature complete. I will add some test and it will be good. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool PR! Just a few minor suggestions on my part.
tagging more reviewers, @wweic @junrushao1994 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR enables compile-time partial evaluation. LGTM. This could slow down compilation, especially when all computation is done on CPU. But it should be fine for most of applications.
Only some minor nits.
src/relay/pass/partial_eval.cc
Outdated
|
||
using namespace runtime; | ||
|
||
struct StaticNode { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just inherit the TVM's node system? Should be similar i guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is considerably more work, and we has no need.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While it might take a bit more work to adopt TVM's node system, it makes us share the same infra and reuse later. I will recommend using node instead here.
Note that it is not too much more work, just replace shared_ptr->NodePtr and make_shared->make_node
src/relay/pass/partial_eval.cc
Outdated
|
||
using namespace runtime; | ||
|
||
struct StaticNode { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While it might take a bit more work to adopt TVM's node system, it makes us share the same infra and reuse later. I will recommend using node instead here.
Note that it is not too much more work, just replace shared_ptr->NodePtr and make_shared->make_node
Nice work. I'll review today. Interesting to see its effect in the future. As I understand, majority of relay function calls are ADT utilities and recursion generated by framework converter. I don't think PE can help the latter case since it's conditioned on dynamic value, not easy to handle by PE. |
@wweic we are planning to use PE to remove the closure/reference of higher order ad as much as possible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall looks great. Request for comments and clarification.
src/relay/pass/partial_eval.cc
Outdated
Environment(const Environment&) = delete; | ||
|
||
template<typename T> | ||
T Extend(const std::function<T()>& cont) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe name it continuation
or body
to make it clear?
@junrushao1994 can you approve? |
@MarisaKirisame I modified your test case a bit, instead of creating a ref from const 1, I pass in an argument and create ref from there. I think PE did an incorrect inlining: In [5]: print(pe_f.astext())
v0.0.1
%17 = fn (%d: bool, %i: int32) -> int32 {
%0 = ref(%i)
%16 = {
let %x: ref(int32) = %0
%4 = fn () -> () {
%1 = add(%i, %i) // ty=int32
%2 = (%x := %1)
%3 = {
let %x2: () = %2
%x2
}
%3
}
%15 = {
let %x1: fn () -> () = %4
%11 = if (%d) {
%5 = add(%i, %i) // ty=int32
%6 = (%x := %5)
%7 = {
let %x4: () = %6
%x4
}
%7
} else {
%8 = add(%i, %i) // ty=int32
%9 = (%x := %8)
%10 = {
let %x5: () = %9
%x5
}
%10
}
%14 = {
let %x3: () = %11
%12 = %x^
%13 = {
let %x6: int32 = %12
%x6
}
%13
}
%14
}
%15
}
%16
}
%17
In [6]: print(f.astext())
v0.0.1
%13 = fn (%d: bool, %i: int32) -> int32 {
%0 = ref(%i)
%12 = {
let %r: ref(int32) = %0
%5 = fn () -> () {
%1 = %r^
%2 = %r^
%3 = add(%1, %2) // ty=int32
%4 = (%r := %3)
%4
}
%11 = {
let %u: fn () -> () = %5
%8 = if (%d) {
%6 = %u() // ty=()
%6
} else {
%7 = %u() // ty=()
%7
}
%10 = {
let %eff: () = %8
%9 = %r^
%9
}
%10
}
%11
}
%12
}
%13 It's aggressively inlines the value into the closure, but I think closure only captures the reference, not the value, right? |
@wweic good catch, that is indeed incorrect. can you give me the offending test? I know how to fix it but wanna add it to the test case. |
def test_if_ref():
shape = ()
dtype = 'bool'
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
i = relay.Var("i", relay.TensorType(shape, 'int32'))
r = relay.Var("r")
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
u = relay.Var("u")
body = relay.If(d, u(), u())
eff = relay.Var("eff")
body = relay.Let(eff, body, relay.RefRead(r))
f = relay.Function([d, i], relay.Let(r, relay.RefCreate(i), relay.Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_eval(f)) |
@wweic I had updated with a regression test to the error. |
@wweic the test is added. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! @MarisaKirisame
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :-)
@tqchen this is good. can you give some review? |
lint lint save save add more case save error lint lint commit do lint save fix lint wrap it back as func lint save remove dead comment fix style fix lint Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> address review feedback pe now handle freevar. as a result preserving function is now trivial. test add basic test, implement pretty printing for generic function test lint fix segfault save save do test fix another error address comment commit save address review feedback add test for invalidate, fix error in lookup rename cont to boduy fix error and add regression test fix error, add test case Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> fix lint remove extra line save save
@tqchen I address your comment. can you merge? |
lint lint save save add more case save error lint lint commit do lint save fix lint wrap it back as func lint save remove dead comment fix style fix lint Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> address review feedback pe now handle freevar. as a result preserving function is now trivial. test add basic test, implement pretty printing for generic function test lint fix segfault save save do test fix another error address comment commit save address review feedback add test for invalidate, fix error in lookup rename cont to boduy fix error and add regression test fix error, add test case Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> fix lint remove extra line save save
lint lint save save add more case save error lint lint commit do lint save fix lint wrap it back as func lint save remove dead comment fix style fix lint Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> address review feedback pe now handle freevar. as a result preserving function is now trivial. test add basic test, implement pretty printing for generic function test lint fix segfault save save do test fix another error address comment commit save address review feedback add test for invalidate, fix error in lookup rename cont to boduy fix error and add regression test fix error, add test case Update src/relay/pass/partial_eval.cc Co-Authored-By: MarisaKirisame <lolisa@marisa.moe> fix lint remove extra line save save
During the pass two weeks, i had prototyped a partial evaluator in metaocaml for a toy DSL with similar characteristic to relay.
It works by taking a evaluator, lifting the value to partially-static domain, reifying the store, and anfing the code generated using let list to avoid code duplication and get capture avoidance substitution for free.
It is written such that it could be translated line by line to tvm.
I had wrote some test, and it can indeed remove the closure and reference #2496 generate.
However, it does generate lots of dead code (closure, reference) that need to be remove by the dead code elimination pass. We probably need an ad hoc naive escape analysis as only write to dead variable is ok to remove.
I will port this but the high level idea is ready for review: the prototype is only 336 loc (including all the DSL definition), and the core function is less then 100 line.
Thanks to @weberlo for reviewing the doc