Implement verify_correctness pytorch#179 (pytorch#252)
* WrapperBackend to enable verifying the correctness of backends; set config.verify_correctness to True to enable it.

* move testing.same() to utils.py
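
A minimal usage sketch of the new flag (the "nnc" backend string is one of the built-ins exercised by the tests below):

import torch

import torchdynamo
import torchdynamo.config as config

config.verify_correctness = True  # compiler_fns now get wrapped in WrapperBackend

model = torch.nn.Linear(10, 10)
x = torch.randn(2, 10)
with torchdynamo.optimize("nnc"):
    y = model(x)  # outputs are checked against the original graph; on mismatch the original graph runs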
pyjhzwh authored May 19, 2022
1 parent 83d4cab commit 909d09c
Showing 6 changed files with 322 additions and 58 deletions.
196 changes: 196 additions & 0 deletions tests/test_verify_correctness.py
@@ -0,0 +1,196 @@
#!/usr/bin/env pytest
import importlib
import operator
import unittest

import torch

import torchdynamo
import torchdynamo.config as config
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.inference import fixed_strategy1
from torchdynamo.optimizations.inference import offline_autotuner
from torchdynamo.testing import same


def has_onnxruntime():
try:
importlib.import_module("onnxruntime")
return True
except ImportError:
return False


def has_ipex():
try:
importlib.import_module("intel_extension_for_pytorch")
return True
except ImportError:
return False


class Seq(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(10, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 10),
torch.nn.Sigmoid(),
)

def forward(self, x):
return self.layers(x)


class Conv_Bn_Relu(torch.nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(Conv_Bn_Relu, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(self.bn(self.conv(x)))


def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b


def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
# Check whether this node calls a function (e.g.
# operator.add)
if node.op == "call_function":
# The target attribute is the function
# that call_function calls.
if node.target == operator.mul:
node.target = operator.add

gm.graph.lint() # Does some checks to make sure the
# Graph is well-formed.

gm.recompile()
return gm


config.verify_correctness = True


class TestVerifyCorrectness(torchdynamo.testing.TestCase):
def test_example_inputs(self):
def fn(a, bc, d):
b, c = bc
return a / d - b / c

def compiler_fn(graph, example_inputs):
nonlocal r1
r1 = graph(*example_inputs)[0]
return graph.forward

a = torch.empty(2).fill_(1)
b = torch.empty(2).fill_(2)
c = torch.empty(2).fill_(3)
d = 4
r1 = None
r2 = fn(a, (b, c), d)
with torchdynamo.optimize_assert(compiler_fn):
r3 = fn(a, (b, c), d)

self.assertIsNotNone(r1)
self.assertTrue(same(r1, r2))
self.assertTrue(same(r1, r3))

def test_fixed_strategy1(self):
s = Seq()
i = torch.randn(10)
r1 = s(i)
with torchdynamo.optimize(fixed_strategy1):
r2 = s(i)
self.assertTrue(same(r1, r2))

def test_nnc(self):
s = Seq()
i = torch.randn(10)
r1 = s(i)
with torchdynamo.optimize("nnc"):
r2 = s(i)
self.assertTrue(same(r1, r2))

def test_incorrect_verify_true(self):
"""
Even the bad optimization return a graph that
is not functionally equal to the original graph;
When config.verify_correctness=True, it will
check the correctness of outputs and fallback using
the original graph
"""
i1 = torch.randn(10)
i2 = torch.randn(10)

def incorrect_compile_fn(gm, example_inputs):
return transform(gm).forward

r1 = toy_example(i1, i2)
with torchdynamo.optimize(incorrect_compile_fn):
r2 = toy_example(i1, i2)
self.assertTrue(same(r1, r2))

def test_incorrect_verify_false(self):
"""
The bad backend returns a graph that is not functionally
equivalent to the original graph; with
config.verify_correctness=False the wrong outputs are
returned unchecked.
"""
config.verify_correctness = False
i1 = torch.randn(10)
i2 = torch.randn(10)

def incorrect_compile_fn(gm, example_inputs):
return transform(gm).forward

r1 = toy_example(i1, i2)
with torchdynamo.optimize(incorrect_compile_fn):
r2 = toy_example(i1, i2)
self.assertTrue(not same(r1, r2))
config.verify_correctness = True

@unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
def test_export(self):
s = Seq()
i = torch.randn(10)
r1 = s(i)
with torchdynamo.optimize_assert(offline_autotuner):
r2 = s(i)
self.assertTrue(same(r1, r2))

@unittest.skipIf(not has_ipex(), "requires ipex")
def test_ipex_fp32(self):
model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
model = model.to(memory_format=torch.channels_last)
model = model.eval()
input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
r1 = model(input)
with torchdynamo.optimize(backends.ipex_fp32), torch.no_grad():
r2 = model(input)
self.assertTrue(same(r1, r2))
self.assertEqual(r2.dtype, torch.float32)

@unittest.skipIf(not has_ipex(), "requires ipex")
def test_ipex_bf16(self):
model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
model = model.to(memory_format=torch.channels_last)
model = model.eval()
input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last)
r1 = model(input)
with torchdynamo.optimize(
backends.ipex_bf16
), torch.no_grad(), torch.cpu.amp.autocast():
r2 = model(input)
self.assertTrue(same(r1, r2.float(), tol=0.1))
self.assertEqual(r2.dtype, torch.bfloat16)
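
Given the #!/usr/bin/env pytest shebang, the new module is meant to be run under pytest:

pytest tests/test_verify_correctness.py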
3 changes: 3 additions & 0 deletions torchdynamo/config.py
@@ -7,6 +7,9 @@
# print out lots of stuff
debug = False

# verify the correctness of the optimized backend
verify_correctness = False

# an unreasonable amount of debug printouts
trace = False

16 changes: 15 additions & 1 deletion torchdynamo/convert_frame.py
@@ -17,6 +17,7 @@
from .bytecode_analysis import remove_pointless_jumps
from .bytecode_transformation import is_generator
from .bytecode_transformation import transform_code_object
from .eval_frame import WrapperBackend
from .eval_frame import skip_code
from .exc import InternalTorchDynamoError
from .exc import TorchRuntimeError
@@ -59,7 +60,7 @@ def fx_forward_from_src_skip_result(*args, **kwargs):
return result


def wrap_compiler_fn(compiler_fn):
def _wrap_compiler_fn(compiler_fn):
"""Expand backend strings to functions"""
if compiler_fn == "inductor":
from torchinductor.compile_fx import compile_fx
@@ -73,6 +74,19 @@ def wrap_compiler_fn(compiler_fn):
return compiler_fn


def wrap_compiler_fn(compiler_fn):
"""WrapperBackend if config.verify_correctness is True"""
wrapped_compiler_fn = _wrap_compiler_fn(compiler_fn)

if config.verify_correctness:
# wrap backend if verify_correctness is True
wrapper_backend_compiler_fn = WrapperBackend(wrapped_compiler_fn)

return wrapper_backend_compiler_fn

return wrapped_compiler_fn


def wrap_convert_context(fn):
"""
Context manager to:
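
A sketch of the resulting behavior, assuming the string-to-backend expansion above succeeds for "nnc":

import torchdynamo.config as config
from torchdynamo.convert_frame import wrap_compiler_fn
from torchdynamo.eval_frame import WrapperBackend

config.verify_correctness = True
backend = wrap_compiler_fn("nnc")  # expand the string, then wrap it
assert isinstance(backend, WrapperBackend)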
50 changes: 50 additions & 0 deletions torchdynamo/eval_frame.py
@@ -1,12 +1,21 @@
import contextlib
import copy
import functools
import logging
import threading

import torch

from torchdynamo.utils import checkpoint_params
from torchdynamo.utils import clone_inputs

from . import config
from . import convert_frame
from . import skipfiles
from .mutation_guard import install_generation_tagging_new
from .utils import same

log = logging.getLogger(__name__)

try:
from . import _eval_frame
@@ -124,6 +133,47 @@ def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
)


class WrapperBackend:
def __init__(self, backend=None):
self.backend = backend

@property
def example_inputs(self):
return clone_inputs(self.original_example_inputs)

def __call__(self, gm: torch.fx.GraphModule, example_inputs):

self.restore = checkpoint_params(gm)
self.original_example_inputs = clone_inputs(example_inputs)
self.gm = gm
copy_gm = copy.deepcopy(self.gm)
self.candidate = self.backend(copy_gm, self.original_example_inputs)

if self.candidate is None or self.candidate is self.gm.forward:
return self.gm.forward

if not config.verify_correctness:
return self.candidate

# if verify_correctness=True
try:
correct = self.gm.forward(*self.example_inputs)
result = self.candidate(*self.example_inputs)

# TODO: replace `same` function with the one in testing
if same(correct, result):
return self.candidate

print(f"incorrect results of backend {self}")
return self.gm.forward

except Exception:
log.exception("error in verify_correctness")
return self.gm.forward
finally:
self.restore()


def optimize(backend, nopython=False):
"""
The main entrypoint of TorchDynamo. Do graph capture and call
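
WrapperBackend can also be exercised directly. A minimal sketch; zero_backend below is a made-up, deliberately wrong backend, not part of this commit:

import torch

import torchdynamo.config as config
from torchdynamo.eval_frame import WrapperBackend

config.verify_correctness = True

def zero_backend(gm, example_inputs):
    # deliberately wrong: right shapes, wrong values
    def run(*args):
        return torch.zeros_like(gm(*args))
    return run

gm = torch.fx.symbolic_trace(torch.nn.Linear(10, 10))
fn = WrapperBackend(zero_backend)(gm, [torch.randn(2, 10)])
# same() reports the mismatch, so fn is gm.forward (the fallback)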
58 changes: 1 addition & 57 deletions torchdynamo/testing.py
@@ -17,6 +17,7 @@
from .bytecode_transformation import is_generator
from .bytecode_transformation import transform_code_object
from .guards import GuardedCode
from .utils import same

unsupported = torchdynamo.eval_frame.unsupported
three = 3
@@ -69,63 +70,6 @@ def reduce_to_scalar_loss(out):
raise NotImplementedError("Don't know how to reduce")


def same(a, b, cos_similarity=False, tol=1e-4, equal_nan=False):
"""Check correctness to see if a and b match"""
if isinstance(a, (list, tuple, torch.nn.ParameterList, torch.Size)):
assert isinstance(b, (list, tuple)), f"type mismatch {type(a)} {type(b)}"
return len(a) == len(b) and all(
same(ai, bi, cos_similarity, tol, equal_nan) for ai, bi in zip(a, b)
)
elif isinstance(a, dict):
assert isinstance(b, dict)
assert set(a.keys()) == set(
b.keys()
), f"keys mismatch {set(a.keys())} == {set(b.keys())}"
for k in a.keys():
if not (same(a[k], b[k], cos_similarity, tol, equal_nan=equal_nan)):
print("Accuracy failed for key name", k)
return False
return True
elif isinstance(a, torch.Tensor):
if a.is_sparse:
assert b.is_sparse
a = a.to_dense()
b = b.to_dense()
assert isinstance(b, torch.Tensor)
if cos_similarity:
# TRT will bring error loss larger than current threshold. Use cosine similarity as replacement
a = a.flatten().to(torch.float32)
b = b.flatten().to(torch.float32)
res = torch.nn.functional.cosine_similarity(a, b, dim=0, eps=1e-6)
if res < 0.99:
print(f"Similarity score={res.cpu().numpy()}")
return res >= 0.99
else:
return torch.allclose(a, b, atol=tol, rtol=tol, equal_nan=equal_nan)
elif isinstance(a, (str, int, float, type(None), bool, torch.device)):
return a == b
elif type(a).__name__ in (
"MaskedLMOutput",
"Seq2SeqLMOutput",
"CausalLMOutputWithCrossAttentions",
"LongformerMaskedLMOutput",
"Instances",
"SquashedNormal",
"Boxes",
"Normal",
"TanhTransform",
"Foo",
"Variable",
):
assert type(a) is type(b)
return all(
same(getattr(a, key), getattr(b, key), cos_similarity, tol, equal_nan)
for key in a.__dict__.keys()
)
else:
raise RuntimeError(f"unsupported type: {type(a).__name__}")


def debug_dir():
path = os.path.join(os.path.dirname(__file__), "../debug")
if not os.path.exists(path):
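
After this commit same() lives in torchdynamo/utils.py (the utils.py hunk is in the part of the diff not shown above), so callers import it from its new home:

import torch

from torchdynamo.utils import same

same(torch.ones(3), torch.ones(3) + 1e-5)  # True: within the default tol=1e-4
same([torch.ones(2)], (torch.ones(2),))    # True: sequences compare element-wise
same(torch.ones(3), torch.zeros(3))        # False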