diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index e5b21cb107f5..b4ea7f3cff62 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -29,3 +29,6 @@
 # Feature
 from . import feature
 from . import sparse_dense
+
+# Utilities
+from .count_layers import count_layers
diff --git a/python/tvm/relay/analysis/count_layers.py b/python/tvm/relay/analysis/count_layers.py
new file mode 100644
index 000000000000..93d4f2766284
--- /dev/null
+++ b/python/tvm/relay/analysis/count_layers.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities that enable counting the number of layers in a graph."""
+import tvm
+from tvm import relay
+from ..expr_functor import ExprVisitor
+
+
+class LayerCounter(ExprVisitor):
+    """A visitor pass that computes the deepest chain of specified ops in a graph."""
+
+    def __init__(self, valid_ops):
+        self.depth_count = 0
+        self.deepest_count = 0
+        self.valid_ops = [relay.op.get(op) for op in valid_ops]
+        super().__init__()
+
+    def visit_call(self, call):
+        if call.op in self.valid_ops:
+            self.depth_count += 1
+        current_count = self.depth_count
+        self.deepest_count = max(self.deepest_count, current_count)
+        for arg in call.args:
+            self.visit(arg)
+            self.depth_count = current_count
+
+    def count(self):
+        return self.deepest_count
+
+
+def count_layers(expr, valid_ops):
+    """Determine the number of layers of specified ops in a graph.
+    This pass computes only the deepest chain of ops rather than the
+    total number of ops in a graph. Thus, if there are two parallel
+    convolutions (for example), they would be considered a single layer.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
+        The input expression.
+
+    valid_ops: List[str]
+        A list of the operations that should be included in the count.
+
+    Returns
+    -------
+    layer_count : int
+        The number of layers of the specified operations found in the graph.
+    """
+    if isinstance(expr, tvm.ir.IRModule):
+        expr = expr["main"]
+    count_pass = LayerCounter(valid_ops)
+    count_pass.visit(expr)
+    return count_pass.count()
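A minimal usage sketch of the new count_layers helper. The toy graph below is illustrative only and is not part of this patch; it shows that parallel branches contribute a single layer to the count.

    import tvm
    from tvm import relay
    from tvm.relay.analysis import count_layers

    # Two parallel convolutions feed a third one, so the deepest conv2d chain is 2.
    data = relay.var("data", shape=(1, 8, 32, 32))
    w1 = relay.var("w1", shape=(8, 8, 3, 3))
    w2 = relay.var("w2", shape=(8, 8, 3, 3))
    w3 = relay.var("w3", shape=(8, 8, 3, 3))
    left = relay.nn.conv2d(data, w1, padding=(1, 1))
    right = relay.nn.conv2d(data, w2, padding=(1, 1))
    out = relay.nn.conv2d(relay.add(left, right), w3, padding=(1, 1))
    func = relay.Function(relay.analysis.free_vars(out), out)

    assert count_layers(func, valid_ops=["nn.conv2d"]) == 2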
+ """ + if isinstance(expr, tvm.ir.IRModule): + expr = expr["main"] + count_pass = LayerCounter(valid_ops) + count_pass.visit(expr) + return count_pass.count() diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e1aabe1e15b5..c9926647989e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -28,6 +28,7 @@ from ..op import OpPattern from .._tensor import elemwise_shape_func from ..strategy.generic import is_depthwise_conv2d +from ...transform import LayoutConfig # relu reg.register_broadcast_schedule("nn.relu") @@ -164,6 +165,16 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): from tvm import relay data, weight = inputs + + # First check if there is a LayoutConfig scope, and if so, whether + # it indicates we should ignore this layer or not. + layout_config = LayoutConfig.current + if layout_config is not None: + skip_layer = layout_config.check_skip() + if skip_layer: + return relay.nn.conv2d(data, weight, **attrs) + + # Prepare new layout. new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" desired_data_layout, desired_kernel_layout = map(str, desired_layouts) @@ -192,6 +203,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): else: new_attrs["kernel_layout"] = "HWIO" return relay.nn.conv2d(data, weight, **new_attrs) + elif desired_data_layout == "HWNC": + new_attrs["kernel_layout"] = "HWOI" + return relay.nn.conv2d(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported." % desired_data_layout) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 7031365251aa..ca44e49ce1dd 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -219,8 +219,13 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): out_channels = oc_chunk * oc_block_factor else: _, _, out_channels, _ = get_const_tuple(kernel.shape) - if topi.cuda.is_shape_tensorcore_direct_qualified( - batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype + + tensorcore_dtypes = ["int4", "uint4", "int8", "uint8"] + if ( + (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0) + or (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0) + or (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0) + and (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes) ): strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 138a36611c6f..1d0ea176b16f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,4 +18,5 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * +from .recast import recast from . import memory_alloc diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py new file mode 100644 index 000000000000..05a72676a907 --- /dev/null +++ b/python/tvm/relay/transform/recast.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 138a36611c6f..1d0ea176b16f 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -18,4 +18,5 @@
 """The Relay IR namespace containing transformations."""
 # transformation passes
 from .transform import *
+from .recast import recast
 from . import memory_alloc
diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py
new file mode 100644
index 000000000000..05a72676a907
--- /dev/null
+++ b/python/tvm/relay/transform/recast.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay type recasting pass"""
+import tvm
+from tvm import relay
+from tvm.ir import IRModule
+from .transform import InferType
+from ..analysis import count_layers
+from ..expr_functor import ExprMutator, Call
+
+
+class RecastMutator(ExprMutator):
+    """Cast operations to the target type."""
+
+    def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers):
+        self.dtype = dtype
+        self.out_dtype = out_dtype
+        self.depth_count = 0
+        self.valid_ops = [relay.op.get(op) for op in valid_ops]
+        self.valid_op_count = valid_op_count
+        self.skip_layers = skip_layers
+        # Convert negative indices to positive ones.
+        for i, layer in enumerate(skip_layers):
+            if layer < 0:
+                skip_layers[i] = self.valid_op_count + layer
+        super().__init__()
+
+    def visit_call(self, call):
+        # Keep track of our current depth and layer count
+        # so we can know whether to skip this layer or not.
+        current_depth = self.depth_count
+        current_layer = self.valid_op_count - current_depth - 1
+        if call.op in self.valid_ops:
+            self.depth_count += 1
+        # Visit current call operation
+        new_fn = self.visit(call.op)
+        # Visit current arguments
+        args = []
+        for arg in call.args:
+            args.append(self.visit(arg))
+        self.depth_count = current_depth
+
+        # Downcast this op if it's the correct type and not skipped.
+        if call.op in self.valid_ops and current_layer not in self.skip_layers:
+            # Recast inputs to specified type.
+            args = [self.visit(arg) for arg in call.args]
+            new_args = list()
+            for arg in args:
+                new_args.append(relay.cast(arg, dtype=self.dtype))
+
+            # If out_dtype is in the attributes, we need to update it.
+            orig_dtype = None
+            if "out_dtype" in call.attrs.keys():
+                new_attr_dict = {}
+                for attr in call.attrs.keys():
+                    attr_value = call.attrs[attr]
+                    if isinstance(attr_value, tvm.ir.container.Array):
+                        attr_value = tuple(attr_value)
+                    new_attr_dict[str(attr)] = attr_value
+                new_attr_dict["out_dtype"] = self.out_dtype
+                attr_type = str(call.attrs).split("(")[0]
+                new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict)
+                if call.attrs["out_dtype"] != "":
+                    orig_dtype = call.attrs["out_dtype"]
+            else:
+                new_attrs = call.attrs
+
+            if orig_dtype is None:
+                # Perform type inference to determine the original type.
+                new_mod = IRModule.from_expr(call)
+                new_mod = InferType()(new_mod)
+                checked_arg = new_mod["main"].body
+                orig_dtype = checked_arg.checked_type.dtype
+            # Recast the output for compatibility with other graph operations.
+            return relay.cast(Call(new_fn, new_args, new_attrs), orig_dtype)
+
+        # Otherwise return the unchanged call.
+        return Call(new_fn, args, call.attrs)
+
+
+def recast(expr, dtype, out_dtype, ops=None, skip_layers=None):
+    """Convert the types of operations in a graph to a new value.
+    Note that this is primarily useful for testing performance of individual
+    operations at the new datatype. In a real setting, this pass will
+    almost certainly do a poor job converting from one datatype to another
+    as it just applies hard casting. For example, when recasting from float
+    to integer, many small values will simply be set to 0. Although this will
+    allow autotuning and benchmarking to produce proper timings at the new
+    data type, the output of the model will of course be heavily impacted.
+
+    Parameters
+    ----------
+    expr: tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
+        The original function that will have its type changed.
+    dtype: str
+        The target type to cast to.
+    out_dtype: str
+        The output type to cast to.
+    ops: List[str]
+        A list of operations that should have their type changed,
+        others will be left as is.
+    skip_layers: List[int]
+        A list of integers indicating operations that should
+        not have their type changed, counted starting with the
+        first valid operation encountered. Negative indices are
+        allowed and indicate starting at the last layer.
+    Returns
+    -------
+    output_expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
+        The graph after recasting to the specified datatype.
+    """
+    return_mod = False
+    if isinstance(expr, tvm.ir.IRModule):
+        expr = expr["main"]
+        return_mod = True
+    if ops is None:
+        ops = ["nn.conv2d"]
+    if skip_layers is None:
+        skip_layers = []
+    layer_depth = count_layers(expr, ops)
+    recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers)
+    expr = recast_pass.visit(expr)
+    if return_mod:
+        return tvm.IRModule.from_expr(expr)
+    return expr
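A minimal usage sketch of the new recast helper, mirroring test_recast_simple added below: a float32 convolution is rewritten to compute in int8 with int32 accumulation, then cast back to float32.

    import tvm
    from tvm import relay
    from tvm.relay.transform import recast

    x = relay.var("x", shape=[8, 8, 8, 8])
    w = relay.var("w", shape=[8, 8, 3, 3])
    conv = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
    func = relay.Function([x, w], conv)

    # Inputs are cast to int8, the conv accumulates into int32, and the result is
    # cast back to float32 so downstream consumers of the graph are unaffected.
    recast_func = recast(func, "int8", "int32", ops=["nn.conv2d"])
    print(recast_func)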
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index e155f83a7c5d..060547e4c4d7 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -386,6 +386,33 @@ def AlterOpLayout():
     return _ffi_api.AlterOpLayout()
 
 
+class LayoutConfig(object):
+    """A structure for customizing the ConvertLayout pass."""
+
+    current = None
+
+    def __init__(self, skip_layers=None):
+        self.skip_counter = 0
+        self.skip_layers = skip_layers if skip_layers is not None else []
+
+    def check_skip(self):
+        skip = self.skip_counter in self.skip_layers
+        self.skip_counter += 1
+        return skip
+
+    def reset(self):
+        self.skip_counter = 0
+        self.skip_layers = []
+
+    def __enter__(self):
+        self._old_manager = LayoutConfig.current
+        LayoutConfig.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        LayoutConfig.current = self._old_manager
+
+
 def ConvertLayout(desired_layouts):
     """Given a dest layout, this pass transforms the expr such that most of the ops input
     data layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
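A short sketch of how the new LayoutConfig scope is meant to be combined with ConvertLayout; it mirrors test_convert_with_config added below, and the two-conv graph is illustrative only. The first conv2d is skipped, so only the second one is converted to HWNC.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 56, 56, 64))
    w1 = relay.var("w1", shape=(3, 3, 64, 64))
    w2 = relay.var("w2", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w1, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout="NHWC", kernel_layout="HWIO")
    y = relay.nn.conv2d(relay.nn.relu(y), w2, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

    seq = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.ConvertLayout({"nn.conv2d": ["HWNC", "default"]})]
    )
    # Skip the first conv2d; later convolutions are converted to HWNC/HWOI.
    with relay.transform.LayoutConfig(skip_layers=[0]):
        with tvm.transform.PassContext(opt_level=3):
            mod = seq(mod)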
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 601c68abdf08..21960d9d4b1b 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -310,7 +310,7 @@ void GraphRuntime::SetupStorage() {
     ICHECK_GE(storage_id, 0) << "Do not support runtime shape op";
     DLDataType t = vtype[i];
     size_t bits = t.bits * t.lanes;
-    ICHECK(bits % 8U == 0U || bits == 1U);
+    ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
     size_t bytes = ((bits + 7U) / 8U) * size;
 
     uint32_t sid = static_cast<uint32_t>(storage_id);
diff --git a/tests/python/relay/test_layer_count.py b/tests/python/relay/test_layer_count.py
new file mode 100644
index 000000000000..f680bb2725f2
--- /dev/null
+++ b/tests/python/relay/test_layer_count.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm.relay.testing import resnet
+from tvm.relay.analysis import count_layers
+
+
+def test_layer_count():
+    def verify(num_layers):
+        # Load a resnet with a known number of layers.
+        mod, _ = resnet.get_workload(num_layers=num_layers)
+        # Count the number of conv and dense layers.
+        count = count_layers(mod, valid_ops=["nn.conv2d", "nn.dense"])
+        assert count == num_layers
+
+    verify(18)
+    verify(50)
+
+
+if __name__ == "__main__":
+    test_layer_count()
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 40aef264a335..1fc5d39b9486 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1162,6 +1162,78 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_convert_with_config():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.nn.relu(y)
+
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64))
+        y2 = relay.nn.conv2d(
+            y,
+            weight2,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y2 = relay.nn.relu(y2)
+
+        out = relay.Function([x, weight, weight2], y2)
+        return out
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+
+        weight2 = relay.var("weight2", shape=(3, 3, 64, 64))
+        weight2 = relay.layout_transform(weight2, "HWIO", "HWOI")
+
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NHWC", "HWNC")
+
+        y2 = relay.nn.conv2d(
+            y,
+            weight2,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="HWNC",
+            kernel_layout="HWOI",
+        )
+        y2 = relay.nn.relu(y2)
+
+        y2 = relay.layout_transform(y2, "HWNC", "NHWC")
+        output = relay.Function(relay.analysis.free_vars(y2), y2)
+        return output
+
+    a = before()
+    layout_config = relay.transform.LayoutConfig(skip_layers=[0])
+    with layout_config:
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["HWNC", "default"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_qnn_binary_no_convert_layout()
     test_no_convert_layout()
@@ -1185,3 +1257,4 @@ def expected():
     test_default_keyword()
     test_different_ops_convert_layout()
     test_no_desired_layout()
+    test_convert_with_config()
diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py
new file mode 100644
index 000000000000..8c5a562ddbba
--- /dev/null
+++ b/tests/python/relay/test_recast.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay.transform import recast
+
+
+def test_recast_simple():
+    """Recast a single convolution operator."""
+
+    def before():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
+        return relay.Function([x, w], c)
+
+    def expected():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        x_int = relay.cast(x, "int8")
+        w_int = relay.cast(w, "int8")
+        c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype="int32")
+        c_float = relay.cast(c, "float32")
+        return relay.Function([x, w], c_float)
+
+    pre = before()
+    post = recast(pre, "int8", "int32")
+    expected = expected()
+    assert tvm.ir.structural_equal(expected, post)
+
+
+def test_recast_medium():
+    """Recast a slightly larger graph."""
+
+    def before():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
+        w2 = relay.var("w2", shape=[8, 8, 3, 3])
+        c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype="float32")
+        return relay.Function([x, w, w2], c2)
+
+    def expected():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        x_int = relay.cast(x, "int8")
+        w_int = relay.cast(w, "int8")
+        c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype="int32")
+        c_float = relay.cast(c, "float32")
+        w2 = relay.var("w2", shape=[8, 8, 3, 3])
+        w2_int = relay.cast(w2, "int8")
+        c_float_int = relay.cast(c_float, "int8")
+        c2 = relay.nn.conv2d(c_float_int, w2_int, padding=(1, 1), out_dtype="int32")
+        c2_float = relay.cast(c2, "float32")
+        return relay.Function([x, w, w2], c2_float)
+
+    pre = before()
+    post = recast(pre, "int8", "int32")
+    expected = expected()
+    assert tvm.ir.structural_equal(expected, post)
+
+
+def test_recast_skip():
+    """Recast a graph using skip layers."""
+
+    def before():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
+        w2 = relay.var("w2", shape=[8, 8, 3, 3])
+        c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype="float32")
+        return relay.Function([x, w, w2], c2)
+
+    def expected():
+        x = relay.var("x", shape=[8, 8, 8, 8])
+        w = relay.var("w", shape=[8, 8, 3, 3])
+        c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
+        w2 = relay.var("w2", shape=[8, 8, 3, 3])
+        w2_int = relay.cast(w2, "int8")
+        c_int = relay.cast(c, "int8")
+        c2 = relay.nn.conv2d(c_int, w2_int, padding=(1, 1), out_dtype="int32")
+        c2_float = relay.cast(c2, "float32")
+        return relay.Function([x, w, w2], c2_float)
+
+    pre = before()
+    post = recast(pre, "int8", "int32", skip_layers=[0])
+    expected = expected()
+    assert tvm.ir.structural_equal(expected, post)
+
+
+if __name__ == "__main__":
+    test_recast_simple()
+    test_recast_medium()
+    test_recast_skip()