Improve FSDP support for low-bit optimizers #538

Merged
merged 7 commits into pytorch:main on Jul 26, 2024

Conversation

gau-nernst
Collaborator

  • Use DTensor.from_local(run_check=False) to wrap quantized optim state (instead of swapping _local_tensor)
  • Make block_size a fixed attribute, calculated once inside __init__ (instead of dynamically recalculating it every time)
  • Implement all_gather_into_tensor and wait_tensor to support DTensor.full_tensor() (see the sketch after this list)
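
As a rough sketch of the first and last bullets (illustrative only, with assumed names such as `local_qstate`, and assuming a process group is already initialized; this is not the PR's actual code), wrapping a per-rank quantized state with `DTensor.from_local(run_check=False)` and gathering it back via `full_tensor()` looks roughly like this:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical stand-in for the quantized optim-state subclass shard on this rank;
# in the PR this would hold the codes/scale buffers.
local_qstate = torch.randn(1024, 128, device="cuda")

mesh = init_device_mesh("cuda", (dist.get_world_size(),))

dtensor_state = DTensor.from_local(
    local_qstate,
    device_mesh=mesh,
    placements=[Shard(0)],   # sharded along dim 0, one shard per rank
    run_check=False,         # skip cross-rank metadata verification
)

# full_tensor() dispatches all_gather_into_tensor followed by wait_tensor,
# which is why those two ops need to be handled by the quantized subclass.
full_state = dtensor_state.full_tensor()
```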


pytorch-bot bot commented Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/538

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6cec214 with merge base e8662e0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label on Jul 25, 2024
@gau-nernst gau-nernst marked this pull request as ready for review July 25, 2024 03:27
@gau-nernst gau-nernst requested review from msaroufim, awgu and drisspg July 25, 2024 03:27
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
Contributor

Curious q: is there some description of the codes / scale tensors and their relation to each other?

I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) the block_size * scale numel, but does this assume square blocks? I think some description here would be helpful.

Collaborator Author

I will add some description. Basically, for 8-bit and FP8, codes has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let codes be a flattened 1D buffer and keep track of the shape manually.

To get the scale, the float tensor is first flattened and then reshaped to (-1, block_size). This relaxes the requirement that the last dimension must be divisible by block_size: now we only need the numel (total size) to be divisible by block_size. This is especially needed when the block size is large (the 8-bit optim uses block_size=2048, as done in bnb). Since the optim update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (though an n-D tensor may have some structure that flattening discards). I believe the original implementation in bnb does this as well.

-> scale is always a 1D tensor, with size = original_tensor.numel() // block_size
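
For illustration, here is a minimal sketch of that flatten-then-reshape layout (the helper name is made up; this is not the actual implementation):

```python
import torch

def blockwise_scale(x: torch.Tensor, block_size: int = 2048) -> torch.Tensor:
    # Flatten first, then reshape to (-1, block_size): only the total numel
    # has to be divisible by block_size, not the last dimension.
    assert x.numel() % block_size == 0
    blocks = x.flatten().view(-1, block_size)
    # One scale per block -> scale is always 1D with numel // block_size elements.
    return blocks.abs().amax(dim=1)

x = torch.randn(64, 96)            # 6144 elements
print(blockwise_scale(x).shape)    # torch.Size([3])
```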

Collaborator Author

@drisspg Added some docs. Lmk if it is still unclear.

@awgu
Contributor

awgu commented Jul 25, 2024

The DTensor-related changes look good to me. I wonder, did you ever have to run .full_tensor() in the training path, or was it only used outside of it, e.g. for debugging?

@gau-nernst
Collaborator Author

I don't use FSDP at work or in personal projects because I don't have access to multi-GPU machines, so I can't really answer your question 😅. I only added FSDP support for the low-bit optimizers due to requests from people.

At least in torchtune, I saw that .full_tensor() is used to retrieve the full optim state dict before saving a checkpoint: https://github.com/pytorch/torchtune/blob/0057fe7cf83e14f0b62538a8d4d20719e0a88639/torchtune/utils/_distributed.py#L437 (though it might be unnecessary, or even less efficient, than saving each shard separately, provided we resume training with the same setup).
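
A rough sketch of that checkpoint-time gather (the helper is hypothetical, assuming the optimizer's state tensors are DTensors under FSDP2):

```python
import torch
from torch.distributed._tensor import DTensor

def gather_full_optim_state(optim: torch.optim.Optimizer) -> dict:
    # Hypothetical helper: materialize the full (unsharded) optimizer state on
    # every rank before saving, similar in spirit to the torchtune code above.
    full_state = {}
    for param, state in optim.state.items():
        full_state[param] = {
            name: t.full_tensor() if isinstance(t, DTensor) else t
            for name, t in state.items()
        }
    return full_state
```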

@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
"""Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507.
Contributor

btw the link is not valid, can you remove the trailing "."?

Collaborator Author

Done

@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
"""Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861.
Contributor

same for this one

Member

@msaroufim msaroufim left a comment

Nice! We should look to get the low-bit optimizers out of prototype soon!

@msaroufim msaroufim merged commit 4280843 into pytorch:main Jul 26, 2024
13 checks passed
@gau-nernst gau-nernst deleted the improve_low_bit_optim branch July 26, 2024 02:49
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024