Remove is_gpt_fast flag (pytorch#172)
Summary:
The flag was originally added to share code between the 8da4w and int4 weight-only quantizers, but we have since duplicated the quantizer code, so the flag can now be removed safely.

In the future we'll refactor everything to use tensor subclasses.
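
For reference, a minimal sketch of how the deduplicated 8da4w quantizer is constructed after this change. The argument names come from the GPTQ.py hunk below; the import path, the toy model, and the bare `quantize(model)` call are assumptions rather than something verified against this exact commit.

import torch
import torch.nn as nn

# Assumed import path; the quantizer classes are defined in torchao/quantization/GPTQ.py.
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

# Toy stand-in for a real transformer; any module whose nn.Linear in_features
# are divisible by the group size should work.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64)).eval()

# The constructor no longer takes `inner_k_tiles` or `_is_gpt_fast`; it always
# produces int8 weights plus per-group scales/zeros and swaps in the 8da4w
# linears via replace_linear_8da4w.
quantizer = Int8DynActInt4WeightQuantizer(
    groupsize=128,
    padding_allowed=False,
    precision=torch.float32,
    scales_precision=torch.float32,
)
quantized_model = quantizer.quantize(model)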

Test Plan:
Tested locally to make sure `test_8da4w_quantizer_eval` still passes.
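
The deleted `test_gptq_quantizer_gpt_fast` exercised the GPTQ variant; after this commit that flow reduces to the sketch below, adapted from the deleted test with the `_is_gpt_fast` / `_use_cuda` arguments dropped. `Transformer`, `prepare_inputs_for_model`, and the checkpoint path are the same local gpt-fast helpers and files the test file already relies on, so treat this as illustrative rather than runnable as-is.

from pathlib import Path

import torch
from sentencepiece import SentencePieceProcessor
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder

# Checkpoint and tokenizer setup, mirroring the deleted test (local gpt-fast checkpoint).
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)  # gpt-fast Transformer, as in the test file
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
tokenizer = SentencePieceProcessor(model_file=str(checkpoint_path.parent / "tokenizer.model"))

# Record calibration inputs exactly as the deleted test did.
calibration_seq_length = 100
inputs = InputRecorder(
    tokenizer,
    calibration_seq_length,
    prepare_inputs_for_model,  # input-prep helper the test file already imports
    False,                     # pad_calibration_inputs
    model.config.vocab_size,
).record_inputs(
    ["wikitext"],              # calibration_tasks
    1,                         # calibration_limit
).get_inputs()

# `_is_gpt_fast=True` / `_use_cuda=True` are gone; the quantizer now always targets
# the 8da4w (int8 dynamic activation, int4 weight) linears.
quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize=128, percdamp=0.01, groupsize=128)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs)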

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Apr 24, 2024
1 parent 7fba313 commit 9e5d9cb
Showing 2 changed files with 19 additions and 124 deletions.
51 changes: 0 additions & 51 deletions test/quantization/test_quant_api.py
@@ -268,57 +268,6 @@ def test_8da4w_quantizer_eval(self):
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_gpt_fast(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
blocksize = 128
percdamp = 0.01
groupsize = 128
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
input_prep_func = prepare_inputs_for_model
pad_calibration_inputs = False

inputs = InputRecorder(
tokenizer,
calibration_seq_length,
input_prep_func,
pad_calibration_inputs,
model.config.vocab_size,
).record_inputs(
calibration_tasks,
calibration_limit,
).get_inputs()

quantizer = Int8DynActInt4WeightGPTQQuantizer(
blocksize,
percdamp,
groupsize,
_is_gpt_fast=True,
_use_cuda=True,
)

model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

model = quantizer.quantize(model, inputs)
compiled = torch.compile(model, mode="max-autotune")
with torch.no_grad():
compiled(inputs[0].values[0], inputs[1].values[0])

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
92 changes: 19 additions & 73 deletions torchao/quantization/GPTQ.py
Expand Up @@ -1176,17 +1176,8 @@ def __init__(
padding_allowed: bool = False,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
inner_k_tiles: Optional[int] = None,
_is_gpt_fast: bool = False,
) -> None:
super().__init__()
if _is_gpt_fast:
assert inner_k_tiles in [2, 4, 8]
assert groupsize in [32, 64, 128, 256]
else:
assert inner_k_tiles is None
self._is_gpt_fast = _is_gpt_fast
self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
@@ -1210,9 +1201,7 @@ def _create_quantized_state_dict(
), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
):
if not _check_linear_int4_k(in_features, self.groupsize):
if self.padding_allowed:
from .utils import find_multiple
import torch.nn.functional as F
@@ -1233,36 +1222,21 @@ def _create_quantized_state_dict(
self.groupsize,
self.scales_precision,
)
if self._is_gpt_fast:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int8.to(torch.int32), self.inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
else:
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
# TODO: support bias?

return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
if self._is_gpt_fast:
# TODO: temporary path for gpt-fast, will remove later
replace_linear_int4(
model,
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
)
else:
replace_linear_8da4w(
model,
self.groupsize,
self.padding_allowed,
self.precision,
self.precision,
)
replace_linear_8da4w(
model,
self.groupsize,
self.padding_allowed,
self.precision,
self.precision,
)
return model

def quantize(
@@ -1284,9 +1258,7 @@ def __init__(
inner_k_tiles=8,
padding_allowed=True,
precision=torch.float32,
_is_gpt_fast=False,
):
self._is_gpt_fast = _is_gpt_fast
self.blocksize = blocksize
self.percdamp = percdamp
self.groupsize = groupsize
@@ -1327,23 +1299,6 @@ def __init__(
)

# we need to do the padding here, both for q and the qparams if necessary

# TODO: this is the gpt-fast version, merge with the main version later
def make_names_and_values_dict_func_gpt_fast(q, qparams):
k = q.shape[1]
new_k = find_multiple(k, 1024)
# how much we need to pad the weight
delta_k = new_k - q.shape[1]
q = q.to(torch.int32)
final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
scales = qparams[0].to(torch.bfloat16)
zeros = qparams[1].to(torch.bfloat16)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
return {"weight": final_q, "scales_and_zeros": final_s_and_z}

def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
new_k = find_multiple(k, 1 if groupsize is None else groupsize)
@@ -1354,26 +1309,17 @@ def make_names_and_values_dict_func(q, qparams):
zeros = qparams[1].to(self.precision)
return {"weight": final_q, "scales": scales, "zeros": zeros}

self.make_names_and_values_dict_func = make_names_and_values_dict_func_gpt_fast if self._is_gpt_fast else make_names_and_values_dict_func
self.make_names_and_values_dict_func = make_names_and_values_dict_func
super().__init__()

def _convert_for_runtime(self, model):
if self._is_gpt_fast:
# TODO: temporary path for gpt-fast, will remove later
replace_linear_int4(
model,
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
)
else:
replace_linear_8da4w(
model,
self.groupsize,
self.padding_allowed,
self.precision,
self.precision,
)
replace_linear_8da4w(
model,
self.groupsize,
self.padding_allowed,
self.precision,
self.precision,
)
return model

def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
