diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py
index 31e20b6b2f..a11fe861e4 100644
--- a/torchao/prototype/quantization/autoquant_v2.py
+++ b/torchao/prototype/quantization/autoquant_v2.py
@@ -56,9 +56,6 @@
 
 target_folder = "/home/jerryzh/local/tmp/20241104_dynamo_test"
 
-prepare_target_folder(target_folder)
-
-
 __all__ = [
     "AutoQuantizableLinearWeight",
     "autoquant_v2",
@@ -128,19 +125,22 @@ def update_cache(gm, cls, shapes_and_dtype, res):
 
 # adjust each input's bsz to target_bsz
 # enable grad
+# a hacky solution, but it should work for the use cases we are testing now:
+# we go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
 def resize_input(t, extracted_bsz, target_bsz):
     if len(t.shape) > 1:
-        old_first_dim, old_second_dim, old_rest = t.size()[0], t.size()[1], t.size()[2:]
-        assert old_first_dim == 1
-        assert (
-            old_second_dim % extracted_bsz == 0
-        ), f"unexpected old_first_dim {old_first_dim} target_bsz {target_bsz}"
-        new_second_dim = old_second_dim // extracted_bsz * target_bsz
-        new_shape = (old_first_dim, new_second_dim, *old_rest)
+        new_shape = []
+        for i in range(len(t.size())):
+            if t.size(i) == extracted_bsz:
+                new_shape.append(target_bsz)
+            else:
+                new_shape.append(t.size(i))
         t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
     return t
 
 
+# a hacky solution, but it should work for the use cases we are testing now:
+# we go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
 def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
     """
     Makes guesses on how to adjust the model graph to account for the
@@ -148,9 +148,13 @@ def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
     """
     for n in m.graph.nodes:
         if n.op == "call_method" and n.target == "view":
-            if n.args[2] == extracted_bsz:
-                new_args = (*n.args[:2], target_bsz, *n.args[3:])
-                n.args = new_args
+            new_args = []
+            for arg in n.args:
+                if arg == extracted_bsz:
+                    new_args.append(target_bsz)
+                else:
+                    new_args.append(arg)
+            n.args = tuple(new_args)
     m.recompile()
 
 
@@ -181,6 +185,7 @@ def __new__(
         fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         kwargs["device"] = weight.device
@@ -204,6 +209,7 @@ def __init__(
        fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         self.weight = weight
@@ -214,6 +220,7 @@ def __init__(
         self.fqn = fqn
         self.example_inputs = example_inputs
         self.fqn_to_submodule = fqn_to_submodule
+        self.batch_size = batch_size
 
     def __repr__(self):
         return (
@@ -236,7 +243,7 @@ def log_shape(act_mat, w_autoquant, bias):
         )
 
     def tune_autoquant2(
-        self, fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+        self, fqn, m, batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
     ):
         act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
 
@@ -248,8 +255,8 @@
                         linear_module = module
                 weight = q_cls.from_float(linear_module.weight)
                 linear_module.weight = torch.nn.Parameter(weight, requires_grad=False)
-                if LLAMA:
-                    extracted_bsz = 256
+                if batch_size is not None:
+                    extracted_bsz = batch_size
                     target_bsz = act_shape[0]
                     inputs = tree_map(
                         lambda t: resize_input(t, extracted_bsz, target_bsz), inputs
@@ -329,7 +336,7 @@ def count_shapes(self, do_print=True):
                     else time_for_best_shape
                 )
                 self.tune_autoquant2(
-                    fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+                    fqn, m, self.batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
                 )
                 ran_new_benchmarks = True
                 torch._dynamo.reset()
@@ -368,6 +375,7 @@ def _apply_fn_to_data(self, fn):
             fqn=self.fqn,
             example_inputs=self.example_inputs,
             fqn_to_submodule=self.fqn_to_submodule,
+            batch_size=self.batch_size,
         )
 
     def __tensor_flatten__(self):
@@ -378,6 +386,7 @@ def __tensor_flatten__(self):
             self.fqn,
             self.example_inputs,
             self.fqn_to_submodule,
+            self.batch_size,
             self.dtype,
             self.shape,
         ]
@@ -394,6 +403,7 @@ def __tensor_unflatten__(
             fqn,
             example_inputs,
             fqn_to_submodule,
+            batch_size,
             dtype,
             shape,
         ) = tensor_attributes
@@ -405,6 +415,7 @@
             fqn=fqn,
             example_inputs=example_inputs,
             fqn_to_submodule=fqn_to_submodule,
+            batch_size=batch_size,
             shape=shape if outer_size is None else outer_size,
             dtype=dtype,
             strides=outer_stride,
@@ -480,16 +491,6 @@ def do_autoquant_bench(op, *args, **kwargs):
     return res
 
 
-@torch.no_grad()
-def do_autoquant_bench2(model, *args, **kwargs):
-    rep = kwargs.pop("rep", 200)
-    warmup = kwargs.pop("warmup", 30)
-
-    torch._dynamo.reset()
-    benchmark_model(model, warmup, args, kwargs)
-    return benchmark_model(model, rep, args, kwargs)
-
-
 def _is_interpolate_mode(mode):
     if (
         isinstance(mode, list)
@@ -997,7 +998,7 @@ def dict_union(*args):
 
 
 def _change_linears_to_autoquantizable(
-    model, example_input, fqn_to_submodule, **kwargs
+    model, example_input, fqn_to_submodule, batch_size, **kwargs
 ):
     """
     Converts all linear weight tensors to the
@@ -1017,6 +1018,7 @@
     kwargs["model"] = model
     kwargs["example_inputs"] = example_input
     kwargs["fqn_to_submodule"] = fqn_to_submodule
+    kwargs["batch_size"] = batch_size
     from torchao.quantization.quant_api import _get_subclass_inserter
 
     _replace_with_custom_fn_if_matches_filter(
@@ -1090,6 +1092,7 @@ def autoquant_v2(
     manual=False,
     set_inductor_config=True,
     supress_autoquant_errors=True,
+    batch_size=None,
     **aq_kwargs,
 ):
     """
@@ -1151,6 +1154,7 @@
 
     assert example_input is not None
 
+    prepare_target_folder(target_folder)
     torch._dynamo.reset()
     # TODO: explore using node.meta to retrieve the subgraph and fqn information
     # disable nn module inlining, our subgraph extraction logic depends on this
@@ -1168,6 +1172,8 @@
     else:
         raise Exception("Unexpected example_input:", example_input)
 
+    torch._inductor.config.pre_grad_custom_pass = None
+
     # verify debug logs and summary got saved
     assert os.path.isfile(
         os.path.join(target_folder, "debug_logs_0.txt")
@@ -1221,6 +1227,7 @@
         model,
         example_input,
         fqn_to_submodule,
+        batch_size,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
         mode=mode,
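
Note: a hedged usage sketch of the new `batch_size` argument (the toy model, shapes, and device below are illustrative assumptions, not taken from this patch; `example_input` is the keyword referenced in the function body). `batch_size` replaces the old hardcoded `LLAMA` / `extracted_bsz = 256` path: it tells autoquant which value in the extracted subgraph input shapes is the batch dimension, so the inputs can be resized to the batch size being benchmarked.

    import torch
    from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

    # Illustrative toy model and input; assumes a CUDA device.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().bfloat16()
    example_input = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)

    # batch_size=256 marks which dimension value of the captured inputs is
    # the batch size, replacing the previously hardcoded extracted_bsz = 256.
    model = autoquant_v2(model, example_input=example_input, batch_size=256)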
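
Note: the rewritten `resize_input` matches the batch dimension by value instead of assuming a `(1, bsz * seq, ...)` layout as the removed asserts did. A minimal standalone sketch of that behavior (hypothetical helper name `swap_batch_dim`; shapes are illustrative):

    import torch

    # Mirrors the new resize_input logic: any dimension whose size equals
    # extracted_bsz is swapped to target_bsz; all other dimensions are kept.
    def swap_batch_dim(t, extracted_bsz, target_bsz):
        if len(t.shape) > 1:
            new_shape = [
                target_bsz if d == extracted_bsz else d for d in t.size()
            ]
            t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
        return t

    x = torch.randn(1, 256, 4096)
    print(swap_batch_dim(x, extracted_bsz=256, target_bsz=8).shape)
    # torch.Size([1, 8, 4096])

Because the match is by value, a batch size that happens to collide with another dimension (e.g. a hidden size of 256) would be swapped as well, which is why the added comments call this a hacky approach that works for the cases currently under test.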