-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity][Transform] Handle symbolic variables in LambdaLift #16411
[Unity][Transform] Handle symbolic variables in LambdaLift #16411
Conversation
b9209cb
to
ec4a2ed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes definitely seem to be an improvement in terms of code organization, but I have a request for clarification (see comments) as to how this change achieves the goal of handling symbolic vars.
src/relax/transform/lambda_lift.cc
Outdated
Function(params, body, func_node->ret_struct_info, func_node->is_pure, func_node->attrs); | ||
// recursive call | ||
if (is_recursive && is_closure) { | ||
// it is required by block_blocker, will be updated later |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this typo was present in the original, but I presume this comment should refer to the BlockBuilder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the catch. The typo was present in the original, and the comment is also out of date after this change. I've removed it, and added an appropriate comment.
src/relax/transform/lambda_lift.cc
Outdated
auto cache = current_lambda_var_; | ||
current_lambda_var_ = binding->var; | ||
|
||
// ExprMutator::VisitBinding_(binding, func_node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably subjective, but this comment is a little more cryptic compared to just pointing out that we are visiting a function literal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the catch, and the comment was some test code, and is now removed.
func->attrs); | ||
builder_->UpdateFunction(pair.first, func); | ||
for (auto [gvar, base_func] : glob_funcs) { | ||
if (auto opt = base_func.as<Function>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, I didn't know about using as
for ObjectRef
s directly 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! This was was added back in #14522, and I really like the way it avoids the repetition in if(obj->IsInstance<TNode>())
checks and Downcast<T>(obj)
casts inside the conditional.
src/relax/transform/lambda_lift.cc
Outdated
struct Context { | ||
explicit Context(int* ptr) : ptr(ptr) { (*ptr)++; } | ||
~Context() { (*ptr)--; } | ||
int* ptr; | ||
} context(&depth); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't appear to be used anywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops, debug code for indenting prints during nested calls. Removed.
src/relax/transform/lambda_lift.cc
Outdated
if (bool is_closure = IsClosure(var); | ||
is_closure && builder_->LookupBinding(var).as<CallNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to write it this way instead of IsClosure(var) && ...
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope. A previous implementation of IsClosure
returned an Optional<_>
, and this stuck around. Updated.
// it is required by block_blocker, will be updated later | ||
nested_closure_map_.emplace( | ||
current_lambda_var_.value(), | ||
Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr { return var; }))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume the map is so that an Array<Var>
is treated as Array<Expr>
, correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, that's correct. I'm wondering if there should be an implicit conversion from Array<DerivedClass>
to Array<BaseClass>
, to avoid needing this type of conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might have to be careful to ensure type safety can't get broken that way, that has tended to lead to soundness issues in languages. (I believe that type of thing is a source of unsoundness in TypeScript, IIRC.)
// Must visit the function itself, and not just the function | ||
// body, to ensure that EraseToWellDefined recognized symbolic | ||
// variables that are exposed by the function signature. | ||
auto func = Downcast<Function>(VisitExpr(opt.value())); | ||
builder_->UpdateFunction(gvar, func); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the only change needed to handle symbolic vars, fundamentally? I ask because I didn't see any logic specific to symbolic vars above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Primarily, yes. Most of the other changes are to allow this specific change to work. Previously, the mutator could assume that every FunctionNode
encountered would be a lambda to be lifted out, but now it could be a top-level function as well.
Prior to this commit, symbolic variables used by a lambda function would be duplicated between the caller and the lifted-out function. In addition, shape inference within the lifted-out function was performed without access to the symbolic variables, resulting in unnecessary fallback from `R.Tensor([m, n])` to `R.Tensor(ndim=2)`. This commit updates the `LambdaLift` transform to handle symbolic variables. All symbolic variables have unique definitions across the resulting `IRModule`, and shape inference in the lifted-out function is aware of symbolic variables that have been exposed to it.
ec4a2ed
to
43b7f8c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My comments were addressed.
Prior to this commit, symbolic variables used by a lambda function would be duplicated between the caller and the lifted-out function. In addition, shape inference within the lifted-out function was performed without access to the symbolic variables, resulting in unnecessary fallback from
R.Tensor([m, n])
toR.Tensor(ndim=2)
.This commit updates the
LambdaLift
transform to handle symbolic variables. All symbolic variables have unique definitions across the resultingIRModule
, and shape inference in the lifted-out function is aware of symbolic variables that have been exposed to it.