Skip to content

Commit

Permalink
[AutoScheduler] Add function name in message (apache#7703)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Add function name in message

* fix
  • Loading branch information
comaniac authored and trevor-m committed May 11, 2021
1 parent dd5dbd9 commit 689beed
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 23 deletions.
49 changes: 29 additions & 20 deletions python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DispatchContext(object):
def __init__(self):
self._old_ctx = DispatchContext.current

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
"""
Query the context to get the specific config for a workload.
If cannot find the result inside this context, this function will query it
Expand All @@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
Whether this workload has at least one complex op.
dag: ComputeDAG
The ComputeDAG of the workload.
func_name: str
The function name of this workload.
Returns
-------
state : StateObject
The state that stores schedule configuration for the workload
"""
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def update(self, target, workload_key, state):
Expand All @@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
"""
raise NotImplementedError()

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
"""
Query the context to get the specific config for a workload.
This function only query config inside this context.
Expand All @@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
The current target
workload_key : str
The current workload_key.
func_name: str
The function name of this workload.
Returns
-------
Expand Down Expand Up @@ -241,7 +245,7 @@ def load(self, records, n_lines=None):

logger.debug("Finish loading %d records", counter)

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
if target is None:
raise RuntimeError(
"Need a target context to find the history best. "
Expand Down Expand Up @@ -343,18 +347,20 @@ def __init__(
records, n_lines=None, include_compatible=True
)

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
if has_complex_op or self.sample_simple_workloads:
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
else:
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def _query_inside(self, target, workload_key):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
def _query_inside(self, target, workload_key, func_name):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
if ret is not None:
return ret

Expand Down Expand Up @@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):

# Load the sampled records and query again.
self.load(log_file)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

del measure_ctx
return ret
Expand All @@ -411,18 +419,19 @@ def __init__(self):
# a set to prevent print duplicated message
self.messages = set()

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
key = (str(target), workload_key)
if key in self.memory:
return self.memory[key]

if self.verbose == 2 or (has_complex_op and self.verbose == 1):
msg = (
"-----------------------------------\n"
"Cannot find tuned schedules for target=%s, workload_key=%s. "
"A fallback TOPI schedule is used, "
"which may bring great performance regression or even compilation failure. "
"Compute DAG info:\n%s" % (target, workload_key, dag)
f"-----------------------------------\n"
f"{func_name}\n"
f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
f"A fallback TOPI schedule is used, "
f"which may bring great performance regression or even compilation failure. "
f"Compute DAG info:\n{dag}"
)
if msg not in self.messages:
self.messages.add(msg)
Expand All @@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
self.memory[key] = state
return state

def _query_inside(self, target, workload_key):
_ = target = workload_key
def _query_inside(self, target, workload_key, func_name):
_ = target = workload_key = func_name
raise RuntimeError("This function should never be called")

def update(self, target, workload_key, state):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,17 @@ def traverse(t):


@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
def auto_schedule_topi(outs):
def auto_schedule_topi(func_name, outs):
"""Use auto-scheduler to schedule any topi compute function.
Note: This is used internally for relay integration. Do
not use this as a general user-facing API.
Parameters
----------
func_name: str
The name of the function being scheduled.
outs: List[Tensor]
The output tensors of topi compute functions
Expand All @@ -289,7 +292,7 @@ def auto_schedule_topi(outs):
target = tvm.target.Target.current()

dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name)
schedule = None

env = TracingEnvironment.current
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(tensor_outs);
ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
Expand Down

0 comments on commit 689beed

Please sign in to comment.