Add Auto-Round support #581
`README.md` (new file):

### Usage
> [!NOTE]
> The current implementation requires installing `Auto-round`:

```bash
pip install -r requirements.txt
```
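If you prefer installing the dependency directly rather than via `requirements.txt`, the optimizer appears to be published on PyPI; the package name below is an assumption, so fall back to `requirements.txt` if it does not resolve:

```bash
# Assumed PyPI package name for Auto-round; prefer requirements.txt if unsure
pip install auto-round
```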
### Quantize `facebook/opt-125m` with Auto-round
```bash
python autoround_demo.py
```
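The demo script also exposes a few tuning knobs via `argparse` (see `autoround_demo.py` below). For example, a run that shortens calibration and also quantizes the `lm_head` could look like this (the flag values here are illustrative):

```bash
# -m / --model_name_or_path, --iters, --nsamples, and --quant_lm_head
# are all defined in autoround_demo.py's argument parser
python autoround_demo.py -m facebook/opt-125m --iters 10 --nsamples 64 --quant_lm_head
```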
`autoround_demo.py` (new file):

```python
import argparse

import torchao
from torchao.prototype.autoround.core import (
    auto_round_config,
    MultiTensor,
    post_process_model_after_applying_auto_round_,
    prepare_model_for_applying_auto_round_,
)


def main(args):
    # 0. Get the model, tokenizer, and decoder_cls
    import torchao.prototype.autoround.utils as ar_utils

    model_name_or_path = args.model_name_or_path
    model, tokenizer, decoder_cls = ar_utils.get_float_model_info(model_name_or_path)
    # Workaround: disable the `kv_cache`, which otherwise causes OOM.
    model.config.use_cache = False
    ar_utils.gen_text(model, tokenizer, "Float model", device="cuda", max_length=50)

    auto_round_config.iters = args.iters
    auto_round_config.nsamples = args.nsamples
    auto_round_config.seqlen = args.seqlen

    # 1. Prepare the model for applying auto-round.
    # The user should provide an `is_decoder` function for identifying the
    # decoder block. It can be extended to other modules, such as `lm_head`:
    #   is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
    if args.quant_lm_head:
        is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
    else:
        is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jerryzh168 here is the new flow using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah using MultiTensor sounds fine, and the flow makes sense, see my suggestion on the implementation below |
```python
    # 2. Calibration and optimization
    dataloader = ar_utils.get_dataloader(
        tokenizer,
        auto_round_config.seqlen,
        seed=auto_round_config.seed,
        bs=auto_round_config.train_bs,
        nsamples=auto_round_config.nsamples,
    )

    input_ids_lst = []
    attn_mask_lst = []
    for i, data in enumerate(dataloader):
        input_ids_lst.append(data["input_ids"])
        attn_mask_lst.append(data["attention_mask"])

    mul_t_input_ids = MultiTensor(input_ids_lst)
    mul_t_attn_mask = MultiTensor(attn_mask_lst)

    # The optimization is applied during the forward pass
    out = model(mul_t_input_ids, mul_t_attn_mask)

    # 3. Post-process the model after applying auto-round
    post_process_model_after_applying_auto_round_(model)
    assert ar_utils.has_tensor_of_type(model, torchao.dtypes.AffineQuantizedTensor)

    # 4 (Optional). Generate text using the optimized model
    ar_utils.gen_text(model, tokenizer, "Quantized model", device="cuda", max_length=50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-m",
        "--model_name_or_path",
        type=str,
        default="facebook/opt-125m",
        help="Model name or path",
    )
    parser.add_argument("--seed", default=0, type=int, help="Random seed for torch")
    parser.add_argument(
        "--iters", default=20, type=int, help="Number of iterations for optimization"
    )
    parser.add_argument(
        "--nsamples", default=128, type=int, help="Number of samples for optimization"
    )
    parser.add_argument(
        "--seqlen", default=2048, type=int, help="Sequence length for optimization"
    )
    parser.add_argument(
        "--quant_lm_head",
        default=False,
        action="store_true",
        help="Quantize the `lm_head` or not",
    )
    args = parser.parse_args()
    main(args)
```
> **Review comment:** I think more explanations of the API and a summary of perf and accuracy for llama2/llama3 would be helpful, similar to https://github.com/pytorch/ao/tree/main/torchao/prototype/quant_llm.
>
> **Reply:** Hi @jerryzh168, I summarized the usage and posted partial results in README.md. More tests are WIP.