[Relay] A set of utilities that allows a model to be run efficiently on tensorcores. (apache#6748)
jwfromm authored and Trevor Morris committed Dec 4, 2020
1 parent bc54cd5 commit b845dd3
Showing 11 changed files with 475 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/analysis/__init__.py
@@ -29,3 +29,6 @@
# Feature
from . import feature
from . import sparse_dense

# Utilities
from .count_layers import count_layers
68 changes: 68 additions & 0 deletions python/tvm/relay/analysis/count_layers.py
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities that enable counting the number of layers in a graph."""
import tvm
from tvm import relay
from ..expr_functor import ExprVisitor


class LayerCounter(ExprVisitor):
    """A visitor pass that computes the deepest chain of specified ops in graph."""

    def __init__(self, valid_ops):
        self.depth_count = 0
        self.deepest_count = 0
        self.valid_ops = [relay.op.get(op) for op in valid_ops]
        super().__init__()

    def visit_call(self, call):
        if call.op in self.valid_ops:
            self.depth_count += 1
        current_count = self.depth_count
        self.deepest_count = max(self.deepest_count, current_count)
        for arg in call.args:
            self.visit(arg)
            self.depth_count = current_count

    def count(self):
        return self.deepest_count


def count_layers(expr, valid_ops):
    """Determine the number of layers of specified ops in a graph.
    This pass computes only the deepest chain of ops rather than the
    total number of ops in a graph. Thus, if there are two parallel
    convolutions (for example), they would be considered a single layer.

    Parameters
    ----------
    expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The input expression.
    valid_ops: List[str]
        A list of the operations that should be included in the count.

    Returns
    -------
    layer_count : int
        The number of layers of the specified operations found in the graph.
    """
    if isinstance(expr, tvm.ir.IRModule):
        expr = expr["main"]
    count_pass = LayerCounter(valid_ops)
    count_pass.visit(expr)
    return count_pass.count()
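
For orientation, here is a minimal usage sketch of the new count_layers utility (the graph, variable names, and shapes are illustrative and not part of this commit). It demonstrates the behavior described in the docstring: parallel branches are not summed, only the deepest chain of matching ops is reported.

import tvm
from tvm import relay
from tvm.relay.analysis import count_layers

# Two convolutions applied to the same input and then merged: a parallel pattern.
x = relay.var("x", shape=(1, 8, 16, 16))
w = relay.var("w", shape=(8, 8, 3, 3))
left = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=8, padding=(1, 1))
right = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=8, padding=(1, 1))
out = relay.add(left, right)
func = relay.Function(relay.analysis.free_vars(out), out)

# The deepest chain of nn.conv2d ops has length 1, even though two convs exist.
assert count_layers(func, valid_ops=["nn.conv2d"]) == 1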
14 changes: 14 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -28,6 +28,7 @@
from ..op import OpPattern
from .._tensor import elemwise_shape_func
from ..strategy.generic import is_depthwise_conv2d
from ...transform import LayoutConfig

# relu
reg.register_broadcast_schedule("nn.relu")
@@ -164,6 +165,16 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
from tvm import relay

data, weight = inputs

# First check if there is a LayoutConfig scope, and if so, whether
# it indicates we should ignore this layer or not.
layout_config = LayoutConfig.current
if layout_config is not None:
skip_layer = layout_config.check_skip()
if skip_layer:
return relay.nn.conv2d(data, weight, **attrs)

# Prepare new layout.
new_attrs = dict(attrs)
assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
@@ -192,6 +203,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
        else:
            new_attrs["kernel_layout"] = "HWIO"
        return relay.nn.conv2d(data, weight, **new_attrs)
    elif desired_data_layout == "HWNC":
        new_attrs["kernel_layout"] = "HWOI"
        return relay.nn.conv2d(data, weight, **new_attrs)

    raise ValueError("Layout %s is not yet supported." % desired_data_layout)

9 changes: 7 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
@@ -219,8 +219,13 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                out_channels = oc_chunk * oc_block_factor
            else:
                _, _, out_channels, _ = get_const_tuple(kernel.shape)
-           if topi.cuda.is_shape_tensorcore_direct_qualified(
-               batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype
+
+           tensorcore_dtypes = ["int4", "uint4", "int8", "uint8"]
+           if (
+               (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0)
+               or (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0)
+               or (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0)
+               and (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes)
            ):
                strategy.add_implementation(
                    wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore),
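As a reading aid (not part of the diff), the shape and dtype gate that the new condition expresses can be written as a standalone predicate with the shape alternatives grouped explicitly; the helper name below is hypothetical. Note that the committed expression relies on Python's and/or precedence, so its dtype check binds only to the last shape alternative.

def qualifies_for_hwnc_tensorcore(N, in_channels, out_channels, data_dtype, kernel_dtype):
    # Batch/channel sizes must line up with one of the wmma fragment shapes ...
    tensorcore_dtypes = ["int4", "uint4", "int8", "uint8"]
    shape_ok = (
        (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0)
        or (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0)
        or (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0)
    )
    # ... and both tensors must use a 4- or 8-bit integer type.
    dtype_ok = data_dtype in tensorcore_dtypes and kernel_dtype in tensorcore_dtypes
    return shape_ok and dtype_ok

assert qualifies_for_hwnc_tensorcore(8, 16, 32, "int4", "int4")
assert not qualifies_for_hwnc_tensorcore(8, 16, 32, "float16", "float16")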
1 change: 1 addition & 0 deletions python/tvm/relay/transform/__init__.py
@@ -18,4 +18,5 @@
"""The Relay IR namespace containing transformations."""
# transformation passes
from .transform import *
from .recast import recast
from . import memory_alloc
139 changes: 139 additions & 0 deletions python/tvm/relay/transform/recast.py
@@ -0,0 +1,139 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay type recasting pass"""
import tvm
from tvm import relay
from tvm.ir import IRModule
from .transform import InferType
from ..analysis import count_layers
from ..expr_functor import ExprMutator, Call


class RecastMutator(ExprMutator):
    """Cast operations to the target type."""

    def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers):
        self.dtype = dtype
        self.out_dtype = out_dtype
        self.depth_count = 0
        self.valid_ops = [relay.op.get(op) for op in valid_ops]
        self.valid_op_count = valid_op_count
        self.skip_layers = skip_layers
        # Convert negative indices to positive ones.
        for i, layer in enumerate(skip_layers):
            if layer < 0:
                skip_layers[i] = self.valid_op_count + layer
        super().__init__()

    def visit_call(self, call):
        # Keep track of our current depth and layer count
        # so we can know whether to skip this layer or not.
        current_depth = self.depth_count
        current_layer = self.valid_op_count - current_depth - 1
        if call.op in self.valid_ops:
            self.depth_count += 1
        # Visit current call operation
        new_fn = self.visit(call.op)
        # Visit current arguments
        args = []
        for arg in call.args:
            args.append(self.visit(arg))
            self.depth_count = current_depth

        # Downcast this op if it's the correct type and not skipped.
        if call.op in self.valid_ops and current_layer not in self.skip_layers:
            # Recast inputs to specified type.
            args = [self.visit(arg) for arg in call.args]
            new_args = list()
            for arg in args:
                new_args.append(relay.cast(arg, dtype=self.dtype))

            # If out_dtype is in the attributes, we need to update it.
            orig_dtype = None
            if "out_dtype" in call.attrs.keys():
                new_attr_dict = {}
                for attr in call.attrs.keys():
                    attr_value = call.attrs[attr]
                    if isinstance(attr_value, tvm.ir.container.Array):
                        attr_value = tuple(attr_value)
                    new_attr_dict[str(attr)] = attr_value
                new_attr_dict["out_dtype"] = self.out_dtype
                attr_type = str(call.attrs).split("(")[0]
                new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict)
                if call.attrs["out_dtype"] != "":
                    orig_dtype = call.attrs["out_dtype"]
            else:
                new_attrs = call.attrs

            if orig_dtype is None:
                # Perform type inference to determine the original type.
                new_mod = IRModule.from_expr(call)
                new_mod = InferType()(new_mod)
                checked_arg = new_mod["main"].body
                orig_dtype = checked_arg.checked_type.dtype
            # Recast the output for compatibility with other graph operations.
            return relay.cast(Call(new_fn, new_args, new_attrs), orig_dtype)

        # Otherwise return the unchanged call.
        return Call(new_fn, args, call.attrs)


def recast(expr, dtype, out_dtype, ops=None, skip_layers=None):
    """Convert the types of operations in a graph to a new value.
    Note that this is primarily useful for testing performance of individual
    operations at the new datatype. In a real setting, this pass will
    almost certainly do a poor job converting from one datatype to another
    as it just applies hard casting. For example, when recasting from float
    to integer, many small values will simply be set to 0. Although this will
    allow autotuning and benchmarking to produce proper timings at the new
    data type, the output of the model will of course be heavily impacted.

    Parameters
    ----------
    expr: tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The original function that will have its type changed.
    dtype: str
        The target type to cast to.
    out_dtype: str
        The output type to cast to.
    ops: List[str]
        A list of operations that should have their type changed,
        others will be left as is.
    skip_layers: List[int]
        A list of integers indicating operations that should
        not have their type changed, counted starting with the
        first valid operation encountered. Negative indices are
        allowed and indicate starting at the last layer.

    Returns
    -------
    output_expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The graph after recasting to the specified datatype.
    """
    return_mod = False
    if isinstance(expr, tvm.ir.IRModule):
        expr = expr["main"]
        return_mod = True
    if ops is None:
        ops = ["nn.conv2d"]
    if skip_layers is None:
        skip_layers = []
    layer_depth = count_layers(expr, ops)
    recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers)
    expr = recast_pass.visit(expr)
    if return_mod:
        return tvm.IRModule.from_expr(expr)
    return expr
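
A usage sketch for the new recast pass (the graph, shapes, and dtypes below are illustrative, not part of the commit). As the docstring notes, this is a hard cast intended for benchmarking and autotuning at the new type rather than a numerically faithful conversion.

import tvm
from tvm import relay
from tvm.relay.transform import recast

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")
w1 = relay.var("w1", shape=(16, 16, 3, 3), dtype="float32")
w2 = relay.var("w2", shape=(16, 16, 3, 3), dtype="float32")
y = relay.nn.conv2d(x, w1, kernel_size=(3, 3), channels=16, padding=(1, 1))
y = relay.nn.conv2d(y, w2, kernel_size=(3, 3), channels=16, padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# Recast the conv2d layers to int8 compute with int32 accumulation, but leave
# the first conv (layer index 0) at its original float32 type via skip_layers.
recast_mod = recast(mod, dtype="int8", out_dtype="int32",
                    ops=["nn.conv2d"], skip_layers=[0])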
27 changes: 27 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -386,6 +386,33 @@ def AlterOpLayout():
return _ffi_api.AlterOpLayout()


class LayoutConfig(object):
    """A structure for customizing the ConvertLayout pass."""

    current = None

    def __init__(self, skip_layers=None):
        self.skip_counter = 0
        self.skip_layers = skip_layers if skip_layers is not None else []

    def check_skip(self):
        skip = self.skip_counter in self.skip_layers
        self.skip_counter += 1
        return skip

    def reset(self):
        self.skip_counter = 0
        self.skip_layers = []

    def __enter__(self):
        self._old_manager = LayoutConfig.current
        LayoutConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        LayoutConfig.current = self._old_manager


def ConvertLayout(desired_layouts):
    """Given a dest layout, this pass transforms the expr such that most of the ops input data
    layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
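A usage sketch tying LayoutConfig to the HWNC conversion added in _nn.py above (the network, shapes, and dtypes are illustrative). check_skip() is consulted once per converted conv2d, so the indices in skip_layers refer to the order in which convolutions are converted.

import tvm
from tvm import relay
from tvm.relay.transform import ConvertLayout, InferType, LayoutConfig

x = relay.var("x", shape=(8, 16, 56, 56), dtype="int8")   # NCHW data
w1 = relay.var("w1", shape=(32, 16, 3, 3), dtype="int8")  # OIHW kernels
w2 = relay.var("w2", shape=(32, 32, 3, 3), dtype="int8")
y = relay.nn.conv2d(x, w1, kernel_size=(3, 3), channels=32, padding=(1, 1), out_dtype="int32")
y = relay.cast(y, "int8")
y = relay.nn.conv2d(y, w2, kernel_size=(3, 3), channels=32, padding=(1, 1), out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# Rewrite conv2d layers to the HWNC data layout with HWOI kernels, skipping the
# first conv2d the conversion encounters (check_skip() returns True for index 0).
seq = tvm.transform.Sequential(
    [InferType(), ConvertLayout({"nn.conv2d": ["HWNC", "HWOI"]})]
)
with LayoutConfig(skip_layers=[0]):
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)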
2 changes: 1 addition & 1 deletion src/runtime/graph/graph_runtime.cc
@@ -310,7 +310,7 @@ void GraphRuntime::SetupStorage() {
    ICHECK_GE(storage_id, 0) << "Do not support runtime shape op";
    DLDataType t = vtype[i];
    size_t bits = t.bits * t.lanes;
-   ICHECK(bits % 8U == 0U || bits == 1U);
+   ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
    size_t bytes = ((bits + 7U) / 8U) * size;

    uint32_t sid = static_cast<uint32_t>(storage_id);
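A worked example (values are illustrative) of the storage-size arithmetic shown above, which is why the ICHECK must now admit bits == 4 for the int4 tensors used by the tensorcore kernels.

# Python mirror of the C++ arithmetic: bits = t.bits * t.lanes and
# bytes = ((bits + 7) / 8) * size, where size is the element count.
bits = 4 * 1             # an int4 tensor with a single lane
size = 8 * 16 * 56 * 56  # number of elements in the tensor
bytes_needed = ((bits + 7) // 8) * size
# With lanes == 1, a 4-bit element still rounds up to one byte in this pool.
assert bytes_needed == size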
34 changes: 34 additions & 0 deletions tests/python/relay/test_layer_count.py
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.relay.testing import resnet
from tvm.relay.analysis import count_layers


def test_layer_count():
    def verify(num_layers):
        # Load a resnet with a known number of layers.
        mod, _ = resnet.get_workload(num_layers=num_layers)
        # Count the number of conv and dense layers.
        count = count_layers(mod, valid_ops=["nn.conv2d", "nn.dense"])
        assert count == num_layers

    verify(18)
    verify(50)


if __name__ == "__main__":
    test_layer_count()

