Skip to content

Commit

Permalink
Fix vector scorer interface consistency (#13365)
Browse files Browse the repository at this point in the history
Follow up to: #13181

I noticed the quantized interface had a slightly different name.

Additionally, testing showed we are inconsistent when there aren't any vectors to score. This makes the response consistent (e.g. null when there aren't any vectors).
  • Loading branch information
benwtrent committed May 13, 2024
1 parent fd98698 commit 8580d50
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {

@Override
public VectorScorer scorer(byte[] query) {
throw new UnsupportedOperationException();
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {

@Override
public VectorScorer scorer(float[] query) {
throw new UnsupportedOperationException();
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ public int advance(int target) throws IOException {

@Override
public VectorScorer scorer(float[] query) throws IOException {
return quantizedVectorValues.vectorScorer(query);
return quantizedVectorValues.scorer(query);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ public float getScoreCorrectionConstant() throws IOException {
}

@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
Expand Down Expand Up @@ -1097,7 +1097,7 @@ public int advance(int target) throws IOException {
}

@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}

Expand Down Expand Up @@ -1203,7 +1203,7 @@ public int advance(int target) throws IOException {
}

@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {
}

@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
public VectorScorer scorer(float[] target) throws IOException {
DenseOffHeapVectorValues copy = copy();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
Expand Down Expand Up @@ -370,7 +370,7 @@ public int length() {
}

@Override
public VectorScorer vectorScorer(float[] target) throws IOException {
public VectorScorer scorer(float[] target) throws IOException {
SparseOffHeapVectorValues copy = copy();
RandomVectorScorer vectorScorer =
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
Expand Down Expand Up @@ -457,8 +457,8 @@ public Bits getAcceptOrds(Bits acceptDocs) {
}

@Override
public VectorScorer vectorScorer(float[] target) {
throw new UnsupportedOperationException();
public VectorScorer scorer(float[] target) {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public static void checkField(LeafReader in, String field) {
* iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(byte[] query) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static void checkField(LeafReader in, String field) {
* iteration of this {@link FloatVectorValues}.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(float[] query) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ static VectorSimilarityScorer fromAcceptDocs(
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
if (scorer == null) {
return null;
}
float[] cachedScore = new float[1];
DocIdSetIterator vectorIterator = scorer.iterator();
DocIdSetIterator conjunction =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,6 @@ public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
this.queryVector = vector;
}

@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
FloatVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
return new DoubleValues() {
private final VectorScorer scorer = vectorValues.scorer(queryVector);
private final DocIdSetIterator iterator = scorer.iterator();

@Override
public double doubleValue() throws IOException {
return scorer.score();
}

@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
}
};
}

@Override
public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public final long cost() {
* Return a {@link VectorScorer} for the given query vector.
*
* @param query the query vector
* @return a {@link VectorScorer} instance
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer vectorScorer(float[] query) throws IOException;
public abstract VectorScorer scorer(float[] query) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ protected TopDocs exactSearch(
}

VectorScorer scorer = byteVectorValues.scorer(query);
if (scorer == null) {
return NO_RESULTS;
}
DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer =
new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer(
acceptIterator, parentBitSet, scorer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,13 @@ protected TopDocs exactSearch(
if (parentBitSet == null) {
return NO_RESULTS;
}
VectorScorer floatVectorScorer = floatVectorValues.scorer(query);
if (floatVectorScorer == null) {
return NO_RESULTS;
}

DiversifyingChildrenVectorScorer vectorScorer =
new DiversifyingChildrenVectorScorer(
acceptIterator, parentBitSet, floatVectorValues.scorer(query));
new DiversifyingChildrenVectorScorer(acceptIterator, parentBitSet, floatVectorScorer);
final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
HitQueue queue = new HitQueue(queueSize, true);
TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,10 @@ public void testFloatVectorScorerIteration() throws Exception {
if (vectorValues == null) {
continue;
}
if (vectorValues.size() == 0) {
assertNull(vectorValues.scorer(vectorToScore));
continue;
}
VectorScorer scorer = vectorValues.scorer(vectorToScore);
assertNotNull(scorer);
DocIdSetIterator iterator = scorer.iterator();
Expand Down Expand Up @@ -818,6 +822,10 @@ public void testByteVectorScorerIteration() throws Exception {
if (vectorValues == null) {
continue;
}
if (vectorValues.size() == 0) {
assertNull(vectorValues.scorer(vectorToScore));
continue;
}
VectorScorer scorer = vectorValues.scorer(vectorToScore);
assertNotNull(scorer);
DocIdSetIterator iterator = scorer.iterator();
Expand Down

0 comments on commit 8580d50

Please sign in to comment.