Skip to content

Commit

Permalink
revert changes to VUDefaultProvider; add checkFinite to VectorUtil in…
Browse files Browse the repository at this point in the history
…stead, and call from KFVF.createType
  • Loading branch information
jbellis committed Jun 10, 2023
1 parent c9b3cb8 commit 113c4ad
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.lucene.document;

import static org.apache.lucene.util.VectorUtil.checkFinite;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand Down Expand Up @@ -50,6 +52,7 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari
throw new IllegalArgumentException(
"cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS);
}
checkFinite(v);
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}
Expand Down
38 changes: 35 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.lucene.util;

import java.util.Arrays;

/** Utilities for computations with numeric arrays */
public final class VectorUtil {

Expand All @@ -34,7 +36,9 @@ public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.dotProduct(a, b);
float r = PROVIDER.dotProduct(a, b);
checkFinite(r, a, b, "dot product");
return r;
}

/**
Expand All @@ -46,7 +50,9 @@ public static float cosine(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.cosine(a, b);
float r = PROVIDER.cosine(a, b);
checkFinite(r, a, b, "cosine");
return r;
}

/** Returns the cosine similarity between the two vectors. */
Expand All @@ -66,7 +72,9 @@ public static float squareDistance(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return PROVIDER.squareDistance(a, b);
float r = PROVIDER.squareDistance(a, b);
checkFinite(r, a, b, "square distance");
return r;
}

/** Returns the sum of squared differences of the two vectors. */
Expand Down Expand Up @@ -154,4 +162,28 @@ public static float dotProductScore(byte[] a, byte[] b) {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}

/**
 * Rejects a non-finite similarity result.
 *
 * <p>If {@code r} is finite this is a no-op. Otherwise each input vector is validated with {@link
 * #checkFinite(float[])} first, so a non-finite input component is reported in preference to the
 * generic message; only when both inputs are entirely finite is the result itself reported,
 * together with both vectors.
 *
 * @param r the computed similarity value
 * @param a the first input vector
 * @param b the second input vector
 * @param optype the operation name embedded in the error message
 * @throws IllegalArgumentException if {@code r} is NaN or infinite
 */
private static void checkFinite(float r, float[] a, float[] b, String optype) {
  if (Float.isFinite(r)) {
    return;
  }

  // Blame a bad input component first; these calls throw on the first one found.
  checkFinite(a);
  checkFinite(b);

  // Inputs were finite, so the computation itself produced the non-finite value.
  final String message =
      "Non-finite (" + r + ") " + optype + " similarity from "
          + Arrays.toString(a)
          + " and "
          + Arrays.toString(b);
  throw new IllegalArgumentException(message);
}

/**
 * Verifies that every component of the given vector is a finite number.
 *
 * @param a the vector to inspect
 * @throws IllegalArgumentException on the first NaN or infinite component; the message names the
 *     index and the offending value
 */
public static void checkFinite(float[] a) {
  int position = 0;
  for (float component : a) {
    // NaN-or-infinite is exactly the complement of Float.isFinite.
    if (Float.isNaN(component) || Float.isInfinite(component)) {
      throw new IllegalArgumentException(
          "non-finite value at vector[" + position + "]=" + component);
    }
    position++;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.lucene.util;

import java.util.Arrays;

/** The default VectorUtil provider implementation. */
final class VectorUtilDefaultProvider implements VectorUtilProvider {

Expand Down Expand Up @@ -87,7 +85,6 @@ public float dotProduct(float[] a, float[] b) {
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
}
checkFinite(res, a, b, "dot product");
return res;
}

Expand All @@ -105,9 +102,7 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
var r = (float) (sum / Math.sqrt(norm1 * norm2));
checkFinite(r, a, b, "cosine");
return r;
return (float) (sum / Math.sqrt(norm1 * norm2));
}

@Override
Expand All @@ -122,32 +117,9 @@ public float squareDistance(float[] a, float[] b) {
float diff = a[i] - b[i];
squareSum += diff * diff;
}
checkFinite(squareSum, a, b, "square distance");
return squareSum;
}

/**
 * Rejects a non-finite similarity result.
 *
 * <p>If {@code r} is finite this is a no-op. Otherwise the inputs are scanned index by index,
 * checking {@code a[i]} before {@code b[i]}, and the first non-finite component found is reported
 * as {@code v1[i]} or {@code v2[i]}. If every component is finite, the result itself is reported
 * together with both vectors.
 *
 * @param r the computed similarity value
 * @param a the first input vector; its length bounds the scan
 * @param b the second input vector, indexed in step with {@code a}
 * @param optype the operation name embedded in the error message
 * @throws IllegalArgumentException if {@code r} is NaN or infinite
 */
private static void checkFinite(float r, float[] a, float[] b, String optype) {
  if (Float.isFinite(r)) {
    return;
  }

  final int dim = a.length;
  for (int pos = 0; pos < dim; pos++) {
    final float left = a[pos];
    if (!Float.isFinite(left)) {
      throw new IllegalArgumentException("v1[" + pos + "]=" + left);
    }
    final float right = b[pos];
    if (!Float.isFinite(right)) {
      throw new IllegalArgumentException("v2[" + pos + "]=" + right);
    }
  }

  // Both inputs are finite, so the computation itself produced the non-finite value.
  final String message =
      "Non-finite (" + r + ") " + optype + " similarity from "
          + Arrays.toString(a)
          + " and "
          + Arrays.toString(b);
  throw new IllegalArgumentException(message);
}

private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
Expand Down

0 comments on commit 113c4ad

Please sign in to comment.