From b247afe8fcfcbdc522107283e567d720820bbcef Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Thu, 27 Jul 2023 14:50:33 -0400 Subject: [PATCH] Move max vector dims limit to Codec (#12436) Move vector max dimension limits enforcement into the default Codec's KnnVectorsFormat implementation. This allows different implementations of knn search algorithms to define their own limits on the maximum number of vector dimensions that they can handle. Closes #12309 --- lucene/CHANGES.txt | 2 + .../lucene/codecs/KnnVectorsFormat.java | 16 +++++ .../lucene95/Lucene95HnswVectorsFormat.java | 5 ++ .../perfield/PerFieldKnnVectorsFormat.java | 5 ++ .../org/apache/lucene/document/FieldType.java | 8 --- .../lucene/document/KnnByteVectorField.java | 4 -- .../lucene/document/KnnFloatVectorField.java | 4 -- .../apache/lucene/index/ByteVectorValues.java | 3 - .../lucene/index/FloatVectorValues.java | 3 - .../apache/lucene/index/IndexingChain.java | 20 ++++++ .../TestPerFieldKnnVectorsFormat.java | 72 +++++++++++++++++++ .../index/BaseFieldInfoFormatTestCase.java | 11 +-- .../index/BaseKnnVectorsFormatTestCase.java | 55 +++++++++----- 13 files changed, 164 insertions(+), 44 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index e06d6ee6b3ba..b5120f63aabb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -11,6 +11,8 @@ API Changes * GITHUB#11248: IntBlockPool's SliceReader, SliceWriter, and all int slice functionality are moved out to MemoryIndex. 
(Stefan Vodita) +* GITHUB#12436: Move max vector dims limit to Codec (Mayya Sharipova) + New Features --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index 66623bacadc6..a17d844d8111 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -32,6 +32,9 @@ */ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI { + /** The maximum number of vector dimensions */ + public static final int DEFAULT_MAX_DIMENSIONS = 1024; + /** * This static holder class prevents classloading deadlock by delaying init of doc values formats * until needed. @@ -76,6 +79,19 @@ public static KnnVectorsFormat forName(String name) { /** Returns a {@link KnnVectorsReader} to read the vectors from the index. */ public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException; + /** + * Returns the maximum number of vector dimensions supported by this codec for the given field + * name + * + *
<p>
Codecs should override this method to specify the maximum number of dimensions they support. + * + * @param fieldName the field name + * @return the maximum number of vector dimensions. + */ + public int getMaxDimensions(String fieldName) { + return DEFAULT_MAX_DIMENSIONS; + } + /** * EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not * support vectors. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsFormat.java index cb3e5ef8b10c..f1594b3068ca 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsFormat.java @@ -185,6 +185,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return new Lucene95HnswVectorsReader(state); } + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + @Override public String toString() { return "Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn=" diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 1247dff556fd..6d344629f1b1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -80,6 +80,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return new FieldsReader(state); } + @Override + public int getMaxDimensions(String fieldName) { + return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName); + } + /** * Returns the numeric vector format that should be used for writing new segments of field * . 
diff --git a/lucene/core/src/java/org/apache/lucene/document/FieldType.java b/lucene/core/src/java/org/apache/lucene/document/FieldType.java index aba3fa3c3bbf..5b37955cc145 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FieldType.java +++ b/lucene/core/src/java/org/apache/lucene/document/FieldType.java @@ -21,7 +21,6 @@ import java.util.Objects; import org.apache.lucene.analysis.Analyzer; // javadocs import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; import org.apache.lucene.index.PointValues; @@ -378,13 +377,6 @@ public void setVectorAttributes( if (numDimensions <= 0) { throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions); } - if (numDimensions > FloatVectorValues.MAX_DIMENSIONS) { - throw new IllegalArgumentException( - "vector numDimensions must be <= FloatVectorValues.MAX_DIMENSIONS (=" - + FloatVectorValues.MAX_DIMENSIONS - + "); got " - + numDimensions); - } this.vectorDimension = numDimensions; this.vectorSimilarityFunction = Objects.requireNonNull(similarity); this.vectorEncoding = Objects.requireNonNull(encoding); diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index 87cb6a9f056e..ddeef44c72ae 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -46,10 +46,6 @@ private static FieldType createType(byte[] v, VectorSimilarityFunction similarit if (dimension == 0) { throw new IllegalArgumentException("cannot index an empty vector"); } - if (dimension > ByteVectorValues.MAX_DIMENSIONS) { - throw new IllegalArgumentException( - "cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS); - } if (similarityFunction == 
null) { throw new IllegalArgumentException("similarity function must not be null"); } diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index 9d1cd02c013e..63a55ddc669f 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -47,10 +47,6 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari if (dimension == 0) { throw new IllegalArgumentException("cannot index an empty vector"); } - if (dimension > FloatVectorValues.MAX_DIMENSIONS) { - throw new IllegalArgumentException( - "cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS); - } if (similarityFunction == null) { throw new IllegalArgumentException("similarity function must not be null"); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 5d532a7f9e88..e731e727aa8c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -28,9 +28,6 @@ */ public abstract class ByteVectorValues extends DocIdSetIterator { - /** The maximum length of a vector */ - public static final int MAX_DIMENSIONS = 1024; - /** Sole constructor */ protected ByteVectorValues() {} diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 9de6b57531e2..0c3194bfac8f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -28,9 +28,6 @@ */ public abstract class FloatVectorValues extends DocIdSetIterator { - /** The maximum length of a vector */ - public static final int 
MAX_DIMENSIONS = 1024; - /** Sole constructor */ protected FloatVectorValues() {} diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java index a68a84ff5ac7..fff4e0d04460 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java @@ -621,6 +621,12 @@ private void initializeFieldInfo(PerField pf) throws IOException { final Sort indexSort = indexWriterConfig.getIndexSort(); validateIndexSortDVType(indexSort, pf.fieldName, s.docValuesType); } + if (s.vectorDimension != 0) { + validateMaxVectorDimension( + pf.fieldName, + s.vectorDimension, + indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions(pf.fieldName)); + } FieldInfo fi = fieldInfos.add( new FieldInfo( @@ -831,6 +837,20 @@ private static void verifyUnIndexedFieldType(String name, IndexableFieldType ft) } } + private static void validateMaxVectorDimension( + String fieldName, int vectorDim, int maxVectorDim) { + if (vectorDim > maxVectorDim) { + throw new IllegalArgumentException( + "Field [" + + fieldName + + "]" + + "vector's dimensions must be <= [" + + maxVectorDim + + "]; got " + + vectorDim); + } + } + private void validateIndexSortDVType(Sort indexSort, String fieldToValidate, DocValuesType dvType) throws IOException { for (SortField sortField : indexSort.getSort()) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index c435de1047d6..dd06dd6c0618 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -29,6 +29,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import 
org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; @@ -43,6 +44,9 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; @@ -162,6 +166,50 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } } + public void testMaxDimensionsPerFieldFormat() throws IOException { + try (Directory directory = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random())); + KnnVectorsFormat format1 = + new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100)); + KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100); + iwc.setCodec( + new AssertingCodec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + if ("field1".equals(field)) { + return format1; + } else { + return format2; + } + } + }); + try (IndexWriter writer = new IndexWriter(directory, iwc)) { + Document doc1 = new Document(); + doc1.add(new KnnFloatVectorField("field1", new float[33])); + Exception exc = + expectThrows(IllegalArgumentException.class, () -> writer.addDocument(doc1)); + assertTrue(exc.getMessage().contains("vector's dimensions must be <= [32]")); + + Document doc2 = new Document(); + doc2.add(new KnnFloatVectorField("field1", new float[32])); + doc2.add(new KnnFloatVectorField("field2", new float[33])); + writer.addDocument(doc2); + } + + // Check that the vectors were written + try (IndexReader reader = DirectoryReader.open(directory)) { + 
IndexSearcher searcher = new IndexSearcher(reader); + Query query1 = new KnnFloatVectorQuery("field1", new float[32], 10); + TopDocs topDocs1 = searcher.search(query1, 1); + assertEquals(1, topDocs1.scoreDocs.length); + + Query query2 = new KnnFloatVectorQuery("field2", new float[33], 10); + TopDocs topDocs2 = searcher.search(query2, 1); + assertEquals(1, topDocs2.scoreDocs.length); + } + } + } + private static class WriteRecordingKnnVectorsFormat extends KnnVectorsFormat { private final KnnVectorsFormat delegate; private final Set fieldsWritten; @@ -216,4 +264,28 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return delegate.fieldsReader(state); } } + + private static class KnnVectorsFormatMaxDims32 extends KnnVectorsFormat { + private final KnnVectorsFormat delegate; + + public KnnVectorsFormatMaxDims32(KnnVectorsFormat delegate) { + super(delegate.getName()); + this.delegate = delegate; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return delegate.fieldsWriter(state); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return delegate.fieldsReader(state); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 32; + } + } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseFieldInfoFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseFieldInfoFormatTestCase.java index f4f3d29a6552..7a054c290c3e 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseFieldInfoFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseFieldInfoFormatTestCase.java @@ -32,7 +32,6 @@ import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.FloatVectorValues; import 
org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; import org.apache.lucene.index.PointValues; @@ -280,7 +279,7 @@ public void testRandom() throws Exception { var builder = INDEX_PACKAGE_ACCESS.newFieldInfosBuilder(softDeletesField); for (String field : fieldNames) { - IndexableFieldType fieldType = randomFieldType(random()); + IndexableFieldType fieldType = randomFieldType(random(), field); boolean storeTermVectors = false; boolean storePayloads = false; boolean omitNorms = false; @@ -319,7 +318,11 @@ public void testRandom() throws Exception { dir.close(); } - private IndexableFieldType randomFieldType(Random r) { + private int getVectorsMaxDimensions(String fieldName) { + return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName); + } + + private IndexableFieldType randomFieldType(Random r, String fieldName) { FieldType type = new FieldType(); if (r.nextBoolean()) { @@ -352,7 +355,7 @@ private IndexableFieldType randomFieldType(Random r) { } if (r.nextBoolean()) { - int dimension = 1 + r.nextInt(FloatVectorValues.MAX_DIMENSIONS); + int dimension = 1 + r.nextInt(getVectorsMaxDimensions(fieldName)); VectorSimilarityFunction similarityFunction = RandomPicks.randomFrom(r, VectorSimilarityFunction.values()); VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values()); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 69925fdb3191..3cf5944fd921 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -27,7 +27,6 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; -import 
org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.NumericDocValuesField; @@ -91,6 +90,10 @@ protected void addRandomFields(Document doc) { } } + private int getVectorsMaxDimensions(String fieldName) { + return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName); + } + public void testFieldConstructor() { float[] v = new float[1]; KnnFloatVectorField field = new KnnFloatVectorField("f", v); @@ -106,14 +109,6 @@ public void testFieldConstructorExceptions() { IllegalArgumentException.class, () -> new KnnFloatVectorField("f", new float[1], (VectorSimilarityFunction) null)); expectThrows(IllegalArgumentException.class, () -> new KnnFloatVectorField("f", new float[0])); - expectThrows( - IllegalArgumentException.class, - () -> new KnnFloatVectorField("f", new float[FloatVectorValues.MAX_DIMENSIONS + 1])); - expectThrows( - IllegalArgumentException.class, - () -> - new KnnFloatVectorField( - "f", new float[FloatVectorValues.MAX_DIMENSIONS + 1], (FieldType) null)); } public void testFieldSetValue() { @@ -483,18 +478,42 @@ public void testIllegalDimensionTooLarge() throws Exception { try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - expectThrows( - IllegalArgumentException.class, - () -> - doc.add( - new KnnFloatVectorField( - "f", - new float[FloatVectorValues.MAX_DIMENSIONS + 1], - VectorSimilarityFunction.DOT_PRODUCT))); + doc.add( + new KnnFloatVectorField( + "f", + new float[getVectorsMaxDimensions("f") + 1], + VectorSimilarityFunction.DOT_PRODUCT)); + Exception exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc)); + assertTrue( + exc.getMessage() + .contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]")); Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[1], 
VectorSimilarityFunction.EUCLIDEAN)); + doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc2); + + Document doc3 = new Document(); + doc3.add( + new KnnFloatVectorField( + "f", + new float[getVectorsMaxDimensions("f") + 1], + VectorSimilarityFunction.DOT_PRODUCT)); + exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc3)); + assertTrue( + exc.getMessage() + .contains("Inconsistency of field data structures across documents for field [f]")); + w.flush(); + + Document doc4 = new Document(); + doc4.add( + new KnnFloatVectorField( + "f", + new float[getVectorsMaxDimensions("f") + 1], + VectorSimilarityFunction.DOT_PRODUCT)); + exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc4)); + assertTrue( + exc.getMessage() + .contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]")); } }