-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPMD] Add apply_backward_optimization_barrier #6097
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice one! LGTM
test/spmd/test_xla_sharding.py
Outdated
# The first layer won't have gradients in the hook. Not sure why. | ||
xs.xla_sharding.apply_backward_optimization_barrier(model.fc2) | ||
|
||
# optimizer.zero_grad() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should remove this. oops...
Thanks Jon for approving the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks!
Summary: This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients. Test Plan: python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
23d1272
to
cb3bad6
Compare
Summary: This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients. It's also used in pytorch-tpu/transformers#50. Test Plan: python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
Summary: This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients. It's also used in pytorch-tpu/transformers#50. Test Plan: python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
Summary: This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients. It's also used in pytorch-tpu/transformers#50. Test Plan: python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.
It's also used in pytorch-tpu/transformers#50.
Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier