Skip to content

Commit

Permalink
Migrate workload_registry.py (apache#16)
Browse files Browse the repository at this point in the history
* add workload registry

* update

* update
  • Loading branch information
merrymercy committed Jun 20, 2020
1 parent b839c0f commit cfe58d7
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 39 deletions.
8 changes: 5 additions & 3 deletions python/tvm/ansor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
from . import measure
from . import serialization
from . import loop_state
from . import task
from . import auto_schedule
from . import utils
from . import feature
from . import workload_registry

# Shortcut
from .compute_dag import ComputeDAG
from .task import SearchTask, MetaTileRewritePolicy, TuneOption
from .task import auto_schedule
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams
from .auto_schedule import auto_schedule
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper
from .cost_model import RandomModel
from .cost_model.xgb_model import XGBModel
from .serialization import LogToFile, LogReader, best_measure_pair_in_file
from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag
File renamed without changes.
2 changes: 1 addition & 1 deletion python/tvm/ansor/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np

from .loop_state import StateObject
from .task import SearchTask
from .auto_schedule import SearchTask
from .measure import MeasureInput, MeasureResult
from . import _ffi_api

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/ansor/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

logger = logging.getLogger('ansor')

MAX_ERROR_MSG_LEN = 512


@tvm._ffi.register_object("ansor.MeasureCallback")
class MeasureCallback(Object):
Expand Down Expand Up @@ -238,8 +240,6 @@ def __exit__(self, type, value, trace):
self.tracker.terminate()
self.server.terminate()

MAX_ERROR_MSG_LEN = 512


class MeasureErrorNo(object):
"""Error type for MeasureResult"""
Expand Down Expand Up @@ -505,3 +505,4 @@ def timed_func(inp, build_res):
print("")

return measure_results

5 changes: 5 additions & 0 deletions python/tvm/ansor/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def write_measure_records_to_file(filename, inputs, results):
_ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)


def get_states_from_measure_inputs(inputs, task):
    """Get states from measure inputs.

    Forwards both arguments unchanged to the registered FFI function
    ``GetStatesFromMeasureInputs`` and returns whatever it returns.

    Parameters
    ----------
    inputs :
        The measure inputs to convert.
    task :
        The search task handed to the FFI function together with ``inputs``.

    Returns
    -------
    The result of ``_ffi_api.GetStatesFromMeasureInputs(inputs, task)``.
    """
    states = _ffi_api.GetStatesFromMeasureInputs(inputs, task)
    return states


def best_measure_pair_in_file(filename, workload_key=None, target=None):
""" Return best results from log file
Expand Down
190 changes: 190 additions & 0 deletions python/tvm/ansor/workload_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


"""
Workload registration and serialization.
We use a json string to represent a workload (a compute dag).
The format of the string is `[func_name, [args...]]`.
The dag should be the return value of this `func_name(*args)`.
Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag.
These strings are efficient for serialization/matching and won't be too long.
When we need the dag, we decode the string and call the function, which will return the dag.
"""

import hashlib
import json
import pickle
from typing import List, Tuple, Callable, Union

try:
    # The ABC aliases in `collections` were removed in Python 3.10;
    # `collections.abc` is the supported location.
    from collections.abc import Hashable
except ImportError:
    from collections import Hashable

import tvm._ffi
from ..te import Tensor, PlaceholderOp, ComputeOp, placeholder
from .utils import get_const_tuple
from .compute_dag import ComputeDAG

# Module-wide workload registry. It holds two kinds of entries:
#   * function name (`func.__name__`) -> workload generation function
#   * hash key of a compute dag       -> the list of tensors registered directly
WORKLOAD_FUNC_REGISTRY = {}


def register_auto_scheduler_workload_func(func: Callable):
    """Register a workload generation function under its ``__name__``.

    The function is expected to take hashable and jsonable arguments
    (int, float, tuple of int, tvm.tensor.Tensor, ...) and to return a list
    of tvm.tensor.Tensor. It can be used directly as a decorator because the
    function itself is returned unchanged.

    Parameters
    ----------
    func : Callable
        The workload generation function.

    Returns
    -------
    func : Callable
        The same function that was passed in.

    Raises
    ------
    RuntimeError
        If a workload with the same name is already in the registry.

    Examples
    --------
    @register_auto_scheduler_workload_func
    def matmul(N, M, K):
        A = tvm.placeholder((N, K), name='A')
        B = tvm.placeholder((K, M), name='B')
        k = tvm.reduce_axis((0, K), name='k')
        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
        return [A, B, C]
    """
    name = func.__name__
    if name not in WORKLOAD_FUNC_REGISTRY:
        WORKLOAD_FUNC_REGISTRY[name] = func
        return func
    raise RuntimeError('%s has been registered already' % name)


def compute_dag_hash(dag: ComputeDAG):
    """Compute an md5 hex digest that identifies a compute dag.

    Every op of the dag contributes one ``<head>,<shape>,<dtype>;`` record,
    where ``<head>`` is the literal ``placeholder`` for placeholder ops and
    the string form of the op body for compute ops. The records are joined
    in the order of ``dag.ops`` and hashed.

    Parameters
    ----------
    dag : ComputeDAG
        The compute dag to hash.

    Returns
    -------
    key : str
        The hexadecimal md5 digest of the concatenated records.

    Raises
    ------
    ValueError
        If the dag contains an op that is neither a PlaceholderOp nor a ComputeOp.
    """
    # TODO: implement this more carefully and move this to c++ as a member function of ComputeDAG
    records = []
    for op in dag.ops:
        t = op.output(0)
        if isinstance(op, PlaceholderOp):
            head = 'placeholder'
        elif isinstance(op, ComputeOp):
            head = str(t.op.body)
        else:
            # `op` is not a string, so it must be converted explicitly; plain
            # concatenation would raise TypeError and hide the real error.
            raise ValueError("Invalid op: " + str(op))
        records.append(head + ',' + str(get_const_tuple(t.shape)) + ',' + t.dtype + ';')

    str_key = ''.join(records).encode(encoding='utf-8')
    return hashlib.md5(str_key).hexdigest()


def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str:
    """Register the buffers of a workload directly and return its workload key.

    The buffers are stored in the registry under the hash of their compute
    dag, so they can later be retrieved with ``workload_key_to_tensors``.

    Parameters
    ----------
    bufs : List[Tensor]
        The input/output tensors of the workload.

    Returns
    -------
    workload_key : str
        A json string holding a one-element list with the dag hash.
    """
    hash_key = compute_dag_hash(ComputeDAG(bufs))
    WORKLOAD_FUNC_REGISTRY[hash_key] = bufs
    return json.dumps((hash_key,))


def list_to_tuple(x: List) -> Tuple:
    """Recursively turn a list, and every list nested directly inside it, into tuples.

    Only elements that are themselves lists are descended into; any other
    element (including tuples) is kept as is.

    Parameters
    ----------
    x : List
        The list to convert. Must be a ``list`` instance.

    Returns
    -------
    ret : Tuple
        The converted tuple.
    """
    assert isinstance(x, list)
    converted = []
    for item in x:
        if isinstance(item, list):
            item = list_to_tuple(item)
        converted.append(item)
    return tuple(converted)


def serialize_args(args: Tuple) -> Tuple:
    """Convert function arguments into a hashable and jsonable tuple.

    Tensors are replaced by a ``('TENSOR', shape, dtype)`` triple and lists
    are converted to (nested) tuples; every other argument is passed through.

    Parameters
    ----------
    args : Tuple
        The arguments of a workload generation function.

    Returns
    -------
    ret : Tuple
        The serialized arguments, in the same order.
    """
    serialized = []
    for arg in args:
        if isinstance(arg, Tensor):
            item = ('TENSOR', get_const_tuple(arg.shape), arg.dtype)
        elif isinstance(arg, list):
            item = list_to_tuple(arg)
        else:
            item = arg

        assert isinstance(item, Hashable), str(item) + " is not hashable"
        serialized.append(item)

    return tuple(serialized)


def deserialize_args(args: Tuple) -> List:
    """The inverse function of :code:`serialize_args`.

    Entries of the form ``('TENSOR', shape, dtype)`` (as a tuple or a list)
    are turned back into placeholders; every other entry is returned as is.

    Parameters
    ----------
    args : Tuple
        The serialized arguments.

    Returns
    -------
    ret : List
        The deserialized arguments, in the same order.
    """
    ret = []
    for t in args:
        # The length check keeps empty tuples/lists (legitimate plain arguments)
        # from raising IndexError when the tag at position 0 is inspected.
        is_tensor = isinstance(t, (tuple, list)) and len(t) > 0 and t[0] == 'TENSOR'
        if is_tensor:
            ret.append(placeholder(shape=t[1], dtype=t[2]))
        else:
            ret.append(t)
    return ret


@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
def workload_key_to_tensors(workload_key: str) -> List[Tensor]:
    """Decode a workload key to the input/output tensors.

    The first element of the decoded json list selects the registry entry.
    A callable entry is invoked with the deserialized remaining elements;
    a non-callable entry is returned directly.

    Parameters
    ----------
    workload_key : str
        The json-encoded workload key.

    Returns
    -------
    tensors : List[Tensor]
        The tensors of the workload.
    """
    decoded = json.loads(workload_key)
    entry = WORKLOAD_FUNC_REGISTRY[decoded[0]]

    if not callable(entry):
        return entry
    return entry(*deserialize_args(decoded[1:]))


@tvm._ffi.register_func("auto_scheduler.workload_key_to_dag")
def workload_key_to_dag(workload_key: str) -> ComputeDAG:
    """Decode a workload key to a compute dag.

    Parameters
    ----------
    workload_key : str
        The json-encoded workload key.

    Returns
    -------
    dag : ComputeDAG
        A compute dag built from the tensors the key decodes to.
    """
    return ComputeDAG(workload_key_to_tensors(workload_key))


def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str:
    """Make a workload key from a function and its arguments.

    Parameters
    ----------
    func : Union[str, Callable]
        A registered workload generation function, or its registered name.
    args : Tuple
        The arguments of the function.

    Returns
    -------
    workload_key : str
        A json string of the form ``[func_name, arg0, arg1, ...]``.

    Raises
    ------
    ValueError
        If ``func`` is neither callable nor a string.
    AssertionError
        If the function name is not in the registry.
    """
    args = serialize_args(args)

    if callable(func):
        func_name = func.__name__
    elif isinstance(func, str):
        func_name = func
    else:
        raise ValueError("Invalid function: " + str(func))

    # Report the name rather than `func` itself, which for a callable would
    # print as an opaque `<function ... at 0x...>`.
    assert func_name in WORKLOAD_FUNC_REGISTRY, \
        "%s is not registered. Please register it with register_auto_scheduler_workload_func" \
        % func_name

    return json.dumps((func_name,) + args)


def make_workload_key_bufs(bufs: List[Tensor]) -> str:
    """Make a workload key from the buffers of a workload.

    Parameters
    ----------
    bufs : List[Tensor]
        The input/output tensors of the workload.

    Returns
    -------
    workload_key : str
        A json string holding a one-element list with the hash of the
        compute dag built from ``bufs``.
    """
    hash_key = compute_dag_hash(ComputeDAG(bufs))
    return json.dumps((hash_key,))


def dump_workload_func_registry(filename: str):
    """Dump workload function registry to a pickle binary file.

    Parameters
    ----------
    filename : str
        Path of the file to write. An existing file is overwritten.
    """
    # The context manager closes (and thereby flushes) the file deterministically,
    # also when pickling raises.
    with open(filename, 'wb') as fout:
        pickle.dump(WORKLOAD_FUNC_REGISTRY, fout)


def load_workload_func_registry(filename: str):
    """Load workload function registry from a pickle binary file.

    The module-level registry is replaced by the loaded object; entries
    registered before the call are discarded.

    Parameters
    ----------
    filename : str
        Path of a file written by :code:`dump_workload_func_registry`.
    """
    global WORKLOAD_FUNC_REGISTRY

    # NOTE(review): unpickling can execute arbitrary code; only load registry
    # files that come from a trusted source.
    with open(filename, 'rb') as fin:
        WORKLOAD_FUNC_REGISTRY = pickle.load(fin)

2 changes: 2 additions & 0 deletions src/ansor/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,8 @@ void GetPerStmtFeaturesFromStates(const Array<State>& states,
for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i],
max_n_bufs, &(*features)[i], &error_ct);
//GetPerStmtFeaturesWorkerFunc(task, states[i],
// max_n_bufs, &(*features)[i], &error_ct);
}
pool.WaitBatch();

Expand Down
62 changes: 55 additions & 7 deletions src/ansor/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,7 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) {
// skip comment lines begin with '#' or ' '
continue;
}

try {
ReadMeasureRecord(cur_line, inp, res, &log_version);
} catch (...) {
return false;
}

ReadMeasureRecord(cur_line, inp, res, &log_version);
return true;
}

Expand Down Expand Up @@ -607,5 +601,59 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext")
}
});

// Packed function "ansor.GetStatesFromMeasureInputs".
// args[0]: array of measure inputs.
// args[1]: optional search task; when given and defined it is used for every input.
// Returns one state per input, obtained by replaying the input's transform steps
// on the initial state of the chosen task's compute dag.
TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Array<MeasureInput> inputs = args[0];
SearchTask external_task;

// The task argument is optional; external_task stays undefined without it.
if (args.size() > 1) {
external_task = args[1];
}

Array<State> states;
states.reserve(inputs.size());

// (workload_key, target) -> (search_task)
// Only tasks rebuilt from a workload key are stored here, so each distinct
// (workload_key, target) pair is rebuilt at most once per call.
std::unordered_map<std::pair<std::string, std::string>, SearchTask> task_cache;

for (const auto& inp : inputs) {
const std::string& workload_key = inp->task->workload_key;
std::pair<std::string, std::string> key(workload_key, inp->task->target->str());

// Pick the task whose compute dag is used for this input:
// external task > cached rebuilt task > the input's own task > newly rebuilt task.
const SearchTaskNode* ptask;
if (external_task.defined()) {
ptask = external_task.operator->();
} else {
auto find_res = task_cache.find(key);
if (find_res == task_cache.end()) {
if (inp->task->compute_dag.defined()) { // the measure input is complete
ptask = inp->task.operator->();
} else { // the measure input is incomplete
// rebuild task for incomplete measure pairs read from file
SearchTask new_task = SearchTaskNode::make(
ComputeDAGNode::make_by_workload_key(workload_key),
workload_key,
inp->task->target,
inp->task->target_host,
inp->task->hardware_params);
// The cache holds a reference to new_task, so ptask remains valid
// after new_task goes out of scope.
task_cache.insert(std::make_pair(key, new_task));
ptask = new_task.operator->();
}
} else {
ptask = find_res->second.operator->();
}
}

// Start from the dag's initial state, copy the recorded transform steps into it
// and replay them against the dag.
State tmp_s = ptask->compute_dag.GetInitState();
StateNode *ps = tmp_s.CopyOnWrite();
ps->transform_steps = inp->state->transform_steps;
tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag);
states.push_back(std::move(tmp_s));
}

*ret = states;
});


} // namespace ansor
} // namespace tvm
Loading

0 comments on commit cfe58d7

Please sign in to comment.