diff --git a/contrib/tvmop/core/multiarray.py b/contrib/tvmop/core/multiarray.py index ba72fe20d190..c8eed5b45368 100644 --- a/contrib/tvmop/core/multiarray.py +++ b/contrib/tvmop/core/multiarray.py @@ -32,7 +32,7 @@ def compute_dot(A, B): return C -@defop(name="dot", target="cpu", dispatch=True, dtype=AllTypes) +@defop(name="dot", target="cpu", dtype=AllTypes) def dot(dtype, fallback): cfg = autotvm.get_config() cfg.define_knob("bn", [64] if fallback else [64, 32]) diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py index 57537bac648e..1e0f34669b10 100644 --- a/contrib/tvmop/opdef.py +++ b/contrib/tvmop/opdef.py @@ -17,6 +17,7 @@ # coding: utf-8 import tvm +import inspect from tvm import autotvm from itertools import product @@ -48,7 +49,7 @@ class OpDef: without considering whether dimension size equals to one. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. """ - def __init__(self, func, name, target, auto_broadcast, dispatch, **kwargs): + def __init__(self, func, name, target, auto_broadcast, **kwargs): # construct the value combination of the arguments # e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"] # arg_combination = [ @@ -69,7 +70,7 @@ def __init__(self, func, name, target, auto_broadcast, dispatch, **kwargs): self.name = name self.target = target self.auto_broadcast = auto_broadcast - self.dispatch = dispatch + self.dispatchable = 'fallback' in inspect.signature(self.func).parameters def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) @@ -79,7 +80,7 @@ def invoke_all(self): if self.attrs_valid(**each_kwargs): name = self.name \ + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) - if self.dispatch is False: + if self.dispatchable is False: sch, args = self.func(**each_kwargs) yield sch, args, name else: @@ -105,7 +106,7 @@ def get_op_name(self, name, args): def get_config_spaces(self): for each_kwargs in self.arg_combination: - if self.attrs_valid(**each_kwargs) and self.dispatch is True: + if self.attrs_valid(**each_kwargs) and self.dispatchable is True: name = self.name \ + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) config_space = autotvm.ConfigSpace() @@ -120,7 +121,7 @@ def get_binds(self, args): return None -def defop(name, target=None, auto_broadcast=False, dispatch=False, **kwargs): +def defop(name, target=None, auto_broadcast=False, **kwargs): """Decorator to define a tvm operator. Parameters ---------- @@ -141,7 +142,7 @@ def defop(name, target=None, auto_broadcast=False, dispatch=False, **kwargs): target = "cpu" if target is None else target def _defop(func): - opdef = OpDef(func, name, target, auto_broadcast, dispatch, **kwargs) + opdef = OpDef(func, name, target, auto_broadcast, **kwargs) __OP_DEF__.append(opdef) return opdef return _defop