Helper function to wrap entry HLO #3920
Conversation
Force-pushed from cc2b4b4 to ceb994e
It is working; need to refactor the code a bit and make …

TODO: we need to add …

- resnet without tupling: …
- resnet with tupling (by forcing a very low `XLA_PARAMETER_WRAPPING_THREADSHOLD`): …

It doesn't seem like parameter tupling will affect speed too much (if at all).
Force-pushed from ca390fd to bf9b61e
LGTM.
torch_xla/csrc/helpers.cpp (outdated):

```cpp
// Call the original computation.
xla::XlaOp orig_result;
orig_result = xla::Call(&builder, computation, inner_params);
```
Combine these two lines into one?
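For reference, the combined form the reviewer is suggesting would presumably look like this (same identifiers as the diff above):

```cpp
// Declare and initialize in a single statement.
xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);
```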
torch_xla/csrc/helpers.cpp (outdated):

```cpp
}();

// Handle the results of the original computation.
const std::vector<xla::XlaOp> inner_params = [&input_tuple,
```
Just for my education, what's the advantage of defining these parameters this way, i.e., using lambdas?
Oh, this part of the code was adapted from legacy internal code. From what I can tell, it just keeps all of the intermediate variables within the lambda's scope. I will remove it, since we don't use this style in other parts of pt/xla.
Actually, this is the first time I've seen this technique. Maybe under some circumstances it's good to free memory gradually within the scope of the function instead of all at once at return... If it's just for readability, a plain block should do the trick, I guess?
Yeah, I think so. The only advantage I can think of (why this might be preferred over a block) is that one must specify which variables from the outer scope can be accessed. I don't think it fits the code style of this repo well, so I will remove it.
That's fair. (TBH, I don't use blocks often myself...)
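For readers following along, here is a minimal standalone sketch of the two styles being compared; the variable names are illustrative, not taken from the PR:

```cpp
#include <utility>
#include <vector>

int main() {
  std::vector<int> source = {1, 2, 3};

  // Style 1: immediately-invoked lambda. The capture list spells out exactly
  // which outer variables are visible, and the result can be const.
  const std::vector<int> doubled = [&source] {
    std::vector<int> tmp;  // intermediate lives only inside the lambda
    for (int v : source) tmp.push_back(v * 2);
    return tmp;
  }();

  // Style 2: plain block. The intermediate still dies at the closing brace,
  // but the block sees every outer variable and the result cannot be const.
  std::vector<int> doubled2;
  {
    std::vector<int> tmp;
    for (int v : source) tmp.push_back(v * 2);
    doubled2 = std::move(tmp);
  }

  return doubled.size() == doubled2.size() ? 0 : 1;
}
```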
@ronghanghu The latest wheel should have this fix; let me know if this fixed your issue (if you set …).
Thanks @JackCaoG! I'll try this out.
@JackCaoG Thanks for the fix here. I confirm that this resolves our previous issue of …
This feature is for PJRT only.
This PR fixes the `Computation requires more parameters (3348) than supported (limit 3302)` error on TPU. TPU has a limit on the number of parameters that is unrelated to the actual parameter sizes; on TPUv3 this limit is 3302. XLA's definition of a "parameter" is data that needs to be moved from the CPU to the device. To work around this issue we can wrap the parameters into a single tuple, which reduces the parameter count to 1. For example, an HLO whose entry computation takes each tensor as a separate parameter is rewritten into one that takes a single tuple parameter (see the sketch below).
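The original before/after HLO snippets were lost in extraction; the following is a schematic reconstruction (shapes and names are made up, and the text is simplified from real HLO syntax):

```
ENTRY %main (p0: f32[2], p1: f32[2], p2: f32[2]) -> f32[2] {
  %p0 = f32[2] parameter(0)
  %p1 = f32[2] parameter(1)
  %p2 = f32[2] parameter(2)
  ...
}
```

becomes

```
ENTRY %main (in: (f32[2], f32[2], f32[2])) -> f32[2] {
  %in = (f32[2], f32[2], f32[2]) parameter(0)
  %p0 = f32[2] get-tuple-element(%in), index=0
  %p1 = f32[2] get-tuple-element(%in), index=1
  %p2 = f32[2] get-tuple-element(%in), index=2
  ...
}
```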
We add another computation as the entry point, which forwards the parameters to the real computation. I verified this unblocks the issue we saw when scaling up.

PJRT handles tupling the input buffers itself, so the caller only needs to handle wrapping the inputs in the HLO text.
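A rough sketch of what such a wrapping helper could look like, using the public XlaBuilder API. This is an illustration under stated assumptions, not the exact code merged into torch_xla/csrc/helpers.cpp; the function name and header paths are assumptions:

```cpp
// Header paths are assumptions; they differ between TF/XLA checkouts.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

#include <cstdint>
#include <vector>

// Illustrative sketch: build a new entry computation whose single parameter
// is a tuple of all the original parameter shapes, then forward the unpacked
// elements to the original computation via a call instruction.
xla::StatusOr<xla::XlaComputation> WrapWithTupleParameter(
    const xla::XlaComputation& computation,
    const std::vector<xla::Shape>& parameter_shapes) {
  xla::XlaBuilder builder("wrapped_entry");
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
  xla::XlaOp input_tuple = xla::Parameter(&builder, 0, tuple_shape, "in");

  // Unpack each element of the tuple parameter.
  std::vector<xla::XlaOp> inner_params;
  inner_params.reserve(parameter_shapes.size());
  for (int64_t i = 0; i < static_cast<int64_t>(parameter_shapes.size()); ++i) {
    inner_params.push_back(xla::GetTupleElement(input_tuple, i));
  }

  // Forward the unpacked parameters to the original computation.
  xla::Call(&builder, computation, inner_params);
  return builder.Build();
}
```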
TODO:
- [ ] Add the `XLA_PARAMETER_WRAPPING_THREADSHOLD` value to the graph hash; otherwise we might get a hash collision between the tuple-input version and the non-tuple version of the same graph. (A possible approach is sketched below.)

On second thought, given that tupling doesn't affect speed too much (if at all), this should not be an issue. Even if someone manually modifies `XLA_PARAMETER_WRAPPING_THREADSHOLD` and incorrectly hits the cache (expecting the non-tupled graph but getting the tupled one), the program will still run correctly, because from the caller's perspective we still pass all parameters as a vector. On top of that, we currently do not support a long-lasting cache for PJRT, so this cannot actually happen.
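If that TODO were ever needed, a minimal sketch might look like the following; the function name and the Boost-style hash combine are made up for illustration:

```cpp
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <string>

// Illustrative only: fold the wrapping threshold into an existing graph hash
// so tupled and non-tupled compilations of the same graph get distinct keys.
std::size_t MixWrappingThresholdIntoHash(std::size_t graph_hash) {
  const char* env = std::getenv("XLA_PARAMETER_WRAPPING_THREADSHOLD");
  std::size_t env_hash = std::hash<std::string>{}(env ? env : "");
  // Boost-style hash combine.
  return graph_hash ^
         (env_hash + 0x9e3779b9 + (graph_hash << 6) + (graph_hash >> 2));
}
```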