Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks in KNNVectorField / KNNVectorQuery to only allow non-null, non-empty and finite vectors #12281

Merged
merged 8 commits into from
Jun 13, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -100,7 +101,7 @@ public static FieldType createFieldType(
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = vector; // null-check done above
}

/**
Expand Down Expand Up @@ -136,7 +137,7 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
+ " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
fieldsData = Objects.requireNonNull(vector, "vector value must not be null");
}

/** Return the vector value of this field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = VectorUtil.checkFinite(vector); // null check done above
}

/**
Expand Down Expand Up @@ -137,7 +138,8 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) {
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
fieldsData =
VectorUtil.checkFinite(Objects.requireNonNull(vector, "vector value must not be null"));
}

/** Return the vector value of this field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final Query filter;

public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) {
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = Objects.requireNonNull(target, "target");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;

/**
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
Expand Down Expand Up @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) {
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}

@Override
Expand Down
27 changes: 24 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.dotProduct(a, b);
float r = PROVIDER.dotProduct(a, b);
assert Float.isFinite(r);
return r;
}

/**
Expand All @@ -46,7 +48,9 @@ public static float cosine(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.cosine(a, b);
float r = PROVIDER.cosine(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the cosine similarity between the two vectors. */
Expand All @@ -66,7 +70,9 @@ public static float squareDistance(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.squareDistance(a, b);
float r = PROVIDER.squareDistance(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the sum of squared differences of the two vectors. */
Expand Down Expand Up @@ -154,4 +160,19 @@ public static float dotProductScore(byte[] a, byte[] b) {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}

/**
 * Verifies that every component of a float vector is finite, i.e. neither NaN nor infinite.
 *
 * @param v float array containing a vector
 * @return the same array instance {@code v}, so the call can be chained
 * @throws IllegalArgumentException if any component is not finite; the message reports the index
 *     and value of the first offending component
 */
public static float[] checkFinite(float[] v) {
  int index = 0;
  for (float component : v) {
    if (Float.isFinite(component) == false) {
      throw new IllegalArgumentException(
          "non-finite value at vector[" + index + "]=" + component);
    }
    index++;
  }
  return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ public void addInOrder(int newNode, float newScore) {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
: "Nodes are added in the incorrect order!";
: "Nodes are added in the incorrect order! Comparing "
+ newScore
+ " to "
+ Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
}
node[size] = newNode;
score[size] = newScore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
Expand Down Expand Up @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);

leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE);
}

reader.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void testScoresDescOrder() {
neighbors.addInOrder(1, 0.8f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
Expand Down Expand Up @@ -76,7 +76,7 @@ public void testScoresAscOrder() {
neighbors.addInOrder(1, 0.3f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
Expand Down