diff --git a/docs/Makefile b/docs/Makefile index c5769c3928..7279ac5bee 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,12 @@ ifneq ($(EXAMPLES_PATTERN),) endif # You can set these variables from the command line. -SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) + +# TODO: Revert this when have docs on pytorch.org/ao +# SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) +# SPHINXOPTS = -WT -j auto --keep-going # enable later when the files are included in the doc build + + SPHINXBUILD = sphinx-build SPHINXPROJ = torchao SOURCEDIR = source diff --git a/docs/source/conf.py b/docs/source/conf.py index f1e700d556..988864f997 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -58,6 +58,12 @@ "remove_config_comments": True, } +### TODO: Delete this when we have content +suppress_warnings = [ + 'toc.unlisted', +] +### + napoleon_use_ivar = True napoleon_numpy_docstring = False napoleon_google_docstring = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 3aee2a5075..fb1649fa48 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,53 +2,55 @@ Welcome to the torchao Documentation ======================================= **torchao** is an open-source library that provides the functionality -to quantize and prune your models using native PyTorch. +to quantize and prune your models using native PyTorch. Our documentation is under development +with more content coming soon. -.. grid:: 3 +.. + .. grid:: 3 - .. grid-item-card:: :octicon:`file-code;1em` - Getting Started - :img-top: _static/img/card-background.svg - :link: getting-started.html - :link-type: url + .. grid-item-card:: :octicon:`file-code;1em` + Getting Started + :img-top: _static/img/card-background.svg + :link: getting-started.html + :link-type: url - Learn about how to get started with torchao - and ts application in your projects. + Learn about how to get started with torchao + and ts application in your projects. - .. grid-item-card:: :octicon:`file-code;1em` - Concepts - :img-top: _static/img/card-background.svg - :link: dtypes.html - :link-type: url + .. grid-item-card:: :octicon:`file-code;1em` + Concepts + :img-top: _static/img/card-background.svg + :link: dtypes.html + :link-type: url - Learn about the key torchao concepts such - as dtypes, quantization, sparsity, among others. + Learn about the key torchao concepts such + as dtypes, quantization, sparsity, among others. - .. grid-item-card:: :octicon:`file-code;1em` - API Reference - :img-top: _static/img/card-background.svg - :link: api_ref_intro.html - :link-type: url + .. grid-item-card:: :octicon:`file-code;1em` + API Reference + :img-top: _static/img/card-background.svg + :link: api_ref_intro.html + :link-type: url - A comprehensive reference for the torchao - API and its functionalities. + A comprehensive reference for the torchao + API and its functionalities. -Tutorials -~~~~~~~~~ + Tutorials + ~~~~~~~~~ -Ready to experiment? Check out some of the -torchao tutorials. + Ready to experiment? Check out some of the + torchao tutorials. -.. customcardstart:: + .. customcardstart:: -.. customcarditem:: - :header: Template Tutorial - :card_description: A placeholder template for demo purposes - :image: _static/img/generic-pytorch-logo.png - :link: tutorials/template_tutorial.html - :tags: template + .. customcarditem:: + :header: Template Tutorial + :card_description: A placeholder template for demo purposes + :image: _static/img/generic-pytorch-logo.png + :link: tutorials/template_tutorial.html + :tags: template -.. customcardend:: + .. customcardend:: .. ---------------------------------------------------------------------- @@ -56,42 +58,43 @@ torchao tutorials. .. Each of the entry below corresponds to a file.rst in docs/source/. .. ---------------------------------------------------------------------- -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Getting Started - :hidden: +.. + .. toctree:: + :glob: + :maxdepth: 1 + :caption: Getting Started + :hidden: - overview - getting-started + overview + getting-started -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Concepts - :hidden: + .. toctree:: + :glob: + :maxdepth: 1 + :caption: Concepts + :hidden: - dtypes - quantization - sparsity - performant_kernels + dtypes + quantization + sparsity + performant_kernels -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Tutorials - :hidden: + .. toctree:: + :glob: + :maxdepth: 1 + :caption: Tutorials + :hidden: - tutorials/template_tutorial + tutorials/template_tutorial .. toctree:: :glob: :maxdepth: 1 :caption: API Reference - :hidden: api_ref_intro api_ref_sparsity api_ref_quantization api_ref_dtypes - api_ref_kernel + .. + api_ref_kernel diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 55bbe0bcb9..3e8b89f9f0 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -1,6 +1,7 @@ import logging import unittest from packaging import version +import math import torch from torch import nn @@ -10,11 +11,17 @@ parametrize, run_tests, ) -from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4 +from torchao.dtypes.nf4tensor import ( + linear_nf4, + NF4Tensor, + to_nf4, + _INNER_TENSOR_NAMES_FOR_SHARDING, +) import torch.nn.functional as F import io from collections import OrderedDict import torchao +from typing import Tuple, Union bnb_available = False @@ -234,9 +241,199 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype): a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("shape", [(16, 16), (32, 16)]) + @parametrize("chunk_size", [8, 16, 32]) + def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size): + a = torch.randn(shape, device='cuda', dtype=dtype) + with unittest.mock.patch("torchao.dtypes.nf4tensor.CHUNK_SIZE", chunk_size): + nf4_patched = to_nf4(a, 16, 2) + # This will be essentially no chunking since the numel is alot smaller than default chunk_size + nf4_base = to_nf4(a, 16, 2) + + torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data) + + + +class TestFSDPOps(TestCase): + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + nf4_tensor = to_nf4(torch.randn(input_size)) + chunks = list(torch.chunk(nf4_tensor, num_chunks)) + self.assertEqual(len(chunks), num_chunks) + if isinstance(input_size, int): + expected_size0 = input_size // num_chunks + else: + expected_size0 = input_size[0] // num_chunks + for chunk in chunks: + self.assertEqual(chunk.size(0), expected_size0) + + @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)]) + def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"): + nf4_tensor = to_nf4(torch.randn(input_size)) + torch.chunk(nf4_tensor, num_chunks) + + @parametrize("input_size", [(512, 512, 512)]) + def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"): + nf4_tensor = to_nf4(torch.randn(input_size)) + torch.chunk(nf4_tensor, num_chunks) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + nf4_tensor_zeros = nf4_tensor.new_zeros(input_size) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4_tensor_zeros, attr) + self.assertEqual(torch.count_nonzero(inner_tensor), 0) + expected_size = input_size if not isinstance(input_size, int) else (input_size, ) + self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size)) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]): + if isinstance(input_size, int): + new_size = input_size + 1 + elif len(input_size) == 1: + new_size = (input_size[0] + 1, ) + else: + new_size = (input_size[0] + 1, input_size[1]) + nf4_tensor = to_nf4(torch.randn(input_size)) + with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"): + nf4_tensor_zeros = nf4_tensor.new_zeros(new_size) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + orig_attrs, _ = nf4_tensor.__tensor_flatten__() + orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs]) + end_idx = input_size if isinstance(input_size, int) else input_size[0] + sliced_tensor = nf4_tensor[:end_idx] + self.assertEqual(nf4_tensor.size(), sliced_tensor.size()) + attrs, _ = sliced_tensor.__tensor_flatten__() + for attr in attrs: + orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr() + sliced_tensor_inner = getattr(sliced_tensor, attr) + self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage) + self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr]) + + def test_tensor_slice_1d_invalid(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with customized step"): + nf4_tensor[..., ::2] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"): + nf4_tensor[:2] + + def test_tensor_slice_2d_invalid(self): + nf4_tensor = to_nf4(torch.randn((512, 512))) + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"): + nf4_tensor[:, :511] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"): + nf4_tensor[:2] + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + viewed_tensor = nf4_tensor.view(-1) + self.assertEqual(viewed_tensor.dim(), 1) + self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(viewed_tensor, attr) + self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + if len(input_size) == 1: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): + nf4_tensor.view(input_size) + if len(input_size) == 2: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"): + nf4_tensor.view(input_size) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset()) + self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size()) + self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride()) + self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset()) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor_orig = getattr(nf4_tensor, attr) + inner_tensor_strided = getattr(nf4_tensor_strided, attr) + self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size()) + self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride()) + self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset()) + + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + if len(input_size) == 1: + size = (input_size[0] - 1, ) + else: + size = (input_size[0] - 1, input_size[1]) + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"): + torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset()) + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"): + torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1) + + if len(input_size) == 2: + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"): + stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0]) + torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset()) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_pin_memory(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertFalse(nf4_tensor.is_pinned()) + + nf4_tensor = nf4_tensor.pin_memory() + self.assertTrue(nf4_tensor.is_pinned()) + nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) + self.assertFalse(nf4_tensor.is_pinned()) + + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_cuda(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda", non_blocking=True) + self.assertEqual(nf4_tensor.device.type, "cuda") + + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda") + self.assertEqual(nf4_tensor.device.type, "cuda") + + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16) + self.assertEqual(nf4_tensor.device.type, "cuda") + self.assertEqual(nf4_tensor.dtype, torch.bfloat16) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_cpu(self): + nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) + nf4_tensor = nf4_tensor.cpu() + self.assertEqual(nf4_tensor.device.type, "cpu") + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4_tensor, attr) + self.assertEqual(inner_tensor.device.type, "cpu") + instantiate_parametrized_tests(TestNF4Linear) +instantiate_parametrized_tests(TestFSDPOps) if __name__ == "__main__": run_tests() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 6cf4b2c9f8..93ac6fe739 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -159,7 +159,7 @@ def test_int8_wo_quant_save_load(self): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index f09d53821d..48249434b7 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -1,21 +1,38 @@ import functools from dataclasses import dataclass +import math from typing import Dict, Tuple +import math +import sys +from enum import Enum, auto import torch import torch.nn.functional as F from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch._prims_common import make_contiguous_strides_for aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional -from typing import Any, Tuple +from typing import Any, Optional, Tuple, Union, List NF4_OPS_TABLE: Dict[Any, Any] = {} +_INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"] + +# Note: Quantize in Chunks +# During quantization to NF4, one of the steps to convert from the original float number +# to the index of the nearest value in the NF4 format. This can cause a large memory spike +# Due to intermediates of the quantization process. Instead we process the original +# tensor in chunks. This is a tradeoff between memory and speed. This number seems to +# strike a good balance between memory and speed +CHUNK_SIZE = 1024**2 + + def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor) return ( @@ -37,11 +54,219 @@ def decorator(func): return decorator -@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) +def construct_nf4_args(nf4tensor: "NF4Tensor", kwargs: Optional[Dict[str, Any]] = None): + if kwargs is None: + kwargs = {} + tensor_meta = SubclassTensorArgs( + kwargs.get("size", nf4tensor.size()), + kwargs.get("stride", nf4tensor.stride()), + kwargs.get("storage_offset", nf4tensor.storage_offset()), + kwargs.get("dtype", nf4tensor.dtype), + kwargs.get("device", nf4tensor.device), + kwargs.get("requires_grad", nf4tensor.requires_grad), + ) + return ( + tensor_meta, + kwargs.get("block_size", nf4tensor.block_size), + kwargs.get("n_blocks", nf4tensor.n_blocks), + kwargs.get("scaler_block_size", nf4tensor.scaler_block_size), + kwargs.get("quantized_scalers", nf4tensor.quantized_scalers), + kwargs.get("quantization_factor", nf4tensor.quantization_factor), + kwargs.get("scaler_mean", nf4tensor.scaler_mean), + kwargs.get("quantized_data", nf4tensor.quantized_data), + kwargs.get("nf4", nf4tensor.nf4), + ) + + +# __torch_dispatch__ utils: apply aten op to inner tensors +def apply_to_inner_tensors(nf4tensor: "NF4Tensor", aten_op, args, kwargs): + attr_to_tensor = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + attr_to_tensor[attr] = aten_op(getattr(nf4tensor, attr), *args, **kwargs) + return attr_to_tensor + +# __torch_function__ utils: call tensor ops from inner tensors +def call_from_inner_tensors(nf4tensor: "NF4Tensor", method_name: str, args, kwargs): + attr_to_tensor = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + func = getattr(inner_tensor, method_name) + attr_to_tensor[attr] = func(*args, **kwargs) + return attr_to_tensor + +class CompareOp(Enum): + EQ = auto() + LT = auto() + +def expect_num_of_args(op: CompareOp, num: int, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.LT and not (len(args) < num): + raise NotImplementedError(msg) + return func(aten_op, args, kwargs) + return wrapper + return decorator + +def expect_arg_value_at_k(k: int, op: CompareOp, value: Any, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.EQ and not (args[k] == value): + raise NotImplementedError(msg + str(args[k])) + return func(aten_op, args, kwargs) + return wrapper + return decorator + +def expect_args_len_at_k(k: int, op: CompareOp, value: Any, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.LT and not (len(args[k]) < value): + raise NotImplementedError(msg + str(len(args[k]))) + elif op == CompareOp.EQ and not (len(args[k]) == value): + raise NotImplementedError(msg + str(len(args[k]))) + return func(aten_op, args, kwargs) + return wrapper + return decorator + + +@implements([torch.ops.aten.detach]) def noop_detach(func, *args, **kwargs): return args[0][0] +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + if len(args) == 3 and args[2] != 0: + raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") + nf4tensor = args[0] + num_chunks = nf4tensor.size(0) // args[1] + + attr_to_chunks = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}" + chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) + attr_to_chunks[attr] = chunks + + orig_dim = nf4tensor.dim() + if orig_dim == 1: + chunked_size = (nf4tensor.size(0) // num_chunks, ) + elif orig_dim == 2: + chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1)) + else: + chunked_size = () + raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}") + + nf4_chunks = [] + for idx in range(num_chunks): + updated_attrs = { + "size": chunked_size + } + for attr, chunks in attr_to_chunks.items(): + updated_attrs[attr] = chunks[idx] + nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))) + return nf4_chunks + +@implements( + [ + aten.new_zeros.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") +def nf4_new_zeros(aten_op, args, kwargs=None): + nf4tensor = args[0] + new_size = tuple(args[1]) + new_size_dim = len(new_size) + if nf4tensor.numel() % math.prod(new_size) != 0: + raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") + ratio = nf4tensor.numel() // math.prod(new_size) + + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}" + inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) + updated_attrs[attr] = inner_tensor + updated_attrs["size"] = new_size + + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.slice.Tensor, + ] +) +@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step") +@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=") +@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=") +def nf4_slice(aten_op, args, kwargs=None): + nf4tensor = args[0] + # for tensor 512 x 512, tensor[:, :512] dispatch to + # aten.slice(dim = 0, end=sys.maxsize) + if not args[3] in [nf4tensor.size(0), sys.maxsize]: + raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") + return NF4Tensor(*construct_nf4_args(nf4tensor)) + +@implements( + [ + aten.view.default, + ] +) +@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") +def nf4_view(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + updated_attrs.update({ + "size": [nf4tensor.numel()], + "stride": (1, ), + }) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + +@implements( + [ + aten.as_strided.default, + ] +) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") +def nf4_as_strided(aten_op, args, kwargs=None): + nf4tensor = args[0] + size = args[1] + stride = tuple(args[2]) + storage_offset = args[3] + if math.prod(size) != nf4tensor.numel(): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") + if stride != make_contiguous_strides_for(size): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") + if nf4tensor.storage_offset() != storage_offset: + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") + kwargs = { + "size": torch.Size(size), + "stride": stride, + "storage_offset": storage_offset, + } + return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) + + @implements([torch.ops.aten._to_copy.default]) def _to_copy(func, *args, **kwargs): if not args[0][0].is_contiguous(): @@ -120,6 +345,31 @@ def copy_(func, *args, **kwargs): return original.copy_(same_meta_nf4) +@implements( + [ + aten.is_pinned.default, + ] +) +def nf4_is_pinned(aten_op, args, kwargs=None): + nf4tensor = args[0] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + if not aten_op(inner_tensor, *(args[1:]), **kwargs): + return False + return True + + +@implements( + [ + aten._pin_memory.default, + ] +) +def nf4_pin_memory(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -224,7 +474,7 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert inpt_tensor.dim() <= 2 + assert inpt_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {inpt_tensor.dim()}" assert ( inpt_tensor.numel() % block_size == 0 ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" @@ -375,7 +625,7 @@ def dequantize_scalers( @staticmethod def convert_to_norm_float_weight( - inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor + inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor ) -> torch.Tensor: """Convert a tensor to the normalized float weight format""" flattened_tensor = inpt_tensor.flatten() @@ -393,9 +643,13 @@ def convert_to_norm_float_weight( scaled_blocks = blocks / scales # Returns a flattened tensor with each element quantized to nf4 index - quantized_blocks = NF4Tensor.quantize_tensor_nearest( - scaled_blocks.flatten(), nf4 - ) + # See Note: Quantize in Chunks + quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device) + flattened = scaled_blocks.flatten() + for chunk_num in range(math.ceil(numel / CHUNK_SIZE)): + start = chunk_num * CHUNK_SIZE + end = min(start + CHUNK_SIZE, numel) + quantized_blocks[start:end] = NF4Tensor.quantize_tensor_nearest(flattened[start:end], nf4).to(torch.uint8) # Combine the quantized elements into uint8 values # This lays out two consecutive elements in the same byte @@ -435,7 +689,7 @@ def get_original_weight(self) -> torch.Tensor: @staticmethod def quantize_tensor_nearest( - value: torch.float16, nf4: torch.Tensor + value: torch.Tensor, nf4: torch.Tensor ) -> torch.Tensor: """Quantize a float16 tensor to nf4 format to nearest and not rounded up""" value = value.unsqueeze(-1) # (numel, 1) @@ -445,36 +699,15 @@ def quantize_tensor_nearest( return closest_nf4 @staticmethod - - # inconsistently. - - # defined in `torch._C.TensorBase`. def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor: """Dequantize a nf4 value to bfloat16 format""" # return nf4.index_select(0, value) return nf4[value] - def unpack( - self, - ) -> Tuple[ - int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size - ]: - - # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`. - return ( - self.block_size, - self.n_blocks, - self.scaler_block_size, - self.quantized_scalers, - self.quantization_factor, - self.scaler_mean, - self.quantized_data, - ) - - def __repr__(self): + def __repr__(self) -> str: return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n" - def __str__(self): + def __str__(self) -> str: return f"NF4Tensor({self.shape}, {self.block_size})" def __tensor_flatten__(self): @@ -501,9 +734,6 @@ def __tensor_flatten__(self): ], ctx @staticmethod - - # `typing.Dict[, ]` to avoid runtime subscripting errors. - def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride): assert len(inner_tensors) == 5, "Expected 5 inner tensors" return NF4Tensor( @@ -565,20 +795,75 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) -class LinearNF4(torch.autograd.Function): - @staticmethod + def fsdp_pre_all_gather(self, mesh: DeviceMesh) -> Tuple[Tuple[torch.Tensor, ...], Any]: + return ( + self.quantized_scalers, + self.quantization_factor, + self.quantized_data, + ), ( + SubclassTensorArgs( + self.size(), + self.stride(), + self.storage_offset(), + self.dtype, + self.device, + self.requires_grad, + ), + self.block_size, + self.n_blocks, + self.scaler_block_size, + self.scaler_mean, + self.nf4, + mesh.get_group().size(), + ) - # inconsistently. + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]: + (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs + (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4, pg_size) = metadata + if len(tensor_meta.original_shape) != 2: + raise NotImplementedError(f"only support 2D shape but got dim={len(tensor_meta.original_shape)}") + tensor_meta.original_shape = torch.Size((tensor_meta.original_shape[0] * pg_size, tensor_meta.original_shape[1])) + if out is not None: + # TODO: add param dtype for mixed precision + assert isinstance(out, NF4Tensor), f"{type(out)}" + assert ( + quantized_scalers.untyped_storage().data_ptr() + == out.quantized_scalers.untyped_storage().data_ptr() and + quantization_factor.untyped_storage().data_ptr() + == out.quantization_factor.untyped_storage().data_ptr() and + quantized_data.untyped_storage().data_ptr() + == out.quantized_data.untyped_storage().data_ptr() + ), f"Expects out's data to be the all-gather output" + return + return NF4Tensor( + tensor_meta, + block_size, + n_blocks, + scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ), (quantized_scalers, quantization_factor, quantized_data) + + +class LinearNF4(torch.autograd.Function): + @staticmethod def forward(ctx, input: torch.Tensor, weight: NF4Tensor): """Save the quantized nf4 weight for backward pass""" ctx.nf4_weight = weight return F.linear(input, weight.to(input.dtype)) @staticmethod - - # inconsistently. - def backward(ctx, grad_output): """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" weight: NF4Tensor = ctx.nf4_weight @@ -613,12 +898,33 @@ def decorator(func): @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): - if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype): + tensor = args[0] + if isinstance(args[1], torch.dtype): # Tensor.to(dtype, non_blocking, copy, memory_format) - return args[0].get_original_weight().to(*args[1:], **kwargs) + return tensor.get_original_weight().to(*args[1:], **kwargs) + elif ( + isinstance(args[1], torch.device) or ( + isinstance(args[1], str) and ( + args[1] == "cpu" or args[1].startswith("cuda") + ) + ) + ) and len(args) == 2: + # Tensor.to(device, non_blocking) + device = args[1] + updated_attrs = call_from_inner_tensors(tensor, "to", args[1:], kwargs) + updated_attrs["device"] = device + return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) else: # Tensor.to(device, dtype, non_blocking, copy, memory_format) # Tensor.to(other, non_blocking, copy) raise NotImplementedError( f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch" ) + + +@implements_torch_function(torch.Tensor.cpu) +def function_cpu(*args, **kwargs): + nf4tensor = args[0] + updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs) + updated_attrs["device"] = "cpu" + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index f7efe5b6a5..b18e996b58 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -44,7 +44,7 @@ The handoff point between these two pieces are sparse weights stored in a dense This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate many users to come up with their own custom frontend masking solution or to use another third party solution, as this is an active area of research. -![pruning_flow](https://private-user-images.githubusercontent.com/8041643/324612475-3873655f-3eab-40c7-8070-722b3eef4444.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTM5MjYwODAsIm5iZiI6MTcxMzkyNTc4MCwicGF0aCI6Ii84MDQxNjQzLzMyNDYxMjQ3NS0zODczNjU1Zi0zZWFiLTQwYzctODA3MC03MjJiM2VlZjQ0NDQucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDQyNCUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA0MjRUMDIyOTQwWiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9N2ZjZTAwNzgyMjc4MGE3ZDZlYTQ3MDZkOTA3YTkwM2I3ODJiYjg4NzE2N2E3ZGJjZGVkZDhjYjJhMTgwOThhOSZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.SXj5_j7CC61CB6hanWrubY7k4Fq9Oko985qD7qaOAy4) +![pruning_flow](https://private-user-images.githubusercontent.com/8041643/324607153-ba91eaca-14ce-4608-9db8-6cbb9ea1f9ec.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTQ1OTgzOTYsIm5iZiI6MTcxNDU5ODA5NiwicGF0aCI6Ii84MDQxNjQzLzMyNDYwNzE1My1iYTkxZWFjYS0xNGNlLTQ2MDgtOWRiOC02Y2JiOWVhMWY5ZWMucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDUwMSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA1MDFUMjExNDU2WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9YWVjOWQ5ZjFjMWZmNjg4ZTgyZGFkYWU3ZDQ3MDBjMTZkNzczZWQxYzczN2ZiM2ZjZGY0NjUwMGUwY2UwZDA1YyZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.ni5F_wDhNkeupMJ84bFNxhaSO3xPH-9zecz_933Uu68) Below, we provide an example of accelerating a model with 2:4 sparsity + bf16 using our PyTorch APIs. @@ -97,7 +97,7 @@ Note that this section focuses on **pruning**, instead of **sparse training**. T Roughly, the flow for achieving a more performant pruned model looks like this: -![flow](https://private-user-images.githubusercontent.com/8041643/324612485-c7008b1d-6c1a-4424-b3d1-34c55a25460d.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTM5MjYwODAsIm5iZiI6MTcxMzkyNTc4MCwicGF0aCI6Ii84MDQxNjQzLzMyNDYxMjQ4NS1jNzAwOGIxZC02YzFhLTQ0MjQtYjNkMS0zNGM1NWEyNTQ2MGQucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDQyNCUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA0MjRUMDIyOTQwWiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9NWVlY2I3OTBlM2ViZTZiZmMwYmQzYjA3NjM1ZDY3NmZkZjNiMzk3M2JhMzkwOTYyZmM4Mjc5MWJkYTI2M2MxMiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.neMkWGtDbGGw0Vn7MA1RJ_Q2iAvGIkcjRD-pLAtNd5k) +![flow](https://private-user-images.githubusercontent.com/8041643/324607146-53542488-65ce-4d99-a3ae-21e724f89467.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTQ1OTgzOTYsIm5iZiI6MTcxNDU5ODA5NiwicGF0aCI6Ii84MDQxNjQzLzMyNDYwNzE0Ni01MzU0MjQ4OC02NWNlLTRkOTktYTNhZS0yMWU3MjRmODk0NjcucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDUwMSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA1MDFUMjExNDU2WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9ZWJlYWMzZDFmNzc2NDM1MGI2ODNlMjUxZjQxYTAwYzhhNzBkNGU2ZGIwYTg4NzA5Yjk3N2JkNzI4MmUyNzg3NiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.Hxk5XMuJXhNsORVNNgcKNRCk7W1nT4CndLTAC3Oz0qE) The general idea behind pruning is that we can mask out some of the weights of a trained neural network and recover any accuracy loss. The resultant pruned model can be run on optimized kernels that take advantage of this sparsity for accelerated inference.