ComputeDAG bug fix & Add Custom TensorCore Matmul Example (apache#42)
* Bug Fix

* Sample example of Custom TensorCore Matmul
jcf94 authored Jun 24, 2020
1 parent 5860191 commit 14a19cd
Showing 3 changed files with 199 additions and 28 deletions.
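The example plugs into Ansor's custom sketch rule hook: a rule is a pair of callbacks, a meet-condition check plus an apply function that rewrites the loop state, preloaded into the search policy before tuning starts. Below is a minimal sketch of that plumbing, condensed from the scripts/tune_test.py diff further down; the TuneOption arguments shown here are only the subset relevant to the rule, not the full call used in the script.

    from tvm import ansor

    def meet_condition(meta_policy, state, stage_id):
        pass  # the sample rule applies unconditionally

    def apply(meta_policy, state, stage_id):
        # Wrap the raw state, apply manual schedule steps here (split / reorder /
        # fuse / tensorize in the full example below), and return the new states.
        state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag)
        return [[state.state_object, -1]]

    tune_option = ansor.TuneOption(
        n_trials=1000,
        pre_search_callbacks=[ansor.PreloadCustomSketchRule(meet_condition, apply)])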
scripts/common.py: 17 additions & 17 deletions
@@ -81,25 +81,25 @@ def add_mn(M, N):
 @register_workload_func
 def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
                 tensor_core_support=False):
-    A = te.placeholder((N, K), name='A', dtype=in_type)
-    B = te.placeholder((K, M), name='B', dtype=in_type)
-    k = te.reduce_axis((0, K), name='k')
-    if in_type == out_type:
-        if not (in_type == 'float16' and out_type == 'float16'):
-            tensor_core_support = False
-        C = te.compute((N, M),
-                       lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
-                       name='C',
-                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})
-    else:
+    if tensor_core_support:
+        A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type)
+        B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type)
+        k = te.reduce_axis((0, K // 16), name='k')
+        kk = te.reduce_axis((0, 16), name='kk')
         if not ((in_type == 'float16' and out_type == 'float32') or \
-                (in_type == 'int8' and out_type == 'int32')):
-            tensor_core_support = False
+                (in_type == 'int8' and out_type == 'int32')):
+            raise ValueError
+        C = te.compute((N // 16, M // 16, 16, 16),
+                       lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type),
+                                                   axis=[k, kk]),
+                       name='C')
+    else:
+        A = te.placeholder((N, K), name='A', dtype=in_type)
+        B = te.placeholder((K, M), name='B', dtype=in_type)
+        k = te.reduce_axis((0, K), name='k')
         C = te.compute((N, M),
-                       lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type),
-                                           axis=[k]),
-                       name='C',
-                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})
+                       lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
+                       name='C')
 
     return [A, B, C]

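In the tensor-core branch the workload now expects pre-packed 4-D inputs. Below is a minimal NumPy sketch of one packing that matches the (N // 16, K // 16, 16, 16) placeholder shape above; the row-major 16x16 tiling is an assumption for illustration, since the diff itself only fixes the placeholder shapes.

    import numpy as np

    N, K = 1024, 1024
    a = np.random.uniform(size=(N, K)).astype('float16')

    # Split both axes into 16x16 tiles: (N, K) -> (N//16, 16, K//16, 16) -> (N//16, K//16, 16, 16)
    a_packed = a.reshape(N // 16, 16, K // 16, 16).transpose(0, 2, 1, 3)

    assert a_packed.shape == (N // 16, K // 16, 16, 16)
    # Tile (i, j) holds the block a[16*i:16*(i+1), 16*j:16*(j+1)]
    assert np.array_equal(a_packed[1, 2], a[16:32, 32:48])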
scripts/tune_test.py: 173 additions & 8 deletions
@@ -24,14 +24,169 @@
 import numpy as np
 
 import tvm
-from tvm import ansor
+from tvm import ansor, te
 from tvm.ansor.utils import request_remote
 
 from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool
 
+def tensor_core_meet_condition(meta_policy, state, stage_id):
+    pass
+
+def intrin_wmma_load_matrix(scope):
+    n = 16
+    A = te.placeholder((n, n), name='A', dtype='float16')
+    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
+    C = te.compute((n, n), lambda i, j: A[i, j], name='C')
+    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+                                    BC.data, n, n, n, BC.elem_offset // 256,
+                                    BA.access_ptr('r'), n, 'row_major'))
+        return ib.get()
+
+    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+@tvm._ffi.register_func
+def intrin_wmma_load_matrix_a():
+    return intrin_wmma_load_matrix("wmma.matrix_a")
+
+@tvm._ffi.register_func
+def intrin_wmma_load_matrix_b():
+    return intrin_wmma_load_matrix("wmma.matrix_b")
+
+@tvm._ffi.register_func
+def intrin_wmma_gemm():
+    n = 16
+    A = te.placeholder((n, n), name='A', dtype='float16')
+    B = te.placeholder((n, n), name='B', dtype='float16')
+    k = te.reduce_axis((0, n), name="k")
+    C = te.compute((n, n),
+                   lambda ii, jj:
+                   te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
+                   name='C')
+    BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
+    BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
+    BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        BA, BB = ins
+        BC, = outs
+
+        def init():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+            return ib.get()
+
+        def update():
+            ib = tvm.tir.ir_builder.create()
+            ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
+                                        BC.data, BC.elem_offset // 256,
+                                        BA.data, BA.elem_offset // 256,
+                                        BB.data, BB.elem_offset // 256,
+                                        BC.data, BC.elem_offset // 256))
+            return ib.get()
+
+        return update(), init(), update()
+
+    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
+
+@tvm._ffi.register_func
+def intrin_wmma_store_matrix():
+    n = 16
+    A = te.placeholder((n, n), name='A', dtype='float32')
+    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
+    C = te.compute((n, n), lambda i, j: A[i, j], name='C')
+    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
+                                    BA.data, n, n, n, BA.elem_offset // 256,
+                                    BC.access_ptr('w'), n, 'row_major'))
+        return ib.get()
+
+    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+def tensor_core_apply(meta_policy, state, stage_id):
+    ret = []
+    state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag)
+
+    A, B, C = meta_policy.cur_task.compute_dag.ops
+
+    C_local = state.cache_write(C, "wmma.accumulator")
+
+    its0 = state.split(C_local, state[C_local].iters[0], [None, None])
+    split_step0 = state.transform_steps_size() - 1
+    its1 = state.split(C_local, state[C_local].iters[3], [None, None])
+    split_step1 = state.transform_steps_size() - 1
+    its2 = state.split(C_local, state[C_local].iters[8], [None])
+
+    state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2],
+                            its2[0], its2[1],
+                            state[C_local].iters[6],
+                            state[C_local].iters[7],
+                            state[C_local].iters[10]])
+    state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]])
+    state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]])
+    state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]])
+
+    its0 = state.follow_split(C, state[C].iters[0], split_step0, 2)
+    its1 = state.follow_split(C, state[C].iters[3], split_step1, 2)
+    state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2],
+                      state[C].iters[6], state[C].iters[7]])
+    state.fuse(C, [state[C].iters[0], state[C].iters[1]])
+    state.fuse(C, [state[C].iters[1], state[C].iters[2]])
+    local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]])
+    state.compute_at(C_local, C, local_write_pos)
+    shared_read_pos = state[C_local].iters[3]
+    local_read_pos = state[C_local].iters[4]
+    state.bind_thread(C, state[C].iters[0], "blockIdx.x")
+    state.bind_thread(C, state[C].iters[1], "vthread")
+    state.bind_thread(C, state[C].iters[2], "threadIdx.x")
+
+    B_shared = state.cache_read(B, "shared", [C_local])
+    B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local])
+    state.compute_at(B_shared, C_local, shared_read_pos)
+    state.compute_at(B_local, C_local, local_read_pos)
+
+    it = state.fuse(B_shared, state[B_shared].iters[:])
+    its = state.split(B_shared, it, [4])  # vectorize add a callback check function
+    state.vectorize(B_shared, its[1])
+    its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True)
+    state.bind_thread(B_shared, its[1], "threadIdx.x")
+
+    A_shared = state.cache_read(A, "shared", [C_local])
+    A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local])
+    state.compute_at(A_shared, C_local, shared_read_pos)
+    state.compute_at(A_local, C_local, local_read_pos)
+
+    it = state.fuse(A_shared, state[A_shared].iters[:])
+    its = state.split(A_shared, it, [4])  # vectorize add a callback check function
+    state.vectorize(A_shared, its[1])
+    its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True)
+    state.bind_thread(A_shared, its[1], "threadIdx.x")
+
+    state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a")
+    state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b")
+    state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm")
+    state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix")
+
+    print(state)
+
+    ret.append([state.state_object, -1])
+    return ret
+
 def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose,
                        n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host,
-                       rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10):
+                       rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10,
+                       tensor_core_matmul=False):
     builder = runner = measure_ctx = None
     if local_measure:
         builder = ansor.LocalBuilder(timeout=build_timeout)
@@ -52,13 +52,16 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose
         config_threadpool = remote.get_function('runtime.config_threadpool')
         config_threadpool(0, rpc_num_threads)
 
+    pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)]
+    if tensor_core_matmul:
+        pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply))
     tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping,
                                    num_measure_per_iter=num_measure_per_iter,
                                    verbose=verbose,
                                    builder=builder,
                                    runner=runner,
                                    measure_callbacks=[ansor.LogToFile(log_file)],
-                                   pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)])
+                                   pre_search_callbacks=pre_search_callbacks)
 
     return tune_option, measure_ctx
 
@@ -113,10 +271,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type,
             model.load(load_model_file)
         elif load_log_file:
             model.load_log_file(load_log_file)
-        elif model_type == "random":
-            model = ansor.RandomModel()
-        else:
-            raise ValueError("Invalid model: " + model_type)
+    elif model_type == "random":
+        model = ansor.RandomModel()
+    else:
+        raise ValueError("Invalid model: " + model_type)
 
     if policy == 'sketch':
         policy = ansor.SketchSearchPolicy(program_cost_model=model)
@@ -200,11 +358,18 @@ def objective_func(costs):
     load_log_file = args.load_log or log_file
     weights = get_workload_weights(args.wkl)
 
+    # Special check for tensor core
+    wkl_key = args.wkl
+    wkl_key = wkl_key.split("-")
+    tensor_core_matmul = False
+    if wkl_key[0] == "matmul" and wkl_key[6] == "tc":
+        tensor_core_matmul = True
+
     tune_option, measure_ctx = create_tune_option(target, log_file,
             args.n_trials, args.num_measure_per_iter, args.verbose,
             args.n_parallel, args.build_timeout, args.local_measure,
             args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads,
-            args.ndk_cc)
+            args.ndk_cc, tensor_core_matmul=tensor_core_matmul)
 
     if args.task_scheduler == 'no':
         # tune workloads one by one
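The tensor-core path is selected purely from the workload key string. Below is a hypothetical key that would satisfy the check added above; the actual key format comes from get_workload_keys in common.py, so this literal is an illustration rather than a value taken from the change.

    wkl_key = "matmul-1024-1024-1024-float16-float32-tc"  # hypothetical key
    fields = wkl_key.split("-")
    # Field 0 names the workload; field 6 set to "tc" turns on the custom rule.
    assert fields[0] == "matmul" and fields[6] == "tc"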
src/ansor/compute_dag.cc: 9 additions & 3 deletions
@@ -569,13 +569,11 @@ State ComputeDAG::GetInitState() const {
 ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   auto node = make_object<ComputeDAGNode>();
   FlopEstimator estimator;
-
   node->tensors = std::move(tensors);
   node->access_analyzer = AccessAnalyzer(node->tensors);
   node->ops = Array<te::Operation>(node->access_analyzer->ops_topo_order);
   node->flop_ct = estimator.EstimateFlop(node->ops);
   node->init_state = State(node->ops);
-
   data_ = std::move(node);
 }
 
@@ -587,7 +585,15 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) {
   } else {
     LOG(FATAL) << "ansor.workload_key_to_tensors is not registered";
   }
-  ComputeDAG(std::move(tens));
+
+  auto node = make_object<ComputeDAGNode>();
+  FlopEstimator estimator;
+  node->tensors = std::move(tens);
+  node->access_analyzer = AccessAnalyzer(node->tensors);
+  node->ops = Array<te::Operation>(node->access_analyzer->ops_topo_order);
+  node->flop_ct = estimator.EstimateFlop(node->ops);
+  node->init_state = State(node->ops);
+  data_ = std::move(node);
 }
 
 std::string BaseName(const std::string& str) {
