Skip to content

Commit

Permalink
[TIR] Fix buffer scope in structural equal (apache#8768)
Browse files Browse the repository at this point in the history
* fix buffer scope in structual equal

* make global equal to empty
  • Loading branch information
Hzfengsy authored and shingjan committed Aug 23, 2021
1 parent 7f2827e commit 95e55f5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
11 changes: 9 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,17 @@ class PointerTypeNode : public TypeNode {
}

bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
// Make "global" equal to ""
String lhs_scope = storage_scope.empty() ? "global" : storage_scope;
String rhs_scope = other->storage_scope.empty() ? "global" : other->storage_scope;
return equal(element_type, other->element_type) && equal(lhs_scope, rhs_scope);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); }
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(element_type);
// Make "global" equal to ""
hash_reduce(storage_scope.empty() ? "global" : storage_scope);
}

static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_tir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@ def func2():
assert consistent_equal(func2(), func2())


def test_buffer_storage_scope():
x = te.var("x", dtype="handle")

buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local")
buffer_global = tvm.tir.decl_buffer((10, 10), "float32", scope="global")
buffer_empty = tvm.tir.decl_buffer((10, 10), "float32", scope="")

func0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_0})
func1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_1})
func2 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_global})
func3 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_empty})

assert consistent_equal(func0, func1)
assert consistent_equal(func2, func3)
assert not consistent_equal(func0, func2)


def test_buffer_load_store():
b = tvm.tir.decl_buffer((10, 10), "float32")
x = tvm.tir.BufferLoad(b, [0, 1])
Expand All @@ -188,4 +206,5 @@ def test_buffer_load_store():
test_array()
test_env_func()
test_stmt()
test_buffer_storage_scope()
test_buffer_load_store()

0 comments on commit 95e55f5

Please sign in to comment.