Skip to content

Commit

Permalink
Integration of Quantization Framework for Binary Quantization with In…
Browse files Browse the repository at this point in the history
…dexing Flow (opensearch-project#1996)

* Integration of Quantization Framework for Binary Quantization with Indexing Flow

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>

* Integration With Qunatization Config

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>

---------

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
  • Loading branch information
Vikasht34 authored Aug 26, 2024
1 parent 59c312b commit bbaaaf9
Show file tree
Hide file tree
Showing 31 changed files with 1,414 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -46,6 +49,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private final FlatVectorsWriter flatVectorsWriter;
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
private boolean finished;
private final QuantizationService quantizationService = QuantizationService.getInstance();

/**
* Add new field for indexing.
Expand All @@ -68,42 +72,24 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc
*/
@Override
public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
// simply write data in the flat file
flatVectorsWriter.flush(maxDoc, sortMap);
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo());
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
trainAndIndex(
field.getFieldInfo(),
(vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter),
NativeIndexWriter::flushIndex,
field
);

NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues);
}
}

@Override
public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException {
// This will ensure that we are merging the FlatIndex during force merge.
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

// For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
final KNNVectorValues<?> knnVectorValues;
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
break;
case BYTE:
final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
break;
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState);

NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues);
}

/**
Expand Down Expand Up @@ -146,4 +132,102 @@ public long ramBytesUsed() {
.sum();
}

/**
* Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer.
*
* @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
* @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors.
* @param <T> The type of vectors being processed.
* @return The {@link KNNVectorValues} associated with the field.
*/
private <T> KNNVectorValues<T> getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter<?> field) {
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
}

/**
* Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
*
* @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
* @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
* @param mergeState The {@link MergeState} representing the state of the merge operation.
* @param <T> The type of vectors being processed.
* @return The {@link KNNVectorValues} associated with the field during the merge.
* @throws IOException If an I/O error occurs during the retrieval.
*/
private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
final VectorDataType vectorDataType,
final FieldInfo fieldInfo,
final MergeState mergeState
) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
}

/**
* Functional interface representing an operation that indexes the provided {@link KNNVectorValues}.
*
* @param <T> The type of vectors being processed.
*/
@FunctionalInterface
private interface IndexOperation<T> {
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues) throws IOException;
}

/**
* Functional interface representing a method that retrieves {@link KNNVectorValues} based on
* the vector data type, field information, and the merge state.
*
* @param <DataType> The type of the data representing the vector (e.g., {@link VectorDataType}).
* @param <FieldInfo> The metadata about the field.
* @param <MergeState> The state of the merge operation.
* @param <Result> The result of the retrieval, typically {@link KNNVectorValues}.
*/
@FunctionalInterface
private interface VectorValuesRetriever<DataType, FieldInfo, MergeState, Result> {
Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException;
}

/**
* Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values
* based on the provided vector data type and applies the specified index operation, potentially including quantization if needed.
*
* @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
* @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
* field information, and additional context (e.g., merge state or field writer).
* @param indexOperation A functional interface that performs the indexing operation using the retrieved
* {@link KNNVectorValues}.
* @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
* From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information
* @param <T> The type of vectors being processed.
* @param <C> The type of the context needed for retrieving the vector values.
* @throws IOException If an I/O error occurs during the processing.
*/
private <T, C> void trainAndIndex(
final FieldInfo fieldInfo,
final VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
final IndexOperation<T> indexOperation,
final C VectorProcessingContext
) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null) {
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
}
NativeIndexWriter writer = (quantizationParams != null)
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);

knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
indexOperation.buildAndWrite(writer, knnVectorValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,32 @@ public static DefaultIndexBuildStrategy getInstance() {
return INSTANCE;
}

/**
* Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both
* quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services.
*
* <p>The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is
* enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are
* flushed and used to build the index. The index is then written to the specified path using JNI calls.</p>
*
* @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index.
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
* @throws IOException If an I/O error occurs during the process of building and writing the index.
*/
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
iterateVectorValuesOnce(knnVectorValues); // to get bytesPerVector
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector());
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

final List<Integer> tranferredDocIds = new ArrayList<>();
while (knnVectorValues.docId() != NO_MORE_DOCS) {
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);
// append is true here so off heap memory buffer isn't overwritten
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
tranferredDocIds.add(knnVectorValues.docId());
vectorTransfer.transfer(vector, true);
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
}
vectorTransfer.flush(true);
Expand All @@ -60,24 +76,24 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
if (params.containsKey(MODEL_ID)) {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(
intListToArray(tranferredDocIds),
intListToArray(transferredDocIds),
vectorAddress,
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
indexInfo.getParameters(),
params,
indexInfo.getKnnEngine()
);
return null;
});
} else {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndex(
intListToArray(tranferredDocIds),
intListToArray(transferredDocIds),
vectorAddress,
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
indexInfo.getParameters(),
params,
indexInfo.getKnnEngine()
);
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
* IndexBuildSetup encapsulates the configuration and parameters required for building an index.
* This includes the size of each vector, the dimensions of the vectors, and any quantization-related
* settings such as the output and state of quantization.
*/
@Getter
@AllArgsConstructor
final class IndexBuildSetup {
/**
* The number of bytes per vector.
*/
private final int bytesPerVector;

/**
* Dimension of Vector for Indexing
*/
private final int dimensions;

/**
* The quantization output that will hold the quantized vector.
*/
private final QuantizationOutput quantizationOutput;

/**
* The state of quantization, which may include parameters and trained models.
*/
private final QuantizationState quantizationState;
}
Loading

0 comments on commit bbaaaf9

Please sign in to comment.