Skip to content

Commit

Permalink
Add binary index support for Lucene engine
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Deng <jayd0104@gmail.com>
  • Loading branch information
jed326 authored and Jay Deng committed Dec 12, 2024
1 parent 2b9a741 commit 0f768ed
Show file tree
Hide file tree
Showing 16 changed files with 289 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public float compare(byte[] v1, byte[] v2) {

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
// This is not used in binary case
return VectorSimilarityFunction.EUCLIDEAN;
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public enum VectorDataType {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalStateException("Unsupported method");
return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, vectorSimilarityFunction);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth,
knnMethodContext.getSpaceType()
);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

import java.io.IOException;

/**
* A FlatVectorsScorer to be used for scoring binary vectors. Meant to be used with {@link KNN990BinaryVectorScorer}
*/
public class KNN990BinaryVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] queryVector
) throws IOException {
throw new IllegalArgumentException("binary vectors do not support float[] targets");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] queryVector
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

static class BinaryRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final byte[] queryVector;

BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.queryVector = query;
this.vectorValues = vectorValues;
}

@Override
public float score(int node) throws IOException {
return 1 / (float) (1 + VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node)));
}

@Override
public int maxOrd() {
return vectorValues.size();
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}

static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;

public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] queryVector = vectorValues1.vectorValue(ord);
return new BinaryRandomVectorScorer(vectorValues2, queryVector);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinaryRandomVectorScorerSupplier(vectorValues.copy());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import static org.opensearch.knn.index.engine.KNNEngine.getMaxDimensionByEngine;

/**
* Custom KnnVectorsFormat implementation to support binary vectors. This class is mostly identical to
* {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat}, however we use the custom {@link KNN990BinaryVectorScorer}
* to perform hamming bit scoring.
*/
public final class KNN990HnswBinaryVectorsFormat extends KnnVectorsFormat {

private final int maxConn;
private final int beamWidth;
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN990BinaryVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

private static final String NAME = "KNN990HnswBinaryVectorsFormat";

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat()}
*/
public KNN990HnswBinaryVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int)}
*/
public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, 1, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int, int, java.util.concurrent.ExecutorService)}
*/
public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
this.maxConn,
this.beamWidth,
flatVectorsFormat.fieldsWriter(state),
this.numMergeWorkers,
this.mergeExec
);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return getMaxDimensionByEngine(KNNEngine.LUCENE);
}

@Override
public String toString() {
return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn="
+ this.maxConn
+ ", beamWidth="
+ this.beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.engine.KNNEngine;

Expand All @@ -24,11 +25,19 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> {
// There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that
// changes in the future.
if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
return new KNN990HnswBinaryVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
} else {
return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth());
}
},
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;

import java.util.Map;

Expand All @@ -17,10 +18,16 @@
public class KNNVectorsFormatParams {
private int maxConnections;
private int beamWidth;
private final SpaceType spaceType;

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED);
}

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) {
initMaxConnections(params, defaultMaxConnections);
initBeamWidth(params, defaultBeamWidth);
this.spaceType = spaceType;
}

public boolean validate(final Map<String, Object> params) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@
*/
public class LuceneHNSWMethod extends AbstractKNNMethod {

private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE);
private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(
VectorDataType.FLOAT,
VectorDataType.BYTE,
VectorDataType.BINARY
);

public final static List<SpaceType> SUPPORTED_SPACES = Arrays.asList(
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.COSINESIMIL,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
);

final static Encoder SQ_ENCODER = new LuceneSQEncoder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
case BINARY:
return new LuceneEngineKnnVectorQuery(
getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
#

org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat
org.opensearch.knn.index.codec.KNN990Codec.KNN990HnswBinaryVectorsFormat
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -109,14 +108,6 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException
writer.close();
}

public void testCreateKnnVectorFieldType_whenBinary_thenException() {
Exception ex = expectThrows(
IllegalStateException.class,
() -> VectorDataType.BINARY.createKnnVectorFieldType(1, VectorSimilarityFunction.EUCLIDEAN)
);
assertTrue(ex.getMessage().contains("Unsupported method"));
}

public void testGetVectorFromBytesRef_whenBinary_thenException() {
byte[] vector = { 1, 2, 3 };
float[] expected = { 1, 2, 3 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,6 @@ public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() {
}

public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() {
validateValidateVectorDataType(
KNNEngine.LUCENE,
KNNConstants.METHOD_HNSW,
VectorDataType.BINARY,
SpaceType.HAMMING,
"UnsupportedMethod"
);
validateValidateVectorDataType(
KNNEngine.NMSLIB,
KNNConstants.METHOD_HNSW,
Expand Down
Loading

0 comments on commit 0f768ed

Please sign in to comment.