Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks in KNNVectorField / KNNVectorQuery to only allow non-null, non-empty and finite vectors #12281

Merged
merged 8 commits into from
Jun 13, 2023
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ Improvements
* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit.
(Greg Miller)

* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler)

Optimizations
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -100,7 +101,7 @@ public static FieldType createFieldType(
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = vector; // null-check done above
}

/**
Expand Down Expand Up @@ -136,6 +137,11 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
+ " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = vector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.document;

import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
fieldsData = VectorUtil.checkFinite(vector); // null check done above
}

/**
Expand Down Expand Up @@ -137,7 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) {
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
Objects.requireNonNull(vector, "vector value must not be null");
if (vector.length != fieldType.vectorDimension()) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
fieldsData = VectorUtil.checkFinite(vector);
}

/** Return the vector value of this field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final Query filter;

public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) {
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = Objects.requireNonNull(target, "target");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;

/**
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
Expand Down Expand Up @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) {
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
super(field, k, filter);
this.target = target;
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}

@Override
Expand Down
28 changes: 25 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.dotProduct(a, b);
float r = PROVIDER.dotProduct(a, b);
assert Float.isFinite(r);
return r;
}

/**
Expand All @@ -46,7 +48,9 @@ public static float cosine(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.cosine(a, b);
float r = PROVIDER.cosine(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the cosine similarity between the two vectors. */
Expand All @@ -66,7 +70,9 @@ public static float squareDistance(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.squareDistance(a, b);
float r = PROVIDER.squareDistance(a, b);
assert Float.isFinite(r);
return r;
}

/** Returns the sum of squared differences of the two vectors. */
Expand Down Expand Up @@ -154,4 +160,20 @@ public static float dotProductScore(byte[] a, byte[] b) {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}

/**
 * Verifies that every component of a float vector is finite, i.e. neither NaN nor infinite.
 *
 * @param v the float vector to validate
 * @return the same array instance that was passed in, so the call can be chained
 * @throws IllegalArgumentException if any component of the vector is not finite
 */
public static float[] checkFinite(float[] v) {
  int index = 0;
  for (float component : v) {
    // A float is non-finite exactly when it is NaN or +/- infinity.
    if (Float.isNaN(component) || Float.isInfinite(component)) {
      throw new IllegalArgumentException(
          "non-finite value at vector[" + index + "]=" + component);
    }
    index++;
  }
  return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ public void addInOrder(int newNode, float newScore) {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
: "Nodes are added in the incorrect order!";
: "Nodes are added in the incorrect order! Comparing "
+ newScore
+ " to "
+ Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
}
node[size] = newNode;
score[size] = newScore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
Expand Down Expand Up @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);

leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE);
}

reader.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void testScoresDescOrder() {
neighbors.addInOrder(1, 0.8f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
Expand Down Expand Up @@ -76,7 +76,7 @@ public void testScoresAscOrder() {
neighbors.addInOrder(1, 0.3f);

AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage();

neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
Expand Down