Skip to content

Commit

Permalink
Fixing filtered vector search when quantization is applied (opensearc…
Browse files Browse the repository at this point in the history
…h-project#2076)

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v authored Sep 9, 2024
1 parent ce735c4 commit 524dbd0
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 74 deletions.
86 changes: 60 additions & 26 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.index.query;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Value;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -42,27 +44,18 @@ public class ExactSearcher {
/**
* Execute an exact search on a subset of documents of a leaf
*
* @param leafReaderContext LeafReaderContext to be searched over
* @param matchedDocs matched documents
* @param knnQuery KNN Query
* @param k number of results to return
* @param isParentHits whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
* filtered nested search where the matchedDocs contain the parent ids and {@link NestedFilteredIdsKNNIterator}
* needs to be used.
* @param leafReaderContext {@link LeafReaderContext}
* @param exactSearcherContext {@link ExactSearcherContext}
* @return Map of re-scored results
* @throws IOException exception during execution of exact search
*/
public Map<Integer, Float> searchLeaf(
final LeafReaderContext leafReaderContext,
final BitSet matchedDocs,
final KNNQuery knnQuery,
int k,
boolean isParentHits
) throws IOException {
KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, matchedDocs, knnQuery, isParentHits);
if (matchedDocs.cardinality() <= k) {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, exactSearcherContext);
if (exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, k);
return searchTopK(iterator, exactSearcherContext.getK());
}

private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOException {
Expand Down Expand Up @@ -105,17 +98,15 @@ private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOExc
return docToScore;
}

private KNNIterator getMatchedKNNIterator(
final LeafReaderContext leafReaderContext,
final BitSet matchedDocs,
KNNQuery knnQuery,
boolean isParentHits
) throws IOException {
private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext)
throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = isParentHits && knnQuery.getParentsFilter() != null;
boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;

if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
Expand All @@ -137,6 +128,17 @@ private KNNIterator getMatchedKNNIterator(
spaceType
);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
// Build Segment Level Quantization info.
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// Quantize the Query Vector Once.
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);
} else {
segmentLevelQuantizationInfo = null;
quantizedQueryVector = null;
}

final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
Expand All @@ -145,10 +147,42 @@ private KNNIterator getMatchedKNNIterator(
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
knnQuery.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}

return new FilteredIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType);
return new FilteredIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}

/**
* Stores the context that is used to do the exact search. This class will help in reducing the explosion of attributes
* for doing exact search.
*/
@Value
@Builder
public static class ExactSearcherContext {
/**
* controls whether we should use Quantized vectors during exact search or not. This is useful because when we do
* re-scoring we need to re-score using full precision vectors and not quantized vectors.
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
* filtered nested search where the matchedDocs contain the parent ids and {@link NestedFilteredIdsKNNIterator}
* needs to be used.
*/
boolean isParentHits;
}
}
62 changes: 25 additions & 37 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
Expand All @@ -39,8 +38,6 @@
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

import java.io.IOException;
import java.nio.file.Path;
Expand Down Expand Up @@ -140,8 +137,17 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
* This improves the recall.
*/
Map<Integer, Float> docIdsToScoreMap;
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.k(k)
.isParentHits(true)
.matchedDocs(filterBitSet)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery)
.build();
if (filterWeight != null && canDoExactSearch(cardinality)) {
docIdsToScoreMap = exactSearch(context, filterBitSet, true, k);
docIdsToScoreMap = exactSearch(context, exactSearcherContext);
} else {
docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
if (docIdsToScoreMap == null) {
Expand All @@ -155,7 +161,7 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
docIdsToScoreMap.size(),
cardinality
);
docIdsToScoreMap = exactSearch(context, filterBitSet, true, k);
docIdsToScoreMap = exactSearch(context, exactSearcherContext);
}
}
if (docIdsToScoreMap.isEmpty()) {
Expand Down Expand Up @@ -258,10 +264,13 @@ private Map<Integer, Float> doANNSearch(
);
}

QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);

final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(
reader,
fieldInfo,
knnQuery.getField()
);
// TODO: Change type of vector once more quantization methods are supported
byte[] quantizedVector = getQuantizedVector(quantizationParams, reader, fieldInfo);
final byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);

