From 14a19cd9597809801d570228818aea61b7082072 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 24 Jun 2020 13:22:45 +0800 Subject: [PATCH] ComputeDAG bug fix & Add Custom TensorCore Matmul Example (#42) * Bug Fix * Sample example of Custom TensorCore Matmul --- scripts/common.py | 34 ++++---- scripts/tune_test.py | 181 +++++++++++++++++++++++++++++++++++++-- src/ansor/compute_dag.cc | 12 ++- 3 files changed, 199 insertions(+), 28 deletions(-) diff --git a/scripts/common.py b/scripts/common.py index ac25b28e55b1..e9cf58e128bb 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -81,25 +81,25 @@ def add_mn(M, N): @register_workload_func def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', tensor_core_support=False): - A = te.placeholder((N, K), name='A', dtype=in_type) - B = te.placeholder((K, M), name='B', dtype=in_type) - k = te.reduce_axis((0, K), name='k') - if in_type == out_type: - if not (in_type == 'float16' and out_type == 'float16'): - tensor_core_support = False - C = te.compute((N, M), - lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) - else: + if tensor_core_support: + A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type) + B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type) + k = te.reduce_axis((0, K // 16), name='k') + kk = te.reduce_axis((0, 16), name='kk') if not ((in_type == 'float16' and out_type == 'float32') or \ - (in_type == 'int8' and out_type == 'int32')): - tensor_core_support = False + (in_type == 'int8' and out_type == 'int32')): + raise ValueError + C = te.compute((N // 16, M // 16, 16, 16), + lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type), + axis=[k, kk]), + name='C') + else: + A = te.placeholder((N, K), name='A', dtype=in_type) + B = te.placeholder((K, M), name='B', dtype=in_type) + k = te.reduce_axis((0, K), name='k') C = te.compute((N, M), - lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), - axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) + lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), + name='C') return [A, B, C] diff --git a/scripts/tune_test.py b/scripts/tune_test.py index c98da3eca53b..6b39cf5e7865 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -24,14 +24,169 @@ import numpy as np import tvm -from tvm import ansor +from tvm import ansor, te from tvm.ansor.utils import request_remote from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool +def tensor_core_meet_condition(meta_policy, state, stage_id): + pass + +def intrin_wmma_load_matrix(scope): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_a(): + return intrin_wmma_load_matrix("wmma.matrix_a") + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_b(): + return intrin_wmma_load_matrix("wmma.matrix_b") + +@tvm._ffi.register_func +def intrin_wmma_gemm(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + B = te.placeholder((n, n), name='B', dtype='float16') + k = te.reduce_axis((0, n), name="k") + C = te.compute((n, n), + lambda ii, jj: + te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) + return ib.get() + + return update(), init(), update() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_store_matrix(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float32') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def tensor_core_apply(meta_policy, state, stage_id): + ret = [] + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) + + A, B, C = meta_policy.cur_task.compute_dag.ops + + C_local = state.cache_write(C, "wmma.accumulator") + + its0 = state.split(C_local, state[C_local].iters[0], [None, None]) + split_step0 = state.transform_steps_size() - 1 + its1 = state.split(C_local, state[C_local].iters[3], [None, None]) + split_step1 = state.transform_steps_size() - 1 + its2 = state.split(C_local, state[C_local].iters[8], [None]) + + state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its2[0], its2[1], + state[C_local].iters[6], + state[C_local].iters[7], + state[C_local].iters[10]]) + state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]]) + state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]]) + state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]]) + + its0 = state.follow_split(C, state[C].iters[0], split_step0, 2) + its1 = state.follow_split(C, state[C].iters[3], split_step1, 2) + state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + state[C].iters[6], state[C].iters[7]]) + state.fuse(C, [state[C].iters[0], state[C].iters[1]]) + state.fuse(C, [state[C].iters[1], state[C].iters[2]]) + local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]]) + state.compute_at(C_local, C, local_write_pos) + shared_read_pos = state[C_local].iters[3] + local_read_pos = state[C_local].iters[4] + state.bind_thread(C, state[C].iters[0], "blockIdx.x") + state.bind_thread(C, state[C].iters[1], "vthread") + state.bind_thread(C, state[C].iters[2], "threadIdx.x") + + B_shared = state.cache_read(B, "shared", [C_local]) + B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local]) + state.compute_at(B_shared, C_local, shared_read_pos) + state.compute_at(B_local, C_local, local_read_pos) + + it = state.fuse(B_shared, state[B_shared].iters[:]) + its = state.split(B_shared, it, [4]) # vectorize add a callback check function + state.vectorize(B_shared, its[1]) + its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(B_shared, its[1], "threadIdx.x") + + A_shared = state.cache_read(A, "shared", [C_local]) + A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local]) + state.compute_at(A_shared, C_local, shared_read_pos) + state.compute_at(A_local, C_local, local_read_pos) + + it = state.fuse(A_shared, state[A_shared].iters[:]) + its = state.split(A_shared, it, [4]) # vectorize add a callback check function + state.vectorize(A_shared, its[1]) + its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(A_shared, its[1], "threadIdx.x") + + state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a") + state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b") + state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm") + state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix") + + print(state) + + ret.append([state.state_object, -1]) + return ret + def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, - rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10, + tensor_core_matmul=False): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -52,13 +207,16 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose config_threadpool = remote.get_function('runtime.config_threadpool') config_threadpool(0, rpc_num_threads) + pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)] + if tensor_core_matmul: + pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply)) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, verbose=verbose, builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) + pre_search_callbacks=pre_search_callbacks) return tune_option, measure_ctx @@ -113,10 +271,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, model.load(load_model_file) elif load_log_file: model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() - else: - raise ValueError("Invalid model: " + model_type) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) if policy == 'sketch': policy = ansor.SketchSearchPolicy(program_cost_model=model) @@ -200,11 +358,18 @@ def objective_func(costs): load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) + # Special check for tensor core + wkl_key = args.wkl + wkl_key = wkl_key.split("-") + tensor_core_matmul = False + if wkl_key[0] == "matmul" and wkl_key[6] == "tc": + tensor_core_matmul = True + tune_option, measure_ctx = create_tune_option(target, log_file, args.n_trials, args.num_measure_per_iter, args.verbose, args.n_parallel, args.build_timeout, args.local_measure, args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, - args.ndk_cc) + args.ndk_cc, tensor_core_matmul=tensor_core_matmul) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index ee87318cdd84..9e6da6ff6f3b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -569,13 +569,11 @@ State ComputeDAG::GetInitState() const { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; - node->tensors = std::move(tensors); node->access_analyzer = AccessAnalyzer(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); - data_ = std::move(node); } @@ -587,7 +585,15 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { } else { LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; } - ComputeDAG(std::move(tens)); + + auto node = make_object(); + FlopEstimator estimator; + node->tensors = std::move(tens); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); } std::string BaseName(const std::string& str) {