Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing redundant type conversions for script scoring for hamming space with binary vectors #2351

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398]
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
- Removing redundant type conversions for script scoring for hamming space with binary vectors (#2351)[https://github.com/opensearch-project/k-NN/pull/2351]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public long ramBytesUsed() {
}

@Override
public ScriptDocValues<float[]> getScriptValues() {
public ScriptDocValues<?> getScriptValues() {
try {
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.index.fielddata.ScriptDocValues;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
public abstract class KNNVectorScriptDocValues<T> extends ScriptDocValues<T> {

private final DocIdSetIterator vectorValues;
private final String fieldName;
Expand All @@ -42,7 +42,7 @@ public void setNextDocId(int docId) throws IOException {
docExists = lastDocID == curDocID;
}

public float[] getValue() {
public T getValue() {
if (!docExists) {
String errorMessage = String.format(
"One of the document doesn't have a value for field '%s'. "
Expand All @@ -60,15 +60,15 @@ public float[] getValue() {
}
}

protected abstract float[] doGetValue() throws IOException;
protected abstract T doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
}

@Override
public float[] get(int i) {
public T get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

Expand All @@ -81,20 +81,20 @@ public float[] get(int i) {
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
public static KNNVectorScriptDocValues<?> create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
return new KNNNativeVectorScriptDocValues<>((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues<byte[]> {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
Expand All @@ -103,17 +103,16 @@ private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptD
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
protected byte[] doGetValue() throws IOException {
try {
return values.vectorValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues<float[]> {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
Expand All @@ -127,7 +126,7 @@ protected float[] doGetValue() throws IOException {
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private static final class KNNNativeVectorScriptDocValues<T> extends KNNVectorScriptDocValues<T> {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
Expand All @@ -136,7 +135,7 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip
}

@Override
protected float[] doGetValue() throws IOException {
protected T doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}
Expand All @@ -148,10 +147,18 @@ protected float[] doGetValue() throws IOException {
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
public static KNNVectorScriptDocValues<?> emptyValues(String fieldName, VectorDataType type) {
if (type == VectorDataType.FLOAT) {
return new KNNVectorScriptDocValues<float[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
return new KNNVectorScriptDocValues<byte[]>(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
protected byte[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
Expand Down
24 changes: 5 additions & 19 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
return binaryValue.bytes;
}

@Override
Expand All @@ -75,15 +68,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc
}

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
}
return vector;
public byte[] getVectorFromBytesRef(BytesRef binaryValue) {
return binaryValue.bytes;
}

@Override
Expand Down Expand Up @@ -143,7 +129,7 @@ public void freeNativeMemory(long memoryAddress) {
* @param binaryValue Binary Value
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromBytesRef(BytesRef binaryValue);
public abstract <T> T getVectorFromBytesRef(BytesRef binaryValue);

/**
* @param trainingDataAllocation training data that has been allocated in native memory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
* KNNVectors with float[] type. The query value passed in is expected to be float[]. The fieldType of the docs
* being searched over are expected to be KNNVector type.
*/
public static class KNNVectorType extends KNNScoreScript<float[]> {
public static class KNNFloatVectorType extends KNNScoreScript<float[]> {

public KNNVectorType(
public KNNFloatVectorType(
Map<String, Object> params,
float[] queryValue,
String field,
Expand All @@ -136,8 +136,45 @@ public KNNVectorType(
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field);
KNNVectorScriptDocValues<float[]> scriptDocValues = (KNNVectorScriptDocValues<float[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue());
}
}

/**
* KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs
* being searched over are expected to be KNNVector type.
*/
public static class KNNByteVectorType extends KNNScoreScript<byte[]> {

public KNNByteVectorType(
Map<String, Object> params,
byte[] queryValue,
String field,
BiFunction<byte[], byte[], Float> scoringMethod,
SearchLookup lookup,
LeafReaderContext leafContext,
IndexSearcher searcher
) throws IOException {
super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher);
}

/**
* This function called for each doc in the segment. We evaluate the score of the vector in the doc
*
* @param explanationHolder A helper to take in an explanation from a script and turn
* it into an {@link org.apache.lucene.search.Explanation}
* @return score of the vector to the query vector
*/
@Override
@SuppressWarnings("unchecked")
public double execute(ScoreScript.ExplanationHolder explanationHolder) {
KNNVectorScriptDocValues<byte[]> scriptDocValues = (KNNVectorScriptDocValues<byte[]>) getDoc().get(this.field);
if (scriptDocValues.isEmpty()) {
return 0.0;
}
Expand Down
Loading
Loading