Skip to content

Commit

Permalink
Fix xgb error & Simplify dispatcher (apache#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jun 20, 2020
1 parent 2c27816 commit 0794875
Show file tree
Hide file tree
Showing 14 changed files with 70 additions and 240 deletions.
2 changes: 1 addition & 1 deletion python/tvm/ansor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
workload_key_to_dag, make_workload_key_func
from .task_scheduler import TaskScheduler, SimpleTaskScheduler
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
FallbackContext, clear_fallback_cache, ApplyGraphBest
FallbackContext
from .relay_integration import extract_from_program, extract_from_multiple_program, \
finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi
from .env import GLOBAL_SCOPE
1 change: 0 additions & 1 deletion python/tvm/ansor/auto_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class MetaTileRewritePolicy(SearchPolicy):
seed: int
Random seed
"""

def __init__(self,
program_cost_model,
params=None,
Expand Down
19 changes: 18 additions & 1 deletion python/tvm/ansor/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def get_init_state(self):

def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE):
"""
Apply transform steps according to the history of a state
Parameters
----------
state : StateObject
Expand All @@ -68,6 +70,8 @@ def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.

def print_python_code_from_state(self, state):
"""
Print transform steps in the history of a state as TVM's python schedule primitive
Parameters
----------
state : StateObject
Expand All @@ -81,16 +85,29 @@ def print_python_code_from_state(self, state):

def infer_bound_from_state(self, state):
"""
Infer bound for a state
Parameters
----------
state : StateObject
Returns
-------
state : StateObject
state : State
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)

def rewrite_layout_from_state(self, state: State):
"""
Rewrite the layout according to the transform steps in the history of a state
Parameters
----------
state : StateObject
Returns
-------
state : StateObject
"""
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)
5 changes: 4 additions & 1 deletion python/tvm/ansor/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@

@tvm._ffi.register_object("ansor.CostModel")
class CostModel(Object):
"""The base class for cost model"""
pass


@tvm._ffi.register_object("ansor.RandomModel")
class RandomModel(Object):
"""A model returns random estimation for all inputs"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.RandomModel)


# A random number generator func for c++'s RandomModel
@tvm._ffi.register_func("ansor.cost_model.random_number")
def random_number(n, return_ptr):
""" A random number generator func for c++'s RandomModel """
if n == 0:
return
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
Expand All @@ -47,6 +49,7 @@ def random_number(n, return_ptr):

@tvm._ffi.register_object("ansor.PythonBasedModel")
class PythonBasedModel(CostModel):
"""Base class for cost models implemented in python"""
def __init__(self):
def update_func(inputs, results):
self.update(inputs, results)
Expand Down
12 changes: 5 additions & 7 deletions python/tvm/ansor/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
# under the License.

"""Cost model based on xgboost"""
from typing import List
import multiprocessing
import logging
import time
from collections import defaultdict

import numpy as np
import xgboost as xgb

from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
from .cost_model import PythonBasedModel
from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
from ..serialization import LogReader
Expand Down Expand Up @@ -65,8 +63,8 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
# todo(lmzheng): automatically decrease learning rate when the loss is too large

'n_gpus': 0,
'n_threads': multiprocessing.cpu_count() / 2,
'silent': 0,
'nthread': multiprocessing.cpu_count() // 2,
'verbosity': 0,
'seed': seed or 43,
'disable_default_eval_metric': 1
}
Expand Down Expand Up @@ -180,7 +178,7 @@ def pack_sum_xgbmatrix_for_prediction(xs):
x_flatten.append(row)
pack_ids.append(ct)

return xgb.DMatrix(x_flatten), pack_ids
return xgb.DMatrix(np.array(x_flatten)), pack_ids


def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
Expand Down Expand Up @@ -214,7 +212,7 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
y_flatten.append(y)
pack_ids.append(ct)

ret = xgb.DMatrix(x_flatten, y_flatten)
ret = xgb.DMatrix(np.array(x_flatten), y_flatten)
if weights is not None:
ret.set_weight(weights_flatten)
dmatrix_context.put('pack_ids', ret, np.array(pack_ids))
Expand Down
Loading

0 comments on commit 0794875

Please sign in to comment.