Add blockwise quantized dot support #7605

Merged (5 commits) on Jul 8, 2024
4 changes: 2 additions & 2 deletions docs/quantized_ops.md
@@ -103,8 +103,8 @@ orig_model.linear = q_linear
| per-channel | N/A | W4A16 | Yes |
| per-channel | per-token | W8A8 | No |
| per-channel | per-token | W4A8 | No |
| blockwise | N/A | W8A16 | No |
| blockwise | N/A | W4A16 | No |
| blockwise | N/A | W8A16 | Yes |
| blockwise | N/A | W4A16 | Yes |
| blockwise | per-token | W8A8 | No |
| blockwise | per-token | W4A8 | No |
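
For illustration only (not part of the diff): a minimal sketch of exercising the newly enabled blockwise W8A16 row through the underlying op. The shapes follow the blockwise layout used in this PR; the random tensors and the registration import are assumptions.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul  # assumed to register torch.ops.xla.quantized_matmul

in_channel, out_channel, block_size = 6, 8, 2
x = torch.randn(3, in_channel)  # W8A16: activation stays in high precision
w_int8 = torch.randint(-128, 128, (in_channel // block_size, block_size, out_channel),
                       dtype=torch.int8)                    # [in_channel / block_size, block_size, out_channel]
scaler = torch.rand(in_channel // block_size, out_channel)  # [in_channel / block_size, out_channel]
device = xm.xla_device()
out = torch.ops.xla.quantized_matmul(
    x.to(device), w_int8.to(device), scaler.to(device), block_size=block_size)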

117 changes: 98 additions & 19 deletions test/quantized_ops/test_quantized_matmul.py
@@ -24,35 +24,66 @@ def __init__(self, input_dim, output_dim):
def weight_quantization_rtn(self,
linear,
n_bits=8,
block_size=-1,
quant_method=torch.per_channel_symmetric):
'''
Quantize linear weight using Round-To-Nearest(RTN) algorithm.
'''
assert isinstance(self.linear, torch.nn.Linear)
w_fp = linear.weight.data
min_val, max_val = torch.aminmax(w_fp, dim=1) # min_val, max_val [out_dim]
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
return w_int, scaler.to(w_fp.dtype), zero_point

def replace_with_xla_quantized_matmul(self, n_bit=8):
if block_size == -1:
min_val, max_val = torch.aminmax(
w_fp, dim=1) # min_val, max_val [out_dim]
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
return w_int, scaler.to(w_fp.dtype), zero_point
else:
assert w_fp.shape[1] % block_size == 0
output_dim = w_fp.shape[0]
input_dim = w_fp.shape[1]
w_fp = w_fp.reshape(output_dim * input_dim // block_size, block_size)
min_val, max_val = torch.aminmax(
w_fp, dim=1)  # min_val, max_val: [output_dim * input_dim // block_size]
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
w_int = w_int.reshape(output_dim, input_dim // block_size,
block_size).permute(1, 2, 0)
scaler = scaler.to(w_fp.dtype).reshape(output_dim,
input_dim // block_size).permute(
1, 0)
return w_int, scaler, zero_point
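  # Editorial sketch (not part of the diff): shape walkthrough for the blockwise
  # branch above, assuming output_dim=8, input_dim=6, block_size=2.
  #   w_fp    [8, 6]  --reshape-->  [24, 2]   (one row per block of 2 input channels)
  #   scaler  [24]    (one scale per block, from determine_qparams)
  #   w_int   [24, 2] --reshape--> [8, 3, 2] --permute(1, 2, 0)--> [3, 2, 8]
  #           i.e. [input_dim // block_size, block_size, output_dim]
  #   scaler  [24]    --reshape--> [8, 3]    --permute(1, 0)-->    [3, 8]
  #           i.e. [input_dim // block_size, output_dim]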

def replace_with_xla_quantized_matmul(self, n_bit=8, block_size=-1):
assert isinstance(self.linear, torch.nn.Linear)
w_int, scaler, _ = self.weight_quantization_rtn(self.linear, n_bit)
w_int, scaler, _ = self.weight_quantization_rtn(
self.linear, n_bits=n_bit, block_size=block_size)
use_int4_weight = n_bit == 4
q_linear = XlaQuantizedLinear(
self.linear.in_features,
self.linear.out_features,
block_size=block_size,
int4_weight=use_int4_weight)
q_linear.load_quantized_weight(w_int, scaler)
self.linear = q_linear
@@ -95,7 +126,7 @@ def test_q_linear_module_dynamo(self):
m = m.to(device)
m_dynamo = torch.compile(m, backend="openxla")
out_quant_dynamo = m_dynamo(x.to(device))
self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01))
self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.02))
self.assertTrue(torch.allclose(out_quant_dynamo.cpu(), out_quant))

def test_q_linear_hlo(self):
@@ -146,6 +177,54 @@ def test_int4_per_channel_linear_module(self):
self.assertGreater(
self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)

def test_blockwise_matmul_op(self):
input_features = 6
out_features = 8
block_size = 2
batch_size = 3
for n_bit in [4]:
with self.subTest(n_bit=n_bit):
weight = torch.randint(-8, 7, (input_features // block_size, block_size,
out_features)).to(torch.int8)
weight_scaler = torch.ones(input_features // block_size, out_features)
x = torch.rand(batch_size, input_features)

# Fake quantize output.
w_dq = (weight * weight_scaler.unsqueeze(1)).reshape(
input_features, out_features)
fake_quant_out = torch.matmul(x, w_dq)
# Eager output.
torch_out = torch.ops.xla.quantized_matmul(
x, weight, weight_scaler, block_size=block_size)
self.assertGreater(
self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999)
# XLA Output.
if not (n_bit == 4 and xr.device_type() != 'TPU'):
x = x.to(device)
weight = weight.to(device)
weight_scaler = weight_scaler.to(device)
xla_out = torch.ops.xla.quantized_matmul(
x, weight, weight_scaler, block_size=block_size)
self.assertTrue(torch.allclose(torch_out, xla_out.cpu(), atol=0.03))

def test_blockwise_linear_module(self):
for n_bit in [4, 8]:
with self.subTest(n_bit=n_bit):
m = M(6, 8)
x = torch.randn(3, 6)
out_fp = m(x)
m.replace_with_xla_quantized_matmul(n_bit=n_bit, block_size=2)
out_quant = m(x)
self.assertGreater(self._calc_cosine_dist(out_fp, out_quant), 0.99)

# Dot with int4 weight is only supported on TPU
if not (n_bit == 4 and xr.device_type() != 'TPU'):
m = m.to(device)

Review comments on this line:

@miladm (Collaborator, Jul 9, 2024): what's the behavior on CUDA and CPU device?

Author (Collaborator): Because int4 only runs on XLA:TPU

@miladm (Collaborator, Jul 9, 2024): offline discussion summary: int4 works only on XLA:TPU today; XLA:CPU does not support INT4. XLA:GPU level of support is unclear as it is not tested currently.

x = x.to(device)
out_quant_xla = m(x)
self.assertGreater(
self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)


if __name__ == '__main__':
unittest.main()
107 changes: 78 additions & 29 deletions torch_xla/experimental/xla_quantized_matmul.py
@@ -5,7 +5,7 @@
from torch_xla.core.xla_model import XLA_LIB

XLA_LIB.define(
"quantized_matmul(Tensor x, Tensor w, Tensor scale, int? blocksize=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
"quantized_matmul(Tensor x, Tensor w, Tensor scale, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
)


@@ -21,83 +21,132 @@ def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
0], f"weight scaler shape is expect to be [out_channel,], got {w_scaler.shape}, weight shape {w_shape}."


def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim,
block_size, w, w_scaler):
assert w.dtype == torch.int8, (
f"Weight dtype is expected to be torch.int8, got {w.dtype}.")
assert w.dim() == 3, (
f"Weight tensor is expected to be 3D, got {w.dim()}D Tensor.")
w_shape = list(w.shape)
assert input_dim % block_size == 0, (
f"input_dim should be divisible by block_size, "
f"got input_dim: {input_dim}, block_size: {block_size}.")
assert w_shape[0] == input_dim / block_size and w_shape[1] == block_size, (
f"Weight shape is expected to be [input_dim / block_size, block_size, output_dim], "
f"input_dim: {input_dim}, block_size: {block_size}, output_dim: {output_dim}, "
f"but got {w_shape}.")
assert w_scaler.dim() == 2, (
f"weight scaler is expected to be 2D, got {w_scaler.dim()}D Tensor.")
assert w_scaler.shape[0] == w_shape[0] and w_scaler.shape[1] == w_shape[-1], (
f"weight scaler shape is expect to be [in_channel / block_size, out_channel], "
f"got {w_scaler.shape}, weight shape {w_shape}.")


@impl(XLA_LIB, "quantized_matmul", "XLA")
def quantized_matmul_xla(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
blocksize: int = -1,
block_size: int = -1,
int4_weight: bool = False):
"""Quantized Matrix Multiply op on XLA devices.

Args:
x: torch.Tensor - Activation of Matmul [..., in_channel].
w: torch.Tensor - Weight Tensor.
per-channel quant: torch.int8 x [out_channel, in_channel].
blockwise quant: torch.int8 x [in_channel / block_size, block_size, out_channel].
scaler: torch.Tensor - Weight scaler.
per-channel quant: [out_channel,].
blocksize: blocksize for blockwise quantization, -1 for per-channel quantization.
blockwise quant: [in_channel / block_size, out_channel].
block_size: The blocksize for blockwise quantization, -1 for per-channel quantization.
int4_weight: if the weights are int4, the int4 weights need to be stored in a int8
container (unpacked).
"""
assert blocksize == -1, "blockwise quantization is not supported yet."
if int4_weight:
# Reinterpret cast the weight to s4 dtype in XLA.
w = torch_xla._XLAC._xla_cast_int4(w, w.cpu().flatten().numpy().tolist())
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
scaler)
return F.linear(x, w) * scaler
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler)
return F.linear(x, w) * scaler
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
out = torch.einsum('scn,...sc->...sn', w, x)
out = torch.einsum('sn,...sn->...n', scaler, out)
return out
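
# Editorial note (not part of the diff): the two einsums in the blockwise branch
# above are equivalent to dequantizing the weight and doing a dense matmul. With
# w: [s, c, n] = [in_channel / block_size, block_size, out_channel],
# scaler: [s, n], and x reshaped to [..., s, c]:
#   einsum('scn,...sc->...sn', w, x)       -> per-block partial products [..., s, n]
#   einsum('sn,...sn->...n', scaler, out)  -> scale each block and sum over s
# which equals x @ (w * scaler.unsqueeze(1)).reshape(in_channel, out_channel),
# the reference computation used in test_blockwise_matmul_op.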


@impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
def quantized_matmul(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
blocksize: int = -1,
block_size: int = -1,
int4_weight: bool = False):
assert blocksize == -1, "blockwise quantization is not supported yet."
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
scaler)
w = w.to(x.dtype)
return torch.mul(F.linear(x, w), scaler)
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler)
w = w.to(x.dtype)
return torch.mul(F.linear(x, w), scaler)
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
w = w.to(x.dtype)
out = torch.einsum('scn,...sc->...sn', w, x)
out = torch.einsum('sn,...sn->...n', scaler, out)
return out


class XlaQuantizedLinear(torch.nn.Module):

def __init__(self,
input_dim,
output_dim,
blocksize=-1,
block_size=-1,
int4_weight: bool = False):
super().__init__()
assert blocksize == -1, "Only per-channel quantization is supported."
self.input_dim = input_dim
self.output_dim = output_dim
self.blocksize = blocksize
self.block_size = block_size
self.int4_weight = int4_weight
self.register_buffer('weight',
torch.zeros(output_dim, input_dim).to(torch.int8))
self.register_buffer('weight_scaler', torch.zeros(output_dim))

def load_quantized_weight(self, weight, weight_scaler):
'''
Weight shape: [output_channel, input_channel]
Weight scaler shape: [output_channel]
weight (Tensor):
per-channel quant: [out_channel, in_channel].
blockwise quant: [in_channel / block_size, block_size, out_channel].

weight_scaler (Tensor):
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
'''
if self.blocksize == -1:
if self.block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(self.input_dim,
self.output_dim, weight,
weight_scaler)
self.weight = weight
self.weight_scaler = weight_scaler
else:
assert False, "Only per-channel quantization is supported."
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(self.input_dim,
self.output_dim,
self.block_size, weight,
weight_scaler)
self.weight = weight
self.weight_scaler = weight_scaler

def forward(self, x):
if self.blocksize == -1:
return torch.ops.xla.quantized_matmul(
x, self.weight, self.weight_scaler, int4_weight=self.int4_weight)
else:
assert False, "Only per-channel quantization is supported."
return torch.ops.xla.quantized_matmul(
x,
self.weight,
self.weight_scaler,
block_size=self.block_size,
int4_weight=self.int4_weight)
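
For completeness, an end-to-end usage sketch (editorial, not part of the diff) of the blockwise path through XlaQuantizedLinear. The shapes, random data, and import path are illustrative assumptions.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.xla_quantized_matmul import XlaQuantizedLinear

in_features, out_features, block_size = 6, 8, 2
q_linear = XlaQuantizedLinear(in_features, out_features, block_size=block_size)
# Blockwise layout: weight [in / block_size, block_size, out] in int8, scaler [in / block_size, out].
w_int = torch.randint(-128, 128, (in_features // block_size, block_size, out_features),
                      dtype=torch.int8)
scaler = torch.rand(in_features // block_size, out_features)
q_linear.load_quantized_weight(w_int, scaler)
q_linear = q_linear.to(xm.xla_device())
out = q_linear(torch.randn(3, in_features).to(xm.xla_device()))  # [3, out_features]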