diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index c62952554bc6..3c7a57a830d9 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -38,6 +38,16 @@ def __getitem__(self, idx): def __len__(self): return _ffi_api.ArraySize(self) + def __dir__(self): + return sorted(dir(self.__class__) + ["type_key"]) + + def __getattr__(self, name): + if name == "handle": + raise AttributeError("handle is not set") + if name == "type_key": + return super().__getattr__(name) + raise AttributeError("%s has no attribute %s" % (str(type(self)), name)) + @tvm._ffi.register_object class Map(Object): @@ -59,6 +69,16 @@ def __iter__(self): for i in range(len(self)): yield akvs[i * 2] + def __dir__(self): + return sorted(dir(self.__class__) + ["type_key"]) + + def __getattr__(self, name): + if name == "handle": + raise AttributeError("handle is not set") + if name == "type_key": + return super().__getattr__(name) + raise AttributeError("%s has no attribute %s" % (str(type(self)), name)) + def keys(self): return iter(self) diff --git a/tests/python/unittest/test_ir_container.py b/tests/python/unittest/test_ir_container.py index fb83817f1eed..3652d5bdb280 100644 --- a/tests/python/unittest/test_ir_container.py +++ b/tests/python/unittest/test_ir_container.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import te import numpy as np @@ -34,6 +35,17 @@ def test_array_save_load_json(): assert a_loaded[1].value == 2 +def test_dir_array(): + a = tvm.runtime.convert([1, 2, 3]) + assert dir(a) + + +def test_getattr_array(): + a = tvm.runtime.convert([1, 2, 3]) + assert getattr(a, "type_key") == "Array" + assert not hasattr(a, "test_key") + + def test_map(): a = te.var("a") b = te.var("b") @@ -70,6 +82,21 @@ def test_map_save_load_json(): assert dd == {"a": 2, "b": 3} +def test_dir_map(): + a = te.var("a") + b = te.var("b") + amap = tvm.runtime.convert({a: 2, b: 3}) + assert dir(amap) + + +def test_getattr_map(): + a = te.var("a") + b = te.var("b") + amap = tvm.runtime.convert({a: 2, b: 3}) + assert getattr(amap, "type_key") == "Map" + assert not hasattr(amap, "test_key") + + def test_in_container(): arr = tvm.runtime.convert(["a", "b", "c"]) assert "a" in arr @@ -86,10 +113,4 @@ def test_ndarray_container(): if __name__ == "__main__": - test_str_map() - test_array() - test_map() - test_array_save_load_json() - test_map_save_load_json() - test_in_container() - test_ndarray_container() + pytest.main([__file__])