Skip to content

Commit

Permalink
Add binary index support for Lucene engine (#2292)
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Deng <jayd0104@gmail.com>
  • Loading branch information
jed326 authored Dec 14, 2024
1 parent 2b9a741 commit 8005bbf
Show file tree
Hide file tree
Showing 21 changed files with 395 additions and 65 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public float compare(byte[] v1, byte[] v2) {

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
// For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
}
};
Expand Down
22 changes: 12 additions & 10 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

/**
* Enum contains data_type of vectors
* Lucene supports byte and float data type
* Lucene supports binary, byte and float data type
* NMSLib supports only float data type
* Faiss supports binary and float data type
*/
Expand All @@ -39,8 +39,10 @@ public enum VectorDataType {
BINARY("binary") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalStateException("Unsupported method");
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
// For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer so the VectorSimilarityFunction will
// not be used.
return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, VectorSimilarityFunction.EUCLIDEAN);
}

@Override
Expand Down Expand Up @@ -68,8 +70,8 @@ public void freeNativeMemory(long memoryAddress) {
BYTE("byte") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
}

@Override
Expand Down Expand Up @@ -97,8 +99,8 @@ public void freeNativeMemory(long memoryAddress) {
FLOAT("float") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
}

@Override
Expand Down Expand Up @@ -129,11 +131,11 @@ public void freeNativeMemory(long memoryAddress) {
* Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and
* VectorSimilarityFunction.
*
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @param dimension Dimension of the vector
* @param knnVectorSimilarityFunction KNNVectorSimilarityFunction for a given spaceType
* @return FieldType
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);
public abstract FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction);

/**
* Deserializes float vector from BytesRef.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth,
knnMethodContext.getSpaceType()
);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;

import java.io.IOException;

/**
* A FlatVectorsScorer to be used for scoring binary vectors. Meant to be used with {@link KNN9120BinaryVectorScorer}
*/
public class KNN9120BinaryVectorScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] queryVector
) throws IOException {
throw new IllegalArgumentException("binary vectors do not support float[] targets");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] queryVector
) throws IOException {
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

static class BinaryRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final byte[] queryVector;

BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.queryVector = query;
this.vectorValues = vectorValues;
}

@Override
public float score(int node) throws IOException {
return KNNVectorSimilarityFunction.HAMMING.compare(queryVector, vectorValues.vectorValue(node));
}

@Override
public int maxOrd() {
return vectorValues.size();
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}

static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;

public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] queryVector = vectorValues1.vectorValue(ord);
return new BinaryRandomVectorScorer(vectorValues2, queryVector);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinaryRandomVectorScorerSupplier(vectorValues.copy());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import static org.opensearch.knn.index.engine.KNNEngine.getMaxDimensionByEngine;

/**
* Custom KnnVectorsFormat implementation to support binary vectors. This class is mostly identical to
* {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat}, however we use the custom {@link KNN9120BinaryVectorScorer}
* to perform hamming bit scoring.
*/
public final class KNN9120HnswBinaryVectorsFormat extends KnnVectorsFormat {

private final int maxConn;
private final int beamWidth;
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN9120BinaryVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

private static final String NAME = "KNN990HnswBinaryVectorsFormat";

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat()}
*/
public KNN9120HnswBinaryVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int)}
*/
public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, 1, null);
}

/**
* Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int, int, java.util.concurrent.ExecutorService)}
*/
public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
this.maxConn,
this.beamWidth,
flatVectorsFormat.fieldsWriter(state),
this.numMergeWorkers,
this.mergeExec
);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return getMaxDimensionByEngine(KNNEngine.LUCENE);
}

@Override
public String toString() {
return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn="
+ this.maxConn
+ ", beamWidth="
+ this.beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
private static final int NUM_MERGE_WORKERS = 1;

public KNN9120PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> {
// There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that
// changes in the future.
if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
return new KNN9120HnswBinaryVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
} else {
return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth());
}
},
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
NUM_MERGE_WORKERS,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
null
)
);
}

@Override
/**
* This method returns the maximum dimension allowed from KNNEngine for Lucene codec
*
* @param fieldName Name of the field, ignored
* @return Maximum constant dimension set by KNNEngine
*/
public int getMaxDimensions(String fieldName) {
return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
Expand Down
Loading

0 comments on commit 8005bbf

Please sign in to comment.