
Improve int4 compressed comparisons performance #13321

Merged (19 commits) on May 1, 2024

Conversation

@benwtrent (Member):

This updates the int4 dot-product comparison to add an optimized path for when one of the vectors is compressed (the most common search case). This change actually makes compressed search on ARM faster than uncompressed. On AVX512/256 it is still slightly slower than uncompressed, but with this optimization it is still much faster than the previous behavior (eagerly decompressing).

This optimization is tightly tied to how the vectors are actually compressed and stored; consequently, I added a new scorer within the lucene99 codec.

So, this gives us an 8x size reduction over float32, queries well more than 2x faster than float32, and no need to rerank, as recall and accuracy are excellent.
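To make the idea concrete, here is a minimal scalar sketch (not the actual Lucene implementation) of a dot product where only the stored vector is packed, two 4-bit values per byte, so no eager decompression pass is needed. The nibble layout used here (low nibble = dimension i, high nibble = dimension i + half) is an assumption for illustration; the codec's on-disk layout may differ.

```java
public class Int4PackedDot {
  // Dot product between an unpacked int4 query (one value per byte, 0..15)
  // and a packed stored vector (two values per byte). Only the stored
  // vector needs nibble extraction; the query is used as-is.
  public static int dotProductPacked(byte[] query, byte[] packed) {
    int half = packed.length; // packed holds two dimensions per byte
    int total = 0;
    for (int i = 0; i < half; i++) {
      int lo = packed[i] & 0x0F;        // assumed: dimension i
      int hi = (packed[i] >> 4) & 0x0F; // assumed: dimension i + half
      total += query[i] * lo + query[i + half] * hi;
    }
    return total;
  }

  public static void main(String[] args) {
    byte[] query = {1, 2, 3, 4, 5, 6};
    byte[] packed = new byte[3];
    for (int i = 0; i < 3; i++) {
      packed[i] = (byte) ((query[i + 3] << 4) | query[i]);
    }
    // Self dot product of {1..6} is 1+4+9+16+25+36 = 91
    System.out.println(dotProductPacked(query, packed)); // prints 91
  }
}
```

The point of this shape is that the mask-and-shift work happens only on the stored side, inside the same loop that accumulates the products.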

Here are some lucene-util numbers over CohereV3 at 1024 dimensions:

New compressed numbers on ARM

recall	latency	nDoc	fanout	maxConn	beamWidth	visited	
0.891	 1.02	500000	0	64	250	        4182
0.910	 1.05	500000	0	64	250	        4348
0.925	 1.06	500000	0	64	250	        4511
0.974	 1.29	500000	0	64	250	        5782
0.986	 1.72	500000	0	64	250	        7285

Compared with uncompressed on ARM:

recall	latency	nDoc	fanout	maxConn	beamWidth	visited
0.891	 1.18 	500000	0	64	250	        4182
0.910	 1.24	500000	0	64	250	        4348
0.925	 1.25	500000	0	64	250	        4511
0.974	 1.57 	500000	0	64	250	        5782
0.986	 2.15	500000	0	64	250	        7285

Here are some JMH numbers as well (note: I am excluding odd dimension counts, as these don't support compression).

NOTE: PackedUnpacked eagerly decompresses the vectors and then uses the regular dot-product; this is the pre-change behavior.
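For contrast, the "PackedUnpacked" baseline can be sketched as follows. This is a hypothetical scalar version for illustration, assuming the same low-nibble/high-nibble layout as above; the real benchmark operates on the codec's actual layout.

```java
public class EagerUnpackDot {
  // "PackedUnpacked"-style baseline: first unpack every nibble into a
  // scratch array, then run a plain byte dot product. The extra pass and
  // scratch-array writes are the cost the new optimized path avoids.
  public static int dotProductEagerUnpack(byte[] query, byte[] packed) {
    int half = packed.length;
    byte[] unpacked = new byte[half * 2];
    for (int i = 0; i < half; i++) {
      unpacked[i] = (byte) (packed[i] & 0x0F);          // assumed: dimension i
      unpacked[i + half] = (byte) ((packed[i] >> 4) & 0x0F); // assumed: dimension i + half
    }
    int total = 0;
    for (int i = 0; i < unpacked.length; i++) {
      total += query[i] * unpacked[i];
    }
    return total;
  }

  public static void main(String[] args) {
    byte[] query = {1, 2, 3, 4, 5, 6};
    byte[] packed = {(byte) ((4 << 4) | 1), (byte) ((5 << 4) | 2), (byte) ((6 << 4) | 3)};
    System.out.println(dotProductEagerUnpack(query, packed)); // prints 91
  }
}
```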

ARM:

VectorUtilBenchmark.binaryHalfByteScalar                   128  thrpt    5   25.072 ±  0.364  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   256  thrpt    5   12.534 ±  0.152  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   300  thrpt    5   10.715 ±  0.116  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   512  thrpt    5    6.275 ±  0.019  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   702  thrpt    5    4.577 ±  0.019  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                  1024  thrpt    5    3.113 ±  0.010  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             128  thrpt    5   24.161 ±  0.183  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             256  thrpt    5   12.261 ±  0.356  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             300  thrpt    5   10.535 ±  0.264  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             512  thrpt    5    6.157 ±  0.062  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             702  thrpt    5    4.505 ±  0.022  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked            1024  thrpt    5    3.104 ±  0.013  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     128  thrpt    5   15.179 ±  0.307  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     256  thrpt    5    7.883 ±  0.126  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     300  thrpt    5    6.826 ±  0.014  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     512  thrpt    5    3.996 ±  0.013  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     702  thrpt    5    2.934 ±  0.010  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked    1024  thrpt    5    2.008 ±  0.026  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   128  thrpt    5   69.386 ±  0.371  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   256  thrpt    5   51.016 ±  0.180  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   300  thrpt    5   40.186 ±  0.117  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   512  thrpt    5   33.453 ±  0.096  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   702  thrpt    5   23.627 ±  0.429  ops/us
VectorUtilBenchmark.binaryHalfByteVector                  1024  thrpt    5   19.833 ±  0.065  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             128  thrpt    5   66.502 ±  0.335  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             256  thrpt    5   47.178 ±  0.546  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             300  thrpt    5   36.942 ±  0.122  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             512  thrpt    5   29.735 ±  0.328  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             702  thrpt    5   21.145 ±  0.085  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked            1024  thrpt    5   17.103 ±  0.050  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     128  thrpt    5   25.077 ±  0.459  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     256  thrpt    5   15.033 ±  0.041  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     300  thrpt    5   12.681 ±  0.222  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     512  thrpt    5    8.240 ±  0.461  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     702  thrpt    5    6.034 ±  0.022  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked    1024  thrpt    5    4.320 ±  0.509  ops/us

AVX512:

VectorUtilBenchmark.binaryHalfByteScalar                   128  thrpt   15   17.767 ± 0.123  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   256  thrpt   15    9.248 ± 0.112  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   300  thrpt   15    8.095 ± 0.102  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   512  thrpt   15    4.723 ± 0.054  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                   702  thrpt   15    3.580 ± 0.030  ops/us
VectorUtilBenchmark.binaryHalfByteScalar                  1024  thrpt   15    2.346 ± 0.047  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             128  thrpt   15   14.119 ± 0.069  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             256  thrpt   15    6.478 ± 0.037  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             300  thrpt   15    4.157 ± 0.048  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             512  thrpt   15    2.490 ± 0.017  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked             702  thrpt   15    1.817 ± 0.011  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked            1024  thrpt   15    1.240 ± 0.009  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     128  thrpt   15   10.022 ± 0.068  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     256  thrpt   15    5.583 ± 0.048  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     300  thrpt   15    4.667 ± 0.083  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     512  thrpt   15    2.698 ± 0.034  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked     702  thrpt   15    1.931 ± 0.019  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPackedUnpacked    1024  thrpt   15    1.294 ± 0.019  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   128  thrpt   15   84.577 ± 2.424  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   207  thrpt   15   44.973 ± 0.448  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   256  thrpt   15   51.049 ± 0.379  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   300  thrpt   15   39.401 ± 0.527  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   512  thrpt   15   27.654 ± 0.145  ops/us
VectorUtilBenchmark.binaryHalfByteVector                   702  thrpt   15   20.007 ± 0.120  ops/us
VectorUtilBenchmark.binaryHalfByteVector                  1024  thrpt   15   14.378 ± 0.070  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             128  thrpt   15   58.249 ± 0.375  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             256  thrpt   15   30.865 ± 0.164  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             300  thrpt   15   22.795 ± 0.280  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             512  thrpt   15   16.406 ± 0.506  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked             702  thrpt   15    9.555 ± 0.167  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked            1024  thrpt   15    8.638 ± 0.095  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     128  thrpt   15   15.507 ± 0.122  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     256  thrpt   15    9.079 ± 0.068  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     300  thrpt   15    7.788 ± 0.083  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     512  thrpt   15    4.992 ± 0.064  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked     702  thrpt   15    3.622 ± 0.033  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPackedUnpacked    1024  thrpt   15    2.488 ± 0.019  ops/us

@benwtrent added this to the 9.11.0 milestone Apr 23, 2024
@ChrisHegarty (Contributor) left a comment:

Overall this is very nice. I left a few initial small comments.

@ChrisHegarty (Contributor) left a comment:

LGTM.

Separately: we know whether or not the vectors are packed when creating the scorer, so why not just pack the query vector too, allowing the comparison to be done more easily? The query vector could be coerced to match the layout of the stored values. But that requires a little more separation, which can follow if it proves worthwhile.

@benwtrent (Member, Author) commented Apr 25, 2024:

in which case why not just pack the query vector, allowing to do the comparison more easily

Because this makes things measurably slower. Decompressing only one of the vectors, without adding the cost of compressing the query vector, is why this is so fast.

And I am not sure how it makes the comparisons easier. We would then have to add code to decompress both and then compare both. The iteration code, etc. all remains the same.

[apologies - I incorrectly edited this comment when trying to reply to it. It should be restored now.]

@ChrisHegarty (Contributor) replied:

Because this makes things measurably slower. ..

Ok. That's a very good reason! ;-)

And I am not sure how it makes the comparisons easier. We then have to add LoC to decompress both and then compare both. The iteration code, etc. all remains the same.

Right. It's just that the masking and shifting would apply equally to each of the inputs, and we'd load less data, since the number of bytes in the query would be smaller. I'm perfectly OK with the code as it is; I was mainly ensuring that this was considered (which it was).

@rmuir (Member) commented Apr 25, 2024:

Panama logic looks good to me. Not for this PR, just a question: is it possible to use this same trick for 8-bit integer calculations too? I imagine this trick could probably help there as well, if it is safe.

// iterate in chunks of 1024 items to ensure we don't overflow the short accumulator

@ChrisHegarty (Contributor) commented:

I went over the Panama code again, and confused myself about the potential to overflow. There is no issue, but it was not obvious, as one has to match the number of potential values (not bytes) to the number of lanes x accumulators. Anyway, I added a test for boundary conditions, which will help when we move to optimising wider bit-widths.

@benwtrent (Member, Author) replied:

panama logic looks good to me. Not for this PR, just a question, is it possible to use this same trick for 8 bit integer calculations too? I imagine this trick could probably help there as well, if it is safe.

Some napkin math (this would need to be double checked...)

For int4, the maximum per-dimension product is just 15*15=225, which means we can fit 290 summations in a short with saturation and 145 without saturation (since 1024 is Lucene's max vector length, I went without saturation).

For int8, the maximum per-dimension product (since we are signed) is -128*-128=16384, which saturates a short with 2 values: 16384 * 2 = 32768 = Short.MAX_VALUE + 1. So, since it's already saturating at 2 values, 3 values might as well be summed (4 "double saturates"? I don't know the correct term, but 4 wraps around and thus won't work).

A 128-bit vector fits 8 shorts.

In int4 this means we can sum up to 1160 dimensions (each lane summing only 1160/8=145 multiplications) without saturation, comfortably covering the 1024 limit.

In int8, this is reduced significantly to 24 dimensions (each lane summing 3 multiplications with saturation). While this is small, it could reduce the number of vector lane expansions due to type sizes and Java type casting, but the overhead of accounting for the saturation of shorts when expanding to integers could slow things down.
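The napkin math above can be sanity-checked with some quick arithmetic (a sketch; `maxProduct` here stands for the largest per-dimension product at a given bit width):

```java
public class AccumulatorMath {
  // How many per-dimension products of magnitude maxProduct can be summed
  // into a signed 16-bit accumulator without overflowing Short.MAX_VALUE.
  public static int maxTermsNoOverflow(int maxProduct) {
    return Short.MAX_VALUE / maxProduct;
  }

  public static void main(String[] args) {
    System.out.println(maxTermsNoOverflow(15 * 15));   // int4: 145
    System.out.println(maxTermsNoOverflow(128 * 128)); // int8: 1 (two terms already reach 32768)
    // A 128-bit vector holds 8 short lanes: 145 * 8 = 1160 dimensions
    System.out.println(maxTermsNoOverflow(15 * 15) * 8); // 1160
  }
}
```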

@benwtrent (Member, Author) commented:

OK, I ran on Google's ARM machine (Tau T2A machine series) to make sure the ARM performance improvements still exist for int4 (and it wasn't some silly macOS thing):

Benchmark                                       (size)   Mode  Cnt  Score   Error   Units
VectorUtilBenchmark.binaryDotProductScalar        1024  thrpt   15  2.850 ± 0.002  ops/us
VectorUtilBenchmark.binaryDotProductVector        1024  thrpt   15  2.771 ± 0.016  ops/us
VectorUtilBenchmark.binaryHalfByteScalar          1024  thrpt   15  2.845 ± 0.009  ops/us
VectorUtilBenchmark.binaryHalfByteScalarPacked    1024  thrpt   15  2.128 ± 0.003  ops/us
VectorUtilBenchmark.binaryHalfByteVector          1024  thrpt   15  7.667 ± 0.007  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  7.009 ± 0.025  ops/us

Something else funny is that this is almost the same speed as floatVector on this hardware.

Micro-benchmarks are VERY close to float. This means the reduction in bytes read and parsed will make int4 much faster than float.

Benchmark                                  (size)   Mode  Cnt  Score   Error   Units
VectorUtilBenchmark.floatDotProductScalar    1024  thrpt   15  2.476 ± 0.028  ops/us
VectorUtilBenchmark.floatDotProductVector    1024  thrpt   75  8.703 ± 0.300  ops/us

@ChrisHegarty (Contributor) commented:

In int8, this is reduced significantly to 24 dimensions (each lane summing 3 multiplications with saturation). While this small, it could reduce the number of vector lane expansions due to type sizes and java type casting, but the overhead of accounting for the saturation of shorts when expanding to integers could slow things down.

This is a great optimisation for int4, but won't work for int8.

We can get a bit more perf out of int4 on AVX 2 and AVX 512 by applying a similar pattern.

AVX 256

A maximum of 2320 dimensions - each lane summing 2320/16=145 multiplications. Let's set an inner loop bound of 2048.

before

Benchmark                                       (size)   Mode  Cnt  Score   Error   Units
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  9.715 ± 0.021  ops/us

after

Benchmark                                       (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  12.190 ± 0.029  ops/us

AVX 512

A maximum of 4640 dimensions - each lane summing 4640/32=145 multiplications. Let's set an inner loop bound of 4096.

before

Benchmark                                       (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  12.506 ± 0.064  ops/us

after

Benchmark                                       (size)   Mode  Cnt   Score   Error   Units
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  14.235 ± 0.043  ops/us
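The per-width maximums quoted above follow from the same 145-products-per-lane limit, and the inner loop bounds are the largest powers of two not exceeding each maximum. A quick arithmetic check:

```java
public class LaneBounds {
  // Max int4 dimensions that can be accumulated into 16-bit lanes without
  // overflow, for a given SIMD register width in bits.
  public static int maxDims(int vectorBits) {
    int perLane = Short.MAX_VALUE / (15 * 15); // 145 products per lane
    return perLane * (vectorBits / 16);        // times the number of short lanes
  }

  public static void main(String[] args) {
    System.out.println(maxDims(128)); // NEON:    1160
    System.out.println(maxDims(256)); // AVX 256: 2320 -> inner loop bound 2048
    System.out.println(maxDims(512)); // AVX 512: 4640 -> inner loop bound 4096
    System.out.println(Integer.highestOneBit(maxDims(256))); // 2048
    System.out.println(Integer.highestOneBit(maxDims(512))); // 4096
  }
}
```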

@benwtrent (Member, Author) commented:

I double-checked @ChrisHegarty's improvements on AVX 512:

My previous numbers:

VectorUtilBenchmark.binaryHalfByteVectorPacked            1024  thrpt   15    8.638 ± 0.095  ops/us

Here are the new numbers (I include float32 and regular binaryVector for comparisons):

VectorUtilBenchmark.binaryDotProductVector        1024  thrpt   15  14.248 ± 0.044  ops/us
VectorUtilBenchmark.binaryHalfByteVectorPacked    1024  thrpt   15  10.157 ± 0.044  ops/us
VectorUtilBenchmark.floatDotProductVector         1024  thrpt   75  17.633 ± 0.330  ops/us

@benwtrent (Member, Author) commented:

@uschindler I addressed your comments.

@uschindler (Contributor) commented May 1, 2024:

To me, the changes look fine.

For discussion: in my opinion, the conditional code that falls back to the default path when neither vector is compressed should possibly be moved to the VectorUtil class. I think the booleans could then be removed, leaving only two variants, each called directly from VectorUtil.

Uwe

@benwtrent merged commit e40e108 into apache:main May 1, 2024
3 checks passed
@benwtrent deleted the feature/packed-optimized-int4 branch May 1, 2024 14:05
benwtrent added a commit that referenced this pull request May 1, 2024