Enable 8a4wdq (pytorch#264)
Summary:
- Removed the in-repo Int8DynActInt4Weight quantization code
- Use torchao's Int8DynActInt4WeightQuantizer to achieve the same (see the usage sketch after the commit message)

Test Plan:
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 128}}' \
  --checkpoint-path stories110M.pt \
  --params-path params.json \
  --output-pte-path /tmp/stories110m_a8w4dq.pte

Run:
./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z /tmp/tokenizer.bin -n 200 -t 0

Reviewers:

Subscribers:

Tasks:

Tags:
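
For context, a minimal sketch of the torchao-based path this commit switches to, applied to a stand-alone toy model. Only the import, the groupsize keyword, the .quantize() call, and the fp32 default precision come from this change; the model, shapes, and groupsize value below are illustrative assumptions.

import torch
import torch.nn as nn

from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy model (assumption): bias-free linears whose in_features divide groupsize
# evenly, mirroring the constraints the removed in-repo handler asserted.
model = nn.Sequential(
    nn.Linear(128, 64, bias=False),
    nn.ReLU(),
    nn.Linear(64, 32, bias=False),
)

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
model = quantizer.quantize(model)  # swaps eligible nn.Linear modules for 8da4w variants

# Default precision is fp32, so the dynamically quantized linears return fp32 outputs.
x = torch.randn(2, 128)
print(model(x).dtype)  # torch.float32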
kimishpatel authored and malfet committed Jul 17, 2024
1 parent f71a8b9 commit 68ee0b0
Showing 2 changed files with 18 additions and 339 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/pull.yml
@@ -285,10 +285,10 @@ jobs:
cat ./output_et
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******** ET: a8w4dq INT4 group-wise quantized *******"
echo "******************************************"
# python export.py --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-pte-path ${MODEL_DIR}/${MODEL_NAME}.pte
# python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --pte-path ${MODEL_DIR}/${MODEL_NAME}.pte > ./output_et
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-pte-path ${MODEL_DIR}/${MODEL_NAME}.pte
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --pte-path ${MODEL_DIR}/${MODEL_NAME}.pte > ./output_et
# cat ./output_et
echo "tests complete"
351 changes: 15 additions & 336 deletions quantize.py
@@ -92,9 +92,20 @@ def quantize_model(model: nn.Module, device, quantize_options):
).quantized_model()
elif quantizer == "linear:a8w4dq":
linears_quantized = True
model = Int8DynActInt4WeightQuantHandler(
model, device, **q_kwargs
).quantized_model()
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
# Note that Int8DynActInt4WeightQuantizer takes a precision arg, which
# determines the precision/dtype of the output. That is, if
# precision=fp32, this dynamically quantized linear returns an output
# tensor with fp32 dtype.
# Ideally the output dtype would be determined dynamically from the
# input dtype, instead of being fixed when the quantizer is
# instantiated. Since that requires a change in torchao, we leave the
# current state as is and use the default precision for
# Int8DynActInt4WeightQuantizer, which is fp32.
assert (
    "groupsize" in q_kwargs
), f"a8w4dq quantization option must specify groupsize. Specified options {q_kwargs}"
model = Int8DynActInt4WeightQuantizer(
    groupsize=q_kwargs["groupsize"]
).quantize(model)
elif quantizer == "linear:gptq":
linears_quantized = True
model = WeightOnlyInt4GPTQQuantHandler(
@@ -968,273 +979,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


#########################################################################
##### Int8 Dynamic Activations 4 Bit Weights #####


def prepare_int4_weight_and_scales_and_zeros(weight, groupsize, precision):
weight_int8, scales, zeros = group_quantize_tensor_symmetric(
weight,
n_bit=4,
groupsize=groupsize,
precision=precision,
)
# TODO: better API
# weight_int4packed = torch.ops.quantized_decomposed.pack_int4_from_int8(weight_int8)
return weight_int8, scales, zeros


def linear_forward_8da4w(
x, weight_int8, scales, zeros, out_features, groupsize, precision
):
x = per_token_dynamic_quant(x)
# TODO: verify and remove following reshape code
# origin_x_size = x.size()
# x = x.reshape(-1, origin_x_size[-1])

# TODO: better API
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
n_bit = 4
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
weight_int8,
scales,
zeros,
quant_min,
quant_max,
torch.int8,
groupsize,
precision,
)

# x = x.to(torch.float16)
# w_dq = w_dq.to(torch.float16)
c = torch.nn.functional.linear(x, w_dq)

# new_shape = origin_x_size[:-1] + (out_features,)
# c = c.reshape(new_shape)

return c
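
(Aside, not part of the file:) the helper above dequantizes the group-quantized int8 weights back to float and then calls a regular F.linear. Below is a self-contained, plain-PyTorch sketch of that dequantize-then-matmul step, with the per-channel-group math written out by hand for the symmetric case (zeros == 0) instead of calling the quantized_decomposed op; all shapes and values are illustrative assumptions.

import torch
import torch.nn.functional as F

out_features, in_features, groupsize = 8, 64, 32

# int4 values stored in an int8 tensor, one scale per (output channel, group);
# zeros are all zero because the quantization is symmetric.
weight_int8 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, in_features // groupsize)
zeros = torch.zeros_like(scales)

# Dequantize per group along the input dimension: (w - zero) * scale.
w_grouped = weight_int8.to(torch.float32).view(out_features, -1, groupsize)
w_dq = ((w_grouped - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)).view(
    out_features, in_features
)

x = torch.randn(2, in_features)  # the real code first runs per_token_dynamic_quant(x)
y = F.linear(x, w_dq)            # shape (2, out_features)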


def find_multiple(n: int, *args: Tuple[int]) -> int:
k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9]
if n % k == 0:
return n
return n + k - (n % k)
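
(Aside:) find_multiple rounds n up to the nearest multiple of the least common multiple of the extra arguments, which is how the padded in_features size is computed below. A few assumed inputs, with the definition restated so the snippet runs on its own:

from functools import reduce
from math import gcd

def find_multiple(n, *args):  # restated from above for a runnable snippet
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    return n if n % k == 0 else n + k - (n % k)

print(find_multiple(100, 32))   # 128: next multiple of 32 at or above 100
print(find_multiple(64, 32))    # 64: already a multiple, returned unchanged
print(find_multiple(10, 4, 6))  # 12: next multiple of lcm(4, 6) = 12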

##### GPTQ #####

def _check_linear_int4_k(k, groupsize=1):
return k % groupsize == 0


def _calc_padded_size_linear_int4(k, groupsize=1):
return find_multiple(k, groupsize)


def replace_linear_8da4w(
module,
groupsize,
padding_allowed,
precision,
scales_precision,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
setattr(
module,
name,
Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
groupsize=groupsize,
precision=precision,
scales_precision=scales_precision,
),
)
else:
replace_linear_8da4w(
child,
groupsize,
padding_allowed,
precision,
scales_precision,
)


class Int8DynActInt4WeightQuantHandler(QuantHandler):
def __init__(
self,
mod,
device,
*,
groupsize=256,
padding_allowed=False,
precision=torch.float32,
scales_precision=torch.float32,
):
self.mod = mod
self.device = device
self.groupsize = groupsize
self.padding_allowed = padding_allowed
self.precision = precision
self.scales_precision = scales_precision
# assert groupsize in [32, 64, 128, 256]

@torch.no_grad()
def create_quantized_state_dict(self):
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
in_features = mod.in_features
# print("in features:", in_features, " out features:", out_features)
# assert out_features % 8 == 0, "require out_features % 8 == 0"
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.groupsize == 0
), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

weight = mod.weight.data
"""
if not _check_linear_int4_k(
in_features, self.groupsize
):
if self.padding_allowed:
print(
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
)
padded_in_features = _calc_padded_size_linear_int4(
in_features, self.groupsize
)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
)
else:
raise RuntimeError(
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ "and that groupsize"
)
"""
(
weight_int4pack,
scales,
zeros,
) = prepare_int4_weight_and_scales_and_zeros(
weight.to(self.precision),
self.groupsize,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")

return cur_state_dict

def convert_for_runtime(self):
replace_linear_8da4w(
self.mod,
self.groupsize,
self.padding_allowed,
self.precision,
self.scales_precision,
)
return self.mod

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod


class Int8DynActInt4WeightLinear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]

in_features: int
out_features: int
weight: torch.Tensor

"""
This module implements a dynamic quantized linear layer with int4 weight.
Weights are per-channel, group-wise quantized. Parameters of importance:
groupsize: the number of elements in each quantized group
precision: precision of input and output. e.g. torch.float32 means input
activation is float32 and output is float32.
scales_precision: precision of per group scale.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias=True,
device=None,
dtype=None,
groupsize: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
super().__init__()
# always pad if needed since it becomes a noop at runtime if not needed
# self.origin_in_features = in_features
assert (
in_features % groupsize == 0
), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
# in_features = _calc_padded_size_linear_int4(
# in_features, groupsize
# )
self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
# Precision of the activation which also indicates
# output precision of the dynamically quantized linear layer
# that this module represents.
self.precision = precision

# currently storing unpacked int8 weights
self.register_buffer(
"weight",
torch.empty((out_features, in_features), dtype=torch.int8),
)
self.register_buffer(
"scales",
torch.empty(
(out_features, in_features // groupsize),
dtype=scales_precision,
),
)
self.register_buffer(
"zeros",
torch.empty(
(out_features, in_features // groupsize),
dtype=scales_precision,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(self.precision)
# padding is removed for perf
# input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_8da4w(
input,
self.weight,
self.scales,
self.zeros,
self.out_features,
self.groupsize,
self.precision,
)


#########################################################################
##### GPTQ #####


class GPTQQuantHandler(QuantHandler):
"""
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
@@ -1445,77 +1195,6 @@ def quantized_model(self) -> nn.Module:
return self.mod


# class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):
# def __init__(
# self,
# groupsize=128,
# inner_k_tiles=8,
# padding_allowed=True,
# precision=torch.float32,
# ):

# self.groupsize = groupsize
# self.inner_k_tiles = inner_k_tiles
# self.padding_allowed = padding_allowed
# self.precision = precision
# self.dyn_quant_func = lambda x: per_token_dynamic_quant(x)
# n_bit = 4
# self.get_qparams_func = lambda w: get_group_qparams_symmetric(
# w, n_bit, groupsize, self.precision
# )
# quant_min = -(2 ** (n_bit - 1))
# quant_max = 2 ** (n_bit - 1) - 1
# self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group(
# w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize
# )
# self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group(
# q,
# qparams[0],
# qparams[1],
# quant_min,
# quant_max,
# torch.int8,
# groupsize,
# self.precision,
# )
# self.combine_qparams_list_func = lambda qparams_list: [
# torch.cat(x, dim=1) for x in zip(*qparams_list)
# ]
# # skip unless padding_allowed=True or its correctly sized
# self.skip_layer_func = lambda linear_weight: not (
# _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
# or padding_allowed
# )

# # we need to do the padding here, both for q and the qparams if necessary
# def make_names_and_values_dict_func(q, qparams):
# k = q.shape[1]
# new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
# # how much we need to pad the weight
# delta_k = new_k - q.shape[1]
# final_q = F.pad(q, pad=(0, delta_k))
# scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
# # how many new groups we need for padded weight
# delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
# # TODO: split scales and zero_points
# final_s_and_z = F.pad(
# scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
# )
# return {"weight": final_q, "scales_and_zeros": final_s_and_z}

# self.make_names_and_values_dict_func = make_names_and_values_dict_func
# super().__init__()

# def convert_for_runtime(self, model):
# replace_linear_8da4w(
# model,
# self.groupsize,
# self.padding_allowed,
# torch.int8,
# self.precision,
# )
# return model

##################################################################
### WIP: HQQ ###

