Skip to content

Commit

Permalink
[Bugfix] Fix visit_attrs error if its function pointer is equal to nu…
Browse files Browse the repository at this point in the history
…llptr (#8920)

* fix visit_attrs equals nullptr on python container object

* add a test a for python container object about function dir and getattr

* change test_ir_container.py to the pytest style

* update the style to fix ci error

* update the style of ir container to fix ci error
  • Loading branch information
Sen Yang authored Sep 8, 2021
1 parent 01aeeb1 commit f8b1df4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
20 changes: 20 additions & 0 deletions python/tvm/ir/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ def __getitem__(self, idx):
def __len__(self):
return _ffi_api.ArraySize(self)

def __dir__(self):
return sorted(dir(self.__class__) + ["type_key"])

def __getattr__(self, name):
if name == "handle":
raise AttributeError("handle is not set")
if name == "type_key":
return super().__getattr__(name)
raise AttributeError("%s has no attribute %s" % (str(type(self)), name))


@tvm._ffi.register_object
class Map(Object):
Expand All @@ -59,6 +69,16 @@ def __iter__(self):
for i in range(len(self)):
yield akvs[i * 2]

def __dir__(self):
return sorted(dir(self.__class__) + ["type_key"])

def __getattr__(self, name):
if name == "handle":
raise AttributeError("handle is not set")
if name == "type_key":
return super().__getattr__(name)
raise AttributeError("%s has no attribute %s" % (str(type(self)), name))

def keys(self):
return iter(self)

Expand Down
35 changes: 28 additions & 7 deletions tests/python/unittest/test_ir_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import te
import numpy as np
Expand All @@ -34,6 +35,17 @@ def test_array_save_load_json():
assert a_loaded[1].value == 2


def test_dir_array():
a = tvm.runtime.convert([1, 2, 3])
assert dir(a)


def test_getattr_array():
a = tvm.runtime.convert([1, 2, 3])
assert getattr(a, "type_key") == "Array"
assert not hasattr(a, "test_key")


def test_map():
a = te.var("a")
b = te.var("b")
Expand Down Expand Up @@ -70,6 +82,21 @@ def test_map_save_load_json():
assert dd == {"a": 2, "b": 3}


def test_dir_map():
a = te.var("a")
b = te.var("b")
amap = tvm.runtime.convert({a: 2, b: 3})
assert dir(amap)


def test_getattr_map():
a = te.var("a")
b = te.var("b")
amap = tvm.runtime.convert({a: 2, b: 3})
assert getattr(amap, "type_key") == "Map"
assert not hasattr(amap, "test_key")


def test_in_container():
arr = tvm.runtime.convert(["a", "b", "c"])
assert "a" in arr
Expand All @@ -86,10 +113,4 @@ def test_ndarray_container():


if __name__ == "__main__":
test_str_map()
test_array()
test_map()
test_array_save_load_json()
test_map_save_load_json()
test_in_container()
test_ndarray_container()
pytest.main([__file__])

0 comments on commit f8b1df4

Please sign in to comment.