diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 332e8eb4f608..18231a4cc9fa 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -164,6 +164,8 @@ Improvements * GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit. (Greg Miller) +* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler) + Optimizations --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index fabbc5259e3b..87cb6a9f056e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -100,7 +101,7 @@ public static FieldType createFieldType( public KnnByteVectorField( String name, byte[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = vector; // null-check done above } /** @@ -136,6 +137,11 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) { + " using byte[] but the field encoding is " + fieldType.vectorEncoding()); } + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } fieldsData = vector; } diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index d6673293c720..9d1cd02c013e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) { public KnnFloatVectorField( String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = VectorUtil.checkFinite(vector); // null check done above } /** @@ -137,7 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { + " using float[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = vector; + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } + fieldsData = VectorUtil.checkFinite(vector); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index cd8d73b8c26f..eb51b623831a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query { private final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this.field = field; + this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 4ec617c24470..10345cd7adf4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) { */ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = Objects.requireNonNull(target, "target"); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 2b1b3a69582e..3036e7c45162 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; @@ -25,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; /** * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) { */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 068a6edc035b..c9e1d368334e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -34,7 +34,9 @@ public static float dotProduct(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.dotProduct(a, b); + float r = PROVIDER.dotProduct(a, b); + assert Float.isFinite(r); + return r; } /** @@ -46,7 +48,9 @@ public static float cosine(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.cosine(a, b); + float r = PROVIDER.cosine(a, b); + assert Float.isFinite(r); + return r; } /** Returns the cosine similarity between the two vectors. */ @@ -66,7 +70,9 @@ public static float squareDistance(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.squareDistance(a, b); + float r = PROVIDER.squareDistance(a, b); + assert Float.isFinite(r); + return r; } /** Returns the sum of squared differences of the two vectors. */ @@ -154,4 +160,20 @@ public static float dotProductScore(byte[] a, byte[] b) { float denom = (float) (a.length * (1 << 15)); return 0.5f + dotProduct(a, b) / denom; } + + /** + * Checks if a float vector only has finite components. + * + * @param v bytes containing a vector + * @return the vector for call-chaining + * @throws IllegalArgumentException if any component of vector is not finite + */ + public static float[] checkFinite(float[] v) { + for (int i = 0; i < v.length; i++) { + if (!Float.isFinite(v[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]); + } + } + return v; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index da8483ed04de..665181e86788 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -102,7 +102,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index a23b9b5254ee..b44f7da8b8ad 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -56,7 +56,10 @@ public void addInOrder(int newNode, float newScore) { float previousScore = score[size - 1]; assert ((scoresDescOrder && (previousScore >= newScore)) || (scoresDescOrder == false && (previousScore <= newScore))) - : "Nodes are added in the incorrect order!"; + : "Nodes are added in the incorrect order! Comparing " + + newScore + + " to " + + Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size)); } node[size] = newNode; score[size] = newScore; diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java index fd599c232fbd..61ec15e0d22d 100644 --- a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java +++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java @@ -217,7 +217,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 5dc11a52fb49..4a6365b53094 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -40,6 +40,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TestVectorUtil; /** * Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException { ExitingReaderException.class, () -> leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE)); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); scanAndRetrieve(leaf, iter); leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE); } reader.close(); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index 039f69c9dc4c..c81077aa6daf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -27,7 +27,7 @@ public void testScoresDescOrder() { neighbors.addInOrder(1, 0.8f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); @@ -76,7 +76,7 @@ public void testScoresAscOrder() { neighbors.addInOrder(1, 0.3f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.3f); assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);