diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py
index 4b6acceb3a83..3a3f6d5aa304 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/python/tvm/relay/op/contrib/__init__.py
@@ -16,4 +16,6 @@
 # under the License.
 # pylint: disable=wildcard-import
 """Contrib modules."""
+from .register import get_pattern_table, register_pattern_table
+
 from .dnnl import *
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 45a8c8331f72..71ef430ec9c6 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -32,7 +32,9 @@
 - The other way is to implement the function by themselves to
   check the attributes of the op and decide if it should be offloaded to DNNL.
 """
-from ... import op as reg
+from ... import expr as _expr
+from ... import op as _op
+from .register import register_pattern_table
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -49,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True):
     f : callable
         A function that returns if the operator is supported by DNNL.
     """
-    @reg.register(op_name, "target.dnnl")
+    @_op.register(op_name, "target.dnnl")
     def _func_wrapper(attrs, args):
         return supported
 
@@ -63,3 +65,23 @@ def _func_wrapper(attrs, args):
 _register_external_op_helper("add")
 _register_external_op_helper("subtract")
 _register_external_op_helper("multiply")
+
+
+def make_pattern(with_bias=True):
+    data = _expr.var("data")
+    weight = _expr.var("weight")
+    bias = _expr.var("bias")
+    conv = _op.nn.conv2d(data, weight)
+    if with_bias:
+        conv_out = _op.add(conv, bias)
+    else:
+        conv_out = conv
+    return _op.nn.relu(conv_out)
+
+
+@register_pattern_table("dnnl")
+def pattern_table():
+    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
+    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
+    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+    return dnnl_patterns
diff --git a/python/tvm/relay/op/contrib/register.py b/python/tvm/relay/op/contrib/register.py
new file mode 100644
index 000000000000..b82abdb88804
--- /dev/null
+++ b/python/tvm/relay/op/contrib/register.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Register utilities for external codegen."""
+_PATTERN_TABLES = {}
+
+
+def register_pattern_table(compiler, table=None):
+    """Register a pattern table for an external compiler.
+
+    Pattern tables are used to create composite functions.
+    See the MergeComposite pass.
+
+    Parameters
+    ----------
+    compiler : str
+        The name of the compiler
+
+    table : function, optional
+        A function that returns the pattern table
+
+    Returns
+    -------
+    fregister : function
+        The register function if table is not specified.
+    """
+    def _register(t):
+        """internal register function"""
+        _PATTERN_TABLES[compiler] = t()
+        return t
+    return _register(table) if table is not None else _register
+
+
+def get_pattern_table(compiler):
+    """Get the pattern table associated with a compiler (if it's registered)."""
+    return _PATTERN_TABLES[compiler] if compiler in _PATTERN_TABLES else None
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 8827fbf1b8b0..3261ccd0d7c9 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -19,7 +19,6 @@
 import sys
 import numpy as np
-import pytest
 
 import tvm
 import tvm.relay.testing
 
@@ -31,6 +30,7 @@
 from tvm.relay.backend import compile_engine
 from tvm.relay.expr_functor import ExprMutator
 from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay.op.contrib.register import get_pattern_table
 from tvm.relay.build_module import bind_params_by_name
 
 
@@ -832,21 +832,8 @@ def expected():
 
 
 def test_dnnl_fuse():
-    def make_pattern(with_bias=True):
-        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
-        weight = relay.var("weight")
-        bias = relay.var("bias")
-        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
-                               channels=8, padding=(1, 1))
-        if with_bias:
-            conv_out = relay.add(conv, bias)
-        else:
-            conv_out = conv
-        return relay.nn.relu(conv_out)
-
-    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
-    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
-    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+    dnnl_patterns = get_pattern_table("dnnl")
+    conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
 
     def get_blocks(prefix, data, in_channel, out_channel,
                    include_bn=True, include_sigmoid=False):