List<String> engineFiles = getEngineFiles(reader, knnEngine.getExtension());
if (engineFiles.isEmpty()) {
Expand All @@ -285,7 +294,7 @@ private Map<Integer, Float> doANNSearch(
knnEngine,
knnQuery.getIndexName(),
// TODO: In the future, more vector data types will be supported with quantization
quantizationParams == null ? vectorDataType : VectorDataType.BINARY
quantizedVector == null ? vectorDataType : VectorDataType.BINARY
),
knnQuery.getIndexName(),
modelId
Expand All @@ -310,11 +319,11 @@ private Map<Integer, Float> doANNSearch(
int[] parentIds = getParentIdsArray(context);
if (k > 0) {
if (knnQuery.getVectorDataType() == VectorDataType.BINARY
|| quantizationParams != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) {
|| quantizedVector != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) {
results = JNIService.queryBinaryIndex(
indexAllocation.getMemoryAddress(),
// TODO: In the future, quantizedVector can have other data types than byte
quantizationParams == null ? knnQuery.getByteQueryVector() : quantizedVector,
quantizedVector == null ? knnQuery.getByteQueryVector() : quantizedVector,
k,
knnQuery.getMethodParameters(),
knnEngine,
Expand Down Expand Up @@ -391,16 +400,14 @@ List<String> getEngineFiles(SegmentReader reader, String extension) throws IOExc
/**
* Execute exact search for the given matched doc ids and return the results as a map of docId to score.
*
* @param leafReaderContext The leaf reader context for the current segment.
* @param matchSet The filterIds to search for.
* @param isParentHits Whether the matchedDocs contains parent ids or child ids.
* @param k The number of results to return.
* @return Map of docId to score for the exact search results.
* @throws IOException If an error occurs during the search.
*/
public Map<Integer, Float> exactSearch(final LeafReaderContext leafReaderContext, final BitSet matchSet, boolean isParentHits, int k)
throws IOException {
return exactSearcher.searchLeaf(leafReaderContext, matchSet, knnQuery, k, isParentHits);
public Map<Integer, Float> exactSearch(
final LeafReaderContext leafReaderContext,
final ExactSearcher.ExactSearcherContext exactSearcherContext
) throws IOException {
return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
}

@Override
Expand Down Expand Up @@ -462,23 +469,4 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) {
private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) {
return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount;
}

// TODO: this will eventually return more types than just byte
private byte[] getQuantizedVector(QuantizationParams quantizationParams, SegmentReader reader, FieldInfo fieldInfo) throws IOException {
if (quantizationParams != null) {
QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector();
reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null);
if (tempCollector.getQuantizationState() == null) {
throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName()));
}
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams);
// TODO: In the future, byte array will not be the only output type from this method
return (byte[]) quantizationService.quantize(
tempCollector.getQuantizationState(),
knnQuery.getQueryVector(),
quantizationOutput
);
}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

/**
* This class encapsulate the necessary details to do the quantization of the vectors present in a lucene segment.
*/
@Getter
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public class SegmentLevelQuantizationInfo {
private final QuantizationParams quantizationParams;
private final QuantizationState quantizationState;

/**
* A builder like function to build the {@link SegmentLevelQuantizationInfo}
* @param leafReader {@link LeafReader}
* @param fieldInfo {@link FieldInfo}
* @param fieldName {@link String}
* @return {@link SegmentLevelQuantizationInfo}
* @throws IOException exception while creating the {@link SegmentLevelQuantizationInfo} object.
*/
public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName)
throws IOException {
final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo);
if (quantizationParams == null) {
return null;
}
final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName);
return new SegmentLevelQuantizationInfo(quantizationParams, quantizationState);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.experimental.UtilityClass;
import org.apache.lucene.index.LeafReader;
import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.Locale;

/**
* A utility class for doing Quantization related operation at a segment level. We can move this utility in {@link SegmentLevelQuantizationInfo}
* but I am keeping it thinking that {@link SegmentLevelQuantizationInfo} free from these utility functions to reduce
* the responsibilities of {@link SegmentLevelQuantizationInfo} class.
*/
@UtilityClass
public class SegmentLevelQuantizationUtil {

/**
* A simple function to convert a vector to a quantized vector for a segment.
* @param vector array of float
* @return array of byte
*/
@SuppressWarnings("unchecked")
public static byte[] quantizeVector(final float[] vector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) {
if (segmentLevelQuantizationInfo == null) {
return null;
}
final QuantizationService quantizationService = QuantizationService.getInstance();
// TODO: We are converting the output of Quantize to byte array for now. But this needs to be fixed when
// other types of quantized outputs are returned like float[].
return (byte[]) quantizationService.quantize(
segmentLevelQuantizationInfo.getQuantizationState(),
vector,
quantizationService.createQuantizationOutput(segmentLevelQuantizationInfo.getQuantizationParams())
);
}

/**
* A utility function to get {@link QuantizationState} for a given segment and field.
* @param leafReader {@link LeafReader}
* @param fieldName {@link String}
* @return {@link QuantizationState}
* @throws IOException exception during reading the {@link QuantizationState}
*/
static QuantizationState getQuantizationState(final LeafReader leafReader, String fieldName) throws IOException {
final QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector();
leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null);
if (tempCollector.getQuantizationState() == null) {
throw new IllegalStateException(String.format(Locale.ROOT, "No quantization state found for field %s", fieldName));
}
return tempCollector.getQuantizationState();
}
}
Loading

0 comments on commit 524dbd0

Please sign in to comment.