Skip to content

Commit

Permalink
add warning when storage align doesn't work
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jul 13, 2022
1 parent 8b7fc70 commit f4b585e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/tvm/meta_schedule/testing/te_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,14 +591,18 @@ def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
return (a, b, c)


def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
def matmul_fp16(n: int, m: int, k: int, out_dtype="float32") -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
a = te.placeholder((n, k), name="A", dtype="float16")
b = te.placeholder((k, m), name="B", dtype="float16")
k = te.reduce_axis((0, k), name="k")

def f_compute(i, j):
v_a = tir.Cast(dtype="float32", value=a[i, k])
v_b = tir.Cast(dtype="float32", value=b[k, j])
if out_dtype == "float32":
v_a = tir.Cast(dtype="float32", value=a[i, k])
v_b = tir.Cast(dtype="float32", value=b[k, j])
else:
v_a = a[i, k]
v_b = b[k, j]
return te.sum(v_a * v_b, axis=[k])

c = te.compute((n, m), f_compute, name="C")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
sch->StorageAlign(cache_read, 0, -2, 32, 8);
} else if (dtype.is_int() && dtype.bits() == 8) {
sch->StorageAlign(cache_read, 0, -2, 32, 16);
} else {
LOG(WARNING) << "StorageAlign is not applied for data type " << dtype
<< ", shared memory accesses might be inefficient.";
}
}
return {state};
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ struct ProducerConsumerSplit {
* \param self The schedule state.
* \param block The queried block.
* \param n The index of the queried buffer.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_type The type of the buffer index, kRead or kWrite.
* \return The buffer of the n-th read/write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Expand All @@ -435,7 +435,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
* \param self The schedule state.
* \param block The queried block.
* \param n The index of the queried buffer.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_type The type of the buffer index, kRead or kWrite.
* \return The n-th read/write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Expand Down

0 comments on commit f4b585e

Please sign in to comment.