forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TVM changes to support introduction and typing of new custom operations.
Merged in SIM-6711 (pull request apache#36) Approved-by: Mikael Sevenier Approved-by: Joey Chou
- Loading branch information
Ashok Sudarsanam
committed
May 6, 2021
1 parent
ea3337c
commit 75b00e8
Showing
5 changed files
with
822 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name, too-many-lines | ||
"""Custom operation configuration interface.""" | ||
from typing import List, Dict, Callable | ||
from dataclasses import dataclass | ||
from tvm.ir import TensorType | ||
import json | ||
|
||
|
||
@dataclass() | ||
class CustomOpConfigInfo(): | ||
""" | ||
Dataclass that contains configuration information for a custom operation. | ||
This dataclass contains the following fields: | ||
1. code: a string that contains the corresponding C code implementation. | ||
2. func_name: the name of the function in the C code that implements the | ||
custom operation. | ||
3. datatype: a string that specifies the underlying tensor datatype that | ||
is assumed by the C code implementation. Currently supported values are | ||
“int8”, “float”, and “double”. | ||
4. type_func: a Python function that returns the type of the custom opera- | ||
tion, based on the types of the input tensor(s) and relevant attributes. | ||
5. compiler_flags: a string that contains custom operation-specific flags | ||
for the target compiler. | ||
""" | ||
|
||
code: str | ||
func_name: str | ||
datatype: str | ||
type_func: Callable[..., TensorType] | ||
compiler_flags: str | ||
|
||
|
||
class CustomOperationConfig: | ||
""" | ||
Singleton class that contains configuration information for each custom | ||
operation that exists in an ML model. This information is used during | ||
the construction and typing of custom operations. | ||
""" | ||
|
||
__instance = None | ||
config_dict: Dict[str, CustomOpConfigInfo] = dict() | ||
|
||
@staticmethod | ||
def get_instance(): | ||
if CustomOperationConfig.__instance == None: | ||
CustomOperationConfig() | ||
return CustomOperationConfig.__instance | ||
|
||
def __init__(self): | ||
if CustomOperationConfig.__instance != None: | ||
raise Exception("CustomOperationConfig class is a singleton.") | ||
else: | ||
CustomOperationConfig.__instance = self | ||
|
||
def add_config_for_custom_op(self, custom_op_name: str, | ||
custom_op_config_info: CustomOpConfigInfo): | ||
self.config_dict[custom_op_name] = custom_op_config_info | ||
|
||
def get_config_for_custom_op(self, custom_op_name: str) -> CustomOpConfigInfo: | ||
return self.config_dict[custom_op_name] | ||
|
||
def get_custom_ops(self) -> List[str]: | ||
return list(self.config_dict.keys()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name, too-many-lines | ||
"""Neural network operations for custom ops.""" | ||
from tvm.relay import expr | ||
from tvm.ir import Attrs | ||
|
||
from . import _make | ||
from typing import List, Tuple | ||
from tvm.custom_operation_config import ( | ||
CustomOpConfigInfo, CustomOperationConfig | ||
) | ||
import tvm._ffi | ||
import json | ||
|
||
|
||
MAX_TENSOR_INPUTS = 5 | ||
|
||
|
||
def custom_op(inputs, input_types, name, code, func_name, datatype, compiler_flags): | ||
""" | ||
Create a Relay IR node for the custom operation. Specifically, a | ||
CallNode around an operator nn.custom_op_{i} is returned, where {i} | ||
denotes the total number of input tensor operands in the custom | ||
operation. The number of input tensor operands cannot exceed 5. | ||
The inputs to a custom operation may also include constant values | ||
that represent attributes of the operation. Each attribute must | ||
be a string, an integer, a floating point value, a list of integers, | ||
or a list of floating-point values. | ||
In the custom operation specification in the ML network, the tensor | ||
operands must appear first, followed by the constant attributes. | ||
""" | ||
|
||
# Partition the inputs into tensor operands and constant attributes. | ||
tensor_inputs = [] | ||
constant_attrs = [] | ||
for input in inputs: | ||
if isinstance(input, tvm.relay.expr.ExprWithOp): | ||
if len(constant_attrs) == 0: | ||
tensor_inputs.append(input) | ||
else: | ||
raise AssertionError("Tensor operands must precede constant attributes.") | ||
elif is_valid_attribute(input): | ||
constant_attrs.append(input) | ||
else: | ||
raise AssertionError(f"Input {input} is neither a tensor nor a constant attribute.") | ||
|
||
# Store all attributes of the custom operation in a dictionary. | ||
# The following string attributes are common to all custom | ||
# operations: | ||
# 1. Custom operation name. | ||
# 2. C code implementation. | ||
# 3. C code function name. | ||
# 4. C code datatype. | ||
# 5. Operation-specific compiler flags. | ||
# | ||
# A custom operation may also have constant attributes that are | ||
# specific to it. | ||
custom_op_attrs = { | ||
"name": name, | ||
"code": code, | ||
"func_name": func_name, | ||
"datatype": datatype, | ||
"compiler_flags": compiler_flags, | ||
"constant_attrs": constant_attrs | ||
} | ||
|
||
custom_op_attr_str = json.dumps(custom_op_attrs) | ||
|
||
if len(tensor_inputs) == 1: | ||
return _make.custom_op_1(*tensor_inputs, custom_op_attr_str) | ||
elif len(tensor_inputs) == 2: | ||
return _make.custom_op_2(*tensor_inputs, custom_op_attr_str) | ||
elif len(tensor_inputs) == 3: | ||
return _make.custom_op_3(*tensor_inputs, custom_op_attr_str) | ||
elif len(tensor_inputs) == 4: | ||
return _make.custom_op_4(*tensor_inputs, custom_op_attr_str) | ||
elif len(tensor_inputs) == 5: | ||
return _make.custom_op_5(*tensor_inputs, custom_op_attr_str) | ||
else: | ||
msg = "Unsupported number of input tensor arguments (%d)." % (len(tensor_inputs)) | ||
raise AssertionError(msg) | ||
|
||
|
||
def is_valid_attribute(input): | ||
""" | ||
Returns True if the input operand is a string, an integer, a floating | ||
point number, a list of integers, or a list of floating-point numbers. | ||
""" | ||
|
||
input_type = type(input) | ||
if input_type == str or input_type == int or input_type == float: | ||
return True | ||
|
||
if input_type == list and type(input[0]) in [int, float]: | ||
for elem in input: | ||
if type(elem) != type(input[0]): | ||
return False | ||
return True | ||
|
||
return False | ||
|
||
|
||
@tvm._ffi.register_func("relay.op.nn.custom_op_type_func") | ||
def custom_op_type_func(types, num_inputs, attrs): | ||
""" | ||
Return the type of the specified custom operation, based on the | ||
input types and constant attribute values. This function is | ||
invoked by the registered add_type_rel() function in the C++ code. | ||
""" | ||
|
||
custom_op_attrs = json.loads(attrs.custom_op_attrs) | ||
custom_op_name = custom_op_attrs["name"] | ||
constant_attrs = custom_op_attrs["constant_attrs"] | ||
|
||
# Get the typing function associated with the custom operation. | ||
custom_op_config = CustomOperationConfig.get_instance() | ||
config_info = custom_op_config.get_config_for_custom_op(custom_op_name) | ||
type_func = config_info.type_func | ||
|
||
msg = f"Unsupported number of input tensor arguments {num_inputs} (max = {MAX_TENSOR_INPUTS})" | ||
assert 0 < num_inputs <= MAX_TENSOR_INPUTS, msg | ||
|
||
input_args = tuple([types[i] for i in range(num_inputs)]) | ||
return type_func(*input_args, *constant_attrs) | ||
|
||
|
||
@tvm._ffi.register_object("relay.attrs.CustomOpAttrs") | ||
class CustomOpAttrs(Attrs): | ||
"""Attributes for nn custom operations""" | ||
|
||
|
||
def make_custom_op(name, code, func_name, datatype, compiler_flags): | ||
def custom_op_func(inputs, input_types): | ||
return custom_op(inputs, input_types, name, | ||
code, func_name, datatype, | ||
compiler_flags) | ||
|
||
return custom_op_func | ||
|
||
|
||
def get_convert_map_from_custom_op_config(): | ||
""" | ||
Construct a mapping from custom operation name to Relay IR | ||
creation function. This mapping will get inserted into | ||
the front-end's operator conversion map. | ||
""" | ||
|
||
convert_map = {} | ||
custom_op_config = CustomOperationConfig.get_instance() | ||
custom_op_names = custom_op_config.get_custom_ops() | ||
|
||
for custom_op_name in custom_op_names: | ||
config_info = custom_op_config.get_config_for_custom_op(custom_op_name) | ||
code = config_info.code | ||
func_name = config_info.func_name | ||
datatype = config_info.datatype | ||
compiler_flags = config_info.compiler_flags | ||
|
||
convert_map[custom_op_name] = make_custom_op(custom_op_name, code, | ||
func_name, datatype, | ||
compiler_flags) | ||
|
||
return convert_map |
Oops, something went wrong.