Skip to content

Commit

Permalink
Add test case to ensure scalar quantization adheres to known ranges (#13336)
Browse files Browse the repository at this point in the history

Lucene provides int7 and int4 quantization; our tests should ensure that the quantized values stay within the expected ranges.
  • Loading branch information
benwtrent committed May 2, 2024
1 parent cccbd4e commit 166b47d
Showing 1 changed file with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

public class TestScalarQuantizer extends LuceneTestCase {

public void testQuantizeAndDeQuantize() throws IOException {
public void testQuantizeAndDeQuantize7Bit() throws IOException {
int dims = 128;
int numVecs = 100;
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
Expand All @@ -39,15 +39,26 @@ public void testQuantizeAndDeQuantize() throws IOException {
float[] dequantized = new float[dims];
byte[] quantized = new byte[dims];
byte[] requantized = new byte[dims];
byte maxDimValue = -128;
byte minDimValue = 127;
for (int i = 0; i < numVecs; i++) {
scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
scalarQuantizer.deQuantize(quantized, dequantized);
scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
for (int j = 0; j < dims; j++) {
if (quantized[j] > maxDimValue) {
maxDimValue = quantized[j];
}
if (quantized[j] < minDimValue) {
minDimValue = quantized[j];
}
assertEquals(dequantized[j], floats[i][j], 0.02);
assertEquals(quantized[j], requantized[j]);
}
}
// int7 should always quantize to 0-127
assertTrue(minDimValue >= (byte) 0);
assertTrue(maxDimValue <= (byte) 127);
}

public void testQuantiles() {
Expand Down Expand Up @@ -123,7 +134,7 @@ public void testScalarWithSampling() throws IOException {
}
}

public void testFromVectorsAutoInterval() throws IOException {
public void testFromVectorsAutoInterval4Bit() throws IOException {
int dims = 128;
int numVecs = 100;
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
Expand All @@ -137,15 +148,26 @@ public void testFromVectorsAutoInterval() throws IOException {
float[] dequantized = new float[dims];
byte[] quantized = new byte[dims];
byte[] requantized = new byte[dims];
byte maxDimValue = -128;
byte minDimValue = 127;
for (int i = 0; i < numVecs; i++) {
scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
scalarQuantizer.deQuantize(quantized, dequantized);
scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
for (int j = 0; j < dims; j++) {
if (quantized[j] > maxDimValue) {
maxDimValue = quantized[j];
}
if (quantized[j] < minDimValue) {
minDimValue = quantized[j];
}
assertEquals(dequantized[j], floats[i][j], 0.2);
assertEquals(quantized[j], requantized[j]);
}
}
// int4 should always quantize to 0-15
assertTrue(minDimValue >= (byte) 0);
assertTrue(maxDimValue <= (byte) 15);
}

static void shuffleArray(float[] ar) {
Expand Down

0 comments on commit 166b47d

Please sign in to comment.