Move Linear int4 to qops (pytorch#537)
* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops
mikekgfb authored and malfet committed Jul 17, 2024
1 parent d69915a commit 863f2cf
Showing 2 changed files with 150 additions and 116 deletions.
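Note: the practical effect of this refactor is that quantize.py now pulls all three weight-only ops from the shared qops module instead of defining them locally. A minimal sketch of the resulting import surface (illustrative only; the aliases are the ones added in the quantize.py hunk below):

from qops import (
    LinearInt4 as WeightOnlyInt4Linear,
    LinearInt8 as WeightOnlyInt8Linear,
    QuantizedEmbedding,
)

The aliases keep the names quantize.py used before the move, so downstream call sites are unchanged.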
97 changes: 97 additions & 0 deletions qops.py
@@ -209,3 +209,100 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))


def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_input_size = input.size()
input = input.reshape(-1, origin_input_size[-1])

if "cuda" in str(input.device):
c = torch.ops.aten._weight_int4pack_mm(
input.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16),
).to(
input.dtype
) # cast back to input.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_input_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


class LinearInt4(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales_and_zeros: torch.Tensor

def __init__(
self,
device: str,
in_features: int,
out_features: int,
bias=True,
dtype=None,
groupsize: int = 128,
inner_k_tiles: int = 8,
) -> None:
super().__init__()
self.padding = not self._check_k(
k=in_features,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
)
if self.padding:
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
in_features % (inner_k_tiles * 16) == 0
), "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
),
)
self.register_buffer(
"scales_and_zeros",
torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
device=device,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_int4(
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)

@classmethod
def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
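
For orientation, here is a small self-contained sketch (hypothetical helper stand-ins, not part of this commit) of the shape bookkeeping LinearInt4 performs: the _check_k rule that decides whether in_features must be padded up to a multiple of 1024, and the packed weight / scales_and_zeros buffer shapes registered in __init__.

# Stand-in for torchchat's find_multiple helper (assumed behavior: round n up to a multiple of k).
def find_multiple(n: int, k: int) -> int:
    return n if n % k == 0 else n + k - (n % k)

# Same rule as LinearInt4._check_k: k must divide evenly into quantization groups
# and into the inner_k_tiles * 16 blocks expected by the packed int4 kernel.
def check_k(k: int, groupsize: int = 128, inner_k_tiles: int = 8) -> bool:
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

out_features, groupsize, inner_k_tiles = 4096, 128, 8
for in_features in (4096, 5000):
    k = in_features if check_k(in_features, groupsize, inner_k_tiles) else find_multiple(in_features, 1024)
    weight_shape = (out_features // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    sz_shape = (k // groupsize, out_features, 2)
    print(f"in_features={in_features}: padded k={k}, weight={weight_shape}, scales_and_zeros={sz_shape}")

When padding is active, forward() right-pads the activation with F.pad to the padded in_features before the packed matmul, which is why origin_in_features is recorded separately.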

169 changes: 53 additions & 116 deletions quantize.py
@@ -25,6 +25,12 @@
)
from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

from qops import (
LinearInt4 as WeightOnlyInt4Linear,
LinearInt8 as WeightOnlyInt8Linear,
QuantizedEmbedding,
)


#########################################################################
### torchchat quantization API ###
@@ -606,31 +612,6 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
return find_multiple(k, 1024)


def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])

if "cuda" in str(x.device):
c = torch.ops.aten._weight_int4pack_mm(
x.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16),
).to(
x.dtype
) # cast back to x.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
x,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


def replace_linear_int4(
module,
device,
@@ -640,9 +621,10 @@ def replace_linear_int4(
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if (
_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
or padding_allowed
if padding_allowed or WeightOnlyInt4Linear._check_k(
k=child.in_features,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
):
setattr(
module,
@@ -704,8 +686,10 @@ def create_quantized_state_dict(self):
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
if not WeightOnlyInt4Linear._check_k(
k=in_features,
groupsize=self.groupsize,
inner_k_tiles=self.inner_k_tiles,
):
if self.padding_allowed:
print(
@@ -751,85 +735,23 @@ def quantized_model(self) -> nn.Module:
return self.model_


class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales_and_zeros: torch.Tensor

def __init__(
self,
device: str,
in_features: int,
out_features: int,
bias=True,
dtype=None,
groupsize: int = 128,
inner_k_tiles: int = 8,
) -> None:
super().__init__()
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
in_features % (inner_k_tiles * 16) == 0
), "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
),
)
self.register_buffer(
"scales_and_zeros",
torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
device=device,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_int4(
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)


#########################################################################
##### GPTQ #####


def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0


class GPTQQuantHandler(QuantHandler):
"""
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each function.
"""This class implements a GPTQ QuantHandler that can be used to
apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to
implement the create_quantized_state_dict, instead they have to
reimplement __init__ such that it defines the functions for the
quantization mode. User is expected to reimplement
convert_for_runtime.
The following functions (which must be defined in __init__) are
used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each
function.
get_qparams_func:
A function that calculates the quantization qparams for an input tensor.
@@ -839,9 +761,11 @@ class GPTQQuantHandler(QuantHandler):
qparams: it can have any format but will need to be handled by the other defined functions below.
quantize_func:
A function that applies quantization to an input tensor. It should be noted
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
or a single column.
A function that applies quantization to an input tensor. It
should be noted that this function needs to be able to handle
quantizing the entire weight tensor, a single group, or a
single column.
Args:
weight: A 2d weight tensor with non-integer dtype.
qparams: the output from get_qparams_func
@@ -850,9 +774,11 @@ class GPTQQuantHandler(QuantHandler):
dequantize_func:
A function that dequantizes an input quantized weight tensor. It should be noted
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
or a single column.
A function that dequantizes an input quantized weight
tensor. It should be noted that this function needs to be able
to handle dequantizing the entire weight tensor, a single
group, or a single column.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
@@ -861,6 +787,7 @@ class GPTQQuantHandler(QuantHandler):
combine_qparams_list_func:
A function that combines several qparams into one qparam.
Args:
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
on a single group from a weight tensor
@@ -875,13 +802,17 @@ class GPTQQuantHandler(QuantHandler):
skip: boolean indicating whether layer should be skipped
make_names_and_values_dict_func:
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
A function that prepares the qparams and quantized_weight and
creates a dictionary indicating how they should be inserted
into the state_dict. Generally any packing of the weight and
qparams should be done here.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
names_and_values_dict: a dictionary mapping the name of
the parameters of the quantized module to the
corresponding quantized weights and qparams.
"""

@@ -1026,14 +957,20 @@ def __init__(
]
# skip unless padding_allowed=True or its correctly sized
self.skip_layer_func = lambda linear_weight: not (
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
or padding_allowed
padding_allowed
or WeightOnlyInt4Linear._check_k(
k=linear_weight.shape[-1],
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
)
)

# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
if not WeightOnlyInt4Linear._check_k(
k=k, groupsize=groupsize, inner_k_tiles=inner_k_tiles
):
new_k = find_multiple(k, 1024)
else:
new_k = k
