
Helper function to wrap entry HLO #3920

Merged: JackCaoG merged 12 commits into master from tuplify_parameters on Aug 27, 2022

Conversation

@JackCaoG (Collaborator) commented on Aug 24, 2022:

This feature is for PJRT only.

This PR fixes the "Computation requires more parameters (3348) than supported (limit 3302)" error on TPU.

The TPU has a limit on the number of parameters that is unrelated to the actual parameter sizes; on TPUv3 this limit is 3302. XLA's definition of a "parameter" is data that needs to be moved from the CPU to the device. To work around this issue, we can wrap the parameters into a single tuple, which reduces the parameter count to 1. For example, an HLO like

ENTRY %SyncTensorsGraph.23 (p0.2: f32[], p1.5: f32[], p2.8: f32[], p3.16: f32[]) -> (f32[]) {
  %p3.16 = f32[] parameter(3), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=9}
  %constant.14 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %constant.13 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.15 = f32[] multiply(f32[] %constant.14, f32[] %constant.13), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.17 = f32[] add(f32[] %p3.16, f32[] %multiply.15), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %constant.11 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %constant.10 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.12 = f32[] multiply(f32[] %constant.11, f32[] %constant.10), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.18 = f32[] add(f32[] %add.17, f32[] %multiply.12), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p2.8 = f32[] parameter(2), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.7 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.9 = f32[] multiply(f32[] %p2.8, f32[] %constant.7), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.19 = f32[] add(f32[] %add.18, f32[] %multiply.9), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p1.5 = f32[] parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.4 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.6 = f32[] multiply(f32[] %p1.5, f32[] %constant.4), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.20 = f32[] add(f32[] %add.19, f32[] %multiply.6), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p0.2 = f32[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.1 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.3 = f32[] multiply(f32[] %p0.2, f32[] %constant.1), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.21 = f32[] add(f32[] %add.20, f32[] %multiply.3), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  ROOT %tuple.22 = (f32[]) tuple(f32[] %add.21)
}

to

HloModule SyncTensorsGraph.23.31, entry_computation_layout={((f32[], f32[], f32[], f32[]))->(f32[])}

%SyncTensorsGraph.6 (p0.8: f32[], p1.11: f32[], p2.14: f32[], p3.22: f32[]) -> (f32[]) {
  %p3.22 = f32[] parameter(3), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=9}
  %constant.20 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %constant.19 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.21 = f32[] multiply(f32[] %constant.20, f32[] %constant.19), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.23 = f32[] add(f32[] %p3.22, f32[] %multiply.21), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %constant.17 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %constant.16 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.18 = f32[] multiply(f32[] %constant.17, f32[] %constant.16), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.24 = f32[] add(f32[] %add.23, f32[] %multiply.18), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p2.14 = f32[] parameter(2), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.13 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.15 = f32[] multiply(f32[] %p2.14, f32[] %constant.13), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.25 = f32[] add(f32[] %add.24, f32[] %multiply.15), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p1.11 = f32[] parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.10 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.12 = f32[] multiply(f32[] %p1.11, f32[] %constant.10), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.26 = f32[] add(f32[] %add.25, f32[] %multiply.12), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  %p0.8 = f32[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<module>@test_para.py" source_line=11}
  %constant.7 = f32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<module>@test_para.py" source_line=11}
  %multiply.9 = f32[] multiply(f32[] %p0.8, f32[] %constant.7), metadata={op_type="aten__mul" op_name="aten__mul" source_file="<module>@test_para.py" source_line=11}
  %add.27 = f32[] add(f32[] %add.26, f32[] %multiply.9), metadata={op_type="aten__add" op_name="aten__add" source_file="<module>@test_para.py" source_line=11}
  ROOT %tuple.28 = (f32[]) tuple(f32[] %add.27)
}

ENTRY %SyncTensorsGraph.23.31 (in.1: (f32[], f32[], f32[], f32[])) -> (f32[]) {
  %in.1 = (f32[], f32[], f32[], f32[]) parameter(0)
  %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[], f32[], f32[]) %in.1), index=0
  %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[], f32[], f32[]) %in.1), index=1
  %get-tuple-element.4 = f32[] get-tuple-element((f32[], f32[], f32[], f32[]) %in.1), index=2
  %get-tuple-element.5 = f32[] get-tuple-element((f32[], f32[], f32[], f32[]) %in.1), index=3
  ROOT %call.29 = (f32[]) call(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3, f32[] %get-tuple-element.4, f32[] %get-tuple-element.5), to_apply=%SyncTensorsGraph.6
}

We add another computation that serves as the entry, unpacks the tuple, and forwards the parameters to the real computation. I verified that this unblocks the issue we saw when scaling up.

PJRT handles tupling the input buffers, so the caller only needs to handle wrapping the inputs in the HLO text.
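
For reference, here is a minimal sketch of what such a wrapping helper can look like with the XLA client builder API. This is an illustration only, not the exact helper added in this PR; the function name WrapInTuple and the include paths are assumptions.

// Sketch only: wrap a computation's N parameters into a single tuple
// parameter and forward them through a call, mirroring the HLO shown above.
// Error handling is elided; include paths follow the TF-era XLA layout.
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<xla::XlaComputation> WrapInTuple(
    const xla::XlaComputation& computation) {
  xla::ProgramShape program_shape = computation.GetProgramShape().value();
  xla::XlaBuilder builder(computation.name() + ".wrapped");

  // One tuple parameter that carries all of the original parameter shapes.
  xla::Shape tuple_shape =
      xla::ShapeUtil::MakeTupleShape(program_shape.parameters());
  xla::XlaOp input_tuple = xla::Parameter(&builder, 0, tuple_shape, "in");

  // Unpack the tuple and forward each element to the original computation.
  std::vector<xla::XlaOp> inner_params;
  for (int i = 0; i < program_shape.parameters_size(); ++i) {
    inner_params.push_back(xla::GetTupleElement(input_tuple, i));
  }
  xla::Call(&builder, computation, inner_params);
  return builder.Build();
}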

TODO:
- [ ] add the XLA_PARAMETER_WRAPPING_THREADSHOLD value to the graph hash, otherwise we might get a hash collision between the tuple-input version and the non-tuple version

On second thought, given that tupling doesn't affect speed much (if at all), this should not be an issue. Even if someone manually modifies XLA_PARAMETER_WRAPPING_THREADSHOLD and incorrectly hits the cache (expecting the non-tupled version but getting the tupled one), the program will still run correctly, because from the caller's perspective we still pass all parameters as a vector.
On top of that, we currently do not support a long-lasting cache for PJRT, so this cannot actually happen.
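
For illustration, a rough sketch of how the wrapping threshold could gate this behavior; the env-var reader, include path, and default value below are assumptions, not necessarily what this PR does.

// Sketch only: wrap parameters into a tuple only when their count exceeds a
// configurable threshold read from the environment.
#include <cstdint>

#include "tensorflow/compiler/xla/xla_client/sys_util.h"

bool ShouldWrapParameters(int64_t num_parameters) {
  // The default value here is an assumption for this sketch.
  static const int64_t wrapping_threshold = xla::sys_util::GetEnvInt(
      "XLA_PARAMETER_WRAPPING_THREADSHOLD", /*defval=*/3200);
  return num_parameters > wrapping_threshold;
}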

@JackCaoG force-pushed the tuplify_parameters branch from cc2b4b4 to ceb994e on August 25, 2022 at 02:11
@JackCaoG (Collaborator, Author) commented:

It is working; I need to refactor the code a bit and make the tupling threshold configurable. I will also test the performance implications of tupling.

@JackCaoG changed the title from "[DRAFT] Helper function to wrap entry HLO" to "Helper function to wrap entry HLO" on Aug 26, 2022
@JackCaoG requested a review from will-cromar on August 26, 2022 at 02:37
@JackCaoG (Collaborator, Author) commented:

TODO: we need to add the XLA_PARAMETER_WRAPPING_THREADSHOLD value to the graph hash, otherwise we might get a hash collision between the tuple-input version and the non-tuple version.

@JackCaoG (Collaborator, Author) commented:

ResNet without tupling:

| Training Device=xla:1/5 Epoch=1 Step=140 Loss=0.01358 Rate=646.51 GlobalRate=156.10 Time=02:42:19
| Training Device=xla:1/7 Epoch=1 Step=140 Loss=0.01358 Rate=646.59 GlobalRate=156.96 Time=02:42:19
| Training Device=xla:0/0 Epoch=1 Step=160 Loss=0.01136 Rate=646.96 GlobalRate=174.58 Time=02:42:23
| Training Device=xla:1/5 Epoch=1 Step=160 Loss=0.01136 Rate=646.85 GlobalRate=172.34 Time=02:42:23
| Training Device=xla:1/3 Epoch=1 Step=160 Loss=0.01136 Rate=646.88 GlobalRate=176.81 Time=02:42:23
| Training Device=xla:1/7 Epoch=1 Step=160 Loss=0.01136 Rate=647.09 GlobalRate=173.26 Time=02:42:23
| Training Device=xla:0/4 Epoch=1 Step=160 Loss=0.01136 Rate=646.57 GlobalRate=172.55 Time=02:42:23
| Training Device=xla:1/1 Epoch=1 Step=160 Loss=0.01136 Rate=646.58 GlobalRate=175.05 Time=02:42:23
| Training Device=xla:0/2 Epoch=1 Step=160 Loss=0.01136 Rate=646.70 GlobalRate=176.84 Time=02:42:23
| Training Device=xla:0/6 Epoch=1 Step=160 Loss=0.01136 Rate=646.73 GlobalRate=173.34 Time=02:42:23

ResNet with tupling (forcing a very low XLA_PARAMETER_WRAPPING_THREADSHOLD):

| Training Device=xla:1/3 Epoch=1 Step=880 Loss=0.00143 Rate=649.20 GlobalRate=428.17 Time=02:48:00
| Training Device=xla:0/0 Epoch=1 Step=880 Loss=0.00143 Rate=649.18 GlobalRate=432.39 Time=02:48:00
| Training Device=xla:1/5 Epoch=1 Step=880 Loss=0.00143 Rate=649.35 GlobalRate=426.60 Time=02:48:00
| Training Device=xla:0/6 Epoch=1 Step=880 Loss=0.00143 Rate=649.19 GlobalRate=426.77 Time=02:48:00
| Training Device=xla:0/2 Epoch=1 Step=880 Loss=0.00143 Rate=649.12 GlobalRate=427.96 Time=02:48:00
| Training Device=xla:1/1 Epoch=1 Step=880 Loss=0.00143 Rate=649.17 GlobalRate=433.22 Time=02:48:00
| Training Device=xla:0/4 Epoch=1 Step=880 Loss=0.00143 Rate=648.96 GlobalRate=426.32 Time=02:48:00

It doesn't seem like parameter tupling affects speed much (if at all).

@JackCaoG force-pushed the tuplify_parameters branch from ca390fd to bf9b61e on August 26, 2022 at 17:09
@JackCaoG requested a review from alanwaketan on August 26, 2022 at 22:22
@alanwaketan (Collaborator) left a comment:

LGTM.


// Call the original computation.
xla::XlaOp orig_result;
orig_result = xla::Call(&builder, computation, inner_params);
@alanwaketan (Collaborator): Combine these two lines into one?

}();

// Handle the results of the original computation.
const std::vector<xla::XlaOp> inner_params = [&input_tuple,
@alanwaketan (Collaborator): Just for my education, what's the advantage of defining these parameters this way, i.e., using lambdas?

@JackCaoG (Collaborator, Author): Oh, this part of the code was adapted from legacy internal code. From what I can tell, it just keeps all of the intermediate variables within the lambda scope. I will remove it, since we don't use this style in other parts of PT/XLA.

@alanwaketan (Collaborator): This is actually the first time I've seen this technique. Maybe under some circumstances it's good to free memory gradually within the scope of the function rather than all at once at return... If it's just for readability, a plain block should do the trick, I guess?

@JackCaoG (Collaborator, Author): Yeah, I think so. The only advantage I can think of over a plain block is that one must specify which variables from the outer scope can be accessed. I don't think it fits the code style of this repo very well, so I will remove it.

@alanwaketan (Collaborator): That's fair. (TBH, I don't use blocks often myself...)
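
For readers unfamiliar with the pattern discussed in this thread, here is a minimal generic illustration of the immediately-invoked-lambda initialization style (not code from this PR):

#include <vector>

std::vector<int> Doubled(const std::vector<int>& input) {
  // The lambda runs immediately: its temporaries stay inside the lambda scope,
  // the capture list documents exactly what the initializer reads, and the
  // result can be declared const.
  const std::vector<int> doubled = [&input] {
    std::vector<int> result;
    result.reserve(input.size());
    for (int v : input) result.push_back(2 * v);
    return result;
  }();
  return doubled;
}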

@JackCaoG merged commit 41a4f85 into master on Aug 27, 2022
@JackCaoG (Collaborator, Author) commented:

@ronghanghu The latest wheel should have this fix; let me know if it resolves your issue. (If you set PJRT_DEVICE=TPU you don't need any other configs.)

@ronghanghu (Collaborator) commented:

Thanks @JackCaoG! I'll try this out.

@ronghanghu (Collaborator) commented:

@JackCaoG Thanks for the fix here. I confirm that this resolves our previous issue of "Computation requires more parameters (4748) than supported (limit 3304)" on very deep networks.
