Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compile tests to test suite #906

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def new_test(self, value=value):


class TorchAOBasicTestCase(common_utils.TestCase):
"""Basic test case for tensor subclasses
"""
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

Expand Down Expand Up @@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)


class TorchAOCompileTestCase(common_utils.TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

TENSOR_SUBCLASS = AffineQuantizedTensor
FACTORY_FN = to_affine_quantized_intx
kwargs = {
"mapping_type": MappingType.ASYMMETRIC,
"block_size": (1, 32),
"target_dtype": torch.uint8,
}
# minimum sqnr for linear operation when the weight is quantized to low precision
# with the above setting
LINEAR_MIN_SQNR = 40
COMPILE_MIN_SQNR = 50

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_output_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
def f(tensor):
return tensor

ref = f(lp_tensor)
f = torch.compile(f)
compiled = f(lp_tensor)
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for most of these tests, you probably want to fully run the test in both eager and compile and assert that outputs are the same, rather than just testing if the output is / is not a subclass?

self.assertEqual(ref.dequantize(), compiled.dequantize())

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
def f(tensor):
return tensor.dequantize()

ref = f(lp_tensor)
f = torch.compile(f)
compiled = f(lp_tensor)
self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
self.assertEqual(ref, compiled)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_output_tensor_subclass(self, device, dtype):
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
def f(hp_tensor):
return self.FACTORY_FN(hp_tensor, **self.kwargs)

ref = f(hp_tensor)
f = torch.compile(f)
compiled = f(hp_tensor)
self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
# bfloat16 seems to result in much larger numerical differences
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdhirsh it seems that bfloat16 gives large errors right now for compile v.s. no compile

if dtype != torch.bfloat16:
self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_linear_compile(self, device, dtype):
Expand All @@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
lp_res = torch.compile(l)(hp_act_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)



common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)

if __name__ == "__main__":
unittest.main()
Loading