From f48cfdfc1ae011b0fae5057bb14b5712dcffebf9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 18 Oct 2024 16:14:06 -0700 Subject: [PATCH 1/2] Add tensor parallelism support for int4_weight_only quantization Summary: Following https://github.com/pytorch/ao/issues/988 we added TP support for int4_weight_only quantization in torchao that's using TensorCoreTiledLayout Addresses one work item in https://github.com/pytorch/ao/issues/988 Also clarified docs based on https://github.com/pytorch/ao/issues/386 Also restructructured the tests in test/dtypes/test_affine_quantized_tensor_parallel.py to not depend on torchao/utils.py to reduce the jumps people have to do to understand what is tested Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py Reviewers: Subscribers: Tasks: Tags: --- .../test_affine_quantized_tensor_parallel.py | 247 ++++++++++-------- torchao/dtypes/affine_quantized_tensor.py | 92 +++++-- torchao/dtypes/utils.py | 21 +- 3 files changed, 229 insertions(+), 131 deletions(-) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 8e6855a5df..0251c6d7b6 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,9 +1,13 @@ import torch import unittest -from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase from torch.testing._internal.common_utils import run_tests from torch.testing._internal import common_utils -from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight +from torchao.quantization import ( + int4_weight_only, + int8_weight_only, + float8_weight_only, + float8_dynamic_activation_float8_weight, +) from torchao.quantization.observer import PerRow, PerTensor import torch.distributed as dist from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh @@ -16,119 +20,142 @@ from torchao.dtypes import AffineQuantizedTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): +class TestAffineQuantizedTensorParallel(DTensorTestBase): + """Basic test case for tensor subclasses + """ QUANT_METHOD_FN = staticmethod(int8_weight_only) -copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp") + QUANT_METHOD_KWARGS = {} -# Run only on H100 -if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): - class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): - QUANT_METHOD_FN = staticmethod(float8_weight_only) - copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp") + @staticmethod + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + @staticmethod + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) + return m + + def _test_tp(self, dtype): + device = "cuda" + # To make sure different ranks create the same module + torch.manual_seed(5) + + class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Get rank and device + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device).to(dtype) + proj_dn = M(2048, 1024).to(device).to(dtype) + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) + y = proj_dn(proj_up(example_input)) + # Quantize the model + up_quant = self.quantize(proj_up) + dn_quant = self.quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + + mesh = self.build_device_mesh() + mesh.device_type = "cuda" + + # Shard the models + up_dist = self.colwise_shard(up_quant, mesh) + dn_dist = self.rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + + if not TORCH_VERSION_AT_LEAST_2_5: + # Need torch 2.5 to support compiled tensor parallelism + return + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) + + +class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(int8_weight_only) + COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + +class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(int4_weight_only) + COMMON_DTYPES = [torch.bfloat16] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + +common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel) +common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel) # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): - class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase): - """Basic test case for tensor subclasses - """ + class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(float8_weight_only) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] - TENSOR_SUBCLASS = AffineQuantizedTensor - QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) - QUANT_METHOD_KWARGS = {} - - @staticmethod - def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: - """ - Shard linear layer of the model in column-wise fashion - """ - # Column-wise is wrt to A^T, so for A it is row-wise. - # Number of rows per rank - orig_weight = m.linear.weight - n_local_rows = orig_weight.size(0) // mesh.size() - rank = mesh.get_local_rank() - local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] - # Construct DTensor from local shard - dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) - # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) - return m - - @staticmethod - def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: - """ - Shard linear layer of the model in row-wise fashion - """ - # Row-wise is wrt to A^T, so for A it is column-wise. - # Number of rows per rank - orig_weight = m.linear.weight - n_local_cols = orig_weight.size(1) // mesh.size() - rank = mesh.get_local_rank() - local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] - # Construct DTensor from local shard - dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) - # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) - return m - - def quantize(self, m: torch.nn.Module) -> torch.nn.Module: - """ - Quantize the model - """ - quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) - return m - - def _test_tp(self, dtype): - device = "cuda" - # To make sure different ranks create the same module - torch.manual_seed(5) - - class M(torch.nn.Module): - def __init__(self, in_features, out_features, **kwargs) -> None: - super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - # Get rank and device - device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") - - # Original model - proj_up = M(1024, 2048).to(device).to(dtype) - proj_dn = M(2048, 1024).to(device).to(dtype) - example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) - # Quantize the model - up_quant = self.quantize(proj_up) - dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) - - mesh = self.build_device_mesh() - mesh.device_type = "cuda" - - # Shard the models - up_dist = self.colwise_shard(up_quant, mesh) - dn_dist = self.rowwise_shard(dn_quant, mesh) - - # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) - - y_d = dn_dist(up_dist(input_dtensor)) - - if not TORCH_VERSION_AT_LEAST_2_5: - # Need torch 2.5 to support compiled tensor parallelism - return - - up_compiled = torch.compile(up_dist) - y_up = up_compiled(input_dtensor) - dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) @@ -151,7 +178,7 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTe @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_tp(self, dtype): return self._test_tp(dtype) - + common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel) if __name__ == "__main__": diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index d14a5dd17c..75d178fb50 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -56,6 +56,9 @@ class AQTTensorImpl(TorchAOBaseTensor): """ Base class for the tensor impl for `AffineQuantizedTensor` + + Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct + the underlying implementation of a AQT based on layout """ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the tensor impl @@ -487,6 +490,10 @@ class BlockSparseLayout(Layout): @dataclass(frozen=True) class TensorCoreTiledLayout(Layout): + """ + inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel + """ inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -512,11 +519,16 @@ def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_poin scale = torch.nn.functional.pad(scale, padding_changes) zero_point = torch.nn.functional.pad(zero_point, padding_changes) return input, scale, zero_point - - - - + def post_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input def extra_repr(self): return f"inner_k_tiles={self.inner_k_tiles}" @@ -551,7 +563,7 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: @register_layout(PlainLayout) class PlainAQTTensorImpl(AQTTensorImpl): """ - TensorImpl storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point tensors directly as plain tensors. fields: @@ -675,7 +687,7 @@ def from_plain( @register_layout(SemiSparseLayout) class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ - TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor + TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor """ @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -855,7 +867,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): @register_layout(MarlinSparseLayout) class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ - TensorImpl storage class for sparse_marlin_24 layout for affine quantized tensor. + TensorImpl for sparse_marlin_24 layout for affine quantized tensor. Can be used with 4 bits and 8 bits quantization. @@ -1025,7 +1037,10 @@ def _apply_fn_to_data(self, fn): @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ - TensorImpl storage class for float8 tensor impl for affine quantized tensor + TensorImpl for float8 layout affine quantized tensor + + Note: technically we should not create a new layout for float8 we should merge this into + plain layout """ float8_data: torch.Tensor scale: torch.Tensor @@ -1166,9 +1181,21 @@ def __repr__(self): @register_layout(TensorCoreTiledLayout) class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): """ - TensorImpl storage class for tensor_core_tiled tensor impl for affine quantized tensor, this is for int4 only, - it stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of + TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm` + + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] + (unpacked Tensor shape is n * k) + where inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel (defaults to 8) + + Note: we also pack scale and zero point together here for tinygemm kernel + + Note: technically tensor core tiled layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass fields: packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout @@ -1251,9 +1278,15 @@ def to(self, *args, **kwargs): ) def _apply_fn_to_data(self, fn): - self.packed_weight = fn(self.packed_weight) - self.scale_and_zero = fn(self.scale_and_zero) - return self + # self.packed_weight = fn(self.packed_weight) + # self.scale_and_zero = fn(self.scale_and_zero) + # return self + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -1273,8 +1306,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose """ - args[0].transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, args[0]) + transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -1510,6 +1571,7 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) + if bias is not None: y += bias return y.to(orig_dtype) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index d17231c1ee..7c0dfd9dc8 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -3,13 +3,22 @@ from dataclasses import dataclass """ -Base class for different Layout, should not be instantiated directly -used to allow users to pass around configurations for the tensor impl, e.g. inner_k_tiles -for int4 tensor core tiled tensor impl +Base class for different layout, following the same design of PyTorch layout +https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout, used to represent different +data layout of a Tensor, it's used in conjunction with TensorImpl to represent custom data layout. -Note: TensorImpl is an abstraction not only for custom data representation, it is also used for how the -tensorImpl interacts with different operators, e.g. the same data representation can have different -behaviors when running the same operator, e.g. transpose, quantized_linear. +As a native PyTorch example, Sparse Coordinate format Tensor (https://pytorch.org/docs/stable/generated/torch.sparse_coo_tensor.html#torch-sparse-coo-tensor) has `torch.sparse_coo` layout, which is backed up by +`SparseImpl`: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/SparseTensorImpl.h which stores two Tensors (indices_ and values_) + +We extended the layout in torchao with Layout class (instead of torch.layout objects), also we use tensor subclass to implement TensorImpl classes. + +Layout also allows users to pass around configurations for the TensorImpl, +e.g. inner_k_tiles for int4 tensor core tiled TensorImpl + +Note: Layout is an abstraction not only for custom data representation, it is also used for how the +Tensor interacts with different operators, e.g. the same data representation can have different +behaviors when running the same operator, e.g. transpose, quantized_linear. This is the same as layout +in PyTorch native Tensor """ @dataclass(frozen=True) class Layout: From ca241ed47c0e9441249da60de00fbedd82031a77 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 18 Oct 2024 16:37:00 -0700 Subject: [PATCH 2/2] typo --- test/dtypes/test_affine_quantized_tensor_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 0251c6d7b6..42511e7db8 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -157,7 +157,7 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): + class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerTensor()} COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -168,7 +168,7 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantize def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): + class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerRow()} COMMON_DTYPES = [torch.bfloat16]