Skip to content

Commit

Permalink
Fix an error in subclass impl
Browse files Browse the repository at this point in the history
Summary:
Accidently changed the device check code for old subclass instead of the new one, forgot to fix before landing

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 7, 2024
1 parent b34d1ac commit 6d59ba2
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et laytout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
# for consistency and to allow people to test
# 2 - we're given non-floats - quantizing long to int8 is crazy
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].device == torch.device("cpu")
and args[0].is_cuda
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
Expand Down Expand Up @@ -803,14 +796,21 @@ def _apply_fn_to_data(self, fn):

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et laytout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
# for consistency and to allow people to test
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].is_cuda
and args[0].device == torch.device("cpu")
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
Expand All @@ -833,6 +833,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
None if len(args) == 2 else args[2],
)
if weight_qtensor.input_quant_func is not None:
# dynamic quantization
input_tensor = weight_qtensor.input_quant_func(input_tensor)
input_tensor = input_tensor.dequantize()
weight_tensor = weight_qtensor.dequantize()
Expand Down

0 comments on commit 6d59ba2

Please sign in to comment.