Skip to content

Commit

Permalink
Move max vector dims limit to Codec (apache#12436)
Browse files Browse the repository at this point in the history
Move vector max dimension limits enforcement into the default Codec's
KnnVectorsFormat implementation. This allows different implementation
of knn search algorithms define their own limits of a maximum
vector dimensions that they can handle.

Closes apache#12309
  • Loading branch information
mayya-sharipova committed Jul 27, 2023
1 parent 941f897 commit b247afe
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 44 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ API Changes
* GITHUB#11248: IntBlockPool's SliceReader, SliceWriter, and all int slice functionality are moved out to MemoryIndex.
(Stefan Vodita)

* GITHUB#12436: Move max vector dims limit to Codec (Mayya Sharipova)

New Features
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
*/
public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {

/** The maximum number of vector dimensions */
public static final int DEFAULT_MAX_DIMENSIONS = 1024;

/**
* This static holder class prevents classloading deadlock by delaying init of doc values formats
* until needed.
Expand Down Expand Up @@ -76,6 +79,19 @@ public static KnnVectorsFormat forName(String name) {
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;

/**
* Returns the maximum number of vector dimensions supported by this codec for the given field
* name
*
* <p>Codecs should override this method to specify the maximum number of dimensions they support.
*
* @param fieldName the field name
* @return the maximum number of vector dimensions.
*/
public int getMaxDimensions(String fieldName) {
return DEFAULT_MAX_DIMENSIONS;
}

/**
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
* support vectors.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
return new Lucene95HnswVectorsReader(state);
}

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}

@Override
public String toString() {
return "Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
return new FieldsReader(state);
}

@Override
public int getMaxDimensions(String fieldName) {
return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName);
}

/**
* Returns the numeric vector format that should be used for writing new segments of <code>field
* </code>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Objects;
import org.apache.lucene.analysis.Analyzer; // javadocs
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
Expand Down Expand Up @@ -378,13 +377,6 @@ public void setVectorAttributes(
if (numDimensions <= 0) {
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
}
if (numDimensions > FloatVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"vector numDimensions must be <= FloatVectorValues.MAX_DIMENSIONS (="
+ FloatVectorValues.MAX_DIMENSIONS
+ "); got "
+ numDimensions);
}
this.vectorDimension = numDimensions;
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
this.vectorEncoding = Objects.requireNonNull(encoding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ private static FieldType createType(byte[] v, VectorSimilarityFunction similarit
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
}
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
if (dimension > FloatVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS);
}
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
*/
public abstract class ByteVectorValues extends DocIdSetIterator {

/** The maximum length of a vector */
public static final int MAX_DIMENSIONS = 1024;

/** Sole constructor */
protected ByteVectorValues() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
*/
public abstract class FloatVectorValues extends DocIdSetIterator {

/** The maximum length of a vector */
public static final int MAX_DIMENSIONS = 1024;

/** Sole constructor */
protected FloatVectorValues() {}

Expand Down
20 changes: 20 additions & 0 deletions lucene/core/src/java/org/apache/lucene/index/IndexingChain.java
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,12 @@ private void initializeFieldInfo(PerField pf) throws IOException {
final Sort indexSort = indexWriterConfig.getIndexSort();
validateIndexSortDVType(indexSort, pf.fieldName, s.docValuesType);
}
if (s.vectorDimension != 0) {
validateMaxVectorDimension(
pf.fieldName,
s.vectorDimension,
indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions(pf.fieldName));
}
FieldInfo fi =
fieldInfos.add(
new FieldInfo(
Expand Down Expand Up @@ -831,6 +837,20 @@ private static void verifyUnIndexedFieldType(String name, IndexableFieldType ft)
}
}

private static void validateMaxVectorDimension(
String fieldName, int vectorDim, int maxVectorDim) {
if (vectorDim > maxVectorDim) {
throw new IllegalArgumentException(
"Field ["
+ fieldName
+ "]"
+ "vector's dimensions must be <= ["
+ maxVectorDim
+ "]; got "
+ vectorDim);
}
}

private void validateIndexSortDVType(Sort indexSort, String fieldToValidate, DocValuesType dvType)
throws IOException {
for (SortField sortField : indexSort.getSort()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
Expand All @@ -43,6 +44,9 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
Expand Down Expand Up @@ -162,6 +166,50 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
}
}

public void testMaxDimensionsPerFieldFormat() throws IOException {
try (Directory directory = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
KnnVectorsFormat format1 =
new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100);
iwc.setCodec(
new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
if ("field1".equals(field)) {
return format1;
} else {
return format2;
}
}
});
try (IndexWriter writer = new IndexWriter(directory, iwc)) {
Document doc1 = new Document();
doc1.add(new KnnFloatVectorField("field1", new float[33]));
Exception exc =
expectThrows(IllegalArgumentException.class, () -> writer.addDocument(doc1));
assertTrue(exc.getMessage().contains("vector's dimensions must be <= [32]"));

Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("field1", new float[32]));
doc2.add(new KnnFloatVectorField("field2", new float[33]));
writer.addDocument(doc2);
}

// Check that the vectors were written
try (IndexReader reader = DirectoryReader.open(directory)) {
IndexSearcher searcher = new IndexSearcher(reader);
Query query1 = new KnnFloatVectorQuery("field1", new float[32], 10);
TopDocs topDocs1 = searcher.search(query1, 1);
assertEquals(1, topDocs1.scoreDocs.length);

Query query2 = new KnnFloatVectorQuery("field2", new float[33], 10);
TopDocs topDocs2 = searcher.search(query2, 1);
assertEquals(1, topDocs2.scoreDocs.length);
}
}
}

