[AUTOTVM] Fix GATuner and improve error message #1605

Merged 2 commits on Aug 15, 2018
2 changes: 2 additions & 0 deletions include/tvm/operation.h
@@ -366,6 +366,8 @@ class ExternOpNode : public OperationNode {
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("input_placeholders", &input_placeholders);
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
Expand Down
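Why the two new Visit calls matter: VisitAttrs is the reflection hook that exposes node fields to the Python frontend, so after this change the placeholder tensors of an extern op become inspectable from Python instead of silently missing. A minimal sketch, assuming the 0.4-era tvm.extern/tvm.call_packed API ("my_ext_fn" is a hypothetical packed function):

import tvm

A = tvm.placeholder((16,), name="A")
C = tvm.extern((16,), [A],
               lambda ins, outs: tvm.call_packed("my_ext_fn", ins[0], outs[0]),
               name="C")

ext_op = C.op  # an ExternOpNode
# Reachable from Python only because VisitAttrs now visits these fields:
print(ext_op.input_placeholders)
print(ext_op.output_placeholders)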
2 changes: 2 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -394,6 +394,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
             msg = str(exc)
             if "Stack trace returned" in msg:
                 msg = msg[:msg.index("Stack trace returned")]
+            if "CUDA Source" in msg:
+                msg = msg[:msg.index("CUDA Source")]
             costs = (RuntimeError(msg),)
             errno = MeasureErrorNo.RUNTIME_DEVICE
         tstamp = time.time()
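For reference, the same trimming as a standalone helper. The marker strings are the ones the TVM runtime embeds in its error text; everything after them is a stack trace or a dump of the generated CUDA source, which drowns out the actual error in tuning logs:

def shorten_error(msg):
    """Cut a runtime error message at the first verbose marker."""
    for marker in ("Stack trace returned", "CUDA Source"):
        if marker in msg:
            msg = msg[:msg.index(marker)]
    return msg

# e.g. shorten_error("invalid configuration\nStack trace returned 10 entries ...")
# returns just "invalid configuration\n"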
15 changes: 14 additions & 1 deletion python/tvm/autotvm/task/nnvm_integration.py
@@ -4,12 +4,16 @@
 
 """
 import warnings
+import logging
 
 
 from ... import tensor, placeholder, target as _target
 
 from ..util import get_const_tuple
 from .task import create, register
+from .dispatcher import ApplyHistoryBest
 
+logger = logging.getLogger('autotvm')
+
 def serialize_args(args):
     """serialize arguments of a topi function to a hashable tuple.
@@ -176,8 +180,17 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
 
     # run compiler to collect all TOPI calls during compilation
     env.reset()
+
+    # disable logger temporarily
+    old_state = logger.disabled
+    logger.disabled = True
+
     # use a dummy target to do a fake compile for collecting topi calls
     dummy_target = _target.create("opencl -device=dummy")
-    nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
+    with ApplyHistoryBest([], allow_fallback=True):
+        nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
+
+    logger.disabled = old_state
+
     tasks = []
     for task_name, args in env.get_tasks():
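The silence-and-restore dance around the fake compile is a reusable pattern; a hedged sketch of the same idea as a context manager (equivalent to the hunk above, except the try/finally also restores the logger if build() raises):

import logging
from contextlib import contextmanager

@contextmanager
def disabled_logger(name):
    """Temporarily disable a logger, restoring its previous state on exit."""
    log = logging.getLogger(name)
    old_state = log.disabled
    log.disabled = True
    try:
        yield
    finally:
        log.disabled = old_state

# usage, mirroring extract_from_graph:
#   with disabled_logger('autotvm'):
#       nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)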
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -368,7 +368,7 @@ def traverse(ops):
                 pass
             else:
                 raise FlopCalculationError("Only support tvm.compute currently. "
-                                           "Other ops like tvm.scan is not supported")
+                                           "Other ops like tvm.scan/tvm.extern is not supported")
         return ret
 
     try:
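Context for the new message: the FLOP counter can only analyze tvm.compute bodies, because only they carry an arithmetic expression tree to walk; scan and extern ops have opaque bodies, so the traversal raises instead of guessing. A simplified sketch of that rule, assuming the module's own FlopCalculationError (count_flop_of_body is a hypothetical stand-in for the real expression walker):

import tvm

def traverse(ops):
    ret = 0
    for op in ops:
        if isinstance(op, tvm.tensor.ComputeOp):
            ret += count_flop_of_body(op.body)  # hypothetical expression walker
        elif isinstance(op, tvm.tensor.PlaceholderOp):
            pass  # inputs contribute no FLOPs
        else:
            raise FlopCalculationError("Only support tvm.compute currently. "
                                       "Other ops like tvm.scan/tvm.extern is not supported")
    return ret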
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/topi_integration.py
@@ -62,7 +62,7 @@ def _decorator(f):
         for target_key in targets:
             if target_key not in _REGISTED_DISPATHCER:
                 _REGISTED_DISPATHCER[target_key] = {}
-            if topi_compute not in _REGISTED_DISPATHCER:
+            if topi_compute not in _REGISTED_DISPATHCER[target_key]:
                 @topi_compute.register(target_key)
                 @dispatcher
                 def config_dispatcher(*args, **kwargs):
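The actual bug is one missing subscript: the old guard tested membership in the outer dict, whose keys are target strings, so it could never match a compute function and the dispatcher was registered again on every call. A minimal repro with generic names (make_dispatcher is a hypothetical stand-in):

registry = {}

def register(target_key, compute_fn, make_dispatcher):
    if target_key not in registry:
        registry[target_key] = {}
    # Buggy form: `compute_fn not in registry` checks the OUTER dict, whose
    # keys are target strings, so it is True for every function passed in.
    if compute_fn not in registry[target_key]:  # fixed: check the inner dict
        registry[target_key][compute_fn] = make_dispatcher(compute_fn)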
8 changes: 7 additions & 1 deletion python/tvm/autotvm/tuner/callback.py
@@ -101,11 +101,17 @@ def __init__(self):
             self.total = total
 
         def __del__(self):
-            sys.stdout.write(' Done.\n')
+            if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
+                sys.stdout.write(' Done.\n')
 
     ctx = _Context()
     tic = time.time()
 
+    if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
+        sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
+                         '| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic))
+        sys.stdout.flush()
+
     def _callback(tuner, inputs, results):
         ctx.ct += len(inputs)
 
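The gating rule in isolation: the progress bar is drawn only when the 'autotvm' logger sits below DEBUG, so debug runs get plain log lines instead of carriage-return redraws. A standalone sketch of the same check (not the exact PR code):

import sys
import logging

logger = logging.getLogger('autotvm')

def draw_progress(prefix, cur, best, done, total, elapsed):
    if logger.level < logging.DEBUG:  # non-debug mode only
        sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
                         '| %.2f s' % (prefix, cur, best, done, total, elapsed))
        sys.stdout.flush()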
14 changes: 10 additions & 4 deletions python/tvm/autotvm/tuner/ga_tuner.py
@@ -47,6 +47,7 @@ def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
 
         # random initialization
         self.pop_size = min(self.pop_size, len(self.space))
+        self.elite_num = min(self.pop_size, self.elite_num)
         for _ in range(self.pop_size):
             tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
             while knob2point(tmp_gene, self.dims) in self.visited:
@@ -70,9 +71,9 @@ def update(self, inputs, results):
                 y = inp.task.flop / np.mean(res.costs)
                 self.scores.append(y)
             else:
-                self.scores.append(0)
+                self.scores.append(0.0)
 
-        if len(self.scores) >= len(self.genes):
+        if len(self.scores) >= len(self.genes) and len(self.visited) < len(self.space):
             genes = self.genes + self.elites
             scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
@@ -85,8 +86,13 @@
 
             # cross over
             indices = np.arange(len(genes))
-            scores /= np.max(scores)
-            probs = scores / np.sum(scores)
+            max_score = np.max(scores)
+            if max_score < 1e-8:
+                probs = np.empty_like(scores)
+                probs[:] = 1.0 / len(scores)
+            else:
+                scores /= max_score
+                probs = scores / np.sum(scores)
             tmp_genes = []
             for _ in range(self.pop_size):
                 p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
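The core GATuner fix: np.random.choice(p=...) requires a valid probability vector, and when an entire generation fails to run every score is 0.0, so the old scores / np.sum(scores) divided zero by zero and handed NaNs to np.random.choice, crashing the tuner. The fix falls back to a uniform distribution. The same rule isolated as a pure function (a sketch, not the PR code verbatim):

import numpy as np

def selection_probs(scores, eps=1e-8):
    """Turn GA fitness scores into selection probabilities, safely."""
    scores = np.asarray(scores, dtype=float)
    max_score = np.max(scores)
    if max_score < eps:  # the whole generation failed to run
        return np.full(len(scores), 1.0 / len(scores))
    scores = scores / max_score  # rescale for numeric stability
    return scores / np.sum(scores)

print(selection_probs([0.0, 0.0, 0.0]))  # uniform: [0.3333 0.3333 0.3333]
print(selection_probs([1.0, 3.0]))       # [0.25 0.75]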