[QNN] Implement 'qnn.softmax' #14536

Merged · 7 commits · May 15, 2023
4 changes: 4 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -1304,3 +1304,7 @@ def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zer
        output_scale,
        output_zero_point,
    )


def softmax(x, scale, zero_point, output_scale, output_zero_point, axis=-1):
    """Quantized softmax operator, taking the input and output quantization parameters explicitly."""
    return _make.softmax(x, axis, scale, zero_point, output_scale, output_zero_point)
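For reference, a minimal sketch of calling the new wrapper from Python; the scale and zero-point values below are illustrative, not taken from this PR:

import tvm
from tvm import relay

# Build a qnn.softmax call over an int8 tensor and infer its type.
x = relay.var("x", shape=(1, 10), dtype="int8")
y = relay.qnn.op.softmax(
    x,
    scale=relay.const(0.08, "float32"),
    zero_point=relay.const(0, "int32"),
    output_scale=relay.const(1.0 / 256.0, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    axis=-1,
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
print(mod)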
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -633,3 +633,16 @@ def take(expr, type_map):

    out = relay.op.take(arg, indices, **expr.attrs)
    return [out, t]


@register_fake_quantization_to_integer("nn.softmax")
def softmax(expr, type_map):
    """Rewrite a softmax op"""
    arg = expr.args[0]
    arg_t = type_map[arg]
    out_t = type_map[expr]

    out = relay.qnn.op.softmax(
        arg, arg_t.scale, arg_t.zero_point, out_t.scale, out_t.zero_point, **expr.attrs
    )
    return [out, out_t]
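Schematically, the handler above swaps a fake-quantized softmax region for the QNN op, pulling the affine parameters from type_map (a rough before/after sketch, not actual printed IR):

# before FakeQuantizationToInteger:
#   %0 = qnn.dequantize(%x, %in_scale, %in_zp)
#   %1 = nn.softmax(%0, axis=1)
#   %2 = qnn.quantize(%1, %out_scale, %out_zp, out_dtype="int8")
#
# after the rewrite:
#   %0 = qnn.softmax(%x, %in_scale, %in_zp, %out_scale, %out_zp, axis=1)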
144 changes: 144 additions & 0 deletions src/relay/qnn/op/softmax.cc
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/qnn/op/softmax.cc
* \brief QNN softmax operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>

#include "op_common.h"
#include "tvm/ir/expr.h"
#include "tvm/relay/attrs/nn.h"
#include "tvm/relay/type.h"
#include "tvm/runtime/data_type.h"
#include "tvm/runtime/logging.h"
#include "tvm/topi/reduction.h"

namespace tvm {
namespace relay {
namespace qnn {

bool QnnSoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  // Expected Types: input, scale, zero_point, output_scale, output_zero_point, output
  ICHECK_EQ(types.size(), 6);
  const auto* x = types[0].as<TensorTypeNode>();
  if (x == nullptr) return false;
  ICHECK(x->dtype == DataType::Int(8))
      << "Expected quantized softmax type(int8) for input but was " << x->dtype;

  // Check the types of scale and zero points.
  for (size_t i = 1; i < 5; ++i) {
    if (types[i].as<IncompleteTypeNode>()) {
      return false;
    }
  }

  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

  // Assign types for scale and zero points.
  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point

  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
  // IdentityRel infer type function.
  Array<Type> tensor_types = {types[0], types[5]};
  return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized softmax operator used by frontend FFI.
Expr MakeQuantizedSoftmax(Expr x, int axis, Expr scale, Expr zero_point, Expr output_scale,
                          Expr output_zero_point) {
  auto attrs = make_object<SoftmaxAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("qnn.softmax");
  return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(attrs), {});
}

/*
 * \brief Canonicalizes the QNN softmax op.
 */
Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                            const Array<tvm::relay::Type>& arg_types) {
  // Expected: input, scale, zero_point, output_scale, output_zero_point
  ICHECK_EQ(new_args.size(), 5);

  const Expr input_scale = new_args[1];
  const Expr input_zero_point = new_args[2];
  const Expr output_scale = new_args[3];
  const Expr output_zero_point = new_args[4];
  const int axis = attrs.as<SoftmaxAttrs>()->axis;

  // Refer to Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf

  const Expr quantized_data =
      Subtract(Cast(new_args[0], DataType::Int(32)), Cast(input_zero_point, DataType::Int(32)));
Contributor:
Is it necessary to cast input_zero_point to int32? It is assumed that input_zero_point is of type int32.

Contributor Author:
This is apparently a redundant cast. I should remove it.

  const Expr x_0 = ConvertDtype(
Contributor:
[image: screenshot of the ShiftExp step of Algorithm 1 from the paper]

This is in ShiftExp under Algorithm 1.

First off, if S > 1 then we already have issues, as x_0 (I_0 in the paper) may be 0...

For attention I would expect the output activation range to potentially be very large (see the LLM int8 paper), so having a high scale factor is not unreasonable for some schemes.

Contributor:
Maybe I am not understanding something, but this seems like an obvious flaw...

I think you can get around this by not rounding I_0 in Algorithm 1 but keeping it a float and rounding when needed. However, this would introduce runtime FLOPS.

Contributor (@AndrewZhaoLuo, Apr 10, 2023):
Another option: when decomposing 2^(S * I_p) into integer and decimal components, you instead decompose 2^[(S * 2^-n * I_p) * 2^n]. At compile time we can choose n so that S * 2^-n << 1 to get around this problem. You can then apply the decomposition routine to the inner term in parentheses, and the outer 2^n merely becomes another shift.

Contributor (@AndrewZhaoLuo, Apr 11, 2023):
I believe S here is defined as:

[image: screenshot of the definition of the scale S from the paper]

which can be arbitrary depending on the range m.

Member:
We can fall back to the fp32-based impl when S > 1.

Contributor:
Yeah, I think that is another reasonable thing to do.
      Round(Divide(MakeConstantScalar(DataType::Float(32), 1.f), input_scale)), DataType::Int(32));
  const Expr max = Max(quantized_data, {axis}, true, false);
  const Expr x = Subtract(quantized_data, max);

  const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
  const int n = 8;
  const int m = 30;
  const int bits = 8;
  const Expr x_p = Subtract(Add(x, RightShift(x, const_i32(1))), RightShift(x, const_i32(4)));
  const Expr q = Divide(x_p, Negative(x_0));
  const Expr r = Subtract(x_p, Multiply(q, Negative(x_0)));
  const Expr x_b = Add(RightShift(r, const_i32(1)), x_0);
  const Expr exps = LeftShift(x_b, Subtract(const_i32(n), q));
  const Expr sums = Sum(exps, {axis}, true, false);
  const Expr output =
      RightShift(Multiply(Divide(const_i32(1 << m), sums), exps), const_i32(m - (bits - 1)));
  const Expr requantized =
      Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
                 MakeConstantScalar(DataType::Float(32), 1.f / (1 << (bits - 1))), const_i32(0),
                 output_scale, output_zero_point, DataType::Int(bits), 0);

  return requantized;
}

RELAY_REGISTER_OP("qnn.softmax")
    .describe("Softmax for quantized tensors.")
    .set_attrs_type<SoftmaxAttrs>()
    .set_num_inputs(5)
    .add_argument("data", "Quantized Tensor", "The input data.")
    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
    .add_argument("output_zero_point", "Tensor",
                  "The quantization zero_point of the output tensor.")
    .set_support_level(11)
    .add_type_rel("QSoftmax", QnnSoftmaxRel)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnSoftmaxCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.softmax").set_body_typed(MakeQuantizedSoftmax);

} // namespace qnn
} // namespace relay
} // namespace tvm
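For intuition, below is a NumPy sketch of the canonicalization above (Algorithm 1 of arXiv:2207.01405). It is only a reference approximation, not the TVM lowering: it assumes the input scale is well below 1 (so round(1/scale) is nonzero, the S > 1 concern raised in the review), it handles q > n with an explicit right shift, and it uses a plain round-and-clip in place of qnn.requantize.

import numpy as np

def qnn_softmax_reference(x_q, in_scale, in_zp, out_scale, out_zp, axis=-1):
    """Integer-only softmax approximation; a sketch for intuition, not the TVM kernel."""
    n, m, bits = 8, 30, 8
    q_data = x_q.astype(np.int64) - int(in_zp)
    x0 = int(np.round(1.0 / in_scale))                   # I_0 = round(1/S); assumes in_scale << 1
    x = q_data - q_data.max(axis=axis, keepdims=True)    # non-positive values
    x_p = x + (x >> 1) - (x >> 4)                        # ~ x * log2(e), arithmetic shifts
    q = x_p // (-x0)                                     # integer quotient (non-negative)
    r = x_p - q * (-x0)                                  # remainder in (-I_0, 0]
    x_b = (r >> 1) + x0                                  # linear approximation of 2^(S*r), scale S
    shift = n - q                                        # exps carries scale S * 2^-n
    exps = np.where(shift >= 0, x_b << np.maximum(shift, 0), x_b >> np.maximum(-shift, 0))
    sums = exps.sum(axis=axis, keepdims=True)
    out = ((1 << m) // sums * exps) >> (m - (bits - 1))  # probabilities with scale 1/2^(bits-1)
    # Requantize from (scale = 1/2^(bits-1), zero point 0) to the requested output quantization.
    deq = out / float(1 << (bits - 1))
    return np.clip(np.round(deq / out_scale) + out_zp, -128, 127).astype(np.int8)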
4 changes: 4 additions & 0 deletions src/relay/transforms/pattern_utils.h
@@ -770,6 +770,10 @@ inline Expr Copy(Expr data) {
  return Call(op, {data}, Attrs(), {});
}

inline Expr Max(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
  return MakeReduce(data, axis, keepdims, exclude, "max");
}

inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
  return MakeReduce(data, axis, keepdims, exclude, "mean");
}
31 changes: 31 additions & 0 deletions tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -1114,5 +1114,36 @@ def test_fake_quantize_take():
    compare_fq_to_int(op, [x_np])


def test_fake_quantize_softmax():
Contributor (@ibsidorenko, Apr 10, 2023):
It looks like this test does not fully check the accuracy.

I printed the output and found that ~70% of the output values are equal to 0.0 in this test. This is because the output of the qnn.quantize operation is equal to "-128", which is not a very interesting/representative case for the "int8" data type.

Can you modify this test slightly, in the following way:

  1. Remove the second qnn.dequantize. Let's check the output of qnn.dequantize + softmax + qnn.quantize only.
  2. Play with the QNN parameters (zero point, scale) in such a way that the output of quantize falls in the range [-100, +100], for example, not only "-128" like now.

P.S.
I checked the output after qnn.quantize and see that some of the values differ by 7. I think that is too much and the accuracy is unsatisfactory... any thoughts?

Contributor Author (@maekawatoshiki, Apr 11, 2023):
> Play with the QNN parameters (zero point, scale) in such a way that the output of quantize falls in the range [-100, +100], for example, not only "-128" like now.

I'm not sure why we need to modify the QNN parameters (of qnn.quantize). I think it's enough to change the value range specified in x_np = np.random.randint(-128, 127, ...) so that the qnn.quantize output falls in the range [-100, +100]. (Sorry if my understanding is wrong.)

Contributor Author:
> P.S.
> I checked the output after qnn.quantize and see that some of the values differ by 7. I think that is too much and the accuracy is unsatisfactory... any thoughts?

When all computation is performed in integer-only arithmetic, how big a diff is generally acceptable for a softmax operation? I'm not sure about this. I'm also not sure whether any other algorithm outperforms the current implementation.

Contributor:
Yes, sure, it's up to you. My main concern here is to avoid the case where all (or almost all) output values are equal to "-128" (it is not a representative case for the "int8" data type).

    shape = [50, 10]
    x = relay.var("x", shape=shape, dtype="int8")

    x = relay.qnn.op.dequantize(x, relay.const(0.08), relay.const(-48))
Contributor:
Things we probably want to test:

  1. Different dtypes
  2. Different scale factors
  3. Different distributions along the axis of reduction (e.g. a flat distribution should give flat probabilities, multiple spikes, etc.)

    op = relay.op.nn.softmax(x, axis=1)
    op = relay.qnn.op.quantize(op, relay.const(0.0039), relay.const(-128), out_dtype="int8")
    op = relay.qnn.op.dequantize(op, relay.const(0.0039), relay.const(-128))

    x_np = np.random.randint(-128, 127, size=shape, dtype="int8")
    args = [x_np]

    mod = tvm.IRModule.from_expr(op)
    mod = tvm.relay.transform.InferType()(mod)
    mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod)
    assert not tvm.ir.structural_equal(mod, mod_int)

    result = (
        relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
        .evaluate()(*args)
        .numpy()
    )
    result_int = (
        relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm")
        .evaluate()(*args)
        .numpy()
    )

    assert np.allclose(result_int, result, atol=0.05)
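Following the review suggestions above, a hedged sketch of how the test could additionally sweep input scales and value ranges; the parameter values and the test name are hypothetical, not part of this PR:

import numpy as np
import pytest
import tvm
from tvm import relay


@pytest.mark.parametrize("in_scale", [0.02, 0.08, 0.1])
@pytest.mark.parametrize("low,high", [(-128, 127), (-20, 20), (0, 5)])
def test_fake_quantize_softmax_sweep(in_scale, low, high):
    # Same structure as test_fake_quantize_softmax, but sweeping the input scale
    # and the input value distribution.
    shape = [50, 10]
    x = relay.var("x", shape=shape, dtype="int8")

    op = relay.qnn.op.dequantize(x, relay.const(in_scale), relay.const(0))
    op = relay.op.nn.softmax(op, axis=1)
    op = relay.qnn.op.quantize(op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8")
    op = relay.qnn.op.dequantize(op, relay.const(1.0 / 256.0), relay.const(-128))

    x_np = np.random.randint(low, high, size=shape, dtype="int8")

    mod = tvm.IRModule.from_expr(op)
    mod = tvm.relay.transform.InferType()(mod)
    mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod)
    assert not tvm.ir.structural_equal(mod, mod_int)

    result = (
        relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm").evaluate()(x_np).numpy()
    )
    result_int = (
        relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm").evaluate()(x_np).numpy()
    )

    assert np.allclose(result_int, result, atol=0.05)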


if __name__ == "__main__":
    tvm.testing.main()