
[QNN] Implement 'qnn.softmax' #14536

Merged: 7 commits merged into apache:main on May 15, 2023

Conversation

maekawatoshiki
Contributor

This PR implements qnn.softmax, which is the quantized counterpart of nn.softmax.
The implementation is based on the algorithm proposed in https://arxiv.org/pdf/2207.01405.pdf.

@tvm-bot
Collaborator

tvm-bot commented Apr 8, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@maekawatoshiki
Contributor Author

Thank you for the review, @masahi. I've just modified the code.

const Expr q = Divide(Negative(x_p), x_0);
const Expr r = Subtract(x_p, Multiply(q, Negative(x_0)));
const Expr x_b = Add(RightShift(r, const_i32(1)), x_0);
const Expr exps = RightShift(LeftShift(x_b, const_i32(n)), q);
Member:

Could q be bigger than n? If so, aren't we ending up doing a left shift by a negative number?

Member:

From the paper, it's not clear to me why it is always safe to right-shift by q. I don't see an obvious bound on it.

Member:

Actually I'm not sure their eq. 14 is correct for a positive exponent, since we cannot decompose S * I into (-q) + fraction where q is positive... thoughts?

Contributor:

I believe the decomposition can be thought of as treating S * I as the "real number" represented by I. For all real numbers we can take the integer and decimal components. The integer component in real-number space becomes a shift, and the decimal component has to be quantized using S to get an integer r in the paper.

It can always be done, though whether it can be done efficiently is, I think, the issue. If we look at Algorithm 1, when S >> 1, I think the way they calculate r and q is wrong.
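
To make the decomposition concrete, here is a minimal NumPy sketch mirroring the quoted Relay expressions (q, r, x_b, exps). The scale S, the bit width n, and the sample inputs are illustrative assumptions, not values taken from the PR:

import numpy as np

S = 0.0038                         # assumed input scale, well below 1
n = 15                             # assumed number of fractional bits in the fixed-point output
x_p = np.array([-5, -300, -900])   # normalized inputs; always <= 0 after subtracting the row max
x_0 = int(round(1.0 / S))          # the integer that represents 1.0 at scale S

q = (-x_p) // x_0                  # integer part of -S * x_p
r = x_p + q * x_0                  # remainder, so S * r falls in (-1, 0]
x_b = (r >> 1) + x_0               # linear approximation of 2^(S*r), kept at scale S
exps = (x_b << n) >> q             # fixed-point 2^(S * x_p) with n fractional bits

print(exps * S / 2 ** n)           # roughly matches the float reference below
print(2.0 ** (S * x_p))

The split gives 2^(S * x_p) = 2^(-q) * 2^(S * r), with the fractional power in (-1, 0] approximated linearly; the question raised above is essentially about how large q can get.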

Member (@masahi, Apr 10, 2023):

Yeah, but the point is to decompose in a way that the integer part is negative, so that we can "divide" (right shift) by this integer. I don't see how such a decomposition is possible if the input is positive.

Contributor (@AndrewZhaoLuo, Apr 11, 2023):

Hmm, I believe it is due to the fact that the integers are all normalized by subtracting the maximum value (eq 12). Therefore all values are <= 0.

Member:

Ok. But I'm still not sure about my first question #14536 (comment).

Contributor:

Ah yes, I think it can, which is bad. I think the way around this might be to explore decomposing 2^[(S * 2^-k * I_p) * 2^k] instead of 2^(S * I_p) for some well-chosen k.

op = relay.qnn.op.quantize(op, relay.const(0.0038), relay.const(-128), out_dtype="int8")
op = relay.qnn.op.dequantize(op, relay.const(0.0038), relay.const(-128))

x_np = np.random.random_sample([5, 10]).astype(np.float32) * 20.0 - 6
Member:

Please test on more inputs, ideally sampling from the full int8 / uint8 range. This choice looks very contrived.


Member:

Yes, it looks good. We should also test on uint8. A new feature involving subtle arithmetic like this should be tested very thoroughly. See also the remark by @ibsidorenko.

@masahi
Member

masahi commented Apr 10, 2023

cc @AndrewZhaoLuo @ibsidorenko @Icemist please help review.

Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
@@ -1114,5 +1114,36 @@ def test_fake_quantize_take():
compare_fq_to_int(op, [x_np])


def test_fake_quantize_softmax():
Contributor (@ibsidorenko, Apr 10, 2023):

Looks like this test does not allow checking accuracy in full.

I printed out the output and found that ~70% of the output values are equal to 0.0 in this test. This is because the output after the qnn.quantize operation is equal to "-128". It is not a very interesting/representative case for the "int8" data type.

Can you slightly modify this test in the following way:

  1. Remove the second qnn.dequantize. Let's check the output of qnn.dequantize + softmax + qnn.quantize only.
  2. Play with the QNN parameters (zero point, scale) in such a way that the output from quantize will be in the range [-100, +100], for example, not only "-128" like now (see the sketch below).
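
A minimal sketch of the suggested test shape; the concrete scale and zero-point values below are illustrative assumptions, not values prescribed by the reviewer:

from tvm import relay

shape = [5, 10]
x = relay.var("x", shape=shape, dtype="int8")
op = relay.qnn.op.dequantize(x, relay.const(0.08), relay.const(-48))
op = relay.nn.softmax(op, axis=1)
# No trailing dequantize: compare the int8 output of quantize directly.
# A smaller output scale spreads the (0, 1) softmax values over a wider
# integer range instead of collapsing most of them to -128.
op = relay.qnn.op.quantize(op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8")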

P.S.
I checked the output after qnn.quantize and see that some of the values differ by 7. I think it is too much and the accuracy is unsatisfactory... any thoughts?

Contributor Author (@maekawatoshiki, Apr 11, 2023):

  2. Play with the QNN parameters (zero point, scale) in such a way that the output from quantize will be in the range [-100, +100], for example, not only "-128" like now.

I'm not sure why we need to modify the QNN parameters (of qnn.quantize).
I think it's enough to change the value range specified in x_np = np.random.randint(-128, 127, ...) so that the qnn.quantize output falls in the range of [-100, +100].
(Sorry if my understanding is wrong)

Contributor Author:

P.S.
I checked the output after qnn.quantize and see that some of the values differ by 7. I think it is too much and the accuracy is unsatisfactory... any thoughts?

When all computation is performed in integer-only arithmetic, how big a diff is generally acceptable for a softmax operation? I'm not sure about this.
I'm also not sure whether any other algorithms outperform the current implementation.

Contributor:

Yes, sure, it's up to you. My main concern here is to avoid the case where all (or almost all) output values are equal to "-128" (it is not a representative case for the "int8" data type).

@ibsidorenko
Contributor

ibsidorenko commented Apr 10, 2023

Hi, @maekawatoshiki! Thank you for your PR, really very interesting work!
But I have my doubts about accuracy. Can you check my comment on the unit test?

Contributor (@AndrewZhaoLuo) left a comment:

Very interesting PR. I agree with folks that if this is the default implementation, it had better be fairly accurate across a wide range of scales and inputs. While the paper you reference targets ViTs, we cannot assume only ViTs will use this implementation.

The paper itself is rather inconsistent in its notation, but I think I've deciphered the intention.

Unfortunately, I have doubts about the baseline implementation discussed in the paper on the grounds of accuracy, much like many of the others have commented.

Furthermore, I believe a main issue with the algorithm is that it can only deal with small input scale factors. I think this is probably a tenable problem, but it requires more work and rigorous testing.

const Expr quantized_data =
    Subtract(Cast(new_args[0], DataType::Int(32)), Cast(input_zero_point, DataType::Int(32)));

const Expr x_0 = ConvertDtype(
Contributor:

[screenshot of the relevant formula from ShiftExp in Algorithm 1 of the paper]

First off, if S > 1 then we already have issues, as x_0 (I_0 in the paper) may be 0...

For attention I would expect the output activation range to potentially be very large (see the LLM.int8() paper), so having a high scale factor is not unreasonable for some schemes.

Contributor:

Maybe I am not understanding something, but this seems like an obvious flaw...

I think you can get around this by not rounding I_0 in Algorithm 1 but keeping it a float and rounding when needed. However, this would introduce runtime FLOPs.

Contributor (@AndrewZhaoLuo, Apr 10, 2023):

Another option: when decomposing 2^(S * I_p) into integer and decimal components, you instead decompose

2^[(S * 2^-n * I_p) * 2^n]. At compile time we can choose an n that makes S * 2^-n << 1 to get around this problem. You can then apply the decomposition routine to the inner term in parentheses, and the outer 2^n merely becomes another shift.
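
A tiny numeric illustration of that compile-time choice of n; the values are assumptions for illustration only:

import math

S = 6.5                              # a hypothetical "large" input scale
n = max(0, math.ceil(math.log2(S)))  # pick n so that S * 2^-n <= 1
S_eff = S * 2.0 ** -n                # 6.5 / 8 = 0.8125
I_p = -7                             # a normalized (non-positive) integer input

# The rescaled exponent is mathematically identical to the original one,
# so the inner term can be fed to the usual decomposition routine.
assert math.isclose(2.0 ** (S * I_p), 2.0 ** ((S_eff * I_p) * 2 ** n))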

Contributor (@AndrewZhaoLuo, Apr 11, 2023):

I believe S here is defined as:

[screenshot of the definition of S from the paper]

which can be arbitrary depending on the range m.

Member:

We can fall back to the fp32-based impl when S > 1.

Contributor:

Yeah I think that is another reasonable thing to do.

shape = [50, 10]
x = relay.var("x", shape=shape, dtype="int8")

x = relay.qnn.op.dequantize(x, relay.const(0.08), relay.const(-48))
Contributor:

Things we probably want to test:

  1. Different dtypes
  2. Different scale factors
  3. Different distributions along the axis of reduction (e.g. a flat distribution should give flat probabilities, multiple spikes, etc.); see the sketch below
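
A hedged sketch of how that coverage could be parametrized. compare_fq_to_int is the helper already used in this test file; the specific scales, zero points, and input distributions below are assumptions for illustration:

import numpy as np
import pytest
from tvm import relay


@pytest.mark.parametrize("dtype,zero_point", [("int8", -48), ("uint8", 80)])
@pytest.mark.parametrize("scale", [0.0038, 0.08, 0.5])
def test_fake_quantize_softmax_variants(dtype, zero_point, scale):
    shape = [50, 10]
    x = relay.var("x", shape=shape, dtype=dtype)
    op = relay.qnn.op.dequantize(x, relay.const(scale), relay.const(zero_point))
    op = relay.nn.softmax(op, axis=1)
    op = relay.qnn.op.quantize(op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8")

    low, high = (-128, 128) if dtype == "int8" else (0, 256)
    flat = np.full(shape, zero_point + 10, dtype=dtype)    # flat rows -> flat probabilities
    spike = np.full(shape, low, dtype=dtype)
    spike[:, 0] = high - 1                                  # one large value per row
    rand = np.random.randint(low, high, size=shape).astype(dtype)

    for x_np in (flat, spike, rand):
        # May need a rounding tolerance, per the accuracy discussion above.
        compare_fq_to_int(op, [x_np])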

@maekawatoshiki
Contributor Author

maekawatoshiki commented Apr 12, 2023

I appreciate your review!

I found a quantized softmax implementation in TensorFlow Lite (https://github.com/tensorflow/tensorflow/blob/305fec9fddc3bdb5bb574a134b955bf4b07fd795/tensorflow/lite/kernels/internal/reference/softmax.h#L60), so I'm going to try it and compare its accuracy with my current implementation.

Besides, from #14536 (comment) and #14536 (comment), I acknowledge that the current unit tests need to be improved.

// Refer to the Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf

const Expr quantized_data =
    Subtract(Cast(new_args[0], DataType::Int(32)), Cast(input_zero_point, DataType::Int(32)));
Contributor:

Is it necessary to cast input_zero_point to int32? It is assumed that input_zero_point is of type int32.

Contributor Author:

This is apparently a redundant cast. I should remove it.

@maekawatoshiki
Contributor Author

I took a look at the latest TensorFlow implementation and noticed that it is, for some reason, using float arithmetic: https://github.com/tensorflow/tensorflow/blob/98a187a8bfcdcf0c55c16f07b4a06b50e06a9a26/tensorflow/lite/kernels/internal/optimized/optimized_ops.h#L3471-L3477.

I also found another paper proposing a quantized softmax (https://arxiv.org/pdf/2101.01321.pdf) and its implementation (https://github.com/kssteven418/I-BERT/blob/1b09c759d6aeb71312df9c6ef74fa268a87c934e/fairseq/quantization/utils/quant_modules.py#L578).
However, unlike the algorithm proposed in the paper, its implementation appears to use float arithmetic in several places.

I realized that it's difficult to implement an integer-only quantized softmax with satisfying quality across a variety of input scales.

Let me abandon this PR for now to investigate further, and hopefully, I'll make another PR.

@AndrewZhaoLuo
Contributor

Yeah, it is a challenging problem. I would not necessarily constrain yourself to making a single implementation that can handle all cases, however.

Feel free to plan for several future specializations with fallbacks if needed, imo. Just make sure the tests reflect this.

Though, I will leave that up to you.

@shinh
Contributor

shinh commented May 9, 2023

Does it make sense to add this qnn.softmax implementation as an optional feature? By default, we wouldn't enable qnn.softmax, but users could activate it when they find its precision satisfactory for their use case. To be more specific, I propose the following:

  1. Add @register_optional_fake_quantization_to_integer and use it in fake_quantization_to_integer.py for softmax:

@register_optional_fake_quantization_to_integer("nn.softmax")
def softmax(expr, type_map):
    ...

  2. Modify fake_quantization_to_integer.cc so that optional rewriters will be ignored unless users explicitly state that they want to use quantized softmax, with something like:

relay.transform.FakeQuantizationToInteger(optional_qnn_ops={"nn.softmax"})(mod)

I guess it's OK to relax the unit-test checks if this feature is optional? What are your thoughts?

@masahi
Member

masahi commented May 9, 2023

@shinh Sounds good to me. As long as things won't break by default, I don't see any problem.

maekawatoshiki and others added 5 commits May 14, 2023 17:24
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
Contributor Author (@maekawatoshiki) left a comment:

I re-implemented the QOp softmax as an optional feature.
I'd appreciate your advice.

Comment on lines +92 to +96
const auto const_input_scale = new_args[1].as<ConstantNode>();
ICHECK(const_input_scale) << "Input scale should be constant.";
ICHECK(const_input_scale->is_scalar()) << "Input scale should be scalar.";
const float input_scale = static_cast<float*>(const_input_scale->data->data)[0];
ICHECK(input_scale <= 1.f) << "Input scale should be less than or equal to 1.";
Contributor Author:

The assertion fails when the input scale does not meet the condition.

Comment on lines +1136 to +1138
mod_int = tvm.relay.transform.FakeQuantizationToInteger(
    hard_fail=True, optional_qnn_ops=["nn.softmax"]
)(mod)
Contributor Author (@maekawatoshiki, May 14, 2023):

The pass does not use the fully integer implementation for nn.softmax unless it is specified in optional_qnn_ops.

Comment on lines +1152 to +1162
# Check at least the softmax output is in ascending order,
# since it is difficult to use allclose due to not-so-good accuracy.
for qdq, qop in zip(result, result_int):
    assert is_sorted(qdq)
    assert is_sorted(qop)

try:
    np.testing.assert_allclose(result_int, result, atol=1)
except AssertionError as e:
    # To see the difference
    print(e)
Contributor Author:

What do you think about this?
I checked that the max absolute difference here is, in most cases, 0 to 6.
Also, the overall trend of the softmax output didn't differ much between the QOp and QDQ implementations.
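
For reference, a minimal sketch of what the is_sorted helper used in the quoted test could look like; this is an assumption, not code copied from the PR:

import numpy as np

def is_sorted(a):
    # True when the flattened values are in non-decreasing order.
    a = np.asarray(a).ravel()
    return bool(np.all(a[:-1] <= a[1:]))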

Member:

As long as this implementation is useful for your use case, I'm fine with this. cc @ibsidorenko @AndrewZhaoLuo

Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
@masahi merged commit eb1ea97 into apache:main on May 15, 2023