private static class WriteRecordingKnnVectorsFormat extends KnnVectorsFormat {
private final KnnVectorsFormat delegate;
private final Set<String> fieldsWritten;
Expand Down Expand Up @@ -216,4 +264,28 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
return delegate.fieldsReader(state);
}
}

private static class KnnVectorsFormatMaxDims32 extends KnnVectorsFormat {
private final KnnVectorsFormat delegate;

public KnnVectorsFormatMaxDims32(KnnVectorsFormat delegate) {
super(delegate.getName());
this.delegate = delegate;
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return delegate.fieldsWriter(state);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return delegate.fieldsReader(state);
}

@Override
public int getMaxDimensions(String fieldName) {
return 32;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
Expand Down Expand Up @@ -280,7 +279,7 @@ public void testRandom() throws Exception {
var builder = INDEX_PACKAGE_ACCESS.newFieldInfosBuilder(softDeletesField);

for (String field : fieldNames) {
IndexableFieldType fieldType = randomFieldType(random());
IndexableFieldType fieldType = randomFieldType(random(), field);
boolean storeTermVectors = false;
boolean storePayloads = false;
boolean omitNorms = false;
Expand Down Expand Up @@ -319,7 +318,11 @@ public void testRandom() throws Exception {
dir.close();
}

private IndexableFieldType randomFieldType(Random r) {
private int getVectorsMaxDimensions(String fieldName) {
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
}

private IndexableFieldType randomFieldType(Random r, String fieldName) {
FieldType type = new FieldType();

if (r.nextBoolean()) {
Expand Down Expand Up @@ -352,7 +355,7 @@ private IndexableFieldType randomFieldType(Random r) {
}

if (r.nextBoolean()) {
int dimension = 1 + r.nextInt(FloatVectorValues.MAX_DIMENSIONS);
int dimension = 1 + r.nextInt(getVectorsMaxDimensions(fieldName));
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.NumericDocValuesField;
Expand Down Expand Up @@ -91,6 +90,10 @@ protected void addRandomFields(Document doc) {
}
}

private int getVectorsMaxDimensions(String fieldName) {
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
}

public void testFieldConstructor() {
float[] v = new float[1];
KnnFloatVectorField field = new KnnFloatVectorField("f", v);
Expand All @@ -106,14 +109,6 @@ public void testFieldConstructorExceptions() {
IllegalArgumentException.class,
() -> new KnnFloatVectorField("f", new float[1], (VectorSimilarityFunction) null));
expectThrows(IllegalArgumentException.class, () -> new KnnFloatVectorField("f", new float[0]));
expectThrows(
IllegalArgumentException.class,
() -> new KnnFloatVectorField("f", new float[FloatVectorValues.MAX_DIMENSIONS + 1]));
expectThrows(
IllegalArgumentException.class,
() ->
new KnnFloatVectorField(
"f", new float[FloatVectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
}

public void testFieldSetValue() {
Expand Down Expand Up @@ -483,18 +478,42 @@ public void testIllegalDimensionTooLarge() throws Exception {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
expectThrows(
IllegalArgumentException.class,
() ->
doc.add(
new KnnFloatVectorField(
"f",
new float[FloatVectorValues.MAX_DIMENSIONS + 1],
VectorSimilarityFunction.DOT_PRODUCT)));
doc.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
Exception exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
assertTrue(
exc.getMessage()
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));

Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc2);

Document doc3 = new Document();
doc3.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc3));
assertTrue(
exc.getMessage()
.contains("Inconsistency of field data structures across documents for field [f]"));
w.flush();

Document doc4 = new Document();
doc4.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc4));
assertTrue(
exc.getMessage()
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
}
}

Expand Down

0 comments on commit b247afe

Please sign in to comment.