forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixing filtered vector search when quantization is applied (opensearc…
…h-project#2076) Signed-off-by: Navneet Verma <navneev@amazon.com>
- Loading branch information
Showing
9 changed files
with
249 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Getter; | ||
import lombok.RequiredArgsConstructor; | ||
import org.apache.lucene.index.FieldInfo; | ||
import org.apache.lucene.index.LeafReader; | ||
import org.opensearch.knn.index.quantizationservice.QuantizationService; | ||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* This class encapsulate the necessary details to do the quantization of the vectors present in a lucene segment. | ||
*/ | ||
@Getter | ||
@RequiredArgsConstructor(access = AccessLevel.PRIVATE) | ||
public class SegmentLevelQuantizationInfo { | ||
private final QuantizationParams quantizationParams; | ||
private final QuantizationState quantizationState; | ||
|
||
/** | ||
* A builder like function to build the {@link SegmentLevelQuantizationInfo} | ||
* @param leafReader {@link LeafReader} | ||
* @param fieldInfo {@link FieldInfo} | ||
* @param fieldName {@link String} | ||
* @return {@link SegmentLevelQuantizationInfo} | ||
* @throws IOException exception while creating the {@link SegmentLevelQuantizationInfo} object. | ||
*/ | ||
public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName) | ||
throws IOException { | ||
final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); | ||
if (quantizationParams == null) { | ||
return null; | ||
} | ||
final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName); | ||
return new SegmentLevelQuantizationInfo(quantizationParams, quantizationState); | ||
} | ||
|
||
} |
60 changes: 60 additions & 0 deletions
60
src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query; | ||
|
||
import lombok.experimental.UtilityClass; | ||
import org.apache.lucene.index.LeafReader; | ||
import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; | ||
import org.opensearch.knn.index.quantizationservice.QuantizationService; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
|
||
import java.io.IOException; | ||
import java.util.Locale; | ||
|
||
/** | ||
* A utility class for doing Quantization related operation at a segment level. We can move this utility in {@link SegmentLevelQuantizationInfo} | ||
* but I am keeping it thinking that {@link SegmentLevelQuantizationInfo} free from these utility functions to reduce | ||
* the responsibilities of {@link SegmentLevelQuantizationInfo} class. | ||
*/ | ||
@UtilityClass | ||
public class SegmentLevelQuantizationUtil { | ||
|
||
/** | ||
* A simple function to convert a vector to a quantized vector for a segment. | ||
* @param vector array of float | ||
* @return array of byte | ||
*/ | ||
@SuppressWarnings("unchecked") | ||
public static byte[] quantizeVector(final float[] vector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) { | ||
if (segmentLevelQuantizationInfo == null) { | ||
return null; | ||
} | ||
final QuantizationService quantizationService = QuantizationService.getInstance(); | ||
// TODO: We are converting the output of Quantize to byte array for now. But this needs to be fixed when | ||
// other types of quantized outputs are returned like float[]. | ||
return (byte[]) quantizationService.quantize( | ||
segmentLevelQuantizationInfo.getQuantizationState(), | ||
vector, | ||
quantizationService.createQuantizationOutput(segmentLevelQuantizationInfo.getQuantizationParams()) | ||
); | ||
} | ||
|
||
/** | ||
* A utility function to get {@link QuantizationState} for a given segment and field. | ||
* @param leafReader {@link LeafReader} | ||
* @param fieldName {@link String} | ||
* @return {@link QuantizationState} | ||
* @throws IOException exception during reading the {@link QuantizationState} | ||
*/ | ||
static QuantizationState getQuantizationState(final LeafReader leafReader, String fieldName) throws IOException { | ||
final QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); | ||
leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null); | ||
if (tempCollector.getQuantizationState() == null) { | ||
throw new IllegalStateException(String.format(Locale.ROOT, "No quantization state found for field %s", fieldName)); | ||
} | ||
return tempCollector.getQuantizationState(); | ||
} | ||
} |
Oops, something went wrong.