Skip to content

Commit

Permalink
[QNN] Lookup operations for hard to implement operators (apache#10053)
Browse files Browse the repository at this point in the history
* initial tanh impl

* small error

* support uint and int lookup into tables

* reinterpret cast, working tanh tests

* refactor relay func creation

* basic casting tests

* explicitly say do not handle multi-channel lookups

* add example funcs

* fix silent fail

* fix some bugs with floating point funcs not working

* add TODO

* add todo

* canonicalizations

* refactor integer lookup ops into own folder

* fq2i stuff

* clean up existing tests

* flesh out todo

* more tests

* test on keeping shape good

* lookup table fix

* replace canonicalization for rsqrt

* remove canonicalization of rsqrt

* add asf headers

* topi tests

* gather supports unsigned integer tests

* fix things

* move to legalization

* jostle ci

* linting

* use take instead of gather

* remove gather changes

* undo changes

* undo changes

* undo changes

* move thing in range

* initial tanh impl

* small error

* support uint and int lookup into tables

* reinterpret cast, working tanh tests

* refactor relay func creation

* basic casting tests

* explicitly say do not handle multi-channel lookups

* add example funcs

* fix silent fail

* fix some bugs with floating point funcs not working

* add TODO

* add todo

* canonicalizations

* refactor integer lookup ops into own folder

* fq2i stuff

* clean up existing tests

* flesh out todo

* more tests

* test on keeping shape good

* lookup table fix

* replace canonicalization for rsqrt

* remove canonicalization of rsqrt

* add asf headers

* gather supports unsigned integer tests

* fix things

* move to legalization

* jostle ci

* linting

* use take instead of gather

* remove gather changes

* undo changes

* undo changes

* undo changes

* move thing in range

* lint

* remove unneeded line

* jostle

Co-authored-by: andrewzhaoluo (generated by with_the_same_user script) <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
2 people authored and ylc committed Feb 16, 2022
1 parent a7f3cb8 commit 6d69ff3
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 69 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
"""QNN dialect related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
from .op import register_qnn_legalize
from . import _qnn, legalizations, layout_conversions
from .op import register_qnn_legalize, register_qnn_canonicalize
from . import _qnn, legalizations, layout_conversions, canonicalizations
160 changes: 160 additions & 0 deletions python/tvm/relay/qnn/op/canonicalizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Consist of utilities and methods for lowering QNN into mainline relay."""
from typing import Callable

import numpy as np
import tvm
from tvm import relay


def run_const_expr(expr: "relay.Expr") -> np.ndarray:
    """Evaluate a constant expression and return the result as a numpy array.

    The expression is wrapped in its own IRModule, run through the Relay "vm"
    executor with no arguments, and the result is converted to numpy.

    Args:
        expr: The Relay expression to evaluate. It is evaluated without any
            arguments being supplied.

    Returns:
        The evaluated result as a numpy array.
    """
    mod = tvm.IRModule.from_expr(expr)
    vm_exe = relay.create_executor("vm", mod=mod)
    # Use .numpy() (as the rest of this file does) rather than the deprecated .asnumpy().
    return vm_exe.evaluate()().numpy()


def create_integer_lookup_table(
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    input_scale: "relay.Expr",
    input_zero_point: "relay.Expr",
    output_scale: "relay.Expr",
    output_zero_point: "relay.Expr",
    in_axis: int = -1,
    out_axis: int = -1,
    in_dtype: str = "uint8",
    out_dtype: str = "uint8",
) -> np.ndarray:
    """
    Build a table mapping every quantized input value to its quantized output.

    Entry ``i`` of the table is the quantized result of applying the given function
    to the input whose bit pattern, read as an unsigned integer, equals ``i``.
    Signed and unsigned integer dtypes may be mixed between input and output.

    Args:
        floating_point_func: The numpy function which this table is to approximate
        input_scale: The scale of the quantized input tensor.
        input_zero_point: The zero point of the quantized input tensor.
        output_scale: The scale of the quantized output tensor.
        output_zero_point: The zero point of the quantized output tensor.
        in_axis: The axis for multi-channel quantization of the input if applicable.
        out_axis: The axis for multi-channel quantization of the output if applicable.
        in_dtype: The dtype of the input tensor.
        out_dtype: The wanted dtype of the output tensor.

    Returns:
        A numpy array where values in quantized space will index to the output in
        quantized space approximating the given function.

    Raises:
        ValueError: If either in_dtype or out_dtype is not an integer dtype.
    """
    both_integer = np.issubdtype(np.dtype(in_dtype), np.integer) and np.issubdtype(
        np.dtype(out_dtype), np.integer
    )
    if not both_integer:
        raise ValueError(
            f"Only integer dtypes allowed got {in_dtype} and {out_dtype} for in and out dtypes."
        )

    bit_width = np.iinfo(in_dtype).bits

    # Enumerate every bit pattern of the input width as an unsigned integer. The
    # unsigned view is the canonical indexing of the table: when the lookup runs in
    # the relay graph, the input is cast back to this unsigned form.
    bit_patterns = np.arange(2 ** bit_width, dtype=f"uint{bit_width}")

    # Reinterpret (not convert) those bits as the real input dtype.
    quantized_inputs = relay.const(bit_patterns.view(in_dtype), dtype=in_dtype)

    # Go through TVM's own dequantize/quantize ops so the table is consistent with
    # what the relay graph would compute.
    dequantize_expr = relay.qnn.op.dequantize(
        quantized_inputs,
        input_scale=input_scale,
        input_zero_point=input_zero_point,
        axis=in_axis,
    )
    float_inputs = run_const_expr(dequantize_expr)

    float_outputs = relay.const(floating_point_func(float_inputs))
    quantize_expr = relay.qnn.op.quantize(
        float_outputs, output_scale, output_zero_point, out_axis, out_dtype
    )
    return run_const_expr(quantize_expr)


def _scalar_qparam(qparam: "relay.Expr", name: str):
    """Return the single Python scalar held by a constant quantization parameter.

    Raises:
        ValueError: If the parameter does not hold exactly one element, i.e. it
            describes multi-channel quantization, which lookup ops do not handle.
    """
    value = qparam.data.numpy()
    if value.size != 1:
        raise ValueError(
            "Integer lookup ops only support per-tensor quantization parameters, "
            f"but {name} has {value.size} elements (shape {value.shape})."
        )
    return value.item()


def create_integer_lookup_op(
    input_arg: "relay.Expr",
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    in_scale: "relay.Expr",
    in_zero_point: "relay.Expr",
    out_scale: "relay.Expr",
    out_zero_point: "relay.Expr",
    in_axis: int = -1,
    out_axis: int = -1,
    in_dtype: str = "uint8",
    out_dtype: str = "uint8",
) -> "relay.Expr":
    """
    Create a quantized version of the given floating point unary operation using table lookup.

    Args:
        input_arg: The quantized input to the final function.
        floating_point_func: The numpy function which this table is to approximate
        in_scale: The scale of the quantized input tensor.
        in_zero_point: The zero point of the quantized input tensor.
        out_scale: The scale of the quantized output tensor.
        out_zero_point: The zero point of the quantized output tensor.
        in_axis: The axis for multi-channel quantization of the input if applicable.
        out_axis: The axis for multi-channel quantization of the output if applicable.
        in_dtype: The dtype of the input tensor.
        out_dtype: The wanted dtype of the output tensor.

    Returns:
        A Relay expression representing a quantized version of the given function.

    Raises:
        ValueError: If any scale or zero point holds more than one element
            (multi-channel quantization is not supported yet).
    """

    # TODO: handle multi-channel q; until then reject it with an explicit error.
    in_scale = _scalar_qparam(in_scale, "in_scale")
    in_zero_point = _scalar_qparam(in_zero_point, "in_zero_point")
    out_scale = _scalar_qparam(out_scale, "out_scale")
    out_zero_point = _scalar_qparam(out_zero_point, "out_zero_point")

    lookup_table = create_integer_lookup_table(
        floating_point_func,
        relay.const(in_scale),
        relay.const(in_zero_point, dtype="int32"),
        relay.const(out_scale),
        relay.const(out_zero_point, dtype="int32"),
        in_axis=in_axis,
        in_dtype=in_dtype,
        out_axis=out_axis,
        out_dtype=out_dtype,
    )

    in_dtype_num_bits = np.iinfo(in_dtype).bits

    lookup_table = relay.const(lookup_table)
    # The table is indexed by the unsigned bit pattern of the input, so reinterpret
    # (rather than convert) the input before using it as an index.
    index_tensor = relay.reinterpret(input_arg, f"uint{in_dtype_num_bits}")
    return relay.take(lookup_table, index_tensor, axis=0, mode="fast")
20 changes: 18 additions & 2 deletions python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
# pylint: disable=invalid-name, unused-argument
"""Backend QNN related feature registration"""
import numpy as np

import tvm
from tvm import relay
from tvm._ffi.base import TVMError
from .. import op as reg
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

from ....topi.x86.utils import target_has_sse42
from .. import op as reg

#################################################
# Register the functions for different operators.
Expand All @@ -46,6 +47,21 @@ def legalize_qnn_dense(attrs, inputs, types):
return qnn_dense_legalize(attrs, inputs, types)


# Registering QNN rsqrt legalization function.
@reg.register_qnn_legalize("qnn.rsqrt")
def legalize_qnn_rsqrt(attrs, inputs, types):
    """Legalize qnn.rsqrt into an integer table-lookup op.

    Parameters
    ----------
    attrs : Attrs
        Attributes of the call; not used here.
    inputs : list of Expr
        The call arguments, in order: data, input scale, input zero point,
        output scale, output zero point.
    types : list of Type
        The argument types; the dtype of the first one is used as both the
        input and the output dtype of the lookup.

    Returns
    -------
    result : Expr
        The expression produced by create_integer_lookup_op.
    """

    def rsqrt(arr):
        return 1 / np.sqrt(arr)

    quantized_dtype = types[0].dtype
    return create_integer_lookup_op(
        inputs[0],
        rsqrt,
        inputs[1],
        inputs[2],
        inputs[3],
        inputs[4],
        in_dtype=quantized_dtype,
        out_dtype=quantized_dtype,
    )


# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
Expand Down
25 changes: 24 additions & 1 deletion python/tvm/relay/qnn/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@


def register_qnn_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for a QNN op
"""Register legal transformation function for a QNN op.
This helps QNN match hardware intrinsics better and is run before
canonicalization.
Parameters
----------
Expand All @@ -34,3 +37,23 @@ def register_qnn_legalize(op_name, legal_op=None, level=10):
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FTVMQnnLegalize", legal_op, level)


def register_qnn_canonicalize(op_name, legal_op=None, level=10):
    """Attach a canonicalization function to a QNN op.

    Canonicalization rewrites QNN ops in terms of mainline Relay components.

    Parameters
    ----------
    op_name : str
        The name of the operator

    legal_op: function (Attrs, List[Expr], List[relay.Type]) -> Expr
        The function for transforming an expr to another expr.

    level : int
        The priority level
    """
    attr_key = "FTVMQnnCanonicalize"
    return tvm.ir.register_op_attr(op_name, attr_key, legal_op, level)
4 changes: 4 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import tvm
from tvm import relay
from tvm.ir import TensorAffineType, TupleAffineType

# import to register canonicalization funcs for fq2i
# pylint: disable=unused-import
from tvm.relay.qnn.op import canonicalizations
from tvm.tir import bijective_layout

from ..op import register_fake_quantization_to_integer
Expand Down
3 changes: 1 addition & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3376,8 +3376,7 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Gather: expect indices type to be TensorType but get " << types[1];
return false;
}
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
<< "indices of gather must be tensor of integer";
ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<GatherAttrs>();
ICHECK(param != nullptr);
ICHECK(param->axis.defined());
Expand Down
42 changes: 4 additions & 38 deletions src/relay/qnn/op/rsqrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,9 @@ Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale,
return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});
}

/*!
 * \brief Canonicalizes the QNN rsqrt op as dequantize -> rsqrt -> quantize.
 * \param attrs The empty attribute.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for the rsqrt op.
 */
Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                          const Array<tvm::relay::Type>& arg_types) {
  // Implementing this op directly in int8 or uint8 is complex, so for now the input is
  // dequantized, the op runs in float, and the result is quantized again. This is a
  // placeholder for future hardware enablement, where a more hardware specific
  // canonicalization can be provided.

  // Unpack the call arguments.
  QnnUnaryOpArguments unary_args(new_args);

  // Dtype and shape of the input tensor.
  QnnUnaryOpTensorType in_tensor_type(arg_types, 0);

  // Entries 1..4 of arg_types are the ones handed to dequantize/quantize.
  Array<tvm::relay::Type> qparam_types;
  for (size_t idx = 1; idx < 5; ++idx) {
    qparam_types.push_back(arg_types[idx]);
  }

  // Dequantize, compute rsqrt in float, then quantize back to the input dtype.
  auto float_input =
      Dequantize(unary_args.x, unary_args.scale, unary_args.zero_point, qparam_types, -1);
  auto float_rsqrt = Rsqrt(float_input);
  return Quantize(float_rsqrt, unary_args.output_scale, unary_args.output_zero_point,
                  in_tensor_type.dtype, qparam_types, -1);
}

// Translation to relay is done via canonicalization/legalization functions in python
// e.g. python/tvm/relay/qnn/op/canonicalizations.py or
// python/tvm/relay/qnn/op/legalizations.py
RELAY_REGISTER_OP("qnn.rsqrt")
.describe("Elementwise rsqrt for quantized tensors.")
.set_num_inputs(5)
Expand All @@ -116,8 +83,7 @@ RELAY_REGISTER_OP("qnn.rsqrt")
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QRsqrt", QnnRsqrtRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnRsqrtCanonicalize);
.set_attr<TNonComputational>("TNonComputational", true);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt);

Expand Down
Loading

0 comments on commit 6d69ff3

Please sign in to comment.