From 3d671a0fbef159e970b060d3f942fba481bafc8b Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Thu, 16 May 2024 14:59:56 +0100 Subject: [PATCH] Fix bug in SQ when just a single vector present in a segment (#13374) This commit fixes a corner case in the ScalarQuantizer when just a single vector is present. I ran into this when updating a test that previously passed successfully with Lucene 9.10 but fails in 9.x. The score error correction is calculated to be NaN, as there are no score docs or variance. --- lucene/CHANGES.txt | 2 + .../util/quantization/ScalarQuantizer.java | 2 +- ...stLucene99ScalarQuantizedVectorScorer.java | 55 +++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index a00c594f9acd..8afa946ca723 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -372,6 +372,8 @@ Bug Fixes * GITHUB#13378: Fix points writing with no values (Chris Hegarty) +* GITHUB#13374: Fix bug in SQ when just a single vector present in a segment (Chris Hegarty) + Build --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 421c181f0210..fb07e0055719 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -698,7 +698,7 @@ public ScoreErrorCorrelator( } corr.add(1 - errors.var() / scoreVariance); } - return corr.mean; + return Double.isNaN(corr.mean) ? 
0.0 : corr.mean; } } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index cbfecdeb1da0..a244302692e1 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -29,12 +29,14 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -252,6 +254,59 @@ private static void indexVectors( } } + public void testSingleVectorPerSegmentCosine() throws IOException { + testSingleVectorPerSegment(VectorSimilarityFunction.COSINE); + } + + public void testSingleVectorPerSegmentDot() throws IOException { + testSingleVectorPerSegment(VectorSimilarityFunction.DOT_PRODUCT); + } + + public void testSingleVectorPerSegmentEuclidean() throws IOException { + testSingleVectorPerSegment(VectorSimilarityFunction.EUCLIDEAN); + } + + public void testSingleVectorPerSegmentMIP() throws IOException { + testSingleVectorPerSegment(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + } + + private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws IOException { + var codec = getCodec(7, false); + try (Directory dir = newDirectory()) { 
+ try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setCodec(codec))) { + Document doc2 = new Document(); + doc2.add(new KnnFloatVectorField("field", new float[] {0.8f, 0.6f}, sim)); + doc2.add(newTextField("id", "A", Field.Store.YES)); + writer.addDocument(doc2); + writer.commit(); + + Document doc1 = new Document(); + doc1.add(new KnnFloatVectorField("field", new float[] {0.6f, 0.8f}, sim)); + doc1.add(newTextField("id", "B", Field.Store.YES)); + writer.addDocument(doc1); + writer.commit(); + + Document doc3 = new Document(); + doc3.add(new KnnFloatVectorField("field", new float[] {-0.6f, -0.8f}, sim)); + doc3.add(newTextField("id", "C", Field.Store.YES)); + writer.addDocument(doc3); + writer.commit(); + + writer.forceMerge(1); + } + try (DirectoryReader reader = DirectoryReader.open(dir)) { + LeafReader leafReader = getOnlyLeafReader(reader); + StoredFields storedFields = reader.storedFields(); + float[] queryVector = new float[] {0.6f, 0.8f}; + var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100); + assertEquals(3, hits.scoreDocs.length); + assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id")); + } + } + } + private static byte[] floatToByteArray(float value) { return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array(); }