Skip to content

Commit

Permalink
adding default inductor config settings (pytorch#423)
Browse files Browse the repository at this point in the history
* adding default inductor config settings

Summary:

making autoquant and quantize apis call a new
recommended_inductor_config_setter util to set recommended apis

also update groupsize -> groupsize in generate.py

Test Plan:

sh benchmarks.sh

comparison of different config combinations for matmul precision,
mixed_mm and coordinate_descent

tok/s=  9.14, mem/s=  60.55 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=147.02, mem/s= 973.53 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.23, mem/s=  61.11 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=139.59, mem/s= 924.33 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.10, mem/s=  60.26 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.98, mem/s= 973.23 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.28, mem/s=  61.48 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.90, mem/s= 972.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.08, mem/s=  60.09 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=137.58, mem/s= 911.00 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.19, mem/s=  60.87 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=166.02, mem/s=1099.30 GB/s, peak_mem= 8.97 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,

Reviewers:

Subscribers:

Tasks:

Tags:

* fixing tests

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* fix weight only failures

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* fixing new broken test

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* fixing autoquant test

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* testing if inductor config is the issue

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* are inductor configs somehow being set?

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* when is coordinate descent tuning beinng enabled?

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* reset inductor config for tests

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* more test fixes

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* adding warning

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* handling of errors

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* option to supress autoquant errors

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles authored Jun 25, 2024
1 parent ce337bb commit 7b03ef3
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 36 deletions.
37 changes: 27 additions & 10 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,21 @@

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_weight_only())
quantize(mod, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_dynamic_activation_int8_weight())
quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4_weight_only())
quantize(mod, int4_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)
Expand All @@ -124,6 +124,13 @@ def _int4wo_api(mod):
_int4wo_api,
]

def undo_recommended_configs():
torch._inductor.config.coordinate_descent_tuning = False
torch._inductor.config.coordinate_descent_check_all_directions = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.triton.unique_kernel_names = False
torch.set_float32_matmul_precision("highest")

def combine_parameters(a, b):
new_tuples = []
Expand Down Expand Up @@ -689,6 +696,7 @@ def test_int8_dynamic_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_weight_only_quant_subclass(self, device, dtype):
undo_recommended_configs()
self._test_lin_weight_subclass_impl(
Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
)
Expand Down Expand Up @@ -794,6 +802,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
undo_recommended_configs()
self._test_lin_weight_subclass_api_impl(
_int8wo_api, device, 40, test_dtype=dtype
)
Expand Down Expand Up @@ -879,6 +888,7 @@ def test_weight_only_quant(self):
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weight_only_quant_force_mixed_mm(self, device, dtype):
undo_recommended_configs()
if device != "cuda":
self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
Expand Down Expand Up @@ -907,6 +917,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weight_only_quant_use_mixed_mm(self, device, dtype):
undo_recommended_configs()
if device != "cuda":
self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
Expand Down Expand Up @@ -1004,6 +1015,7 @@ def test_save_load_dqtensors(self, device, dtype):
@torch.no_grad()
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_save_load_int8woqtensors(self, device, dtype):
undo_recommended_configs()
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand Down Expand Up @@ -1153,6 +1165,7 @@ class TestAutoQuant(unittest.TestCase):
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_one_input(self, device, dtype, m, k, n):
undo_recommended_configs()
print("(m, k, n): ", (m, k, n))
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
Expand All @@ -1173,7 +1186,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model)
torchao.autoquant(model, set_inductor_config=False)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)
Expand All @@ -1186,6 +1199,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1202,7 +1216,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
out = model(example_input)

mod = torchao.autoquant(torch.compile(model), manual=True)
mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
mod(example_input)
mod(example_input2)
mod.finalize_autoquant()
Expand All @@ -1214,6 +1228,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_manual(self, device, dtype):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1229,15 +1244,15 @@ def test_autoquant_manual(self, device, dtype):
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
out = model(example_input)

mod = torchao.autoquant(torch.compile(model), manual=True)
mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
mod(example_input)
mod(example_input2)
mod.finalize_autoquant()
out2 = mod(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

mod2 = torchao.autoquant(model, manual=True)
mod2 = torchao.autoquant(model, manual=True, set_inductor_config=False)
mod2(example_input)
mod2(example_input2)
mod2.finalize_autoquant()
Expand All @@ -1254,6 +1269,7 @@ def test_autoquant_manual(self, device, dtype):
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1280,7 +1296,7 @@ def forward(self, x, y):
}
out = model(**example_input)

mod = torchao.autoquant(torch.compile(model))
mod = torchao.autoquant(torch.compile(model), set_inductor_config=False)
mod(**example_input)

out2 = mod(**example_input)
Expand All @@ -1293,6 +1309,7 @@ def forward(self, x, y):
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_double_access(self, device, dtype, m, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1316,7 +1333,7 @@ def forward(self, x):
x_in = torch.randn(m, k, device=device, dtype=dtype)
model = DoubleAccess().to(device).to(dtype)
model(x_in)
torchao.autoquant(model)
torchao.autoquant(model, set_inductor_config=False)
assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
model(x_in)

Expand Down Expand Up @@ -1443,7 +1460,7 @@ def test_get_model_size_autoquant(self, device, dtype):
qtensor_class_list = (
AQWeightOnlyQuantizedLinearWeight2,
)
mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list)
mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
mod(example_input)
size2 = torchao.utils.get_model_size_in_bytes(mod)
self.assertTrue(size2 < size)
Expand Down
6 changes: 3 additions & 3 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import prepare_inputs_for_model

