Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose FlatVectorsFormat #13469

Merged
merged 7 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
null);
}

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}

static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
Expand Down
3 changes: 2 additions & 1 deletion lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
Expand All @@ -27,14 +28,23 @@
*
* @lucene.experimental
*/
public abstract class FlatVectorsFormat {
public abstract class FlatVectorsFormat extends KnnVectorsFormat {

/** Sole constructor */
protected FlatVectorsFormat() {}
protected FlatVectorsFormat(String name) {
super(name);
}

/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
@Override
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;

/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
@Override
public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@

package org.apache.lucene.codecs.hnsw;

import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
Expand All @@ -39,7 +38,7 @@
*
* @lucene.experimental
*/
public abstract class FlatVectorsReader implements Closeable, Accountable {
public abstract class FlatVectorsReader extends KnnVectorsReader implements Accountable {

/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorScorer;
Expand All @@ -56,6 +55,18 @@ public FlatVectorsScorer getFlatVectorScorer() {
return vectorScorer;
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}

/**
* Returns a {@link RandomVectorScorer} for the given field and target vector.
*
Expand All @@ -77,28 +88,4 @@ public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] t
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
throws IOException;

/**
* Checks consistency of this reader.
*
* <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
* against large data files.
*
* @lucene.internal
*/
public abstract void checkIntegrity() throws IOException;

/**
* Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;

/**
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,19 @@

package org.apache.lucene.codecs.hnsw;

import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;

/**
* Vectors' writer for a field that allows additional indexing logic to be implemented by the caller
*
* @lucene.experimental
*/
public abstract class FlatVectorsWriter implements Accountable, Closeable {
public abstract class FlatVectorsWriter extends KnnVectorsWriter {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorsScorer;

Expand Down Expand Up @@ -60,6 +57,11 @@ public FlatVectorsScorer getFlatVectorScorer() {
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}

/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
* any additional merging logic can be implemented by the user of this class.
Expand All @@ -72,15 +74,4 @@ public abstract FlatFieldVectorsWriter<?> addField(
*/
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException;

/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
}

/** Called once at the end before close */
public abstract void finish() throws IOException;

/** Flush all buffered data on disk * */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
*/
public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {

static final String NAME = "Lucene99FlatVectorsFormat";
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData";
static final String META_EXTENSION = "vemf";
Expand All @@ -80,6 +81,7 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {

/** Constructs a format */
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
super(NAME);
this.vectorsScorer = vectorsScorer;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ public FlatFieldVectorsWriter<?> addField(
return newField;
}

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}

@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public Lucene99ScalarQuantizedVectorsFormat() {
*/
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
super(NAME);
if (confidenceInterval != null
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
Expand Down
28 changes: 22 additions & 6 deletions lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
Expand Down Expand Up @@ -2739,6 +2742,14 @@ public static Status.VectorValuesStatus testVectors(
return status;
}

private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
KnnVectorsReader vectorsReader = codecReader.getVectorReader();
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
vectorsReader = perFieldReader.getFieldReader(fieldName);
}
return (vectorsReader instanceof FlatVectorsReader) == false;
}

private static void checkFloatVectorValues(
FloatVectorValues values,
FieldInfo fieldInfo,
Expand All @@ -2751,11 +2762,15 @@ private static void checkFloatVectorValues(
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
}
int valueLength = values.vectorValue().length;
Expand Down Expand Up @@ -2791,9 +2806,10 @@ private static void checkByteVectorValues(
throws IOException {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
while (values.nextDoc() != NO_MORE_DOCS) {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
if (supportsSearch && values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@

org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
Loading