From f278c421683500c5aebc895592c2e9cd7661f058 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 4 Dec 2020 15:02:25 +0100 Subject: [PATCH] Save PyTorch frontend state in object (#7023) While the functional approach is pretty neat, we ended up having global state (default frontend, dtype) and it'll be more soon (caching of inferred types, see #6900). To not have to pass around the state, this moves the op conversion into a class with instances having the state. --- python/tvm/relay/frontend/pytorch.py | 2013 ++++++++++---------------- 1 file changed, 774 insertions(+), 1239 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 38478e27ff92..4f75cf380cc6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -17,6 +17,7 @@ # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda +# pylint: disable=missing-function-docstring """PT: PyTorch frontend.""" import itertools import logging @@ -133,16 +134,24 @@ def _is_quantized_tensor(data, prelude): # operator implementation -def _elemwise(name): - def _impl(inputs, input_types): - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) - return get_relay_op(name)(data0, data1) - return _impl +class PyTorchOpConverter: + """A helper class for holding PyTorch op converters.""" + + def __init__(self, prelude, default_dtype): + self.prelude = prelude + self.default_dtype = default_dtype + self.create_convert_map() + + def make_elemwise(self, name): + def elemwise(inputs, input_types): + data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + return get_relay_op(name)(data0, data1) + + return elemwise -def _min_max_common(name_elemwise, name_reduce): - def _impl(inputs, input_types): + def min_max_common(self, name_elemwise, name_reduce, inputs, input_types): if len(inputs) == 1: data = _pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name_reduce)(data[0]) @@ -156,38 +165,27 @@ def _impl(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name_elemwise)(data0, data1) - return _impl - + def max(self, inputs, input_types): + return self.min_max_common("maximum", "max", inputs, input_types) -def _max(): - return _min_max_common("maximum", "max") + def min(self, inputs, input_types): + return self.min_max_common("minimum", "min", inputs, input_types) + def make_unary(self, name): + def unary(inputs, input_types): + # this is just to ensure tensor input + (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) + return get_relay_op(name)(data) -def _min(): - return _min_max_common("minimum", "min") + return unary - -def _unary(name): - def _impl(inputs, input_types): - # this is just to ensure tensor input - (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) - return get_relay_op(name)(data) - - return _impl - - -def _log1p(): - def _impl(inputs, input_types): + def log1p(self, inputs, input_types): # 1_plus_log x = log(x + 1) (dtype,) = input_types one = _expr.const(1, dtype=dtype) return _op.log(inputs[0] + one) - return _impl - - -def _arange(): - def _impl(inputs, input_types): + def arange(self, inputs, input_types): def _get_value(val, dtype): # dtype 
is a tvm dtype if isinstance(val, _expr.Expr): @@ -235,11 +233,7 @@ def _get_type(val, inp_type): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _squeeze(): - def _impl(inputs, input_types): + def squeeze(self, inputs, input_types): data = inputs[0] if len(inputs) == 1: axis = None @@ -249,33 +243,27 @@ def _impl(inputs, input_types): return _op.transform.squeeze(data, axis) - return _impl - - -def _unsqueeze(): - def _impl(inputs, input_types): + def unsqueeze(self, inputs, input_types): data = inputs[0] axis = inputs[1] return _op.transform.expand_dims(data, int(axis), 1) - return _impl - - -def _concatenate(prelude): - def tensor_array_concat(lst, axis): - assert axis == 0, "Tensor array concat supported only for axis 0" - tensor_array, shape = _convert_to_tensor_array(lst, prelude) - concat_shape = (Any(),) + shape[1:] - concat = prelude.get_global_var_static("tensor_array_concat", "float32", shape) - concatenated = concat(tensor_array) - - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) - static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", concat_shape) - return get_tensor(concatenated) + def concatenate(self, inputs, input_types): + def tensor_array_concat(lst, axis): + assert axis == 0, "Tensor array concat supported only for axis 0" + tensor_array, shape = _convert_to_tensor_array(lst, self.prelude) + concat_shape = (Any(),) + shape[1:] + concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape) + concatenated = concat(tensor_array) + + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = self.prelude.get_global_var_static( + "tensor_get_data", "float32", concat_shape + ) + return get_tensor(concatenated) - def _impl(inputs, input_types): data = inputs[0] axis = inputs[1] @@ -287,11 +275,7 @@ def _impl(inputs, input_types): return _op.tensor.concatenate(data, int(axis)) - return _impl - - -def _slice(): - def _impl(inputs, input_types): + def slice(self, inputs, input_types): axis_dtype = "int64" index_size_limit = 2 ** 63 - 1 data = inputs[0] @@ -391,11 +375,7 @@ def _impl(inputs, input_types): data, begin=begin, end=end, strides=strides, slice_mode="end" ) - return _impl - - -def _split(): - def _impl(inputs, input_types): + def split(self, inputs, input_types): data = inputs[0] split_size = int(inputs[1]) dim = int(inputs[2]) @@ -408,11 +388,7 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _split_with_sizes(): - def _impl(inputs, input_types): + def split_with_sizes(self, inputs, input_types): data = inputs[0] sections = inputs[1] dim = int(inputs[2]) @@ -430,31 +406,19 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _select(): - def _impl(inputs, input_types): + def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) index = _wrap_const(inputs[2]) return _op.transform.take(data, index, axis=dim) - return _impl - - -def _take(): - def _impl(inputs, input_types): + def take(self, inputs, input_types): data = inputs[0] indices = _op.cast(inputs[1], "int32") return _op.transform.take(data, indices=indices) - return _impl - - -def _topk(): - def _impl(inputs, input_types): + def topk(self, inputs, input_types): data = inputs[0] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) @@ -473,28 +437,16 @@ def 
_impl(inputs, input_types): return outs[0], outs[1] - return _impl - - -def _reciprocal(): - def _impl(inputs, input_types): + def reciprocal(self, inputs, input_types): data = inputs[0] return _expr.const(1.0, dtype=input_types[0]) / data - return _impl - - -def _repeat(): - def _impl(inputs, input_types): + def repeat(self, inputs, input_types): data = inputs[0] reps = inputs[1] return _op.transform.tile(data, reps=reps) - return _impl - - -def _repeat_interleave(): - def _impl(inputs, input_types): + def repeat_interleave(self, inputs, input_types): data = inputs[0] if isinstance(inputs[1], int): repeats = inputs[1] @@ -507,77 +459,60 @@ def _impl(inputs, input_types): axis = 0 return _op.transform.repeat(data, repeats=repeats, axis=axis) - return _impl - - -def _addcdiv(): - def _impl(inputs, input_types): + def addcdiv(self, inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 / t2)) - return _impl - - -def _addcmul(): - def _impl(inputs, input_types): + def addcmul(self, inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 * t2)) - return _impl - - -def _where(): - def _impl(inputs, input_types): + def where(self, inputs, input_types): if len(inputs) == 1: - return _nonzero(False)([inputs[0], True], input_types) + return self.nonzero([inputs[0], True], input_types) cond = inputs[0] x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3]) return _op.where(cond, x, y) - return _impl - - -def _full_impl(data, fill_value, dtype): - size = [] - need_reshape = False - new_shape = [] - for dim in data: - if isinstance(dim, _expr.Expr): - if isinstance(dim, _expr.Constant): - dim = int(dim.data.asnumpy()) - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) - else: - dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) - new_shape.append(dim) - - if success: + def full_impl(self, data, fill_value, dtype): + size = [] + need_reshape = False + new_shape = [] + for dim in data: + if isinstance(dim, _expr.Expr): + if isinstance(dim, _expr.Constant): + dim = int(dim.data.asnumpy()) if isinstance(size, list): size.append(dim) + new_shape.append(dim) else: - size = None - need_reshape = True - else: - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) + dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) + new_shape.append(dim) - if size is None: - tmp = [] - for dim in data: - tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) - size = _op.concatenate(tmp, axis=0) + if success: + if isinstance(size, list): + size.append(dim) + else: + size = None + need_reshape = True + else: + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) - out = _op.full(_expr.const(fill_value), size, dtype=dtype) - if need_reshape: - out = _op.reshape(out, new_shape) - return out + if size is None: + tmp = [] + for dim in data: + tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) + size = _op.concatenate(tmp, axis=0) + out = _op.full(_expr.const(fill_value), size, dtype=dtype) + if need_reshape: + out = _op.reshape(out, new_shape) + return out -def _ones(default_dtype): - def _impl(inputs, input_types): + def ones(self, inputs, input_types): data = inputs[0] import torch @@ -589,14 +524,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 1, dtype) - - return _impl + dtype = 
self.default_dtype + return self.full_impl(data, 1, dtype) - -def _ones_like(default_dtype): - def _impl(inputs, input_types): + def ones_like(self, inputs, input_types): data = inputs[0] out = _op.ones_like(data) @@ -604,17 +535,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] != dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _zeros(default_dtype): - def _impl(inputs, input_types): + def zeros(self, inputs, input_types): data = inputs[0] import torch @@ -626,14 +553,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 0, dtype) - - return _impl - + dtype = self.default_dtype + return self.full_impl(data, 0, dtype) -def _zeros_like(default_dtype): - def _impl(inputs, input_types): + def zeros_like(self, inputs, input_types): data = inputs[0] out = _op.zeros_like(data) @@ -641,17 +564,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _full(default_dtype): - def _impl(inputs, input_types): + def full(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -665,15 +584,11 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype - - return _full_impl(data, fill_value, dtype) + dtype = self.default_dtype - return _impl + return self.full_impl(data, fill_value, dtype) - -def _full_like(default_dtype): - def _impl(inputs, input_types): + def full_like(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -684,17 +599,13 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _linspace(): - def _impl(inputs, input_types): + def linspace(self, inputs, input_types): start = inputs[0] stop = inputs[1] step = inputs[2] @@ -713,51 +624,31 @@ def _impl(inputs, input_types): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _relu(prelude): - def _impl(inputs, input_types): + def relu(self, inputs, input_types): data = inputs[0] - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 3, "Input quant param not found in op inputs" input_zero_point = _expr.const(inputs[2], dtype="int32") return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) - return _impl - - -def _prelu(): - def _impl(inputs, input_types): + def prelu(self, inputs, input_types): data = inputs[0] alpha = inputs[1] return _op.nn.prelu(data, alpha) - return _impl - - -def _leaky_relu(): - def _impl(inputs, input_types): + def leaky_relu(self, inputs, input_types): data = inputs[0] alpha = float(inputs[1]) return _op.nn.leaky_relu(data, alpha) - return _impl - - -def _elu(): - def _impl(inputs, input_types): + def elu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) 
return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) - return _impl - - -def _celu(): - def _impl(inputs, input_types): + def celu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) @@ -765,11 +656,7 @@ def _impl(inputs, input_types): _expr.const(1, dtype=dtype) - _op.exp(data / alpha) ) + _op.nn.relu(data) - return _impl - - -def _gelu(): - def _impl(inputs, input_types): + def gelu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] # gelu is data * normcdf(data) @@ -781,11 +668,7 @@ def _impl(inputs, input_types): + _op.erf(data * _expr.const(0.5 ** 0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype) ) - return _impl - - -def _selu(): - def _impl(inputs, input_types): + def selu(self, inputs, input_types): data = inputs[0] # https://pytorch.org/docs/stable/nn.html#selu dtype = input_types[0] @@ -795,65 +678,41 @@ def _impl(inputs, input_types): alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) ) - return _impl - - -def _log_sigmoid(): - def _impl(inputs, input_types): + def log_sigmoid(self, inputs, input_types): data = inputs[0] return _op.log(_op.tensor.sigmoid(data)) - return _impl - - -def _adaptive_avg_pool_2d(prelude): - def _impl(inputs, input_types): + def adaptive_avg_pool_2d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] def func(x): return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _adaptive_max_pool_2d(): - def _impl(inputs, input_types): + def adaptive_max_pool_2d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None - return _impl - - -def _adaptive_max_pool_3d(): - def _impl(inputs, input_types): + def adaptive_max_pool_3d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None - return _impl - - -def _adaptive_avg_pool_3d(): - def _impl(inputs, input_types): + def adaptive_avg_pool_3d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) - return _impl - - -def _maxpool_2d(): - def _impl(inputs, input_types): + def maxpool_2d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -868,19 +727,11 @@ def _impl(inputs, input_types): return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) - return _impl - - -def _maxpool_2d_with_indices(): - def _impl(inputs, input_types): + def maxpool_2d_with_indices(self, inputs, input_types): # returns dummy indices too - return _maxpool_2d()(inputs, input_types), None - - return _impl + return self.maxpool_2d(inputs, input_types), None - -def _maxpool_1d(): - def _impl(inputs, input_types): + def maxpool_1d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -895,11 +746,7 @@ def _impl(inputs, input_types): return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) - return _impl - - -def _maxpool_3d(): - def _impl(inputs, input_types): + def maxpool_3d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -915,21 +762,13 @@ def _impl(inputs, input_types): data, pool_size=pool_size, 
strides=strides, padding=padding, ceil_mode=ceil_mode ) - return _impl - - -def _hardtanh(): - def _impl(inputs, input_types): + def hardtanh(self, inputs, input_types): a = inputs[0] tanh_min = float(inputs[1]) tanh_max = float(inputs[2]) return _op.tensor.clip(a, tanh_min, tanh_max) - return _impl - - -def _convolution(): - def _impl(inputs, input_types): + def convolution(self, inputs, input_types): # Use transpose or normal use_transpose = True if inputs[6] == 1 else False @@ -1018,11 +857,7 @@ def _impl(inputs, input_types): res = _op.squeeze(res, axis=[2]) return res - return _impl - - -def _softmax(): - def _impl(inputs, input_types): + def softmax(self, inputs, input_types): data = inputs[0] axis = inputs[1] if isinstance(axis, str): @@ -1030,27 +865,15 @@ def _impl(inputs, input_types): return _op.nn.softmax(data, axis=axis) - return _impl - - -def _threshold(): - def _impl(inputs, input_types): + def threshold(self, inputs, input_types): data = inputs[0] return _op.nn.relu(data) - return _impl - - -def _contiguous(): - def _impl(inputs, input_types): + def contiguous(self, inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _batch_norm(): - def _impl(inputs, input_types): + def batch_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] @@ -1086,11 +909,7 @@ def _impl(inputs, input_types): scale=scale, )[0] - return _impl - - -def _instance_norm(): - def _impl(inputs, input_types): + def instance_norm(self, inputs, input_types): data = inputs[0] data_type = input_types[0] channels = _infer_shape(data) @@ -1114,28 +933,24 @@ def _impl(inputs, input_types): data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale ) - return _impl - - -def _get_dims(data): - import torch - - if isinstance(data, _expr.Expr): - dims = _infer_shape(data) - elif isinstance(data, list): - dims = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - dims = data.shape - else: - msg = "Data type %s could not be parsed" % type(data) - raise AssertionError(msg) - return dims + @staticmethod + def get_dims(data): + import torch + if isinstance(data, _expr.Expr): + dims = _infer_shape(data) + elif isinstance(data, list): + dims = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + dims = data.shape + else: + msg = "Data type %s could not be parsed" % type(data) + raise AssertionError(msg) + return dims -def _layer_norm(): - def _impl(inputs, input_types): + def layer_norm(self, inputs, input_types): data = inputs[0] - ndims = len(_get_dims(inputs[1])) + ndims = len(self.get_dims(inputs[1])) assert ndims == 1, "Support only normalization over last one dimension." 
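        # Editor's note (clarification, not part of the original patch): aten::layer_norm
        # takes (input, normalized_shape, weight, bias, eps, cudnn_enable), so inputs[1]
        # here is normalized_shape; the assert above restricts the conversion to
        # normalizing over the last dimension only, which relay's nn.layer_norm expresses
        # as normalization over a single axis.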
return _op.nn.layer_norm( @@ -1148,11 +963,7 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _group_norm(): - def _impl(inputs, input_types): + def group_norm(self, inputs, input_types): data = inputs[0] gamma = inputs[2] beta = inputs[3] @@ -1170,17 +981,13 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _transpose(prelude): - def _impl(inputs, input_types): + def transpose(self, inputs, input_types): data = inputs[0] import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data, prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -1211,11 +1018,7 @@ def _impl(inputs, input_types): axes = inputs[1] return _op.transform.transpose(data, axes) - return _impl - - -def _flatten(): - def _impl(inputs, input_types): + def flatten(self, inputs, input_types): data = inputs[0] start = int(inputs[1]) end = int(inputs[2]) @@ -1237,11 +1040,7 @@ def _impl(inputs, input_types): out = _op.squeeze(out, axis=squeeze_axes) return out - return _impl - - -def _addmm(): - def _impl(inputs, input_types): + def addmm(self, inputs, input_types): input_mat = inputs[0] mat1 = inputs[1] data_type = input_types[1] @@ -1265,35 +1064,24 @@ def _impl(inputs, input_types): return dense_out + input_mat - return _impl - - -def _size(prelude): - def _impl_dynamic(inp, axis): - shape_dynamic = _op.shape_of(inp, dtype="int32") - if axis is not None: - return _op.take(shape_dynamic, _expr.const(axis), 0) - return shape_dynamic - - def _impl(inputs, input_types): - shape = _infer_shape(inputs[0], prelude.mod) + def size(self, inputs, input_types): + shape = _infer_shape(inputs[0], self.prelude.mod) axis = None if len(inputs) > 1: axis = int(inputs[1]) if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): if axis is None or isinstance(shape[axis], tvm.tir.expr.Any): - return _impl_dynamic(inputs[0], axis) + shape_dynamic = _op.shape_of(inputs[0], dtype="int32") + if axis is not None: + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic if axis is not None: return _expr.const(shape[axis]) return _expr.const(shape) - return _impl - - -def _numtotensor(): - def _impl(inputs, input_types): + def numtotensor(self, inputs, input_types): val = inputs[0] dtype = input_types[0] @@ -1307,18 +1095,10 @@ def _impl(inputs, input_types): arr = val * np.ones([]).astype(dtype) return arr - return _impl - - -def _tensortonum(): - def _impl(inputs, input_types): + def tensortonum(self, inputs, input_types): return inputs[0] - return _impl - - -def _view(): - def _impl(inputs, input_types): + def view(self, inputs, input_types): data = inputs[0] if len(inputs) == 3: @@ -1336,11 +1116,7 @@ def _impl(inputs, input_types): return _op.transform.reshape(data, new_shape) - return _impl - - -def _reshape(): - def _impl(inputs, input_types): + def reshape(self, inputs, input_types): data = inputs[0] new_shape = inputs[1] @@ -1371,11 +1147,7 @@ def _impl(inputs, input_types): new_shape = tmp_shape return _op.transform.reshape(data, new_shape) - return _impl - - -def _pixel_shuffle(prelude): - def _impl(inputs, input_types): + def pixel_shuffle(self, inputs, input_types): data = inputs[0] upscale_factor = inputs[1] upscale_squared = upscale_factor * upscale_factor @@ -1384,7 +1156,7 @@ def _impl(inputs, input_types): c % upscale_squared == 0 ), "input channel should be divisible by square of upscale_factor" - ndims = len(_infer_shape(data, 
prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) axes = list(range(ndims)) num_inputs = len(inputs) oc = c // upscale_squared @@ -1402,46 +1174,26 @@ def _impl(inputs, input_types): data = _op.transform.transpose(data, axes) return _op.transform.reshape(data, out_shape) - return _impl - - -def _clone(): - def _impl(inputs, input_types): + def clone(self, inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _log_softmax(): - def _impl(inputs, input_types): + def log_softmax(self, inputs, input_types): data = inputs[0] axis = int(inputs[1]) return _op.nn.log_softmax(data, axis) - return _impl - - -def _sigmoid(): - def _impl(inputs, input_types): + def sigmoid(self, inputs, input_types): data = inputs[0] return _op.tensor.sigmoid(data) - return _impl - - -def _softplus(): - def _impl(inputs, input_types): + def softplus(self, inputs, input_types): data = inputs[0] dtype = input_types[0] beta = _expr.const(float(inputs[1]), dtype=dtype) return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta - return _impl - - -def _avg_pool2d(prelude): - def _impl(inputs, input_types): + def avg_pool2d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1460,16 +1212,12 @@ def func(x): count_include_pad=count_include_pad, ) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _avg_pool3d(): - def _impl(inputs, input_types): + def avg_pool3d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1487,41 +1235,32 @@ def _impl(inputs, input_types): count_include_pad=count_include_pad, ) - return _impl - - -def _dropout(): - def _impl(inputs, input_types): + def dropout(self, inputs, input_types): data = inputs[0] rate = float(inputs[1]) return _op.nn.dropout(data, rate) - return _impl - - -def _reduce(name): - def _impl(inputs, input_types): - data = inputs[0] - axis = None - keepdims = False - - if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False - if isinstance(inputs[1], int): - axis = int(inputs[1]) - elif _is_int_seq(inputs[1]): - axis = inputs[1] - else: - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[2]) + def make_reduce(self, name): + def reduce(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False - return get_relay_op(name)(data, axis=axis, keepdims=keepdims) + if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False + if isinstance(inputs[1], int): + axis = int(inputs[1]) + elif _is_int_seq(inputs[1]): + axis = inputs[1] + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[2]) - return _impl + return get_relay_op(name)(data, axis=axis, keepdims=keepdims) + return reduce -def _norm(): - def _impl(inputs, input_types): + def norm(self, inputs, input_types): data = inputs[0] dtype = input_types[0] axis = None @@ -1543,11 +1282,7 @@ def _impl(inputs, input_types): reci_order, ) - return _impl - - -def _frobenius_norm(): - def _impl(inputs, input_types): + def frobenius_norm(self, inputs, input_types): data = inputs[0] axis = None keepdims = False @@ -1557,11 +1292,7 @@ def _impl(inputs, input_types): return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims)) - return _impl - - -def _std(): - def _impl(inputs, input_types): + def std(self, inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1574,11 +1305,7 @@ def 
_impl(inputs, input_types): return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _variance(): - def _impl(inputs, input_types): + def variance(self, inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1591,11 +1318,7 @@ def _impl(inputs, input_types): return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _mean(prelude): - def _impl(inputs, input_types): + def mean(self, inputs, input_types): data = inputs[0] if inputs[1]: @@ -1615,7 +1338,7 @@ def _impl(inputs, input_types): def func(x): return _op.mean(x, axis, keepdims, exclude) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 6, "Input quant param not found in op inputs" input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) @@ -1623,18 +1346,14 @@ def func(x): return func(data) - return _impl - - -def _chunk(prelude): - def _impl(inputs, input_types): + def chunk(self, inputs, input_types): data = inputs[0] num_chunks = int(inputs[1]) axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data, prelude.mod) + inferred_shape = _infer_shape(data, self.prelude.mod) shape = [] for infer in inferred_shape: @@ -1670,18 +1389,14 @@ def _impl(inputs, input_types): return chunks - return _impl - - -def _matmul(prelude): - def _impl(inputs, input_types): + def matmul(self, inputs, input_types): inputs_0 = inputs[0] inputs_1 = inputs[1] # Need to check input shape as batch matmul must be supported. - a_shape = _infer_shape(inputs_0, prelude.mod) - b_shape = _infer_shape(inputs_1, prelude.mod) + a_shape = _infer_shape(inputs_0, self.prelude.mod) + b_shape = _infer_shape(inputs_1, self.prelude.mod) # When performing a batch matmul, we need to properly handle N-dim shapes. 
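        # Editor's note (illustrative example, not part of the original patch): for
        # torch.matmul with a of shape (2, 3, 4, 5) and b of shape (5, 6), both operands
        # are reshaped to rank-3 tensors below, a -> (6, 4, 5) and b -> (1, 5, 6); b is
        # then broadcast along the batch dimension to match a, a single batch matmul
        # yields (6, 4, 6), and the result corresponds to the torch output shape (2, 3, 4, 6).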
if len(a_shape) > 2 or len(b_shape) > 2: @@ -1689,8 +1404,8 @@ def _impl(inputs, input_types): a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) # Broadcast b to match batch size of a - new_b_shape = list(_infer_shape(b, prelude.mod)) - new_a_shape = _infer_shape(a, prelude.mod) + new_b_shape = list(_infer_shape(b, self.prelude.mod)) + new_a_shape = _infer_shape(a, self.prelude.mod) if new_a_shape[0] > new_b_shape[0]: new_b_shape[0] = new_a_shape[0] b = _op.broadcast_to(b, new_b_shape) @@ -1714,11 +1429,7 @@ def _impl(inputs, input_types): return out - return _impl - - -def _expand(): - def _impl(inputs, input_types): + def expand(self, inputs, input_types): data_in = inputs[0] shape = list(_infer_shape(data_in)) @@ -1740,85 +1451,64 @@ def _impl(inputs, input_types): return out - return _impl - - -def _int(): - def _impl(inputs, input_types): + def int(self, inputs, input_types): if isinstance(inputs[0], _expr.Expr): return inputs[0] return int(inputs[0]) - return _impl - - -def _identity(): - def _impl(inputs, input_types): + def identity(self, inputs, input_types): return inputs[0] - return _impl - - -def _none(): - def _impl(inputs, input_types): + def none(self, inputs, input_types): return None - return _impl - - -def _pad(mode): - def _impl(inputs, input_types): - data = inputs[0] - if isinstance(inputs[1], list): - pad_list = inputs[1] - else: - pad_list = list(_infer_shape(inputs[1])) - - # initialize paddings based on input len - pad_len = len(_infer_shape(data)) * 2 - paddings = [0] * pad_len - - if len(pad_list) >= 2: - paddings[-1] = pad_list[1] - paddings[-2] = pad_list[0] - if len(pad_list) >= 4: - paddings[-3] = pad_list[3] - paddings[-4] = pad_list[2] - if len(pad_list) >= 6: - paddings[-5] = pad_list[5] - paddings[-6] = pad_list[4] - - # group into tuple of 2 ints - paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] - - const_paddings = [] - for pad in paddings: - const_paddings.append([]) - for p in pad: - if not isinstance(p, int): - p = int(_infer_value(p, {}).asnumpy()) - const_paddings[-1].append(p) - - if mode == "constant": - return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) - else: - return _op.nn.pad(data, const_paddings, pad_mode=mode) - - return _impl + def make_pad(self, mode): + def pad(inputs, input_types): + data = inputs[0] + if isinstance(inputs[1], list): + pad_list = inputs[1] + else: + pad_list = list(_infer_shape(inputs[1])) + + # initialize paddings based on input len + pad_len = len(_infer_shape(data)) * 2 + paddings = [0] * pad_len + + if len(pad_list) >= 2: + paddings[-1] = pad_list[1] + paddings[-2] = pad_list[0] + if len(pad_list) >= 4: + paddings[-3] = pad_list[3] + paddings[-4] = pad_list[2] + if len(pad_list) >= 6: + paddings[-5] = pad_list[5] + paddings[-6] = pad_list[4] + + # group into tuple of 2 ints + paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] + + const_paddings = [] + for pad in paddings: + const_paddings.append([]) + for p in pad: + if not isinstance(p, int): + p = int(_infer_value(p, {}).asnumpy()) + const_paddings[-1].append(p) + + if mode == "constant": + return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) + else: + return _op.nn.pad(data, const_paddings, pad_mode=mode) + return pad -def _clamp(): - def _impl(inputs, input_types): + def clamp(self, inputs, input_types): data = inputs[0] amin = inputs[1] if inputs[1] else np.finfo(np.float32).min amax = inputs[2] if inputs[2] 
else np.finfo(np.float32).max return _op.clip(data, amin, amax) - return _impl - - -def _to(): - def _impl(inputs, input_types): + def to(self, inputs, input_types): data = inputs[0] dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2] # special handling for aten::to(data, 6, _, _, _) case @@ -1844,87 +1534,81 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _get_upsample_out_size(inputs, method): - # This assumes a static shape - out_size = [] - if inputs[1] is not None: - for size in inputs[1]: - if not isinstance(size, int): - out_size.append(int(_infer_value(size, {}).asnumpy())) - else: - out_size.append(size) - else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 - scales = inputs[scale_index] - assert scales is not None, "neither out size nor scale provided" - assert isinstance(scales, list) - ishape = _infer_shape(inputs[0]) - for i, scale in enumerate(scales): - out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) - - return out_size - - -def _upsample(method, prelude): - def _impl(inputs, input_types): - data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) - - if len(inputs) > 2 and method == "bilinear": - align_corners = inputs[2] - else: - align_corners = False - - if method == "nearest_neighbor": - coord_trans = "asymmetric" - elif align_corners: - coord_trans = "align_corners" + @staticmethod + def get_upsample_out_size(inputs, method): + # This assumes a static shape + out_size = [] + if inputs[1] is not None: + for size in inputs[1]: + if not isinstance(size, int): + out_size.append(int(_infer_value(size, {}).asnumpy())) + else: + out_size.append(size) else: - coord_trans = "half_pixel" - - def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scales = inputs[scale_index] + assert scales is not None, "neither out size nor scale provided" + assert isinstance(scales, list) + ishape = _infer_shape(inputs[0]) + for i, scale in enumerate(scales): + out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) + + return out_size + + def make_upsample(self, method): + def upsample(inputs, input_types): + data = inputs[0] + out_size = self.get_upsample_out_size(inputs, method) + + if len(inputs) > 2 and method == "bilinear": + align_corners = inputs[2] + else: + align_corners = False - if _is_quantized_tensor(data, prelude): - # input qparams are manually appended by us - assert isinstance(inputs[-2], float) - assert isinstance(inputs[-1], int) - input_scale = _expr.const(inputs[-2]) - input_zero_point = _expr.const(inputs[-1]) - return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func) + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" - return func(data) + def func(x): + return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - return _impl + if _is_quantized_tensor(data, self.prelude): + # input qparams are manually appended by us + assert isinstance(inputs[-2], float) + assert isinstance(inputs[-1], int) + input_scale = _expr.const(inputs[-2]) + input_zero_point = _expr.const(inputs[-1]) + return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func) + return func(data) -def _upsample3d(method): - def _impl(inputs, input_types): - data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) + return upsample - if len(inputs) 
> 2 and method == "trilinear": - align_corners = inputs[2] - else: - align_corners = False + def make_upsample3d(self, method): + def upsample3d(inputs, input_types): + data = inputs[0] + out_size = self.get_upsample_out_size(inputs, method) - if method == "nearest_neighbor": - coord_trans = "asymmetric" - elif align_corners: - coord_trans = "align_corners" - else: - coord_trans = "half_pixel" + if len(inputs) > 2 and method == "trilinear": + align_corners = inputs[2] + else: + align_corners = False - return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" - return _impl + return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) + return upsample3d -def _expand_as(): - def _impl(inputs, input_types): + def expand_as(self, inputs, input_types): target = inputs[1] t0 = _infer_type(inputs[0]).checked_type.dtype t1 = _infer_type(inputs[1]).checked_type.dtype @@ -1932,34 +1616,18 @@ def _impl(inputs, input_types): target = _op.cast(target, t0) return _op.broadcast_to_like(inputs[0], target) - return _impl - - -def _Bool(): - def _impl(inputs, input_types): + def Bool(self, inputs, input_types): assert len(inputs) == 1 return inputs[0] - return _impl - - -def _Float(): - def _impl(inputs, input_types): + def Float(self, inputs, input_types): assert len(inputs) == 1 return _op.cast(inputs[0], "float32") - return _impl - - -def _mm(): - def _impl(inputs, input_types): + def mm(self, inputs, input_types): return _op.nn.dense(inputs[0], inputs[1]) - return _impl - - -def _bitwise_not(): - def _impl(inputs, input_types): + def bitwise_not(self, inputs, input_types): data = inputs[0] # The input tensor must be of integral or Boolean types. 
# For bool tensors, it computes the logical NOT @@ -1970,11 +1638,7 @@ def _impl(inputs, input_types): return out - return _impl - - -def _bitwise_xor(): - def _impl(inputs, input_types): + def bitwise_xor(self, inputs, input_types): lhs = inputs[0] rhs = inputs[1] lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int") @@ -1982,91 +1646,55 @@ def _impl(inputs, input_types): return _op.bitwise_xor(lhs, rhs) - return _impl - - -def _logical_not(): - def _impl(inputs, input_types): + def logical_not(self, inputs, input_types): data = _wrap_const(inputs[0]) return _op.logical_not(_op.cast(data, "bool")) - return _impl - - -def _logical_xor(): - def _impl(inputs, input_types): + def logical_xor(self, inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_xor(lhs, rhs) - return _impl - - -def _list_getitem(prelude): - def _impl(inputs, input_types): - return prelude.nth(inputs[0], _wrap_const(inputs[1])) - - return _impl - - -def _list_len(prelude): - def _impl(inputs, input_types): - return prelude.length(inputs[0]) + def list_getitem(self, inputs, input_types): + return self.prelude.nth(inputs[0], _wrap_const(inputs[1])) - return _impl + def list_len(self, inputs, input_types): + return self.prelude.length(inputs[0]) - -def _type_as(): - def _impl(inputs, input_types): + def type_as(self, inputs, input_types): assert len(inputs) == 2 assert len(input_types) == 2 return _op.cast(inputs[0], input_types[1]) - return _impl - - -def _gather(): - def _impl(inputs, input_types): + def gather(self, inputs, input_types): data = inputs[0] axis = inputs[1] indices = inputs[2] return _op.gather(data, axis, indices) - return _impl - - -def _add(prelude): - # add_ is overloaded for tensor add and list concat - def _impl(inputs, input_types): + def add(self, inputs, input_types): + # add_ is overloaded for tensor add and list concat if input_types[0] == "ListType": - return prelude.concat(inputs[0], inputs[1]) - return _elemwise("add")(inputs, input_types) - - return _impl + return self.prelude.concat(inputs[0], inputs[1]) + return self.make_elemwise("add")(inputs, input_types) - -def _tensor_array_stack(prelude): - def _impl(inputs, input_types): + def tensor_array_stack(self, inputs, input_types): dim = inputs[1] assert dim == 0, "stacking on a dynamic tensor list only supported on a first axis" - tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + tensor_array, shape = _convert_to_tensor_array(inputs[0], self.prelude) stacked_shape = (Any(),) + shape - stack = prelude.get_global_var_static("tensor_array_stack", "float32", shape) + stack = self.prelude.get_global_var_static("tensor_array_stack", "float32", shape) stacked = stack(tensor_array) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", stacked_shape) static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) + get_tensor = self.prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) return get_tensor(stacked) - return _impl - - -def _stack(prelude): - def _impl(inputs, input_types): + def stack(self, inputs, input_types): if isinstance(inputs[0], list): # a static python list of tensors dim = inputs[1] @@ -2074,17 +1702,13 @@ def _impl(inputs, input_types): else: # List ADT case assert isinstance(inputs[0], _expr.Expr) - ty = _infer_type_with_prelude(inputs[0], 
prelude) - list_ty = prelude.mod.get_global_type_var("List") + ty = _infer_type_with_prelude(inputs[0], self.prelude) + list_ty = self.prelude.mod.get_global_type_var("List") msg = "The input list is expected to be List ADT" assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg - return _tensor_array_stack(prelude)(inputs, input_types) - - return _impl - + return self.tensor_array_stack(inputs, input_types) -def _rsub(): - def _impl(inputs, input_types): + def rsub(self, inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) # TODO (t-vi): should this also be part of the type promotion? @@ -2093,21 +1717,13 @@ def _impl(inputs, input_types): # note: rsub means data0 and data1 swap places return get_relay_op("subtract")(data1, alpha * data0) - return _impl - - -def _embedding(): - def _impl(inputs, input_types): + def embedding(self, inputs, input_types): weight = inputs[0] indices = inputs[1] return _op.take(weight, indices.astype("int32"), axis=0) - return _impl - - -def _one_hot(): - def _impl(inputs, input_types): + def one_hot(self, inputs, input_types): indices = inputs[0].astype("int32") num_classes = inputs[1] if num_classes == -1: @@ -2120,28 +1736,16 @@ def _impl(inputs, input_types): return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype) - return _impl - - -def _index(): - def _impl(inputs, input_types): + def index(self, inputs, input_types): data = inputs[0] indices = inputs[1] return _op.adv_index([data] + indices) - return _impl - - -def _meshgrid(): - def _impl(inputs, input_types): + def meshgrid(self, inputs, input_types): data = inputs[0] return _op.meshgrid(data, indexing="ij") - return _impl - - -def _nms(prelude): - def _impl(inputs, input_types): + def nms(self, inputs, input_types): boxes = inputs[0] scores = inputs[1] iou_threshold = inputs[2] @@ -2187,11 +1791,7 @@ def _impl(inputs, input_types): # in torchvision, indices from nms are int64 return _op.cast(ret, "int64") - return _impl - - -def _logsumexp(): - def _impl(inputs, input_types): + def logsumexp(self, inputs, input_types): data = _pytorch_promote_types(inputs[:1], input_types[:1]) dim_list = inputs[1] keepdim = inputs[2] if len(inputs) > 2 else False @@ -2199,11 +1799,7 @@ def _impl(inputs, input_types): assert isinstance(dim_list, list), "dim is expected to be a list" return _op.logsumexp(data[0], axis=dim_list, keepdims=keepdim) - return _impl - - -def _roi_align(prelude): - def _impl(inputs, input_types): + def roi_align(self, inputs, input_types): data = inputs[0] boxes = inputs[1] @@ -2217,16 +1813,12 @@ def _impl(inputs, input_types): return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) - return _impl - - -def _unbind(): - def _impl(inputs, input_types): + def unbind(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) ishapes = _infer_shape(data) if dim >= len(ishapes): - msg = "Please check input dim, it shouldn't" "be greater than or equal to rank." + msg = "Please check input dim, it shouldn't be greater than or equal to rank." 
raise AttributeError(msg) selections = ishapes[dim] @@ -2239,13 +1831,9 @@ def _impl(inputs, input_types): ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) return ret - return _impl - - -def _shape_as_tensor(prelude): - def _impl(inputs, input_types): + def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], prelude.mod) + input_shape = _infer_shape(inputs[0], self.prelude.mod) for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True @@ -2258,45 +1846,30 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _logical_and(): - def _impl(inputs, input_types): + def logical_and(self, inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_and(lhs, rhs) - return _impl - - -def _nonzero(is_numpy_style): - def _impl(inputs, input_types): + def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) - if is_numpy_style or (len(inputs) > 1 and inputs[1]): - return _unbind()([ret, 1], None) - + return self.unbind([ret, 1], None) return ret - return _impl - + def nonzero_numpy(self, inputs, input_types): + return self.nonzero(inputs, input_types, is_numpy_style=False) -def _scatter(): - def _impl(inputs, input_types): + def scatter(self, inputs, input_types): data = inputs[0] axis = int(inputs[1]) index = inputs[2] src = inputs[3] return _op.transform.scatter(data, index, src, axis) - return _impl - - -def _scalar_tensor(): - def _impl(inputs, input_types): + def scalar_tensor(self, inputs, input_types): data = inputs[0] cast_map = { 6: "float32", @@ -2309,11 +1882,7 @@ def _impl(inputs, input_types): data = data.data.asnumpy().tolist() return _expr.const(data, cast_map[type_key]) - return _impl - - -def _interpolate(): - def _impl(inputs, input_types): + def interpolate(self, inputs, input_types): if isinstance(inputs[1], _expr.Expr): out_size = inputs[1] elif isinstance(inputs[1], list): @@ -2342,26 +1911,14 @@ def _impl(inputs, input_types): return _op.image.resize(data, out_size, "NCHW", method, coord_trans) - return _impl - - -def _numel(): - def _impl(inputs, input_types): + def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) - return _impl - - -def _empty(): - def _impl(inputs, input_types): + def empty(self, inputs, input_types): shape = inputs[0] return _op.zeros(shape, _convert_dtype_value(inputs[1])) - return _impl - - -def _bincount(): - def _impl(inputs, input_types): + def bincount(self, inputs, input_types): data = inputs[0] weights = inputs[1] maximum = _op.max(data) @@ -2377,18 +1934,427 @@ def _impl(inputs, input_types): counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) return _op.scatter_add(counts, data, updates, axis=0) - return _impl - - -def _scatter_add(): - def _impl(inputs, input_types): + def scatter_add(self, inputs, input_types): data = inputs[0] axis = inputs[1] index = inputs[2] src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) - return _impl + # Operator mappings + def create_convert_map(self): + self.convert_map = { + "aten::pixel_shuffle": self.pixel_shuffle, + "aten::device": self.none, + "prim::device": self.none, + "aten::sub": self.make_elemwise("subtract"), + "aten::sub_": self.make_elemwise("subtract"), + "aten::max": self.max, + "aten::min": self.min, + "aten::mul": self.make_elemwise("multiply"), + "aten::mul_": self.make_elemwise("multiply"), + "aten::pow": self.make_elemwise("power"), + 
"aten::arange": self.arange, + "aten::meshgrid": self.meshgrid, + "aten::div": self.make_elemwise("divide"), + "aten::div_": self.make_elemwise("divide"), + "aten::floor_divide": self.make_elemwise("floor_divide"), + "aten::true_divide": self.make_elemwise("divide"), + "aten::addcdiv": self.addcdiv, + "aten::addcmul": self.addcmul, + "aten::ones": self.ones, + "aten::ones_like": self.ones_like, + "aten::zeros": self.zeros, + "aten::zeros_like": self.zeros_like, + "aten::full": self.full, + "aten::full_like": self.full_like, + "aten::linspace": self.linspace, + "aten::reciprocal": self.reciprocal, + "aten::repeat": self.repeat, + "aten::repeat_interleave": self.repeat_interleave, + "aten::to": self.to, + "aten::squeeze": self.squeeze, + "aten::unsqueeze": self.unsqueeze, + "aten::cat": self.concatenate, + "aten::slice": self.slice, + "aten::split": self.split, + "aten::split_with_sizes": self.split_with_sizes, + "aten::select": self.select, + "aten::take": self.take, + "aten::where": self.where, + "aten::topk": self.topk, + "aten::relu": self.relu, + "aten::relu_": self.relu, + "aten::prelu": self.prelu, + "aten::leaky_relu": self.leaky_relu, + "aten::leaky_relu_": self.leaky_relu, + "aten::elu": self.elu, + "aten::elu_": self.elu, + "aten::celu": self.celu, + "aten::gelu": self.gelu, + "aten::selu": self.selu, + "aten::log_sigmoid": self.log_sigmoid, + "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d, + "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d, + "aten::max_pool2d": self.maxpool_2d, + "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices, + "aten::max_pool1d": self.maxpool_1d, + "aten::max_pool3d": self.maxpool_3d, + "aten::hardtanh": self.hardtanh, + "aten::hardtanh_": self.hardtanh, + "aten::_convolution": self.convolution, + "aten::softmax": self.softmax, + "aten::threshold": self.threshold, + "aten::threshold_": self.threshold, + "aten::contiguous": self.contiguous, + "aten::batch_norm": self.batch_norm, + "aten::instance_norm": self.instance_norm, + "aten::layer_norm": self.layer_norm, + "aten::group_norm": self.group_norm, + "aten::transpose": self.transpose, + "aten::transpose_": self.transpose, + "aten::t": self.transpose, + "aten::flatten": self.flatten, + "aten::addmm": self.addmm, + "aten::size": self.size, + "aten::view": self.view, + "aten::reshape": self.reshape, + "aten::clone": self.clone, + "aten::log_softmax": self.log_softmax, + "aten::sigmoid": self.sigmoid, + "aten::softplus": self.softplus, + "aten::avg_pool2d": self.avg_pool2d, + "aten::avg_pool3d": self.avg_pool3d, + "aten::dropout": self.dropout, + "aten::dropout_": self.dropout, + "aten::feature_dropout": self.dropout, + "aten::alpha_dropout": self.dropout, + "aten::mean": self.mean, + "aten::chunk": self.chunk, + "aten::matmul": self.matmul, + "aten::bmm": self.matmul, + "aten::expand": self.expand, + "aten::Int": self.int, + "prim::NumToTensor": self.numtotensor, + "prim::ImplicitTensorToNum": self.tensortonum, + "aten::ScalarImplicit": self.tensortonum, + "aten::constant_pad_nd": self.make_pad("constant"), + "aten::reflection_pad1d": self.make_pad("reflect"), + "aten::reflection_pad2d": self.make_pad("reflect"), + "aten::replication_pad1d": self.make_pad("edge"), + "aten::replication_pad2d": self.make_pad("edge"), + "aten::replication_pad3d": self.make_pad("edge"), + "aten::permute": self.transpose, + "aten::sum": self.make_reduce("sum"), + "aten::prod": self.make_reduce("prod"), + "aten::argmin": self.make_reduce("argmin"), + "aten::argmax": self.make_reduce("argmax"), + 
"aten::norm": self.norm, + "aten::frobenius_norm": self.frobenius_norm, + "aten::std": self.std, + "aten::var": self.variance, + "aten::abs": self.make_unary("abs"), + "aten::neg": self.make_unary("negative"), + "aten::cos": self.make_unary("cos"), + "aten::cosh": self.make_unary("cosh"), + "aten::sin": self.make_unary("sin"), + "aten::sinh": self.make_unary("sinh"), + "aten::tan": self.make_unary("tan"), + "aten::tanh": self.make_unary("tanh"), + "aten::acos": self.make_unary("acos"), + "aten::asin": self.make_unary("asin"), + "aten::atan": self.make_unary("atan"), + "aten::log": self.make_unary("log"), + "aten::log2": self.make_unary("log2"), + "aten::log10": self.make_unary("log10"), + "aten::log1p": self.log1p, + "aten::exp": self.make_unary("exp"), + "aten::erf": self.make_unary("erf"), + "aten::trunc": self.make_unary("trunc"), + "aten::sign": self.make_unary("sign"), + "aten::sqrt": self.make_unary("sqrt"), + "aten::rsqrt": self.make_unary("rsqrt"), + "aten::ceil": self.make_unary("ceil"), + "aten::floor": self.make_unary("floor"), + "aten::round": self.make_unary("round"), + "aten::isfinite": self.make_unary("isfinite"), + "aten::isinf": self.make_unary("isinf"), + "aten::isnan": self.make_unary("isnan"), + "aten::clamp": self.clamp, + "aten::clamp_": self.clamp, + "aten::detach": self.identity, + "aten::upsample_bilinear2d": self.make_upsample("bilinear"), + "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), + "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"), + "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), + "aten::expand_as": self.expand_as, + "aten::lt": self.make_elemwise("less"), + "aten::gt": self.make_elemwise("greater"), + "aten::le": self.make_elemwise("less_equal"), + "aten::ge": self.make_elemwise("greater_equal"), + "aten::ne": self.make_elemwise("not_equal"), + "aten::eq": self.make_elemwise("equal"), + "aten::logical_not": self.logical_not, + "aten::logical_xor": self.logical_xor, + "aten::bitwise_not": self.bitwise_not, + "aten::bitwise_xor": self.bitwise_xor, + "aten::Bool": self.Bool, + "aten::Float": self.Float, + "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d, + "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d, + "aten::rsub": self.rsub, + "aten::embedding": self.embedding, + "aten::one_hot": self.one_hot, + "aten::mm": self.matmul, + "aten::add": self.add, + "aten::add_": self.add, + "aten::stack": self.stack, + "aten::__getitem__": self.list_getitem, + "aten::len": self.list_len, + "aten::type_as": self.type_as, + "aten::gather": self.gather, + "aten::index_select": self.select, + "aten::index": self.index, + "torchvision::nms": self.nms, + "aten::logsumexp": self.logsumexp, + "torchvision::roi_align": self.roi_align, + "aten::unbind": self.unbind, + "aten::__and__": self.logical_and, + "aten::_shape_as_tensor": self.shape_as_tensor, + "aten::nonzero": self.nonzero, + "aten::nonzero_numpy": self.nonzero_numpy, + "aten::scatter": self.scatter, + "aten::scalar_tensor": self.scalar_tensor, + "aten::__interpolate": self.interpolate, + "aten::IntImplicit": self.identity, + "aten::tensor": self.identity, # used for example in tensor(1.0) + "aten::numel": self.numel, + "aten::empty": self.empty, + "aten::bincount": self.bincount, + "aten::scatter_add": self.scatter_add, + "aten::__not__": self.logical_not, + } + + def update_convert_map(self, custom_map): + self.convert_map.update(custom_map) + + def report_missing_conversion(self, op_names): + """ Check if all ops in an input graph are supported by 
TVM """ + known_ops = [ + "prim::Constant", + "prim::GetAttr", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::TupleConstruct", + "prim::TupleUnpack", + "prim::RaiseException", + "prim::If", + "prim::Loop", + ] + known_ops += list(self.convert_map.keys()) + known_ops += list(qnn_torch.convert_map.keys()) + + missing = [op_name for op_name in op_names if op_name not in known_ops] + + if missing: + msg = "The following operators are not implemented: {}".format(missing) + raise NotImplementedError(msg) + + def convert_block(self, block, outputs): + """ Translate Torch "Block", used for prim::If and prim::Loop """ + ops = _get_operator_nodes(block.nodes()) + ret_names = _get_input_names(block.returnNode()) + return self.convert_operators(ops, outputs, ret_names) + + def convert_if(self, if_node, outputs): + """ Translate Torch prim::If to Relay If """ + cond = outputs[if_node.inputsAt(0).debugName()] + blocks = list(if_node.blocks()) + true_branch = self.convert_block(blocks[0], outputs) + false_branch = self.convert_block(blocks[1], outputs) + assert len(true_branch) == 1 and len(false_branch) == 1 + return _expr.If(cond, true_branch[0], false_branch[0]) + + def convert_loop(self, loop_node, outputs): + """ Translate Torch prim::Loop to Relay while_loop """ + + def get_input(index): + ivalue = loop_node.inputsAt(index) + inode = ivalue.node() + if inode.kind() == "prim::Constant": + return _expr.const(_get_constant(inode)) + var_name = ivalue.debugName() + assert var_name in outputs + return _wrap_const(outputs[var_name]) + + # Refer to the spec for prim::Loop below + # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops + # The first input: %max_trip_count + # The second input: %initial_condition + # The rest of input: loop variables + max_loop_count = get_input(0) + init_cond = get_input(1) + num_loop_var = len(list(loop_node.inputs())) - 2 + init_vals = [get_input(i + 2) for i in range(num_loop_var)] + + # while loop has always max_loop_count being int64 max + # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again + is_while_loop = ( + isinstance(max_loop_count, _expr.Constant) + and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize + ) + + if is_while_loop: + loop_iter_dtype = "bool" + # while loop with non input dependent condition such as while i < 10: + # init_cond is int, need to cast to bool to type check + if isinstance(init_cond, _expr.Constant): + init_cond = _op.cast(init_cond, "bool") + init_loop_iter_val = init_cond + else: + loop_iter_dtype = "int32" + # always count from 0 + init_loop_iter_val = _expr.const(0, dtype="int32") + + body_block = list(loop_node.blocks())[0] + block_input_names = _get_input_names(body_block) + num_block_inputs = len(block_input_names) + name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals)) + outputs.update(name_val_pairs) + + def get_var(name, val): + if val: + checked_type = _infer_type_with_prelude(val, self.prelude) + if hasattr(checked_type, "shape"): + shape = get_const_tuple(checked_type.shape) + actual_shape = [] + for dim in shape: + if isinstance(dim, int) and dim == 0: + actual_shape.append(Any()) + else: + actual_shape.append(dim) + return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype) + else: + return _expr.var(name, type_annotation=checked_type) + return _expr.var(name) + + loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) + loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] + + # Add 
non constant free variables to loop variables to prevent code blow up + # Without this, if there are two for loops in a row, which often happens + # if the outer loop is unrolled, the computation corresponding to the first for loop + # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2). + # This issue was found when converting the stacked LSTM test. Torch does not add the + # output of the earlier loop into loop variables of the next loop. + # So the variable corresponding to the first loop output appears free in the second + # loop body. + free_vars = [ + var + for var in _get_free_vars_from_block(body_block) + if var in outputs + and not isinstance(outputs[var], (_expr.Constant, int, float, str)) + and outputs[var] + ] + + prev_outputs = {} + for name in free_vars: + prev_output = outputs[name] + new_loop_var = get_var(name, prev_output) + prev_outputs[name] = prev_output + outputs[name] = new_loop_var + loop_vars.append(new_loop_var) + init_vals.append(prev_output) + + def cond(*current_vals): + i = current_vals[0] + + if is_while_loop: + return _op.equal(i, _expr.const(True, "bool")) + + return _op.less(i, max_loop_count) + + def body(*current_vals): + # Update loop variables using the prev iteration outputs + assert len(current_vals) == num_block_inputs + len(free_vars) + + for (i, val) in enumerate(current_vals): + if i < num_block_inputs: + outputs[block_input_names[i]] = val + else: + outputs[free_vars[i - num_block_inputs]] = val + + block_outputs = self.convert_block(body_block, outputs) + block_outputs += [outputs[name] for name in free_vars] + + if not is_while_loop: + # iter var increment implicit in torch, so do it manually + # for while loop, block_outputs[0] is already a boolean, + # the result of termination check + incr = _expr.const(1, dtype="int32") + block_outputs[0] = current_vals[0] + incr + + return block_outputs + + loop = while_loop(cond, [loop_iter_var] + loop_vars, body) + loop_val = loop(init_loop_iter_val, *init_vals) + + # restore original output values for free vars + outputs.update(prev_outputs) + + # The first element is a loop counter or boolean condition, ignore it + return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)] + + def convert_operators(self, operators, outputs, ret_names): + """ Convert each Torch IR operator to its Relay equivalent """ + for node_name, op_node in operators: + operator = op_node.kind() + inputs = _get_op_inputs(op_node, outputs) + + if operator == "prim::Constant": + outputs[node_name] = _get_constant(op_node) + elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): + outputs[node_name] = _convert_to_list_adt(inputs, self.prelude) + elif operator == "prim::ListConstruct": + # This assumes that no more elements will be appended to this list + # In this case, we keep the Python list + outputs[node_name] = inputs + elif operator == "prim::TupleConstruct": + outputs[node_name] = _expr.Tuple(inputs) + elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: + assert len(inputs) == 1 + if isinstance(inputs[0], (list, _expr.TupleWrapper)): + unpacked = inputs[0] + else: + unpacked = _unpack_tuple(inputs[0]) + outputs.update(zip(_get_output_names(op_node), unpacked)) + elif operator == "prim::RaiseException": + logging.warning("raising exceptions is ignored") + outputs[node_name] = None + elif operator == "prim::If": + if_out = self.convert_if(op_node, outputs) + outputs[node_name] = if_out + elif operator == "prim::Loop": + loop_out =
self.convert_loop(op_node, outputs) + unpacked_names = _get_output_names(op_node) + assert len(loop_out) == len(unpacked_names) + outputs.update(zip(unpacked_names, loop_out)) + else: + relay_op = self.convert_map[operator] + relay_out = relay_op( + inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) + ) + + if isinstance(relay_out, tuple): + # This is for torch operators that return multiple outputs + # See _adaptive_max_2d above for example + out_names = _get_output_names(op_node) + outputs.update(zip(out_names, relay_out)) + else: + assert op_node.outputsSize() == 1 + outputs[node_name] = relay_out + + return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] def _pytorch_result_type(dtypes, non_tensor_inputs): @@ -2544,202 +2510,6 @@ def _wrap_const(c): return c -# Operator mappings -def _get_convert_map(prelude, default_dtype): - convert_map = { - "aten::pixel_shuffle": _pixel_shuffle(prelude), - "aten::device": _none(), - "prim::device": _none(), - "aten::sub": _elemwise("subtract"), - "aten::sub_": _elemwise("subtract"), - "aten::max": _max(), - "aten::min": _min(), - "aten::mul": _elemwise("multiply"), - "aten::mul_": _elemwise("multiply"), - "aten::pow": _elemwise("power"), - "aten::arange": _arange(), - "aten::meshgrid": _meshgrid(), - "aten::div": _elemwise("divide"), - "aten::div_": _elemwise("divide"), - "aten::floor_divide": _elemwise("floor_divide"), - "aten::true_divide": _elemwise("divide"), - "aten::addcdiv": _addcdiv(), - "aten::addcmul": _addcmul(), - "aten::ones": _ones(default_dtype), - "aten::ones_like": _ones_like(default_dtype), - "aten::zeros": _zeros(default_dtype), - "aten::zeros_like": _zeros_like(default_dtype), - "aten::full": _full(default_dtype), - "aten::full_like": _full_like(default_dtype), - "aten::linspace": _linspace(), - "aten::reciprocal": _reciprocal(), - "aten::repeat": _repeat(), - "aten::repeat_interleave": _repeat_interleave(), - "aten::to": _to(), - "aten::squeeze": _squeeze(), - "aten::unsqueeze": _unsqueeze(), - "aten::cat": _concatenate(prelude), - "aten::slice": _slice(), - "aten::split": _split(), - "aten::split_with_sizes": _split_with_sizes(), - "aten::select": _select(), - "aten::take": _take(), - "aten::where": _where(), - "aten::topk": _topk(), - "aten::relu": _relu(prelude), - "aten::relu_": _relu(prelude), - "aten::prelu": _prelu(), - "aten::leaky_relu": _leaky_relu(), - "aten::leaky_relu_": _leaky_relu(), - "aten::elu": _elu(), - "aten::elu_": _elu(), - "aten::celu": _celu(), - "aten::gelu": _gelu(), - "aten::selu": _selu(), - "aten::log_sigmoid": _log_sigmoid(), - "aten::adaptive_avg_pool2d": _adaptive_avg_pool_2d(prelude), - "aten::adaptive_max_pool2d": _adaptive_max_pool_2d(), - "aten::max_pool2d": _maxpool_2d(), - "aten::max_pool2d_with_indices": _maxpool_2d_with_indices(), - "aten::max_pool1d": _maxpool_1d(), - "aten::max_pool3d": _maxpool_3d(), - "aten::hardtanh": _hardtanh(), - "aten::hardtanh_": _hardtanh(), - "aten::_convolution": _convolution(), - "aten::softmax": _softmax(), - "aten::threshold": _threshold(), - "aten::threshold_": _threshold(), - "aten::contiguous": _contiguous(), - "aten::batch_norm": _batch_norm(), - "aten::instance_norm": _instance_norm(), - "aten::layer_norm": _layer_norm(), - "aten::group_norm": _group_norm(), - "aten::transpose": _transpose(prelude), - "aten::transpose_": _transpose(prelude), - "aten::t": _transpose(prelude), - "aten::flatten": _flatten(), - "aten::addmm": _addmm(), - "aten::size": _size(prelude), - "aten::view": _view(), - "aten::reshape": 
_reshape(), - "aten::clone": _clone(), - "aten::log_softmax": _log_softmax(), - "aten::sigmoid": _sigmoid(), - "aten::softplus": _softplus(), - "aten::avg_pool2d": _avg_pool2d(prelude), - "aten::avg_pool3d": _avg_pool3d(), - "aten::dropout": _dropout(), - "aten::dropout_": _dropout(), - "aten::feature_dropout": _dropout(), - "aten::alpha_dropout": _dropout(), - "aten::mean": _mean(prelude), - "aten::chunk": _chunk(prelude), - "aten::matmul": _matmul(prelude), - "aten::bmm": _matmul(prelude), - "aten::expand": _expand(), - "aten::Int": _int(), - "prim::NumToTensor": _numtotensor(), - "prim::ImplicitTensorToNum": _tensortonum(), - "aten::ScalarImplicit": _tensortonum(), - "aten::constant_pad_nd": _pad("constant"), - "aten::reflection_pad1d": _pad("reflect"), - "aten::reflection_pad2d": _pad("reflect"), - "aten::replication_pad1d": _pad("edge"), - "aten::replication_pad2d": _pad("edge"), - "aten::replication_pad3d": _pad("edge"), - "aten::permute": _transpose(prelude), - "aten::sum": _reduce("sum"), - "aten::prod": _reduce("prod"), - "aten::argmin": _reduce("argmin"), - "aten::argmax": _reduce("argmax"), - "aten::norm": _norm(), - "aten::frobenius_norm": _frobenius_norm(), - "aten::std": _std(), - "aten::var": _variance(), - "aten::abs": _unary("abs"), - "aten::neg": _unary("negative"), - "aten::cos": _unary("cos"), - "aten::cosh": _unary("cosh"), - "aten::sin": _unary("sin"), - "aten::sinh": _unary("sinh"), - "aten::tan": _unary("tan"), - "aten::tanh": _unary("tanh"), - "aten::acos": _unary("acos"), - "aten::asin": _unary("asin"), - "aten::atan": _unary("atan"), - "aten::log": _unary("log"), - "aten::log2": _unary("log2"), - "aten::log10": _unary("log10"), - "aten::log1p": _log1p(), - "aten::exp": _unary("exp"), - "aten::erf": _unary("erf"), - "aten::trunc": _unary("trunc"), - "aten::sign": _unary("sign"), - "aten::sqrt": _unary("sqrt"), - "aten::rsqrt": _unary("rsqrt"), - "aten::ceil": _unary("ceil"), - "aten::floor": _unary("floor"), - "aten::round": _unary("round"), - "aten::isfinite": _unary("isfinite"), - "aten::isinf": _unary("isinf"), - "aten::isnan": _unary("isnan"), - "aten::clamp": _clamp(), - "aten::clamp_": _clamp(), - "aten::detach": _identity(), - "aten::upsample_bilinear2d": _upsample("bilinear", prelude), - "aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude), - "aten::upsample_trilinear3d": _upsample3d("trilinear"), - "aten::upsample_nearest3d": _upsample3d("nearest_neighbor"), - "aten::expand_as": _expand_as(), - "aten::lt": _elemwise("less"), - "aten::gt": _elemwise("greater"), - "aten::le": _elemwise("less_equal"), - "aten::ge": _elemwise("greater_equal"), - "aten::ne": _elemwise("not_equal"), - "aten::eq": _elemwise("equal"), - "aten::logical_not": _logical_not(), - "aten::logical_xor": _logical_xor(), - "aten::bitwise_not": _bitwise_not(), - "aten::bitwise_xor": _bitwise_xor(), - "aten::Bool": _Bool(), - "aten::Float": _Float(), - "aten::adaptive_avg_pool3d": _adaptive_avg_pool_3d(), - "aten::adaptive_max_pool3d": _adaptive_max_pool_3d(), - "aten::rsub": _rsub(), - "aten::embedding": _embedding(), - "aten::one_hot": _one_hot(), - "aten::mm": _matmul(prelude), - "aten::add": _add(prelude), - "aten::add_": _add(prelude), - "aten::stack": _stack(prelude), - "aten::__getitem__": _list_getitem(prelude), - "aten::len": _list_len(prelude), - "aten::type_as": _type_as(), - "aten::gather": _gather(), - "aten::index_select": _select(), - "aten::index": _index(), - "torchvision::nms": _nms(prelude), - "aten::logsumexp": _logsumexp(), - "torchvision::roi_align": 
_roi_align(prelude), - "aten::unbind": _unbind(), - "aten::__and__": _logical_and(), - "aten::_shape_as_tensor": _shape_as_tensor(prelude), - "aten::nonzero": _nonzero(False), - "aten::nonzero_numpy": _nonzero(True), - "aten::scatter": _scatter(), - "aten::scalar_tensor": _scalar_tensor(), - "aten::__interpolate": _interpolate(), - "aten::IntImplicit": _identity(), - "aten::tensor": _identity(), # used for example in tensor(1.0) - "aten::numel": _numel(), - "aten::empty": _empty(), - "aten::bincount": _bincount(), - "aten::scatter_add": _scatter_add(), - "aten::__not__": _logical_not(), - } - return convert_map - - def _run_jit_passes(graph): """ The inline pass is necessary to unwrap prim::CallMethod """ # pylint: disable=c-extension-no-member @@ -2793,29 +2563,6 @@ def _get_users(node): return [use.user for use in _get_uses(node)] -def _report_missing_conversion(op_names, convert_map): - """ Check if all ops in an input graph are supported by TVM """ - known_ops = [ - "prim::Constant", - "prim::GetAttr", - "prim::ListConstruct", - "prim::ListUnpack", - "prim::TupleConstruct", - "prim::TupleUnpack", - "prim::RaiseException", - "prim::If", - "prim::Loop", - ] - known_ops += list(convert_map.keys()) - known_ops += list(qnn_torch.convert_map.keys()) - - missing = [op_name for op_name in op_names if op_name not in known_ops] - - if missing: - msg = "The following operators are not implemented: {}".format(missing) - raise NotImplementedError(msg) - - def _getattr_attr_name(node): attribute_names = node.attributeNames() assert len(attribute_names) == 1 @@ -3117,211 +2864,6 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs, convert_map, prelude, default_dtype="float32"): - """ Translate Torch "Block", used for prim::If and prim::Loop """ - ops = _get_operator_nodes(block.nodes()) - ret_names = _get_input_names(block.returnNode()) - return convert_operators( - ops, outputs, ret_names, convert_map, prelude, default_dtype=default_dtype - ) - - -def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"): - """ Translate Torch prim::If to Relay If """ - cond = outputs[if_node.inputsAt(0).debugName()] - blocks = list(if_node.blocks()) - true_branch = convert_block( - blocks[0], outputs, convert_map, prelude, default_dtype=default_dtype - ) - false_branch = convert_block( - blocks[1], outputs, convert_map, prelude, default_dtype=default_dtype - ) - assert len(true_branch) == 1 and len(false_branch) == 1 - return _expr.If(cond, true_branch[0], false_branch[0]) - - -def convert_loop(loop_node, outputs, convert_map, prelude): - """ Translate Torch prim::Loop to Relay while_loop """ - - def get_input(index): - ivalue = loop_node.inputsAt(index) - inode = ivalue.node() - if inode.kind() == "prim::Constant": - return _expr.const(_get_constant(inode)) - var_name = ivalue.debugName() - assert var_name in outputs - return _wrap_const(outputs[var_name]) - - # Refer to the spec for prim::Loop below - # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops - # The first input: %max_trip_count - # The second input: %initial_condition - # The rest of input: loop variables - max_loop_count = get_input(0) - init_cond = get_input(1) - num_loop_var = len(list(loop_node.inputs())) - 2 - init_vals = [get_input(i + 2) for i in range(num_loop_var)] - - # while loop has always max_loop_count being int64 max - # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again - is_while_loop = ( - 
isinstance(max_loop_count, _expr.Constant) - and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize - ) - - if is_while_loop: - loop_iter_dtype = "bool" - # while loop with non input dependent condition such as while i < 10: - # init_cond is int, need to cast to bool to type check - if isinstance(init_cond, _expr.Constant): - init_cond = _op.cast(init_cond, "bool") - init_loop_iter_val = init_cond - else: - loop_iter_dtype = "int32" - # always count from 0 - init_loop_iter_val = _expr.const(0, dtype="int32") - - body_block = list(loop_node.blocks())[0] - block_input_names = _get_input_names(body_block) - num_block_inputs = len(block_input_names) - name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals)) - outputs.update(name_val_pairs) - - def get_var(name, val): - if val: - checked_type = _infer_type_with_prelude(val, prelude) - if hasattr(checked_type, "shape"): - shape = get_const_tuple(checked_type.shape) - actual_shape = [] - for dim in shape: - if isinstance(dim, int) and dim == 0: - actual_shape.append(Any()) - else: - actual_shape.append(dim) - return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype) - else: - return _expr.var(name, type_annotation=checked_type) - return _expr.var(name) - - loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) - loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]] - - # Add non constant free variables to loop variables to prevent code blow up - # Without this, if there are two for loops in a row, which often happens - # if the outer loop is unrolled, the computation corresponding to the first for loop - # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2). - # This issue was found when converting from Stacked LSTM test. Torch does not add the output - # of the eariler loop into loop variables of the next loop. - # So the variable corresponding to the first loop output appears free in the second loop body. 
- free_vars = [ - var - for var in _get_free_vars_from_block(body_block) - if var in outputs - and not isinstance(outputs[var], (_expr.Constant, int, float, str)) - and outputs[var] - ] - - prev_outputs = {} - for name in free_vars: - prev_output = outputs[name] - new_loop_var = get_var(name, prev_output) - prev_outputs[name] = prev_output - outputs[name] = new_loop_var - loop_vars.append(new_loop_var) - init_vals.append(prev_output) - - def cond(*current_vals): - i = current_vals[0] - - if is_while_loop: - return _op.equal(i, _expr.const(True, "bool")) - - return _op.less(i, max_loop_count) - - def body(*current_vals): - # Update loop variables using the prev iteration outputs - assert len(current_vals) == num_block_inputs + len(free_vars) - - for (i, val) in enumerate(current_vals): - if i < num_block_inputs: - outputs[block_input_names[i]] = val - else: - outputs[free_vars[i - num_block_inputs]] = val - - block_outputs = convert_block(body_block, outputs, convert_map, prelude) - block_outputs += [outputs[name] for name in free_vars] - - if not is_while_loop: - # iter var increment implicit in torch, so do it manually - # for while loop, block_outputs[0] is already a boolean, - # the result of termination check - incr = _expr.const(1, dtype="int32") - block_outputs[0] = current_vals[0] + incr - - return block_outputs - - loop = while_loop(cond, [loop_iter_var] + loop_vars, body) - loop_val = loop(init_loop_iter_val, *init_vals) - - # restore original output values for free vars - outputs.update(prev_outputs) - - # The first element is a loop counter or boolean condition, ignore it - return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)] - - -def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"): - """ Convert each Torch IR operators to Relay equivalent """ - for node_name, op_node in operators: - operator = op_node.kind() - inputs = _get_op_inputs(op_node, outputs) - - if operator == "prim::Constant": - outputs[node_name] = _get_constant(op_node) - elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node): - outputs[node_name] = _convert_to_list_adt(inputs, prelude) - elif operator == "prim::ListConstruct": - # This assumes that no more elements will be appended to this list - # In this case, we keep the Python list - outputs[node_name] = inputs - elif operator == "prim::TupleConstruct": - outputs[node_name] = _expr.Tuple(inputs) - elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]: - assert len(inputs) == 1 - if isinstance(inputs[0], (list, _expr.TupleWrapper)): - unpacked = inputs[0] - else: - unpacked = _unpack_tuple(inputs[0]) - outputs.update(zip(_get_output_names(op_node), unpacked)) - elif operator == "prim::prim::RaiseException": - logging.warning("raising exceptions is ignored") - outputs[node_name] = None - elif operator == "prim::If": - if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype) - outputs[node_name] = if_out - elif operator == "prim::Loop": - loop_out = convert_loop(op_node, outputs, convert_map, prelude) - unpacked_names = _get_output_names(op_node) - assert len(loop_out) == len(unpacked_names) - outputs.update(zip(unpacked_names, loop_out)) - else: - relay_op = convert_map[operator] - relay_out = relay_op( - inputs, _get_input_types(op_node, outputs, default_dtype=default_dtype) - ) - - if isinstance(relay_out, tuple): - # This is for torch operators that return multiple outputs - # See _adaptive_max_2d above for example - out_names 
= _get_output_names(op_node) - outputs.update(zip(out_names, relay_out)) - else: - assert op_node.outputsSize() == 1 - outputs[node_name] = relay_out - - return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] - - def get_all_op_names(graph): """ Return all operator names in the input graph """ nodes = list(graph.nodes()) @@ -3370,16 +2912,16 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt mod = tvm.IRModule() prelude = Prelude(mod) - convert_map = _get_convert_map(prelude, default_dtype) + converter = PyTorchOpConverter(prelude, default_dtype) graph = script_module.graph.copy() _run_jit_passes(graph) if custom_convert_map: - convert_map.update(custom_convert_map) + converter.update_convert_map(custom_convert_map) op_names = get_all_op_names(graph) - _report_missing_conversion(op_names, convert_map) + converter.report_missing_conversion(op_names) is_module = isinstance(script_module, torch.jit.ScriptModule) params = script_module.state_dict() if is_module else {} @@ -3399,16 +2941,9 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt qnn_torch.add_input_quant_params_to_op_inputs(graph) qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params) qnn_torch.add_quant_params(tvm_params, weight_quant_params) - convert_map.update(qnn_torch.convert_map) - - ret = convert_operators( - _get_operator_nodes(graph.nodes()), - outputs, - ret_name, - convert_map, - prelude, - default_dtype=default_dtype, - ) + converter.update_convert_map(qnn_torch.convert_map) + + ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name) mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
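The public entry point is unchanged by this refactor: from_pytorch now builds a PyTorchOpConverter internally in place of the free-standing convert map. A minimal usage sketch, assuming torchvision is installed; the input name "input0" and the ResNet-18 model are illustrative only:

import torch
import torchvision
from tvm import relay

# Trace a model and convert it to a Relay module; "input0" is an arbitrary name.
model = torchvision.models.resnet18().eval()
inp = torch.randn(1, 3, 224, 224)
scripted = torch.jit.trace(model, inp)

shape_list = [("input0", (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(scripted, shape_list)
print(mod["main"])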
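Every value in the converter's convert_map, and in a user-supplied custom_convert_map merged via update_convert_map, follows the same calling convention that convert_operators dispatches to: a callable taking (inputs, input_types) and returning a Relay expression, or a tuple of expressions for multi-output ops. A sketch with a hypothetical aten::hardswish converter, not part of this patch:

from tvm import relay

def hardswish(inputs, input_types):
    # inputs holds the already-converted Relay expressions for the Torch op's inputs
    x = inputs[0]
    return x * relay.clip(x + relay.const(3.0), 0.0, 6.0) / relay.const(6.0)

# Passed as from_pytorch(..., custom_convert_map=custom_map)
custom_map = {"aten::hardswish": hardswish}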
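convert_loop tells for-style and while-style loops apart through the %max_trip_count input: TorchScript lowers both to prim::Loop, and a data-dependent while receives the int64 maximum as its trip count, which is what the is_while_loop comparison against sys.maxsize detects. A small sketch, assuming PyTorch is available, that makes the lowering visible:

import sys
import torch

@torch.jit.script
def count_down(x: int) -> int:
    # A data-dependent `while` becomes prim::Loop whose first input
    # (%max_trip_count) is the int64 maximum, i.e. sys.maxsize.
    while x > 0:
        x -= 1
    return x

print(count_down.graph)  # shows the prim::Loop node and its trip-count input
print(sys.maxsize)       # 9223372036854775807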
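On the Relay side, convert_loop targets tvm.relay.loops.while_loop: cond and body receive the current loop variables, body returns their next values, and calling the constructed loop with the initial values yields a tuple whose element 0 is the counter or condition, which the converter drops via TupleGetItem. A standalone sketch of that convention, separate from the converter itself:

from tvm import relay
from tvm.relay.loops import while_loop

i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def cond(i, acc):
    # keep iterating while i < 10
    return relay.less(i, relay.const(10, "int32"))

def body(i, acc):
    # return the next value of every loop variable
    return [i + relay.const(1, "int32"), acc + i]

loop = while_loop(cond, [i, acc], body)
loop_val = loop(relay.const(0, "int32"), relay.const(0, "int32"))
# element 0 is the counter; keep only the accumulated value, as convert_loop does
result = relay.TupleGetItem(loop_val, 1)
print(result)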