From 90aea1558a45c626e8c5f753d34e4e61004f86b5 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 9 Jun 2023 16:03:54 -0500 Subject: [PATCH 1/8] improve NeighborArray assert message when results do not sort correctly (this let me figure out that it was a NaN causing problems) --- .../src/java/org/apache/lucene/util/hnsw/NeighborArray.java | 5 ++++- .../test/org/apache/lucene/util/hnsw/TestNeighborArray.java | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index a23b9b5254ee..b44f7da8b8ad 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -56,7 +56,10 @@ public void addInOrder(int newNode, float newScore) { float previousScore = score[size - 1]; assert ((scoresDescOrder && (previousScore >= newScore)) || (scoresDescOrder == false && (previousScore <= newScore))) - : "Nodes are added in the incorrect order!"; + : "Nodes are added in the incorrect order! 
Comparing " + + newScore + + " to " + + Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size)); } node[size] = newNode; score[size] = newScore; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index 039f69c9dc4c..c81077aa6daf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -27,7 +27,7 @@ public void testScoresDescOrder() { neighbors.addInOrder(1, 0.8f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); @@ -76,7 +76,7 @@ public void testScoresAscOrder() { neighbors.addInOrder(1, 0.3f); AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.3f); assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors); From c9b3cb836446ea89a6e82139156966ff8a1149df Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 9 Jun 2023 16:17:14 -0500 Subject: [PATCH 2/8] add checkFinite and fix TestExitableDirectoryReader (cosine taken with 0 is undefined) --- .../util/VectorUtilDefaultProvider.java | 30 ++++++++++++++++++- .../index/TestExitableDirectoryReader.java | 13 ++++++-- .../apache/lucene/util/TestVectorUtil.java | 10 +++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java 
b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index da8483ed04de..4f87ffd92d10 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; +import java.util.Arrays; + /** The default VectorUtil provider implementation. */ final class VectorUtilDefaultProvider implements VectorUtilProvider { @@ -85,6 +87,7 @@ public float dotProduct(float[] a, float[] b) { + b[i + 6] * a[i + 6] + b[i + 7] * a[i + 7]; } + checkFinite(res, a, b, "dot product"); return res; } @@ -102,7 +105,9 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + var r = (float) (sum / Math.sqrt(norm1 * norm2)); + checkFinite(r, a, b, "cosine"); + return r; } @Override @@ -117,9 +122,32 @@ public float squareDistance(float[] a, float[] b) { float diff = a[i] - b[i]; squareSum += diff * diff; } + checkFinite(squareSum, a, b, "square distance"); return squareSum; } + private static void checkFinite(float r, float[] a, float[] b, String optype) { + if (!Float.isFinite(r)) { + for (int i = 0; i < a.length; i++) { + if (!Float.isFinite(a[i])) { + throw new IllegalArgumentException("v1[" + i + "]=" + a[i]); + } + if (!Float.isFinite(b[i])) { + throw new IllegalArgumentException("v2[" + i + "]=" + b[i]); + } + } + throw new IllegalArgumentException( + "Non-finite (" + + r + + ") " + + optype + + " similarity from " + + Arrays.toString(a) + + " and " + + Arrays.toString(b)); + } + } + private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { float diff0 = v1[index + 0] - v2[index + 0]; float diff1 = v1[index + 1] - v2[index + 1]; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 
5dc11a52fb49..4a6365b53094 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -40,6 +40,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TestVectorUtil; /** * Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException { ExitingReaderException.class, () -> leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE)); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); scanAndRetrieve(leaf, iter); leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE); } reader.close(); diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index 358db95641f2..fde211ab35f8 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -103,6 +103,16 @@ public void testCosineThrowsForDimensionMismatch() { expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); } + public void testCosineThrowsForNaN() { + float[] v = {1, 0, Float.NaN}, u = {0, 0, 0}; + expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); + } + + public void testCosineThrowsForInfinity() { + float[] v = {1, 0, Float.NEGATIVE_INFINITY}, u = {0, 0, 0}; + expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); + } 
+ public void testNormalize() { float[] v = randomVector(); v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes From 113c4ad5d43652c102b105e02e3bcfeb490822f5 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Sat, 10 Jun 2023 13:38:54 -0500 Subject: [PATCH 3/8] revert changes to VUDefaultProvider; add checkFinite to VectorUtil instead, and call from KFVF.createType --- .../lucene/document/KnnFloatVectorField.java | 3 ++ .../org/apache/lucene/util/VectorUtil.java | 38 +++++++++++++++++-- .../util/VectorUtilDefaultProvider.java | 30 +-------------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index d6673293c720..20e00dc1427e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,6 +17,8 @@ package org.apache.lucene.document; +import static org.apache.lucene.util.VectorUtil.checkFinite; + import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -50,6 +52,7 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari throw new IllegalArgumentException( "cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS); } + checkFinite(v); if (similarityFunction == null) { throw new IllegalArgumentException("similarity function must not be null"); } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 068a6edc035b..dc40b3876be9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; +import java.util.Arrays; + /** 
Utilities for computations with numeric arrays */ public final class VectorUtil { @@ -34,7 +36,9 @@ public static float dotProduct(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.dotProduct(a, b); + float r = PROVIDER.dotProduct(a, b); + checkFinite(r, a, b, "dot product"); + return r; } /** @@ -46,7 +50,9 @@ public static float cosine(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.cosine(a, b); + float r = PROVIDER.cosine(a, b); + checkFinite(r, a, b, "dot product"); + return r; } /** Returns the cosine similarity between the two vectors. */ @@ -66,7 +72,9 @@ public static float squareDistance(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.squareDistance(a, b); + float r = PROVIDER.squareDistance(a, b); + checkFinite(r, a, b, "square distance"); + return r; } /** Returns the sum of squared differences of the two vectors. 
*/ @@ -154,4 +162,28 @@ public static float dotProductScore(byte[] a, byte[] b) { float denom = (float) (a.length * (1 << 15)); return 0.5f + dotProduct(a, b) / denom; } + + private static void checkFinite(float r, float[] a, float[] b, String optype) { + if (!Float.isFinite(r)) { + checkFinite(a); + checkFinite(b); + throw new IllegalArgumentException( + "Non-finite (" + + r + + ") " + + optype + + " similarity from " + + Arrays.toString(a) + + " and " + + Arrays.toString(b)); + } + } + + public static void checkFinite(float[] a) { + for (int i = 0; i < a.length; i++) { + if (!Float.isFinite(a[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + a[i]); + } + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index 4f87ffd92d10..da8483ed04de 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -17,8 +17,6 @@ package org.apache.lucene.util; -import java.util.Arrays; - /** The default VectorUtil provider implementation. 
*/ final class VectorUtilDefaultProvider implements VectorUtilProvider { @@ -87,7 +85,6 @@ public float dotProduct(float[] a, float[] b) { + b[i + 6] * a[i + 6] + b[i + 7] * a[i + 7]; } - checkFinite(res, a, b, "dot product"); return res; } @@ -105,9 +102,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - var r = (float) (sum / Math.sqrt(norm1 * norm2)); - checkFinite(r, a, b, "cosine"); - return r; + return (float) (sum / Math.sqrt(norm1 * norm2)); } @Override @@ -122,32 +117,9 @@ public float squareDistance(float[] a, float[] b) { float diff = a[i] - b[i]; squareSum += diff * diff; } - checkFinite(squareSum, a, b, "square distance"); return squareSum; } - private static void checkFinite(float r, float[] a, float[] b, String optype) { - if (!Float.isFinite(r)) { - for (int i = 0; i < a.length; i++) { - if (!Float.isFinite(a[i])) { - throw new IllegalArgumentException("v1[" + i + "]=" + a[i]); - } - if (!Float.isFinite(b[i])) { - throw new IllegalArgumentException("v2[" + i + "]=" + b[i]); - } - } - throw new IllegalArgumentException( - "Non-finite (" - + r - + ") " - + optype - + " similarity from " - + Arrays.toString(a) - + " and " - + Arrays.toString(b)); - } - } - private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { float diff0 = v1[index + 0] - v2[index + 0]; float diff1 = v1[index + 1] - v2[index + 1]; From f0009287e9ad7420cb1cf52a1788cd95cd63cafa Mon Sep 17 00:00:00 2001 From: Uwe Schindler Date: Sun, 11 Jun 2023 13:49:38 +0200 Subject: [PATCH 4/8] Check vector parameters to be finite in KnnFloat types and queries --- .../lucene/document/KnnByteVectorField.java | 5 ++- .../lucene/document/KnnFloatVectorField.java | 9 ++--- .../lucene/search/AbstractKnnVectorQuery.java | 2 +- .../lucene/search/KnnByteVectorQuery.java | 2 +- .../lucene/search/KnnFloatVectorQuery.java | 4 +- .../org/apache/lucene/util/VectorUtil.java | 39 +++++++------------ .../apache/lucene/util/TestVectorUtil.java 
| 10 ----- 7 files changed, 26 insertions(+), 45 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index fabbc5259e3b..281e4dffe999 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -100,7 +101,7 @@ public static FieldType createFieldType( public KnnByteVectorField( String name, byte[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = vector; // null-check done above } /** @@ -136,7 +137,7 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) { + " using byte[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = vector; + fieldsData = Objects.requireNonNull(vector, "vector value must not be null"); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index 20e00dc1427e..bc63a2642d3a 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,8 +17,7 @@ package org.apache.lucene.document; -import static org.apache.lucene.util.VectorUtil.checkFinite; - +import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -52,7 +51,6 @@ private static FieldType createType(float[] v, 
VectorSimilarityFunction similari throw new IllegalArgumentException( "cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS); } - checkFinite(v); if (similarityFunction == null) { throw new IllegalArgumentException("similarity function must not be null"); } @@ -104,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) { public KnnFloatVectorField( String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = VectorUtil.checkFinite(vector); // null check done above } /** @@ -140,7 +138,8 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { + " using float[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = vector; + fieldsData = + VectorUtil.checkFinite(Objects.requireNonNull(vector, "vector value must not be null")); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index cd8d73b8c26f..eb51b623831a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query { private final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this.field = field; + this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 4ec617c24470..10345cd7adf4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ 
b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) { */ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = Objects.requireNonNull(target, "target"); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 2b1b3a69582e..3036e7c45162 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; @@ -25,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; /** * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) { */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index dc40b3876be9..d313a2714d0c 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,8 +17,6 @@ package org.apache.lucene.util; -import java.util.Arrays; - /** Utilities for computations with numeric arrays */ public final class 
VectorUtil { @@ -37,7 +35,7 @@ public static float dotProduct(float[] a, float[] b) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } float r = PROVIDER.dotProduct(a, b); - checkFinite(r, a, b, "dot product"); + assert Float.isFinite(r); return r; } @@ -51,7 +49,7 @@ public static float cosine(float[] a, float[] b) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } float r = PROVIDER.cosine(a, b); - checkFinite(r, a, b, "dot product"); + assert Float.isFinite(r); return r; } @@ -73,7 +71,7 @@ public static float squareDistance(float[] a, float[] b) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } float r = PROVIDER.squareDistance(a, b); - checkFinite(r, a, b, "square distance"); + assert Float.isFinite(r); return r; } @@ -163,27 +161,18 @@ public static float dotProductScore(byte[] a, byte[] b) { return 0.5f + dotProduct(a, b) / denom; } - private static void checkFinite(float r, float[] a, float[] b, String optype) { - if (!Float.isFinite(r)) { - checkFinite(a); - checkFinite(b); - throw new IllegalArgumentException( - "Non-finite (" - + r - + ") " - + optype - + " similarity from " - + Arrays.toString(a) - + " and " - + Arrays.toString(b)); - } - } - - public static void checkFinite(float[] a) { - for (int i = 0; i < a.length; i++) { - if (!Float.isFinite(a[i])) { - throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + a[i]); + /** + * Checks if a float vector only has finite components. 
+ * + * @param v a float array containing a vector + * @return the vector for call-chaining + */ + public static float[] checkFinite(float[] v) { + for (int i = 0; i < v.length; i++) { + if (!Float.isFinite(v[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]); } } + return v; } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index fde211ab35f8..358db95641f2 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -103,16 +103,6 @@ public void testCosineThrowsForDimensionMismatch() { expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); } - public void testCosineThrowsForNaN() { - float[] v = {1, 0, Float.NaN}, u = {0, 0, 0}; - expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); - } - - public void testCosineThrowsForInfinity() { - float[] v = {1, 0, Float.NEGATIVE_INFINITY}, u = {0, 0, 0}; - expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v)); - } - public void testNormalize() { float[] v = randomVector(); v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes From 942536bb700caf5b268347cb1b32f97fa95d2548 Mon Sep 17 00:00:00 2001 From: Uwe Schindler Date: Sun, 11 Jun 2023 14:06:42 +0200 Subject: [PATCH 5/8] improve javadocs --- lucene/core/src/java/org/apache/lucene/util/VectorUtil.java | 1 + 1 file changed, 1 insertion(+) diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index d313a2714d0c..c9e1d368334e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -166,6 +166,7 @@ public static float dotProductScore(byte[] a, byte[] b) { * * @param v a float array containing a vector * @return the vector 
for call-chaining + * @throws IllegalArgumentException if any component of vector is not finite */ public static float[] checkFinite(float[] v) { for (int i = 0; i < v.length; i++) { From 2dc84849a40629ec6564bd05f39de09921506a4f Mon Sep 17 00:00:00 2001 From: Uwe Schindler Date: Mon, 12 Jun 2023 16:52:45 +0200 Subject: [PATCH 6/8] Make sure the last multiplication of norms are explicitly using double. The current code is quite unclear, so it is better to be explicit (Math.sqrt uses double argument anyways) --- .../java/org/apache/lucene/util/VectorUtilDefaultProvider.java | 2 +- .../java20/org/apache/lucene/util/VectorUtilPanamaProvider.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index da8483ed04de..665181e86788 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -102,7 +102,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java index fd599c232fbd..61ec15e0d22d 100644 --- a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java +++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java @@ -217,7 +217,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); } @Override From 0bde96904565c88ce0766b67763cf17ed6ba389a Mon Sep 17 
00:00:00 2001 From: Uwe Schindler Date: Mon, 12 Jun 2023 17:08:51 +0200 Subject: [PATCH 7/8] Add length check for the constructor that does not create a new field type --- .../org/apache/lucene/document/KnnByteVectorField.java | 7 ++++++- .../org/apache/lucene/document/KnnFloatVectorField.java | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index 281e4dffe999..87cb6a9f056e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -137,7 +137,12 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) { + " using byte[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = Objects.requireNonNull(vector, "vector value must not be null"); + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } + fieldsData = vector; } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index bc63a2642d3a..9d1cd02c013e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -138,8 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { + " using float[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = - VectorUtil.checkFinite(Objects.requireNonNull(vector, "vector value must not be null")); + Objects.requireNonNull(vector, "vector value must not be null"); + if 
(vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } + fieldsData = VectorUtil.checkFinite(vector); } /** Return the vector value of this field */ From 0ab0be67c3107d6090f3eb28cbcd595905865c70 Mon Sep 17 00:00:00 2001 From: Uwe Schindler Date: Mon, 12 Jun 2023 17:13:29 +0200 Subject: [PATCH 8/8] Add CHANGES.txt --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 332e8eb4f608..18231a4cc9fa 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -164,6 +164,8 @@ Improvements * GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit. (Greg Miller) +* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler) + Optimizations ---------------------