Skip to content

Commit

Permalink
update intrin registrations
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jul 12, 2022
1 parent 826a3fe commit 564bc89
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
14 changes: 4 additions & 10 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,14 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling_tensor_core(target: Target, scope="shared") -> ScheduleRule:
def multi_level_tiling_tensor_core(
target: Target, scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling reuse for tensor core"""
assert scope in ["shared", "global"]
if target.kind.name == "cuda":
return MultiLevelTilingTensorCore(
intrin_group={
"init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
"load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
"load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
"compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
"store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if scope == "shared"
else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
},
intrin_group=tensor_intrin.get_wmma_intrin_group(scope, in_dtype, out_dtype, trans_b),
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
Expand Down
65 changes: 64 additions & 1 deletion python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
from typing import Tuple
from typing import Tuple, Dict
from tvm.script import tir as T
from tvm.tir.function import PrimFunc
from .. import IntImm, Cast
Expand Down Expand Up @@ -806,3 +806,66 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "global")
)


def get_wmma_intrin_group(
    scope: str, in_dtype: str, out_dtype: str, trans_b: bool
) -> Dict[str, str]:
    """Get a group of intrinsics for wmma tensor core with the given configurations.

    Parameters
    ----------
    scope : str
        Must be one of ["global", "shared"]. The memory scope of the result buffer.
    in_dtype : str
        The input data type. Currently only "float16" is supported.
    out_dtype : str
        The output data dtype. Must be one of ["float16", "float32"].
    trans_b : bool
        Whether the input matrix B is transposed.

    Returns
    -------
    ret : Dict[str, str]
        A group of tensor intrinsic names keyed by
        "init", "load_a", "load_b", "compute", and "store".
    """
    assert scope in ["global", "shared"]
    assert in_dtype in ["float16"]
    assert out_dtype in ["float16", "float32"]

    load_a_intrins = {"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN}
    # B is loaded with the transposed intrinsic when trans_b is set.
    load_b_intrins = {
        "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
        if trans_b
        else WMMA_LOAD_16x16x16_F16_B_INTRIN
    }
    # The compute intrinsic is selected by the accumulator dtype and layout of B.
    compute_intrins = {
        "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
        if trans_b
        else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
        "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
        if trans_b
        else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
    }
    init_intrins = {
        "float16": WMMA_FILL_16x16x16_F16_INTRIN,
        "float32": WMMA_FILL_16x16x16_F32_INTRIN,
    }
    # The store intrinsic depends on both the accumulator dtype and the result scope.
    store_intrins = {
        "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
        if scope == "shared"
        else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
        "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
        if scope == "shared"
        else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
    }
    return {
        "init": init_intrins[out_dtype],
        "load_a": load_a_intrins[in_dtype],
        # Fix: index by in_dtype — the original returned the whole dict here,
        # violating the Dict[str, str] contract that every other key follows.
        "load_b": load_b_intrins[in_dtype],
        "compute": compute_intrins[out_dtype],
        "store": store_intrins[out_dtype],
    }
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class TensorCoreState : public State {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};


TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);

TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
Expand Down

0 comments on commit 564bc89

Please sign in to comment.