diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 148228a030..3aea9d3c68 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -139,21 +139,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
-        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
-        # 1. we'll add cpu/cuda version (int4mm etc.)
-        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
-        #    cpu device + et laytout --> gives current 8da4w executorch representation
-        #    cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
-        #    cuda device + some layout --> gives cuda kernel
-
         # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
-        #     kernels in CPU as well, see the note above
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        #     for consistency and to allow people to test
         # 2 - we're given non-floats - quantizing long to int8 is crazy
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].device == torch.device("cpu")
+            and args[0].is_cuda
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
@@ -803,14 +796,21 @@ def _apply_fn_to_data(self, fn):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
+        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
+        # 1. we'll add cpu/cuda version (int4mm etc.)
+        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
+        #    cpu device + et laytout --> gives current 8da4w executorch representation
+        #    cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
+        #    cuda device + some layout --> gives cuda kernel
+
         # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
-        #     for consistency and to allow people to test
+        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
+        #     kernels in CPU as well, see the note above
         # 2 - we're given non-floats - quantizing long to int8 is crazy
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].is_cuda
+            and args[0].device == torch.device("cpu")
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
@@ -833,6 +833,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     None if len(args) == 2 else args[2],
                 )
             if weight_qtensor.input_quant_func is not None:
+                # dynamic quantization
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
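
For context, below is a minimal runnable sketch (not part of the patch) of the device-gated fallback that the two hunks swap between the two tensor subclasses: intercept aten.mm / aten.addmm for floating-point activations on the device that lacks a fused quantized kernel, dequantize the weight, and run a plain matmul. The helpers dequantize and dispatch_mm and the per-tensor int8 scheme are invented for illustration and are not torchao APIs.

# Illustrative sketch only; a "quantized" weight here is just (int8 data, scale).
import torch

aten = torch.ops.aten

def dequantize(int_data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # symmetric per-tensor int8 -> float dequantization
    return int_data.to(torch.float32) * scale

def dispatch_mm(func, input_tensor, int_data, scale, bias=None):
    # Mirror of the gating condition the patch moves between the two subclasses:
    # only fall back to a plain float matmul for floating-point activations on
    # the device without a quantized mm kernel (CPU here, `is_cuda` in the other
    # class); anything else would be handled by a dedicated kernel or rejected.
    if (
        func in [aten.mm.default, aten.addmm.default]
        and input_tensor.is_floating_point()
        and input_tensor.device == torch.device("cpu")
    ):
        weight = dequantize(int_data, scale)
        if func is aten.addmm.default:
            return torch.addmm(bias, input_tensor, weight.t())
        return torch.mm(input_tensor, weight.t())
    raise NotImplementedError(f"no quantized kernel for {func} on {input_tensor.device}")

# usage: (4, 8) activation against a (16, 8) int8 weight -> (4, 16) output
x = torch.randn(4, 8)
w_int8 = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
scale = torch.tensor(0.05)
y = dispatch_mm(aten.mm.default, x, w_int8, scale)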