From 340ab8705494ee622598ab40d42015879badce8a Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Wed, 14 Jun 2023 15:41:07 -0500 Subject: [PATCH] rename checkFinite to checkInBounds and require that float vector components are smaller than 1E17 to prevent overflowing to Infinity --- .../apache/lucene/document/KnnFloatVectorField.java | 4 ++-- .../org/apache/lucene/index/FloatVectorValues.java | 9 +++++++++ .../apache/lucene/search/KnnFloatVectorQuery.java | 2 +- .../src/java/org/apache/lucene/util/VectorUtil.java | 11 +++++++++-- .../test/org/apache/lucene/util/TestVectorUtil.java | 13 +++++++++++++ .../valuesource/ConstKnnFloatValueSource.java | 2 +- 6 files changed, 35 insertions(+), 6 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index 9d1cd02c013e..4a5f770bd76a 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -102,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) { public KnnFloatVectorField( String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = VectorUtil.checkFinite(vector); // null check done above + fieldsData = VectorUtil.checkInBounds(vector); // null check done above } /** @@ -143,7 +143,7 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { throw new IllegalArgumentException( "The number of vector dimensions does not match the field type"); } - fieldsData = VectorUtil.checkFinite(vector); + fieldsData = VectorUtil.checkInBounds(vector); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 
9de6b57531e2..02eba71763be 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -31,6 +31,15 @@ public abstract class FloatVectorValues extends DocIdSetIterator { /** The maximum length of a vector */ public static final int MAX_DIMENSIONS = 1024; + /** + * This is the largest float vector value that we allow. + * + *

<p>The largest float32 that you can square without overflowing is about 1.8E19. We reduce that + * further to accommodate the addition of multiple such components in similarity computations in + * vectors up to MAX_DIMENSIONS in length. + */ + public static final float MAX_FLOAT32_COMPONENT = 1E17f; + /** Sole constructor */ protected FloatVectorValues() {} diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 3036e7c45162..4a373a1e7efa 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -72,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) { */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { super(field, k, filter); - this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); + this.target = VectorUtil.checkInBounds(Objects.requireNonNull(target, "target")); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index c9e1d368334e..b1cd0f6d0768 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; +import org.apache.lucene.index.FloatVectorValues; + /** Utilities for computations with numeric arrays */ public final class VectorUtil { @@ -162,17 +164,22 @@ public static float dotProductScore(byte[] a, byte[] b) { } /** - * Checks if a float vector only has finite components. + * Checks if a float vector only has components with absolute value no greater than + * MAX_FLOAT32_COMPONENT. NaN is not allowed.
* * @param v bytes containing a vector * @return the vector for call-chaining * @throws IllegalArgumentException if any component of vector is not finite */ - public static float[] checkFinite(float[] v) { + public static float[] checkInBounds(float[] v) { for (int i = 0; i < v.length; i++) { if (!Float.isFinite(v[i])) { throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]); } + + if (Math.abs(v[i]) > FloatVectorValues.MAX_FLOAT32_COMPONENT) { + throw new IllegalArgumentException("Out-of-bounds value at vector[" + i + "]=" + v[i]); + } } return v; } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index 358db95641f2..3fb93dce0278 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -16,7 +16,11 @@ */ package org.apache.lucene.util; +import static org.apache.lucene.index.FloatVectorValues.MAX_FLOAT32_COMPONENT; + +import java.util.Arrays; import java.util.Random; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -262,4 +266,13 @@ public void testOrthogonalCosineBytes() { u[1] = -v[0]; assertEquals(0, VectorUtil.cosine(u, v), DELTA); } + + public void testLargeVectorSimilarities() { + float[] v = new float[FloatVectorValues.MAX_DIMENSIONS]; + Arrays.fill(v, MAX_FLOAT32_COMPONENT); + + assertTrue(Float.isFinite(VectorUtil.cosine(v, v))); + assertTrue(Float.isFinite(VectorUtil.dotProduct(v, v))); + assertTrue(Float.isFinite(VectorUtil.squareDistance(v, v))); + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java index 57c016eb793e..9c49d1ca2c38 100644 --- 
a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -30,7 +30,7 @@ public class ConstKnnFloatValueSource extends ValueSource { private final float[] vector; public ConstKnnFloatValueSource(float[] constVector) { - this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector")); + this.vector = VectorUtil.checkInBounds(Objects.requireNonNull(constVector, "constVector")); } @Override