Skip to content

Commit

Permalink
Enable 6-bit kernel
Browse files Browse the repository at this point in the history
Summary: Enables 6-bit kernel that bootcamper wrote, and improves linear callsite when input has rank 3 or higher.

Differential Revision: D63991820
  • Loading branch information
metascroy authored and facebook-github-bot committed Oct 7, 2024
1 parent 52d27a1 commit 78b4b19
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ TORCH_LIBRARY(torchao, m) {
DEFINE_OP(3);
DEFINE_OP(4);
DEFINE_OP(5);
DEFINE_OP(6);
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
Expand All @@ -74,6 +75,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
DEFINE_CPU_IMPL(3);
DEFINE_CPU_IMPL(4);
DEFINE_CPU_IMPL(5);
DEFINE_CPU_IMPL(6);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
Expand All @@ -82,4 +84,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
DEFINE_META_IMPL(3);
DEFINE_META_IMPL(4);
DEFINE_META_IMPL(5);
DEFINE_META_IMPL(6);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// Unlike ATen, ExecuTorch op registration appears to only allow on
// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
// file is needed for each variant

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>

namespace {
Tensor _op_out(
RuntimeContext& ctx,
const Tensor& activations,
const Tensor& packed_weights,
const Tensor& group_size_tensor,
const Tensor& n_tensor,
const Tensor& k_tensor,
Tensor& out) {
(void)ctx;
linear_out_cpu</*weight_nbit*/ 6, /*has_weight_zeros*/ false>(
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
return out;
}
} // namespace

EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit0zp_weight.out", _op_out);
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// Unlike ATen, ExecuTorch op registration appears to only allow on
// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
// file is needed for each variant

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>

namespace {
Tensor _op_out(
RuntimeContext& ctx,
const Tensor& activations,
const Tensor& packed_weights,
const Tensor& group_size_tensor,
const Tensor& n_tensor,
const Tensor& k_tensor,
Tensor& out) {
(void)ctx;
linear_out_cpu</*weight_nbit*/ 6, /*has_weight_zeros*/ true>(
activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
return out;
}
} // namespace

EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit_weight.out", _op_out);
12 changes: 2 additions & 10 deletions torchao/experimental/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,7 @@ def forward(self, x):
lead_shape = x.shape[0:-2]
m, k = x.shape[-2], x.shape[-1]
n = self._n.shape[1]
x = x.reshape(-1, m, k)

res = [
self._linear_op(
x[i, :, :], self.packed_weights, self._group_size, self._n, self._k
)
for i in range(x.shape[0])
]
res = torch.stack(res)
res = self._linear_op(x.reshape(-1, k), self.packed_weights, self._group_size, self._n, self._k)
res = res.reshape(*lead_shape, m, n)
return res

Expand Down Expand Up @@ -206,7 +198,7 @@ def forward(self, x):

def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
try:
if nbit in [1, 2, 3, 4, 5]:
if nbit in [1, 2, 3, 4, 5, 6]:
wzp_suffix = "" if has_weight_zeros else "0zp"
return _Int8DynActIntxWeightQuantizedLinearNative(
pack_weight_op=getattr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_accuracy(self):
m = 1
n = 1071
k = 4096
activations = torch.randn(m, k, dtype=torch.float32)
activations = torch.randn(2, 3, m, k, dtype=torch.float32)
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

for nbit in [1, 2, 3, 4, 5, 6, 7]:
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_export_compile_aoti(self):
layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
model = torch.nn.Sequential(*layers)

activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
activations = torch.randn(m, k0, dtype=torch.float32)

print("Quantizing model")
quantizer = Int8DynActIntxWeightQuantizer(
Expand Down

0 comments on commit 78b4b19

Please sign in to comment.