diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index f1b052c8d31f..3194a5601688 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -114,6 +114,8 @@ def squeeze(a, axis=None): for i, a_dim in enumerate(a_shape): if i not in search_axis: out_shape.append(a_dim) + if not out_shape: + out_shape.append(1) def _compute(*indices): real_indices = [] flag = 0 diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 8acb4b4f5a1b..1113856fbfdb 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -82,7 +82,11 @@ def check_device(device): data_npy = np.random.normal(size=src_shape).astype(A.dtype) out_npy = np.squeeze(data_npy, axis=axis) data_nd = tvm.nd.array(data_npy, ctx) - out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype) + if out_npy.shape == (): + out_nd_shape = (1,) + else: + out_nd_shape = out_npy.shape + out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype) foo(data_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy) @@ -159,6 +163,7 @@ def test_squeeze(): verify_squeeze((1, 2, 3, 4), 0) verify_squeeze((1, 2, 1, 4), None) verify_squeeze((1, 1, 1, 4), (1, 2)) + verify_squeeze((1, 1, 1, 1), None) def test_concatenate():