torch._inductor.config.fx_graph_cache = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

def run_evaluation(
checkpoint_path: Path,
tasks: List[str],
Expand All @@ -41,6 +38,9 @@ def run_evaluation(
pad_calibration_inputs: Optional[bool] = False,
):
"""Runs the evaluation of a model using LM Eval."""

torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
Expand Down
16 changes: 6 additions & 10 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ def device_sync(device):
else:
print(f"device={device} is not yet suppported")


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.force_fuse_int_mm_with_mul = True
# torch._inductor.config.use_mixed_mm = True

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
Expand Down Expand Up @@ -163,6 +156,9 @@ def main(
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""

torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
Expand Down Expand Up @@ -203,7 +199,7 @@ def main(
if "int4wo" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize(model, int4_weight_only(groupsize=groupsize))
quantize(model, int4_weight_only(group_size=groupsize))
if "autoquant" == quantization:
model = autoquant(model, manual=True)

Expand Down Expand Up @@ -339,8 +335,8 @@ def callback(x):
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
Expand Down
10 changes: 3 additions & 7 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ of the activations that the different linear layers see, it then benchmarks thes
import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
Expand Down Expand Up @@ -107,9 +103,6 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# temporary workaround for tensor subclass + torch.compile
from torchao.quantization.utils import unwrap_tensor_subclass
m = unwrap_tensor_subclass(m)
Expand Down Expand Up @@ -163,6 +156,9 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
torch._export.aot_compile(m_unwrapped, example_inputs)
```

### Automatic Inductor Configuration
The `quantize` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.

### Other Available Quantization Techniques
#### A8W8 Dynamic Quantization

Expand Down
35 changes: 30 additions & 5 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torchao
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
Expand Down Expand Up @@ -90,7 +91,11 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
with torch.no_grad():
act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
try:
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
except Exception as e:
print(f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}")
res = torch.inf
update_cache(q_cls, shapes_and_dtype, res)

@torch.no_grad()
Expand Down Expand Up @@ -407,16 +412,21 @@ def _change_linears_to_autoquantizable(model, **kwargs):
filter_fn if filter_fn is not None else _is_linear,
)

def _change_autoquantizable_to_quantized(model, **kwargs):
def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, **kwargs):
"""
Converts AutoQuantizableLinearWeight tensor subclasses
to various quantized/non-quantized tensor subclasses depending
on benchmark results. Expectation is that these modules are
torch.compiled afterwards.
"""
hold = torch._dynamo.config.automatic_dynamic_shapes
hold_automatic_dynamic_shapes = torch._dynamo.config.automatic_dynamic_shapes
torch._dynamo.config.automatic_dynamic_shapes = False

if supress_autoquant_errors:
hold_supress_errors = torch._dynamo.config.suppress_errors
torch._dynamo.config.suppress_errors = True
import logging
torch._logging.set_logs(inductor=logging.CRITICAL, dynamo=logging.CRITICAL)
filter_fn = kwargs.pop(
"filter_fn",
lambda mod, *args:
Expand All @@ -432,7 +442,13 @@ def _change_autoquantizable_to_quantized(model, **kwargs):
),
filter_fn,
)
torch._dynamo.config.automatic_dynamic_shapes = hold
# undo dynamic shape change
torch._dynamo.config.automatic_dynamic_shapes = hold_automatic_dynamic_shapes

# undo error supression
if supress_autoquant_errors:
torch._dynamo.config.suppress_errors = hold_supress_errors
torch._logging.set_logs()
torch._dynamo.reset()

# TODO: example_input seems weird to include in the API
Expand All @@ -443,8 +459,11 @@ def autoquant(
model,
example_input=None,
qtensor_class_list=DEFAULT_CLASS_LIST,
filter_fn=None, mode=["interpolate", .85],
filter_fn=None,
mode=["interpolate", .85],
manual=False,
set_inductor_config=True,
supress_autoquant_errors=True,
**aq_kwargs
):
"""
Expand Down Expand Up @@ -477,6 +496,8 @@ def autoquant(
and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for
the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True)
**aq_kwargs: Additional keyword arguments for the autoquantization process.
Returns:
Expand All @@ -493,6 +514,9 @@ def autoquant(
model(*example_input2)
model.finalize_autoquant()
"""
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()


# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
Expand Down Expand Up @@ -539,6 +563,7 @@ def autoquant_prehook(module, args, kwargs):
def finalize_autoquant():
_change_autoquantizable_to_quantized(
real_model,
supress_autoquant_errors,
**aq_kwargs,
)
if hasattr(real_model, "old_forward"):
Expand Down
Loading

0 comments on commit 7b03ef3

Please sign in to comment.