diff --git a/tests/test_verify_correctness.py b/tests/test_verify_correctness.py new file mode 100644 index 0000000000000..624a3bbc6e329 --- /dev/null +++ b/tests/test_verify_correctness.py @@ -0,0 +1,196 @@ +#!/usr/bin/env pytest +import importlib +import operator +import unittest + +import torch + +import torchdynamo +import torchdynamo.config as config +from torchdynamo.optimizations import backends +from torchdynamo.optimizations.inference import fixed_strategy1 +from torchdynamo.optimizations.inference import offline_autotuner +from torchdynamo.testing import same + + +def has_onnxruntime(): + try: + importlib.import_module("onnxruntime") + return True + except ImportError: + return False + + +def has_ipex(): + try: + importlib.import_module("intel_extension_for_pytorch") + return True + except ImportError: + return False + + +class Seq(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.Sigmoid(), + ) + + def forward(self, x): + return self.layers(x) + + +class Conv_Bn_Relu(torch.nn.Module): + def __init__(self, in_channels, out_channels, **kwargs): + super(Conv_Bn_Relu, self).__init__() + self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + + +def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # operator.add) + if node.op == "call_function": + # The target attribute is the function + # that call_function calls. + if node.target == operator.mul: + node.target = operator.add + + gm.graph.lint() # Does some checks to make sure the + # Graph is well-formed. + + gm.recompile() + return gm + + +config.verify_correctness = True + + +class TestVerifyCorrectness(torchdynamo.testing.TestCase): + def test_example_inputs(self): + def fn(a, bc, d): + b, c = bc + return a / d - b / c + + def compiler_fn(graph, example_inputs): + nonlocal r1 + r1 = graph(*example_inputs)[0] + return graph.forward + + a = torch.empty(2).fill_(1) + b = torch.empty(2).fill_(2) + c = torch.empty(2).fill_(3) + d = 4 + r1 = None + r2 = fn(a, (b, c), d) + with torchdynamo.optimize_assert(compiler_fn): + r3 = fn(a, (b, c), d) + + self.assertIsNotNone(r1) + self.assertTrue(same(r1, r2)) + self.assertTrue(same(r1, r3)) + + def test_fixed_strategy1(self): + s = Seq() + i = torch.randn(10) + r1 = s(i) + with torchdynamo.optimize(fixed_strategy1): + r2 = s(i) + self.assertTrue(same(r1, r2)) + + def test_nnc(self): + s = Seq() + i = torch.randn(10) + r1 = s(i) + with torchdynamo.optimize("nnc"): + r2 = s(i) + self.assertTrue(same(r1, r2)) + + def test_incorrect_verify_true(self): + """ + Even the bad optimization return a graph that + is not functionally equal to the original graph; + When config.verify_correctness=True, it will + check the correctness of outputs and fallback using + the original graph + """ + i1 = torch.randn(10) + i2 = torch.randn(10) + + def incorrect_compile_fn(gm, example_inputs): + return transform(gm).forward + + r1 = toy_example(i1, i2) + with torchdynamo.optimize(incorrect_compile_fn): + r2 = toy_example(i1, i2) + self.assertTrue(same(r1, r2)) + + def test_incorrect_verify_false(self): + config.verify_correctness = False + """ + The bad optimization return a graph that + is not functionally equal to the original graph; + When config.verify_correctness=False, wrong outputs + will return + """ + i1 = torch.randn(10) + i2 = torch.randn(10) + + def incorrect_compile_fn(gm, example_inputs): + return transform(gm).forward + + r1 = toy_example(i1, i2) + with torchdynamo.optimize(incorrect_compile_fn): + r2 = toy_example(i1, i2) + self.assertTrue(not same(r1, r2)) + config.verify_correctness = True + + @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") + def test_export(self): + s = Seq() + i = torch.randn(10) + r1 = s(i) + with torchdynamo.optimize_assert(offline_autotuner): + r2 = s(i) + self.assertTrue(same(r1, r2)) + + @unittest.skipIf(not has_ipex(), "requires ipex") + def test_ipex_fp32(self): + model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) + model = model.to(memory_format=torch.channels_last) + model = model.eval() + input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) + r1 = model(input) + with torchdynamo.optimize(backends.ipex_fp32), torch.no_grad(): + r2 = model(input) + self.assertTrue(same(r1, r2)) + self.assertEqual(r2.dtype, torch.float32) + + @unittest.skipIf(not has_ipex(), "requires ipex") + def test_ipex_bf16(self): + model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) + model = model.to(memory_format=torch.channels_last) + model = model.eval() + input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) + r1 = model(input) + with torchdynamo.optimize( + backends.ipex_bf16 + ), torch.no_grad(), torch.cpu.amp.autocast(): + r2 = model(input) + self.assertTrue(same(r1, r2.float(), tol=0.1)) + self.assertEqual(r2.dtype, torch.bfloat16) diff --git a/torchdynamo/config.py b/torchdynamo/config.py index 0c37fe8df0158..0b6abc75a6520 100644 --- a/torchdynamo/config.py +++ b/torchdynamo/config.py @@ -7,6 +7,9 @@ # print out lots of stuff debug = False +# verify the correctness of optimized backend +verify_correctness = False + # an unreasonable amount of debug printouts trace = False diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index 9d7626b52e491..35e68f13567d0 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -17,6 +17,7 @@ from .bytecode_analysis import remove_pointless_jumps from .bytecode_transformation import is_generator from .bytecode_transformation import transform_code_object +from .eval_frame import WrapperBackend from .eval_frame import skip_code from .exc import InternalTorchDynamoError from .exc import TorchRuntimeError @@ -59,7 +60,7 @@ def fx_forward_from_src_skip_result(*args, **kwargs): return result -def wrap_compiler_fn(compiler_fn): +def _wrap_compiler_fn(compiler_fn): """Expand backend strings to functions""" if compiler_fn == "inductor": from torchinductor.compile_fx import compile_fx @@ -73,6 +74,19 @@ def wrap_compiler_fn(compiler_fn): return compiler_fn +def wrap_compiler_fn(compiler_fn): + """WrapperBackend if config.verify_correctness is True""" + wrapped_compiler_fn = _wrap_compiler_fn(compiler_fn) + + if config.verify_correctness: + # wrap backend if verify_correctness is True + wrapper_backend_compiler_fn = WrapperBackend(wrapped_compiler_fn) + + return wrapper_backend_compiler_fn + + return wrapped_compiler_fn + + def wrap_convert_context(fn): """ Context manager to: diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index 809002cb07a48..b64e257fd6391 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -1,12 +1,21 @@ import contextlib +import copy import functools import logging import threading +import torch + +from torchdynamo.utils import checkpoint_params +from torchdynamo.utils import clone_inputs + from . import config from . import convert_frame from . import skipfiles from .mutation_guard import install_generation_tagging_new +from .utils import same + +log = logging.getLogger(__name__) try: from . import _eval_frame @@ -124,6 +133,47 @@ def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context): ) +class WrapperBackend: + def __init__(self, backend=None): + self.backend = backend + + @property + def example_inputs(self): + return clone_inputs(self.original_example_inputs) + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + + self.restore = checkpoint_params(gm) + self.original_example_inputs = clone_inputs(example_inputs) + self.gm = gm + copy_gm = copy.deepcopy(self.gm) + self.candidate = self.backend(copy_gm, self.original_example_inputs) + + if self.candidate is None or self.candidate is self.gm.forward: + return self.gm.forward + + if not config.verify_correctness: + return self.candidate + + # if verify_correctness=True + try: + correct = self.gm.forward(*self.example_inputs) + result = self.candidate(*self.example_inputs) + + # TODO: replace `same` function with the one in testing + if same(correct, result): + return self.candidate + + print(f"incorrect results of backend {self}") + return self.gm.forward + + except Exception: + log.exception("error in verify_correctness") + return self.gm.forward + finally: + self.restore() + + def optimize(backend, nopython=False): """ The main entrypoint of TorchDynamo. Do graph capture and call diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 02e58ace2d233..91d3e0ac8885d 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -17,6 +17,7 @@ from .bytecode_transformation import is_generator from .bytecode_transformation import transform_code_object from .guards import GuardedCode +from .utils import same unsupported = torchdynamo.eval_frame.unsupported three = 3 @@ -69,63 +70,6 @@ def reduce_to_scalar_loss(out): raise NotImplementedError("Don't know how to reduce") -def same(a, b, cos_similarity=False, tol=1e-4, equal_nan=False): - """Check correctness to see if a and b match""" - if isinstance(a, (list, tuple, torch.nn.ParameterList, torch.Size)): - assert isinstance(b, (list, tuple)), f"type mismatch {type(a)} {type(b)}" - return len(a) == len(b) and all( - same(ai, bi, cos_similarity, tol, equal_nan) for ai, bi in zip(a, b) - ) - elif isinstance(a, dict): - assert isinstance(b, dict) - assert set(a.keys()) == set( - b.keys() - ), f"keys mismatch {set(a.keys())} == {set(b.keys())}" - for k in a.keys(): - if not (same(a[k], b[k], cos_similarity, tol, equal_nan=equal_nan)): - print("Accuracy failed for key name", k) - return False - return True - elif isinstance(a, torch.Tensor): - if a.is_sparse: - assert b.is_sparse - a = a.to_dense() - b = b.to_dense() - assert isinstance(b, torch.Tensor) - if cos_similarity: - # TRT will bring error loss larger than current threshold. Use cosine similarity as replacement - a = a.flatten().to(torch.float32) - b = b.flatten().to(torch.float32) - res = torch.nn.functional.cosine_similarity(a, b, dim=0, eps=1e-6) - if res < 0.99: - print(f"Similarity score={res.cpu().numpy()}") - return res >= 0.99 - else: - return torch.allclose(a, b, atol=tol, rtol=tol, equal_nan=equal_nan) - elif isinstance(a, (str, int, float, type(None), bool, torch.device)): - return a == b - elif type(a).__name__ in ( - "MaskedLMOutput", - "Seq2SeqLMOutput", - "CausalLMOutputWithCrossAttentions", - "LongformerMaskedLMOutput", - "Instances", - "SquashedNormal", - "Boxes", - "Normal", - "TanhTransform", - "Foo", - "Variable", - ): - assert type(a) is type(b) - return all( - same(getattr(a, key), getattr(b, key), cos_similarity, tol, equal_nan) - for key in a.__dict__.keys() - ) - else: - raise RuntimeError(f"unsupported type: {type(a).__name__}") - - def debug_dir(): path = os.path.join(os.path.dirname(__file__), "../debug") if not os.path.exists(path): diff --git a/torchdynamo/utils.py b/torchdynamo/utils.py index 3fe52239a77a3..f2b6f3422f8d7 100644 --- a/torchdynamo/utils.py +++ b/torchdynamo/utils.py @@ -353,3 +353,60 @@ def rename_implicit(v): # to support .1 etc see guards.py and _eval_frame.c return f"___implicit{m.group(1)}" return v + + +def same(a, b, cos_similarity=False, tol=1e-4, equal_nan=False): + """Check correctness to see if a and b match""" + if isinstance(a, (list, tuple, torch.nn.ParameterList, torch.Size)): + assert isinstance(b, (list, tuple)), f"type mismatch {type(a)} {type(b)}" + return len(a) == len(b) and all( + same(ai, bi, cos_similarity, tol, equal_nan) for ai, bi in zip(a, b) + ) + elif isinstance(a, dict): + assert isinstance(b, dict) + assert set(a.keys()) == set( + b.keys() + ), f"keys mismatch {set(a.keys())} == {set(b.keys())}" + for k in a.keys(): + if not (same(a[k], b[k], cos_similarity, tol, equal_nan=equal_nan)): + print("Accuracy failed for key name", k) + return False + return True + elif isinstance(a, torch.Tensor): + if a.is_sparse: + assert b.is_sparse + a = a.to_dense() + b = b.to_dense() + assert isinstance(b, torch.Tensor) + if cos_similarity: + # TRT will bring error loss larger than current threshold. Use cosine similarity as replacement + a = a.flatten().to(torch.float32) + b = b.flatten().to(torch.float32) + res = torch.nn.functional.cosine_similarity(a, b, dim=0, eps=1e-6) + if res < 0.99: + print(f"Similarity score={res.cpu().numpy()}") + return res >= 0.99 + else: + return torch.allclose(a, b, atol=tol, rtol=tol, equal_nan=equal_nan) + elif isinstance(a, (str, int, float, type(None), bool, torch.device)): + return a == b + elif type(a).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + "LongformerMaskedLMOutput", + "Instances", + "SquashedNormal", + "Boxes", + "Normal", + "TanhTransform", + "Foo", + "Variable", + ): + assert type(a) is type(b) + return all( + same(getattr(a, key), getattr(b, key), cos_similarity, tol, equal_nan) + for key in a.__dict__.keys() + ) + else: + raise RuntimeError(f"unsupported type: {type(a).__name__}")