Expose FlatVectorsFormat #13469

Merged: 7 commits, Jun 13, 2024
Changes from 4 commits
@@ -52,14 +52,9 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
null);
}

-  @Override
-  public int getMaxDimensions(String fieldName) {
-    return 1024;
-  }
-
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
-        new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
+        new Lucene99FlatVectorsFormat("FlatVectorsFormat", new DefaultFlatVectorScorer());

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
@@ -54,6 +54,7 @@
public final class HnswBitVectorsFormat extends KnnVectorsFormat {

public static final String NAME = "HnswBitVectorsFormat";
+  public static final String NAME_FLAT = "HnswBitVectorsFlatFormat";

/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
@@ -128,7 +129,7 @@ public HnswBitVectorsFormat(
} else {
this.mergeExec = null;
}
-    this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer());
+    this.flatVectorsFormat = new Lucene99FlatVectorsFormat(NAME_FLAT, new FlatBitVectorsScorer());
Member:
I wonder why a unique name here is required.

The top-level format name is HnswBitVectorsFormat, and it doesn't really do anything to the inner format; it simply overrides the scoring mechanism.

Contributor Author:
On the subject of the brute-force query: I started to look at adding one and found we already have o.a.l.queries.function.valuesource.VectorSimilarityFunction. I think all we may need here is a static convenience method that produces a FunctionQuery wrapping one of these, although I'll confess I'm not entirely conversant with this API. It seems like it might want to be rewritten in terms of VectorScorer.

}

@Override
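As an illustration of the convenience method floated in the thread above, a rough sketch (not part of this change): it assumes the existing value sources in org.apache.lucene.queries.function.valuesource (ConstKnnFloatValueSource, FloatKnnVectorFieldSource, FloatVectorSimilarityFunction) keep their current constructors, and the wrapper class and method names are made up.

    import org.apache.lucene.index.VectorSimilarityFunction;
    import org.apache.lucene.queries.function.FunctionQuery;
    import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
    import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
    import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
    import org.apache.lucene.search.Query;

    /** Hypothetical helper; the class and method names are placeholders. */
    public final class BruteForceVectorQueries {
      private BruteForceVectorQueries() {}

      /** Scores every matching document by the similarity of its indexed vector to {@code target}. */
      public static Query bruteForceFloatSimilarity(
          String field, float[] target, VectorSimilarityFunction similarity) {
        return new FunctionQuery(
            new FloatVectorSimilarityFunction(
                similarity,
                new FloatKnnVectorFieldSource(field),
                new ConstKnnFloatValueSource(target)));
      }
    }

Such a query scores all matching documents exactly, which is the brute-force behavior a flat (graph-less) format would need.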
3 changes: 2 additions & 1 deletion lucene/core/src/java/module-info.java
@@ -76,7 +76,8 @@
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
-        org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
+        org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
+        org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
@@ -18,6 +18,7 @@
package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
+import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@@ -27,14 +28,23 @@
*
* @lucene.experimental
*/
-public abstract class FlatVectorsFormat {
+public abstract class FlatVectorsFormat extends KnnVectorsFormat {

/** Sole constructor */
-  protected FlatVectorsFormat() {}
+  protected FlatVectorsFormat(String name) {
+    super(name);
+  }

/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
+  @Override
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;

/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
+  @Override
  public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;
+
+  @Override
+  public int getMaxDimensions(String fieldName) {
+    return 1024;
+  }
}
@@ -17,11 +17,8 @@

package org.apache.lucene.codecs.hnsw;

-import java.io.Closeable;
import java.io.IOException;
-import org.apache.lucene.index.ByteVectorValues;
-import org.apache.lucene.index.FieldInfo;
-import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

@@ -39,7 +36,7 @@
*
* @lucene.experimental
*/
-public abstract class FlatVectorsReader implements Closeable, Accountable {
+public abstract class FlatVectorsReader extends KnnVectorsReader implements Accountable {

/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorScorer;
@@ -77,28 +74,4 @@ public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] t
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
      throws IOException;
-
-  /**
-   * Checks consistency of this reader.
-   *
-   * <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
-   * against large data files.
-   *
-   * @lucene.internal
-   */
-  public abstract void checkIntegrity() throws IOException;
-
-  /**
-   * Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
-   * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
-   * never {@code null}.
-   */
-  public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;
-
-  /**
-   * Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
-   * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
-   * never {@code null}.
-   */
-  public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
}
@@ -17,22 +17,19 @@

package org.apache.lucene.codecs.hnsw;

-import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
-import org.apache.lucene.index.Sorter;
-import org.apache.lucene.util.Accountable;
-import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;

/**
* Vectors' writer for a field that allows additional indexing logic to be implemented by the caller
*
* @lucene.experimental
*/
-public abstract class FlatVectorsWriter implements Accountable, Closeable {
+public abstract class FlatVectorsWriter extends KnnVectorsWriter {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorsScorer;

@@ -60,6 +57,11 @@ public FlatVectorsScorer getFlatVectorScorer() {
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;

+  @Override
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    return addField(fieldInfo, null);
+  }
+
/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
* any additional merging logic can be implemented by the user of this class.
@@ -72,15 +74,4 @@ public abstract FlatFieldVectorsWriter<?> addField(
*/
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
      FieldInfo fieldInfo, MergeState mergeState) throws IOException;
-
-  /** Write field for merging */
-  public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
-    IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
-  }
-
-  /** Called once at the end before close */
-  public abstract void finish() throws IOException;
-
-  /** Flush all buffered data on disk * */
-  public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
}
@@ -79,7 +79,8 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
private final FlatVectorsScorer vectorsScorer;

/** Constructs a format */
-  public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
+  public Lucene99FlatVectorsFormat(String name, FlatVectorsScorer vectorsScorer) {
+    super(name);
Member:
I wonder why for Lucene99FlatVectorsFormat we allow an external name to be provided. Is this because it isn't really ever loaded via the SPI?

It seems like Lucene99FlatVectorsFormat shouldn't accept a name as a parameter and should simply pass the correct name to super.

Contributor Author:
Makes sense. TBH I kind of just filled in the blanks on this one without thinking much about it, but I agree the name is really only useful for the SPI interface.

this.vectorsScorer = vectorsScorer;
}

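A sketch of the reviewer's suggestion above (not what this commit does): Lucene99FlatVectorsFormat could own a fixed name and pass it to super itself instead of taking one as a parameter. Only the constructor area is shown; the rest of the class is elided.

    /** Sketch only: keep the format name internal to the class. */
    public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {

      static final String NAME = "Lucene99FlatVectorsFormat";

      private final FlatVectorsScorer vectorsScorer;

      /** Constructs a format */
      public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
        super(NAME); // fixed name; callers no longer supply one
        this.vectorsScorer = vectorsScorer;
      }

      // ... fieldsWriter/fieldsReader and the rest unchanged ...
    }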
@@ -38,10 +38,12 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.ReadAdvice;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
@@ -217,6 +219,18 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
vectorData);
}

+  @Override
+  public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
FieldEntry fieldEntry = fields.get(field);
@@ -119,6 +119,11 @@ public FlatFieldVectorsWriter<?> addField(
return newField;
}

+  @Override
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    return addField(fieldInfo, null);
+  }
+
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
@@ -139,7 +139,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {

/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat =
-      new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+      new Lucene99FlatVectorsFormat(
+          "Lucene99FlatVectorsFormat", FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

private final int numMergeWorkers;
private final TaskExecutor mergeExec;
@@ -49,7 +49,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "veq";

private static final FlatVectorsFormat rawVectorFormat =
-      new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+      new Lucene99FlatVectorsFormat(
+          "Lucene99FlatVectorsFormat", FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
@@ -89,6 +90,7 @@ public Lucene99ScalarQuantizedVectorsFormat() {
*/
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
super("Lucene99ScalarQuantizedVectorsFormat");
if (confidenceInterval != null
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
@@ -36,11 +36,13 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.ReadAdvice;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
@@ -189,6 +191,18 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
return rawVectorsReader.getByteVectorValues(field);
}

+  @Override
Member:
So, doing this means KnnFloatVectorQuery will throw an error if it's used against a field whose codec uses this reader.

Traditionally, queries were fairly ignorant of the underlying codec used to store the field.

It does seem weird to have a codec that KNN queries can be executed against throw an exception.

I think either:

  • The KNN queries should automatically do the correct thing (e.g. not eagerly rewrite themselves, etc.), or
  • We don't throw on search like this.

Member:
> The KNN queries should automatically do the correct thing (e.g. not eagerly rewrite themselves, etc.)

I am not sure how to do this other than adding something to the leaf readers that says "supports approximate search". Or a try {} catch () in the query, but then we are using exceptions as query control flow, which seems trappy and bad.

Contributor Author:
I did toy with adding a "supportsSearch" flag when this problem surfaced via BaseKnnVectorsFormatTestCase. Potentially we could return no results. I feel like this is the moral equivalent of a stored field with IndexOptions.NONE (like StoredField.TYPE)? But we have tended not to want to make such vector-representation choices part of the Fields API, so it looks different.

+  public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
+      throws IOException {
+    throw new UnsupportedOperationException();
+  }
+
private static IndexInput openDataInput(
SegmentReadState state,
int versionMeta,
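For comparison, a minimal sketch of the "return no results" alternative debated in the review thread above; the commit under review throws UnsupportedOperationException instead. This is illustrative only and not part of the change.

    // Hypothetical alternative body for a flat reader's search() override:
    @Override
    public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
        throws IOException {
      // No graph to search: collect nothing, so a KNN query simply finds no hits here and
      // callers can fall back to exact scoring via getRandomVectorScorer(field, target).
    }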
28 changes: 22 additions & 6 deletions lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
@@ -45,11 +45,14 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
@@ -2739,6 +2742,14 @@ public static Status.VectorValuesStatus testVectors(
return status;
}

+  private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
+    KnnVectorsReader vectorsReader = codecReader.getVectorReader();
+    if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
+      vectorsReader = perFieldReader.getFieldReader(fieldName);
+    }
+    return (vectorsReader instanceof FlatVectorsReader) == false;
+  }
+
private static void checkFloatVectorValues(
FloatVectorValues values,
FieldInfo fieldInfo,
@@ -2751,11 +2762,15 @@ private static void checkFloatVectorValues(
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
-      codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
-      TopDocs docs = collector.topDocs();
-      if (docs.scoreDocs.length == 0) {
-        throw new CheckIndexException(
-            "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
+      if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
+        codecReader
+            .getVectorReader()
+            .search(fieldInfo.name, values.vectorValue(), collector, null);
+        TopDocs docs = collector.topDocs();
+        if (docs.scoreDocs.length == 0) {
+          throw new CheckIndexException(
+              "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
+        }
}
}
int valueLength = values.vectorValue().length;
@@ -2791,9 +2806,10 @@ private static void checkByteVectorValues(
throws IOException {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
+    boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
while (values.nextDoc() != NO_MORE_DOCS) {
// search the first maxNumSearches vectors to exercise the graph
-      if (values.docID() % everyNdoc == 0) {
+      if (supportsSearch && values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
@@ -15,3 +15,4 @@

org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
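
With Lucene99ScalarQuantizedVectorsFormat registered as a KnnVectorsFormat SPI (the module-info and META-INF entries above), it can be picked per field like any other vectors format. A hedged usage sketch follows; the example class and field names are made up, and it assumes Lucene99Codec's existing getKnnVectorsFormatForField hook.

    import org.apache.lucene.codecs.KnnVectorsFormat;
    import org.apache.lucene.codecs.lucene99.Lucene99Codec;
    import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
    import org.apache.lucene.index.IndexWriterConfig;

    public class FlatQuantizedVectorsExample {
      /** Stores the "embedding" field flat and scalar-quantized, with no HNSW graph. */
      public static IndexWriterConfig newConfig() {
        IndexWriterConfig config = new IndexWriterConfig();
        config.setCodec(
            new Lucene99Codec() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                if ("embedding".equals(field)) {
                  return new Lucene99ScalarQuantizedVectorsFormat();
                }
                return super.getKnnVectorsFormatForField(field);
              }
            });
        return config;
      }
    }

Note that, per the review thread above, the flat readers in this change throw from search(), so KNN queries against such a field still need a brute-force path.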