[QNN] Lookup operations for hard to implement operators #10053

Merged
Changes from 35 commits (74 commits total)

Commits
29e5e8a
initial tanh impl
AndrewZhaoLuo Jan 13, 2022
b414aeb
smalls error
AndrewZhaoLuo Jan 13, 2022
8f7a4f6
support uint and int lookup into tables
AndrewZhaoLuo Jan 18, 2022
b8a54ee
reinterpret cast, working tanh tests
AndrewZhaoLuo Jan 18, 2022
cf3eb4e
refactor relay func creation
AndrewZhaoLuo Jan 19, 2022
0c1a71d
basic casting tests
AndrewZhaoLuo Jan 19, 2022
c943ff1
explicitly say do not handle multi-channel lookups
AndrewZhaoLuo Jan 19, 2022
2073740
add example funcs
AndrewZhaoLuo Jan 19, 2022
11674d3
fix silent fail
AndrewZhaoLuo Jan 19, 2022
67baa39
fix some bugs with floating point funcs not working
AndrewZhaoLuo Jan 19, 2022
47e4b5c
add TODO
AndrewZhaoLuo Jan 19, 2022
446e25a
add tood
AndrewZhaoLuo Jan 22, 2022
87e265c
canonicalizations
AndrewZhaoLuo Jan 24, 2022
400880c
refactor integer lookup ops into own folder
AndrewZhaoLuo Jan 24, 2022
3d26528
fq2i stuff
AndrewZhaoLuo Jan 24, 2022
e60f2b4
clean up existing tests
AndrewZhaoLuo Jan 24, 2022
8bd0b44
flesh out todo
AndrewZhaoLuo Jan 24, 2022
daef150
more tests
AndrewZhaoLuo Jan 24, 2022
173e251
test on keeping shape good
AndrewZhaoLuo Jan 24, 2022
c4efbfb
lookup table fix
AndrewZhaoLuo Jan 24, 2022
ddd8dd5
replace canonicalization for rsqrt
AndrewZhaoLuo Jan 24, 2022
f65583a
remove canonicalization of rsqrt
AndrewZhaoLuo Jan 24, 2022
0b8dc75
add asf headers
AndrewZhaoLuo Jan 25, 2022
3c29f6b
topi tests
AndrewZhaoLuo Jan 25, 2022
eda9f19
gather supports unsigned integer tests
AndrewZhaoLuo Jan 25, 2022
ab25dc0
fix things
AndrewZhaoLuo Jan 25, 2022
fcc8313
move to legalization
AndrewZhaoLuo Jan 25, 2022
72e150f
jostle ci
AndrewZhaoLuo Jan 26, 2022
19de289
linting
AndrewZhaoLuo Jan 27, 2022
76fb6bc
use take instead of gather
AndrewZhaoLuo Jan 27, 2022
46f82c0
remove gather changes
AndrewZhaoLuo Jan 27, 2022
520f4f1
undo changes
AndrewZhaoLuo Jan 27, 2022
7a0f43b
undo changes
AndrewZhaoLuo Jan 27, 2022
6f8f34a
undo changes
AndrewZhaoLuo Jan 27, 2022
4e7b96a
move thing in range
Jan 28, 2022
40d5a28
initial tanh impl
AndrewZhaoLuo Jan 13, 2022
95537af
smalls error
AndrewZhaoLuo Jan 13, 2022
496c250
support uint and int lookup into tables
AndrewZhaoLuo Jan 18, 2022
2334e1c
reinterpret cast, working tanh tests
AndrewZhaoLuo Jan 18, 2022
5c65eb1
refactor relay func creation
AndrewZhaoLuo Jan 19, 2022
7b865e0
basic casting tests
AndrewZhaoLuo Jan 19, 2022
f2934c0
explicitly say do not handle multi-channel lookups
AndrewZhaoLuo Jan 19, 2022
a16a352
add example funcs
AndrewZhaoLuo Jan 19, 2022
b28a65e
fix silent fail
AndrewZhaoLuo Jan 19, 2022
fb22ee3
fix some bugs with floating point funcs not working
AndrewZhaoLuo Jan 19, 2022
0a03d46
add TODO
AndrewZhaoLuo Jan 19, 2022
f8a5114
add tood
AndrewZhaoLuo Jan 22, 2022
cc2f5a9
canonicalizations
AndrewZhaoLuo Jan 24, 2022
16aad84
refactor integer lookup ops into own folder
AndrewZhaoLuo Jan 24, 2022
eacf383
fq2i stuff
AndrewZhaoLuo Jan 24, 2022
f1753c9
clean up existing tests
AndrewZhaoLuo Jan 24, 2022
76cef1b
flesh out todo
AndrewZhaoLuo Jan 24, 2022
e996279
more tests
AndrewZhaoLuo Jan 24, 2022
1ff3adc
test on keeping shape good
AndrewZhaoLuo Jan 24, 2022
eabd40a
lookup table fix
AndrewZhaoLuo Jan 24, 2022
efe7b1a
replace canonicalization for rsqrt
AndrewZhaoLuo Jan 24, 2022
3b00080
remove canonicalization of rsqrt
AndrewZhaoLuo Jan 24, 2022
3adcb9e
add asf headers
AndrewZhaoLuo Jan 25, 2022
7928957
gather supports unsigned integer tests
AndrewZhaoLuo Jan 25, 2022
3b5759b
fix things
AndrewZhaoLuo Jan 25, 2022
a2f4c5e
move to legalization
AndrewZhaoLuo Jan 25, 2022
b5ec138
jostle ci
AndrewZhaoLuo Jan 26, 2022
fe54fa3
linting
AndrewZhaoLuo Jan 27, 2022
804e9fb
use take instead of gather
AndrewZhaoLuo Jan 27, 2022
9a22774
remove gather changes
AndrewZhaoLuo Jan 27, 2022
a148ff1
undo changes
AndrewZhaoLuo Jan 27, 2022
a75ea9f
undo changes
AndrewZhaoLuo Jan 27, 2022
3b3c685
undo changes
AndrewZhaoLuo Jan 27, 2022
b609d63
move thing in range
Jan 28, 2022
b0b7676
lint
AndrewZhaoLuo Jan 31, 2022
5b919f1
remove unneeded line
AndrewZhaoLuo Feb 7, 2022
3240c86
jostle
AndrewZhaoLuo Feb 7, 2022
858d6de
Merge branch 'aluo/fq2i/elemwise-lookup-ops-qnn-flavor' of github.com…
AndrewZhaoLuo Feb 8, 2022
35730c3
Merge branch 'main' into aluo/fq2i/elemwise-lookup-ops-qnn-flavor
AndrewZhaoLuo Feb 8, 2022
2 changes: 1 addition & 1 deletion include/tvm/topi/transform.h
@@ -1321,7 +1321,7 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
ICHECK_GE(indices_dim_i, 1);
}
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
ICHECK(indices->dtype.is_int());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < ndim_i; ++i) {
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/op/__init__.py
@@ -18,5 +18,5 @@
"""QNN dialect related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
from .op import register_qnn_legalize
from . import _qnn, legalizations, layout_conversions
from .op import register_qnn_legalize, register_qnn_canonicalize
from . import _qnn, legalizations, layout_conversions, canonicalizations
160 changes: 160 additions & 0 deletions python/tvm/relay/qnn/op/canonicalizations.py
@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Consist of utilities and methods for lowering QNN into mainline relay."""
from typing import Callable

import numpy as np
import tvm
from tvm import relay


def run_const_expr(expr: "relay.Expr") -> np.ndarray:
"""Evaluate a const expression, receiving result as np array."""
mod = tvm.IRModule.from_expr(expr)
vm_exe = relay.create_executor("vm", mod=mod)
return vm_exe.evaluate()().asnumpy()


def create_integer_lookup_table(
floating_point_func: Callable[[np.ndarray], np.ndarray],
input_scale: "relay.Expr",
input_zero_point: "relay.Expr",
output_scale: "relay.Expr",
output_zero_point: "relay.Expr",
in_axis: int = -1,
out_axis: int = -1,
in_dtype: str = "uint8",
out_dtype: str = "uint8",
) -> np.ndarray:
"""
Return a table where each quantized input value indexes to the quantized output approximating the given function.

Note this also supports mapping unsigned and signed integers to each other.

Args:
floating_point_func: The numpy function which this table is to approximate
input_scale: The scale of the quantized input tensor.
input_zero_point: The zero point of the quantized input tensor.
output_scale: The scale of the quantized output tensor.
output_zero_point: The zero point of the quantized output tensor.
in_axis: The axis for multi-channel quantization of the input if applicable.
out_axis: The axis for multi-channel quantization of the output if applicable.
in_dtype: The dtype of the input tensor.
out_dtype: The wanted dtype of the output tensor.

Returns:
A numpy array where values in quantized space will index to the output in quantized space
approximating the given function.
"""
if not np.issubdtype(np.dtype(in_dtype), np.integer) or not np.issubdtype(
np.dtype(out_dtype), np.integer
):
raise ValueError(
f"Only integer dtypes allowed got {in_dtype} and {out_dtype} for in and out dtypes."
)

dtype_info = np.iinfo(in_dtype)

num_bits = dtype_info.bits

# Use TVMs quantization methods via relay to be consistent
# inputs_quantized = np.array(range(dtype_info.min, dtype_info.max + 1)).astype(in_dtype)

# First generate a list of all num_bit integer patterns
inputs_quantized = np.array(range(0, 2 ** num_bits), dtype=f"uint{num_bits}")

# Reinterpret bits as the real datatype
# Note what we are doing here is a bit tricky, the canonical view of our lookup table
# is using the uintX version. When we run the lookup in the relay graph, we cast the
# bit pattern back into this form.
inputs_quantized = inputs_quantized.view(in_dtype)
inputs_quantized = relay.const(inputs_quantized, dtype=in_dtype)
inputs_dequantized = run_const_expr(
relay.qnn.op.dequantize(
inputs_quantized,
input_scale=input_scale,
input_zero_point=input_zero_point,
axis=in_axis,
)
)

output_dequantized = relay.const(floating_point_func(inputs_dequantized))
output_quantized = run_const_expr(
relay.qnn.op.quantize(
output_dequantized, output_scale, output_zero_point, out_axis, out_dtype
)
)

return output_quantized


def create_integer_lookup_op(
input_arg: "relay.Expr",
floating_point_func: Callable[[np.array], np.array],
in_scale: "relay.Expr",
in_zero_point: "relay.Expr",
out_scale: "relay.Expr",
out_zero_point: "relay.Expr",
in_axis: int = -1,
out_axis: int = -1,
in_dtype: str = "uint8",
out_dtype: str = "uint8",
) -> "relay.Expr":
"""
Create a quantized version of the given floating point unary operation using table lookup.

Args:
input_arg: The quantized input to the final function.
floating_point_func: The numpy function which this table is to approximate
in_scale: The scale of the quantized input tensor.
in_zero_point: The zero point of the quantized input tensor.
out_scale: The scale of the quantized output tensor.
out_zero_point: The zero point of the quantized output tensor.
in_axis: The axis for multi-channel quantization of the input if applicable.
out_axis: The axis for multi-channel quantization of the output if applicable.
in_dtype: The dtype of the input tensor.
out_dtype: The wanted dtype of the output tensor.

Returns:
A Relay expression representing a quantized version of the given function.
"""

# TODO: handle multi-channel q, below will fail with multi-channel q
in_scale = in_scale.data.numpy().item()
in_zero_point = in_zero_point.data.numpy().item()
out_scale = out_scale.data.numpy().item()
out_zero_point = out_zero_point.data.numpy().item()

lookup_table = create_integer_lookup_table(
floating_point_func,
relay.const(in_scale),
relay.const(in_zero_point, dtype="int32"),
relay.const(out_scale),
relay.const(out_zero_point, dtype="int32"),
in_axis=in_axis,
in_dtype=in_dtype,
out_axis=out_axis,
out_dtype=out_dtype,
)

in_dtype_info = np.iinfo(in_dtype)
in_dtype_num_bits = in_dtype_info.bits

lookup_table = relay.const(lookup_table)
index_tensor = relay.reinterpret(input_arg, f"uint{in_dtype_num_bits}")
result = relay.take(lookup_table, index_tensor, axis=0, mode="fast")
return result
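
A minimal usage sketch of the helper above (not part of the diff; the scale and zero-point values are made-up examples): building a quantized tanh with create_integer_lookup_op. The result is a Relay expression of the form take(<256-entry constant table>, reinterpret(x, "uint8"), axis=0, mode="fast"), matching the last lines of the function.

import numpy as np
from tvm import relay
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

x = relay.var("x", shape=(1, 16), dtype="uint8")
quantized_tanh = create_integer_lookup_op(
    input_arg=x,
    floating_point_func=np.tanh,
    in_scale=relay.const(0.03),                      # example input scale
    in_zero_point=relay.const(128, dtype="int32"),   # example input zero point
    out_scale=relay.const(1.0 / 128),                # example output scale
    out_zero_point=relay.const(128, dtype="int32"),  # example output zero point
    in_dtype="uint8",
    out_dtype="uint8",
)
print(quantized_tanh)  # take(<const table>, reinterpret(x, uint8), axis=0, mode="fast")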
20 changes: 18 additions & 2 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -17,12 +17,13 @@
# pylint: disable=invalid-name, unused-argument
"""Backend QNN related feature registration"""
import numpy as np

import tvm
from tvm import relay
from tvm._ffi.base import TVMError
from .. import op as reg
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

from ....topi.x86.utils import target_has_sse42
from .. import op as reg

#################################################
# Register the functions for different operators.
@@ -46,6 +47,21 @@ def legalize_qnn_dense(attrs, inputs, types):
return qnn_dense_legalize(attrs, inputs, types)


# Registering QNN rsqrt legalization function.
@reg.register_qnn_legalize("qnn.rsqrt")
def legalize_qnn_rsqrt(attrs, inputs, types):
return create_integer_lookup_op(
input_arg=inputs[0],
floating_point_func=lambda arr: 1 / np.sqrt(arr),
in_scale=inputs[1],
in_zero_point=inputs[2],
out_scale=inputs[3],
out_zero_point=inputs[4],
in_dtype=types[0].dtype,
out_dtype=types[0].dtype,
)


# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
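
Hypothetical extension sketch (the op name "qnn.exp" is illustrative and is not added by this PR; the registration assumes such an op already exists on the C++ side): any other hard-to-lower elementwise QNN op could reuse the same lookup-table helper by registering a legalization shaped exactly like legalize_qnn_rsqrt above.

import numpy as np
from tvm.relay.qnn import op as reg
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

# Assumes a "qnn.exp" operator is registered; otherwise register_op_attr will fail.
@reg.register_qnn_legalize("qnn.exp")
def legalize_qnn_exp(attrs, inputs, types):
    return create_integer_lookup_op(
        input_arg=inputs[0],
        floating_point_func=np.exp,
        in_scale=inputs[1],
        in_zero_point=inputs[2],
        out_scale=inputs[3],
        out_zero_point=inputs[4],
        in_dtype=types[0].dtype,
        out_dtype=types[0].dtype,
    )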
25 changes: 24 additions & 1 deletion python/tvm/relay/qnn/op/op.py
@@ -20,7 +20,10 @@


def register_qnn_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for a QNN op
"""Register legal transformation function for a QNN op.

This helps QNN match hardware intrinsics better and is run before
canonicalization.

Parameters
----------
@@ -34,3 +37,23 @@ def register_qnn_legalize(op_name, legal_op=None, level=10):
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FTVMQnnLegalize", legal_op, level)


def register_qnn_canonicalize(op_name, legal_op=None, level=10):
"""Register canonicalization function for a QNN op.

This transforms QNN ops to mainline Relay components.

Parameters
----------
op_name : str
The name of the operator

legal_op: function (Attrs, List[Expr], List[relay.Type]) -> Expr
The function for transforming an expr to another expr.

level : int
The priority level
"""

return tvm.ir.register_op_attr(op_name, "FTVMQnnCanonicalize", legal_op, level)
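
For context, a rough sketch of how the two hooks are consumed in a QNN lowering flow (both pass names are existing tvm.relay.qnn.transform passes; the input module is assumed to already contain QNN ops):

import tvm
from tvm.relay import qnn

def lower_qnn(mod: tvm.IRModule) -> tvm.IRModule:
    # Legalization (FTVMQnnLegalize) runs first and may do target-aware rewrites,
    # e.g. the qnn.rsqrt lookup-table rewrite registered in legalizations.py.
    mod = qnn.transform.Legalize()(mod)
    # Canonicalization (FTVMQnnCanonicalize) then lowers any remaining QNN ops
    # into mainline Relay components.
    mod = qnn.transform.CanonicalizeOps()(mod)
    return mod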
4 changes: 4 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -19,6 +19,10 @@
import tvm
from tvm import relay
from tvm.ir import TensorAffineType, TupleAffineType

# import to register canonicalization funcs for fq2i
# pylint: disable=unused-import
from tvm.relay.qnn.op import canonicalizations
from tvm.tir import bijective_layout

from ..op import register_fake_quantization_to_integer
3 changes: 1 addition & 2 deletions src/relay/op/tensor/transform.cc
@@ -3322,8 +3322,7 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Gather: expect indices type to be TensorType but get " << types[1];
return false;
}
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
<< "indices of gather must be tensor of integer";
ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<GatherAttrs>();
ICHECK(param != nullptr);
ICHECK(param->axis.defined());
42 changes: 4 additions & 38 deletions src/relay/qnn/op/rsqrt.cc
@@ -69,42 +69,9 @@ Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale,
return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});
}

/*
* \brief Canonicalizes the QNN rsqrt op.
* \param attrs The empty attribute.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for add op.
*/
Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// At this time, due to the complexity of implementing this op in int8 or uint8,
// we dequantize the input, run the op in float, and then quantize the output (as below).
// This acts as a placeholder for future hardware enablement, where more hardware specific
// canonicalization can be provided.

// Get the args.
QnnUnaryOpArguments args(new_args);

// Get the input dtype and shape.
QnnUnaryOpTensorType input_type(arg_types, 0);

// Get the types for dequantize/quantize.
Array<tvm::relay::Type> types;
for (size_t i = 1; i < 5; ++i) {
types.push_back(arg_types[i]);
}

// Dequantize input.
auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1);

// Compute Rsqrt(Q_x')
auto output = Rsqrt(dequantized_arg);

// Quantize output.
return Quantize(output, args.output_scale, args.output_zero_point, input_type.dtype, types, -1);
}

// Translation to relay is done via canonicalization/legalization functions in python
// e.g. python/tvm/relay/qnn/op/canonicalizations.py or
// python/tvm/relay/qnn/op/legalizations.py
RELAY_REGISTER_OP("qnn.rsqrt")
.describe("Elementwise rsqrt for quantized tensors.")
.set_num_inputs(5)
@@ -116,8 +83,7 @@ RELAY_REGISTER_OP("qnn.rsqrt")
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QRsqrt", QnnRsqrtRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnRsqrtCanonicalize);
.set_attr<TNonComputational>("TNonComputational", true);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt);

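
A rough end-to-end sketch of the new lowering path (quantization parameters are example values; this assumes the pre-existing relay.qnn.op.rsqrt Python wrapper): after the Python legalization registered above runs, the qnn.rsqrt call is replaced by a constant 256-entry table plus reinterpret and take, instead of the dequantize/rsqrt/quantize chain the removed C++ canonicalization used to emit.

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="uint8")
expr = relay.qnn.op.rsqrt(
    x,
    relay.const(0.125),             # input scale (example value)
    relay.const(0, dtype="int32"),  # input zero point (example value)
    relay.const(0.125),             # output scale (example value)
    relay.const(0, dtype="int32"),  # output zero point (example value)
)
mod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(mod)
mod = relay.qnn.transform.Legalize()(mod)  # applies legalize_qnn_rsqrt -> lookup table
print(mod)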