Skip to content

Commit

Permalink
Add checks in KNNVectorField / KNNVectorQuery to only allow non-null,…
Browse files Browse the repository at this point in the history
… non-empty and finite vectors (#12281)


---------

Co-authored-by: Uwe Schindler <uschindler@apache.org>
  • Loading branch information
jbellis and uschindler authored Jun 13, 2023
1 parent 30eba6d commit 071461e
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 16 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ Improvements
* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit.
(Greg Miller)

* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler)

Optimizations
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -100,7 +101,7 @@ public static FieldType createFieldType(
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = vector; // null-check done above
}

/**
Expand Down Expand Up @@ -136,6 +137,11 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
+ " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = vector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = VectorUtil.checkFinite(vector); // null check done above
}

/**
Expand Down Expand Up @@ -137,7 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) {
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = VectorUtil.checkFinite(vector);
}

/** Return the vector value of this field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final Query filter;

public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) {
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = Objects.requireNonNull(target, "target");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;

/**
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
Expand Down Expand Up @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) {
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}

@Override
Expand Down
28 changes: 25 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.dotProduct(a, b);
float r = PROVIDER.dotProduct(a, b);
assert Float.isFinite(r);
return r;
}

/**
Expand All @@ -46,7 +48,9 @@ public static float cosine(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.cosine(a, b);
float r = PROVIDER.cosine(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the cosine similarity between the two vectors. */
Expand All @@ -66,7 +70,9 @@ public static float squareDistance(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.squareDistance(a, b);
float r = PROVIDER.squareDistance(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the sum of squared differences of the two vectors. */
Expand Down Expand Up @@ -154,4 +160,20 @@ public static float dotProductScore(byte[] a, byte[] b) {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}

/**
 * Verifies that every component of a float vector is finite, i.e. neither NaN nor infinite.
 *
 * @param v the float vector to validate
 * @return the same array instance that was passed in, to allow call-chaining
 * @throws IllegalArgumentException if any component of the vector is not finite
 */
public static float[] checkFinite(float[] v) {
  int index = 0;
  for (float component : v) {
    // A float is non-finite exactly when it is NaN or +/- infinity.
    if (Float.isNaN(component) || Float.isInfinite(component)) {
      throw new IllegalArgumentException(
          "non-finite value at vector[" + index + "]=" + component);
    }
    index++;
  }
  return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ public void addInOrder(int newNode, float newScore) {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
: "Nodes are added in the incorrect order!";
: "Nodes are added in the incorrect order! Comparing "
+ newScore
+ " to "
+ Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
}
node[size] = newNode;
score[size] = newScore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
Expand Down Expand Up @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);

leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE);
}

reader.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void testScoresDescOrder() {
neighbors.addInOrder(1, 0.8f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
Expand Down Expand Up @@ -76,7 +76,7 @@ public void testScoresAscOrder() {
neighbors.addInOrder(1, 0.3f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
Expand Down

0 comments on commit 071461e

Please sign in to comment.