Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPMD] Add apply_backward_optimization_barrier #6097

Merged
merged 3 commits into from
Dec 12, 2023

Conversation

alanwaketan
Copy link
Collaborator

@alanwaketan alanwaketan commented Dec 11, 2023

Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.

It's also used in pytorch-tpu/transformers#50.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier

Copy link
Collaborator

@jonb377 jonb377 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one! LGTM

# The first layer won't have gradients in the hook. Not sure why.
xs.xla_sharding.apply_backward_optimization_barrier(model.fc2)

# optimizer.zero_grad()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should remove this. oops...

@alanwaketan
Copy link
Collaborator Author

Thanks Jon for approving the pull request.

Copy link
Contributor

@yeounoh yeounoh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks!

Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier
where registers a full backward hook that apply an optimization barrier to the given module.
This API will prevent the XLA compiler from fusing the module's backward pass with others.
And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
@alanwaketan alanwaketan force-pushed the alanwaketan/opt-barrier branch from 23d1272 to cb3bad6 Compare December 12, 2023 04:09
@alanwaketan alanwaketan merged commit 07540f2 into master Dec 12, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.

It's also used in pytorch-tpu/transformers#50.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.

It's also used in pytorch-tpu/transformers#50.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients.

It's also used in pytorch-tpu/transformers#50.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants