Skip to content

Commit

Permalink
[UnitTests] Require cached fixtures to be copy-able, with opt-in.
Browse files Browse the repository at this point in the history
Previously, any class that doesn't raise a TypeError in copy.deepcopy
could be used as a return value in a @tvm.testing.fixture.  This has
the possibility of incorrectly copying classes inherit the default
object.__reduce__ implementation.  Therefore, only classes that
explicitly implement copy functionality (e.g. __deepcopy__ or
__getstate__/__setstate__), or that are explicitly listed in
tvm.testing._fixture_cache are allowed to be cached.
  • Loading branch information
Lunderberg committed Jul 14, 2021
1 parent 1d7a9e9 commit 9b99050
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 12 deletions.
76 changes: 64 additions & 12 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def test_something():
"""
import collections
import copy
import copyreg
import ctypes
import functools
import logging
import os
Expand Down Expand Up @@ -1160,6 +1162,60 @@ def wraps(func):
return wraps(func)


class _DeepCopyAllowedClasses(dict):
def __init__(self, allowed_class_list):
self.allowed_class_list = allowed_class_list
super().__init__()

def get(self, key, *args, **kwargs):
"""Overrides behavior of copy.deepcopy to avoid implicit copy.
By default, copy.deepcopy uses a dict of id->object to track
all objects that it has seen, which is passed as the second
argument to all recursive calls. This class is intended to be
passed in instead, and inspects the type of all objects being
copied.
Where copy.deepcopy does a best-effort attempt at copying an
object, for unit tests we would rather have all objects either
be copied correctly, or to throw an error. Classes that
define an explicit method to perform a copy are allowed, as
are any explicitly listed classes. Classes that would fall
back to using object.__reduce__, and are not explicitly listed
as safe, will throw an exception.
"""
obj = ctypes.cast(key, ctypes.py_object).value
cls = type(obj)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(obj, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in self.allowed_class_list
):
return super().get(key, *args, **kwargs)

rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md"
)
raise TypeError(
(
f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For third-party classes that are "
"safe to use with copy.deepcopy, please add the class to "
"the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n"
"\n"
f"For discussion on this restriction, please see {rfc_url}."
)
)


def _fixture_cache(func):
cache = {}

Expand Down Expand Up @@ -1199,18 +1255,14 @@ def wrapper(*args, **kwargs):
except KeyError:
cached_value = cache[cache_key] = func(*args, **kwargs)

try:
yield copy.deepcopy(cached_value)
except TypeError as e:
rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/"
"0007-parametrized-unit-tests.md#unresolved-questions"
)
message = (
"TVM caching of fixtures can only be used on serializable data types, not {}.\n"
"Please see {} for details/discussion."
).format(type(cached_value), rfc_url)
raise TypeError(message) from e
yield copy.deepcopy(
cached_value,
# allowed_class_list should be a list of classes that
# are safe to copy using copy.deepcopy, but do not
# implement __deepcopy__, __reduce__, or
# __reduce_ex__.
_DeepCopyAllowedClasses(allowed_class_list=[]),
)

finally:
# Clear the cache once all tests that use a particular fixture
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,44 @@ def test_num_uses_cached(self):
assert self.num_uses_broken_cached_fixture == 0


@pytest.mark.skipif(
bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
reason="Cannot test cache behavior while caching is disabled",
)
class TestCacheableTypes:
class EmptyClass:
pass

@tvm.testing.fixture(cache_return_value=True)
def uncacheable_fixture(self):
return self.EmptyClass()

@pytest.mark.xfail(reason="Requests cached fixture of uncacheable type", strict=True)
def test_uses_uncacheable(self, uncacheable_fixture):
pass

class ImplementsReduce:
def __reduce__(self):
return super().__reduce__()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_reduce(self):
return self.ImplementsReduce()

def test_uses_reduce(self, fixture_with_reduce):
pass

class ImplementsDeepcopy:
def __deepcopy__(self, memo):
return type(self)()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_deepcopy(self):
return self.ImplementsDeepcopy()

def test_uses_deepcopy(self, fixture_with_deepcopy):
pass


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 9b99050

Please sign in to comment.