diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 28a2df628c530..29fe710594f2f 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -591,14 +591,18 @@ def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: return (a, b, c) -def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: +def matmul_fp16(n: int, m: int, k: int, out_dtype="float32") -> Tuple[te.Tensor, te.Tensor, te.Tensor]: a = te.placeholder((n, k), name="A", dtype="float16") b = te.placeholder((k, m), name="B", dtype="float16") k = te.reduce_axis((0, k), name="k") def f_compute(i, j): - v_a = tir.Cast(dtype="float32", value=a[i, k]) - v_b = tir.Cast(dtype="float32", value=b[k, j]) + if out_dtype == "float32": + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + else: + v_a = a[i, k] + v_b = b[k, j] return te.sum(v_a * v_b, axis=[k]) c = te.compute((n, m), f_compute, name="C") diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 9544a9a9463f4..91df62fc36632 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -227,6 +227,9 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( sch->StorageAlign(cache_read, 0, -2, 32, 8); } else if (dtype.is_int() && dtype.bits() == 8) { sch->StorageAlign(cache_read, 0, -2, 32, 16); + } else { + LOG(WARNING) << "StorageAlign is not applied for data type " << dtype + << ", shared memory accesses might be inefficient."; } } return {state}; diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index ee70596ddc502..37b3e00cc23ee 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -423,7 +423,7 @@ struct ProducerConsumerSplit { * \param self The schedule state. * \param block The queried block. * \param n The index of the queried buffer. - * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \param index_type The type of the buffer index, kRead or kWrite. * \return The buffer of the n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. */ @@ -435,7 +435,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, * \param self The schedule state. * \param block The queried block. * \param n The index of the queried buffer. - * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \param index_type The type of the buffer index, kRead or kWrite. * \return The n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. */