Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Transform] Handle symbolic variables in LambdaLift #16411

Merged
merged 2 commits into from
Jan 23, 2024

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, symbolic variables used by a lambda function would be duplicated between the caller and the lifted-out function. In addition, shape inference within the lifted-out function was performed without access to the symbolic variables, resulting in unnecessary fallback from R.Tensor([m, n]) to R.Tensor(ndim=2).

This commit updates the LambdaLift transform to handle symbolic variables. All symbolic variables have unique definitions across the resulting IRModule, and shape inference in the lifted-out function is aware of symbolic variables that have been exposed to it.

@Lunderberg Lunderberg force-pushed the unity_lambda_lift_symbolic_vars branch from b9209cb to ec4a2ed Compare January 16, 2024 20:51
Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes definitely seem to be an improvement in terms of code organization, but I have a request for clarification (see comments) as to how this change achieves the goal of handling symbolic vars.

Function(params, body, func_node->ret_struct_info, func_node->is_pure, func_node->attrs);
// recursive call
if (is_recursive && is_closure) {
// it is required by block_blocker, will be updated later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this typo was present in the original, but I presume this comment should refer to the BlockBuilder

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the catch. The typo was present in the original, and the comment is also out of date after this change. I've removed it, and added an appropriate comment.

auto cache = current_lambda_var_;
current_lambda_var_ = binding->var;

// ExprMutator::VisitBinding_(binding, func_node);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably subjective, but this comment is a little more cryptic compared to just pointing out that we are visiting a function literal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the catch, and the comment was some test code, and is now removed.

func->attrs);
builder_->UpdateFunction(pair.first, func);
for (auto [gvar, base_func] : glob_funcs) {
if (auto opt = base_func.as<Function>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, I didn't know about using as for ObjectRefs directly 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! This was was added back in #14522, and I really like the way it avoids the repetition in if(obj->IsInstance<TNode>()) checks and Downcast<T>(obj) casts inside the conditional.

Comment on lines 428 to 432
struct Context {
explicit Context(int* ptr) : ptr(ptr) { (*ptr)++; }
~Context() { (*ptr)--; }
int* ptr;
} context(&depth);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't appear to be used anywhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, debug code for indenting prints during nested calls. Removed.

Comment on lines 374 to 375
if (bool is_closure = IsClosure(var);
is_closure && builder_->LookupBinding(var).as<CallNode>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to write it this way instead of IsClosure(var) && ...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope. A previous implementation of IsClosure returned an Optional<_>, and this stuck around. Updated.

// it is required by block_blocker, will be updated later
nested_closure_map_.emplace(
current_lambda_var_.value(),
Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr { return var; })));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the map is so that an Array<Var> is treated as Array<Expr>, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's correct. I'm wondering if there should be an implicit conversion from Array<DerivedClass> to Array<BaseClass>, to avoid needing this type of conversion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might have to be careful to ensure type safety can't get broken that way, that has tended to lead to soundness issues in languages. (I believe that type of thing is a source of unsoundness in TypeScript, IIRC.)

Comment on lines +483 to +478
// Must visit the function itself, and not just the function
// body, to ensure that EraseToWellDefined recognized symbolic
// variables that are exposed by the function signature.
auto func = Downcast<Function>(VisitExpr(opt.value()));
builder_->UpdateFunction(gvar, func);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only change needed to handle symbolic vars, fundamentally? I ask because I didn't see any logic specific to symbolic vars above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Primarily, yes. Most of the other changes are to allow this specific change to work. Previously, the mutator could assume that every FunctionNode encountered would be a lambda to be lifted out, but now it could be a top-level function as well.

Prior to this commit, symbolic variables used by a lambda function
would be duplicated between the caller and the lifted-out function.
In addition, shape inference within the lifted-out function was
performed without access to the symbolic variables, resulting in
unnecessary fallback from `R.Tensor([m, n])` to `R.Tensor(ndim=2)`.

This commit updates the `LambdaLift` transform to handle symbolic
variables.  All symbolic variables have unique definitions across the
resulting `IRModule`, and shape inference in the lifted-out function
is aware of symbolic variables that have been exposed to it.
@Lunderberg Lunderberg force-pushed the unity_lambda_lift_symbolic_vars branch from ec4a2ed to 43b7f8c Compare January 22, 2024 20:02
@Lunderberg Lunderberg changed the base branch from unity to main January 22, 2024 20:02
Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comments were addressed.

@Lunderberg Lunderberg merged commit 7789b24 into apache:main Jan 23, 2024
16 of 17 checks passed
@Lunderberg Lunderberg deleted the unity_lambda_lift_symbolic_vars branch January 23, 2024 16:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants