Skip to content

Commit

Permalink
Increase vector size limit to 4096 (#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakn7 authored Aug 6, 2024
1 parent a7718c1 commit 35970d0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
Expand All @@ -38,6 +40,8 @@
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
Expand All @@ -53,7 +57,7 @@ public class VectorFieldDef extends IndexableFieldDef implements VectorQueryable
"cosine",
VectorSimilarityFunction.COSINE);
private static final String HNSW_FORMAT_TYPE = "hnsw";
private static final int MAX_DOC_VALUE_DIMENSIONS = 2048;
private static final int MAX_VECTOR_DIMENSIONS = 4096;
private final int vectorDimensions;
private final VectorSimilarityFunction similarityFunction;
private final KnnVectorsFormat vectorsFormat;
Expand Down Expand Up @@ -87,7 +91,29 @@ private static KnnVectorsFormat createVectorsFormat(VectorIndexingOptions vector
vectorIndexingOptions.getHnswEfConstruction() > 0
? vectorIndexingOptions.getHnswEfConstruction()
: Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
return new Lucene99HnswVectorsFormat(m, efConstruction);
Lucene99HnswVectorsFormat lucene99HnswVectorsFormat =
new Lucene99HnswVectorsFormat(m, efConstruction);
return new KnnVectorsFormat(lucene99HnswVectorsFormat.getName()) {
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return lucene99HnswVectorsFormat.fieldsWriter(state);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return lucene99HnswVectorsFormat.fieldsReader(state);
}

@Override
public int getMaxDimensions(String fieldName) {
return MAX_VECTOR_DIMENSIONS;
}

@Override
public String toString() {
return lucene99HnswVectorsFormat.toString();
}
};
}

/**
Expand Down Expand Up @@ -132,19 +158,9 @@ protected void validateRequest(Field requestField) {
if (requestField.getVectorDimensions() <= 0) {
throw new IllegalArgumentException("Vector dimension should be > 0");
}
if (requestField.getStoreDocValues()
&& requestField.getVectorDimensions() > MAX_DOC_VALUE_DIMENSIONS) {
throw new IllegalArgumentException(
"Vector dimension must be <= " + MAX_DOC_VALUE_DIMENSIONS + " for doc values");
}

if (requestField.getSearch()) {
if (requestField.getVectorDimensions() > Lucene99HnswVectorsFormat.DEFAULT_MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"Vector dimension must be <= "
+ Lucene99HnswVectorsFormat.DEFAULT_MAX_DIMENSIONS
+ " for search");
}
if ((requestField.getStoreDocValues() || requestField.getSearch())
&& requestField.getVectorDimensions() > MAX_VECTOR_DIMENSIONS) {
throw new IllegalArgumentException("Vector dimension must be <= " + MAX_VECTOR_DIMENSIONS);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -768,14 +768,14 @@ public void testMaxDocValueDimensions() {
Field.newBuilder()
.setName("vector")
.setType(FieldType.VECTOR)
.setVectorDimensions(2049)
.setVectorDimensions(4097)
.setStoreDocValues(true)
.build();
try {
new VectorFieldDef("vector", field);
fail();
} catch (IllegalArgumentException e) {
assertEquals("Vector dimension must be <= 2048 for doc values", e.getMessage());
assertEquals("Vector dimension must be <= 4096", e.getMessage());
}
}

Expand All @@ -785,15 +785,15 @@ public void testMaxSearchDimensions() {
Field.newBuilder()
.setName("vector")
.setType(FieldType.VECTOR)
.setVectorDimensions(1025)
.setVectorDimensions(4097)
.setSearch(true)
.setVectorSimilarity("cosine")
.build();
try {
new VectorFieldDef("vector", field);
fail();
} catch (IllegalArgumentException e) {
assertEquals("Vector dimension must be <= 1024 for search", e.getMessage());
assertEquals("Vector dimension must be <= 4096", e.getMessage());
}
}

Expand Down

0 comments on commit 35970d0

Please sign in to comment.