Some improvement to make autoquant v2 work with Mixtral-8x7B-v0.1 (#1328)

* Some improvement to make autoquant v2 work with Mixtral-8x7B-v0.1

Summary:
Tested locally by running autoquant v2 with llama2-7b and Mixtral-8x7B-v0.1
in https://github.com/pytorch/pytorch/blob/main/benchmarks/gpt_fast/benchmark.py

Llama-2-7b-chat-hf:
Compilation time: 81.71 seconds
Average tokens/sec: 131.12 tokens/sec
Average bandwidth achieved: 1732.77 GB/s
Memory used: 27.71 GB

Mixtral-8x7B-v0.1:
Compilation time: 108.89 seconds
Average tokens/sec: 79.59 tokens/sec
Average bandwidth achieved: 1025.14 GB/s
Memory used: 61.62 GB

More results can be found in pytorch/pytorch#140627.

Test Plan:
Local test with pytorch/benchmarks/gpt_fast/benchmark.py.
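
For reference, a rough sketch of how the new batch_size argument might be passed. This is illustrative only: it assumes autoquant_v2 keeps the (model, example_input=...) calling convention of autoquant, and the toy model below merely stands in for the real Llama-2-7b / Mixtral-8x7B checkpoints.

import torch
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

# hypothetical toy model standing in for the real checkpoints
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# example input captured at the batch size the subgraphs are extracted with
example_input = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")

# batch_size (new in this PR) records that extracted batch size, so cached
# subgraph inputs can later be resized to whatever batch size shows up at
# benchmark time (see the resize_input / tune_autoquant2 changes below)
model = autoquant_v2(model, example_input=example_input, batch_size=256)

The numbers above come from pytorch/benchmarks/gpt_fast/benchmark.py, not from this toy snippet.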

Reviewers:

Subscribers:

Tasks:

Tags:

* remove print
jerryzh168 authored Nov 22, 2024
1 parent f3c1a00 commit 7c3c51f
63 changes: 35 additions & 28 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -56,9 +56,6 @@

target_folder = "/home/jerryzh/local/tmp/20241104_dynamo_test"

prepare_target_folder(target_folder)


__all__ = [
"AutoQuantizableLinearWeight",
"autoquant_v2",
@@ -128,29 +125,36 @@ def update_cache(gm, cls, shapes_and_dtype, res):

# adjust each input's bsz to target_bsz
# enable grad
# a hacky solution, but it should work for the use cases we are testing now
# we go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
def resize_input(t, extracted_bsz, target_bsz):
if len(t.shape) > 1:
old_first_dim, old_second_dim, old_rest = t.size()[0], t.size()[1], t.size()[2:]
assert old_first_dim == 1
assert (
old_second_dim % extracted_bsz == 0
), f"unexpected old_first_dim {old_first_dim} target_bsz {target_bsz}"
new_second_dim = old_second_dim // extracted_bsz * target_bsz
new_shape = (old_first_dim, new_second_dim, *old_rest)
new_shape = []
for i in range(len(t.size())):
if t.size(i) == extracted_bsz:
new_shape.append(target_bsz)
else:
new_shape.append(t.size(i))
t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
return t
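
For illustration, a standalone sketch (not part of the diff; shapes are made up) of what the new loop-based resize_input does: any dimension equal to the extracted batch size is swapped to the target batch size, instead of assuming the batch is folded into the second dimension as the old code did.

import torch

def resize_input(t, extracted_bsz, target_bsz):
    # swap every dimension that equals extracted_bsz to target_bsz,
    # keeping all other dimensions unchanged
    if len(t.shape) > 1:
        new_shape = []
        for i in range(len(t.size())):
            if t.size(i) == extracted_bsz:
                new_shape.append(target_bsz)
            else:
                new_shape.append(t.size(i))
        t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
    return t

x = torch.randn(1, 256, 4096)                        # cached with batch size 256
y = resize_input(x, extracted_bsz=256, target_bsz=8)
print(y.shape)                                       # torch.Size([1, 8, 4096])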


# a hacky solution, but it should work for the use cases we are testing now
# we go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
"""
Makes guesses on how to adjust the model graph to account for the
fact that we changed the batch size. Note: this is very brittle
"""
for n in m.graph.nodes:
if n.op == "call_method" and n.target == "view":
if n.args[2] == extracted_bsz:
new_args = (*n.args[:2], target_bsz, *n.args[3:])
n.args = new_args
new_args = []
for arg in n.args:
if arg == extracted_bsz:
new_args.append(target_bsz)
else:
new_args.append(arg)
n.args = tuple(new_args)

m.recompile()
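
Along the same lines, a small self-contained sketch (assumptions: a toy fx-traced module named Toy with a hard-coded batch dimension; not part of the diff) showing how the relaxed rewrite now replaces any view() argument equal to the extracted batch size, rather than only the argument at a fixed position:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        # batch size 256 is baked into the view call at trace time
        return x.view(1, 256, 16)

def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
    # rewrite every view() argument equal to extracted_bsz to target_bsz
    for n in m.graph.nodes:
        if n.op == "call_method" and n.target == "view":
            n.args = tuple(
                target_bsz if arg == extracted_bsz else arg for arg in n.args
            )
    m.recompile()

gm = fx.symbolic_trace(Toy())
maybe_adjust_model_bsz(gm, extracted_bsz=256, target_bsz=8)
print(gm(torch.randn(1, 8, 16)).shape)  # torch.Size([1, 8, 16])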

@@ -181,6 +185,7 @@ def __new__(
fqn=None,
example_inputs=None,
fqn_to_submodule=None,
batch_size=None,
**kwargs,
):
kwargs["device"] = weight.device
@@ -204,6 +209,7 @@ def __init__(
fqn=None,
example_inputs=None,
fqn_to_submodule=None,
batch_size=None,
**kwargs,
):
self.weight = weight
@@ -214,6 +220,7 @@ def __init__(
self.fqn = fqn
self.example_inputs = example_inputs
self.fqn_to_submodule = fqn_to_submodule
self.batch_size = batch_size

def __repr__(self):
return (
@@ -236,7 +243,7 @@ def log_shape(act_mat, w_autoquant, bias):
)

def tune_autoquant2(
self, fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
self, fqn, m, batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
):
act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype

@@ -248,8 +255,8 @@ def tune_autoquant2(
linear_module = module
weight = q_cls.from_float(linear_module.weight)
linear_module.weight = torch.nn.Parameter(weight, requires_grad=False)
if LLAMA:
extracted_bsz = 256
if batch_size is not None:
extracted_bsz = batch_size
target_bsz = act_shape[0]
inputs = tree_map(
lambda t: resize_input(t, extracted_bsz, target_bsz), inputs
@@ -329,7 +336,7 @@ def count_shapes(self, do_print=True):
else time_for_best_shape
)
self.tune_autoquant2(
fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
fqn, m, self.batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
)
ran_new_benchmarks = True
torch._dynamo.reset()
@@ -368,6 +375,7 @@ def _apply_fn_to_data(self, fn):
fqn=self.fqn,
example_inputs=self.example_inputs,
fqn_to_submodule=self.fqn_to_submodule,
batch_size=self.batch_size,
)

def __tensor_flatten__(self):
@@ -378,6 +386,7 @@ def __tensor_flatten__(self):
self.fqn,
self.example_inputs,
self.fqn_to_submodule,
self.batch_size,
self.dtype,
self.shape,
]
@@ -394,6 +403,7 @@ def __tensor_unflatten__(
fqn,
example_inputs,
fqn_to_submodule,
batch_size,
dtype,
shape,
) = tensor_attributes
@@ -405,6 +415,7 @@ def __tensor_unflatten__(
fqn=fqn,
example_inputs=example_inputs,
fqn_to_submodule=fqn_to_submodule,
batch_size=batch_size,
shape=shape if outer_size is None else outer_size,
dtype=dtype,
strides=outer_stride,
@@ -480,16 +491,6 @@ def do_autoquant_bench(op, *args, **kwargs):
return res


@torch.no_grad()
def do_autoquant_bench2(model, *args, **kwargs):
rep = kwargs.pop("rep", 200)
warmup = kwargs.pop("warmup", 30)

torch._dynamo.reset()
benchmark_model(model, warmup, args, kwargs)
return benchmark_model(model, rep, args, kwargs)


def _is_interpolate_mode(mode):
if (
isinstance(mode, list)
@@ -997,7 +998,7 @@ def dict_union(*args):


def _change_linears_to_autoquantizable(
model, example_input, fqn_to_submodule, **kwargs
model, example_input, fqn_to_submodule, batch_size, **kwargs
):
"""
Converts all linear weight tensors to the
@@ -1017,6 +1018,7 @@ def _change_linears_to_autoquantizable(
kwargs["model"] = model
kwargs["example_inputs"] = example_input
kwargs["fqn_to_submodule"] = fqn_to_submodule
kwargs["batch_size"] = batch_size
from torchao.quantization.quant_api import _get_subclass_inserter

_replace_with_custom_fn_if_matches_filter(
@@ -1090,6 +1092,7 @@ def autoquant_v2(
manual=False,
set_inductor_config=True,
supress_autoquant_errors=True,
batch_size=None,
**aq_kwargs,
):
"""
@@ -1151,6 +1154,7 @@ def autoquant_v2(

assert example_input is not None

prepare_target_folder(target_folder)
torch._dynamo.reset()
# TODO: explore using node.meta to retrieve the subgraph and fqn information
# disable nn module inlining, our subgraph extraction logic depends on this
@@ -1168,6 +1172,8 @@
else:
raise Exception("Unexpected example_input:", example_input)

torch._inductor.config.pre_grad_custom_pass = None

# verify debug logs and summary got saved
assert os.path.isfile(
os.path.join(target_folder, "debug_logs_0.txt")
@@ -1221,6 +1227,7 @@ def autoquant_v2(
model,
example_input,
fqn_to_submodule,
batch_size,
filter_fn=filter_fn,
qtensor_class_list=qtensor_class_list,
mode=mode,
