From fee16622452009510f32ba204aa7f7697ddb722b Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Tue, 14 Jan 2025 11:20:37 -0600 Subject: [PATCH] Use Faiss SQ for quantizing vectors Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/index/KNNIndexShard.java | 15 +- .../KNN990QuantizationStateReader.java | 5 +- .../NativeEngines990KnnVectorsWriter.java | 14 +- .../DefaultIndexBuildStrategy.java | 14 +- .../codec/nativeindex/NativeIndexWriter.java | 14 +- .../nativeindex/QuantizationIndexUtils.java | 3 +- .../engine/faiss/AbstractFaissMethod.java | 36 +++-- .../QuantizationService.java | 16 +- .../opensearch/knn/index/query/KNNWeight.java | 1 + .../query/SegmentLevelQuantizationInfo.java | 7 +- .../ByteScalarQuantizationState.java | 17 +- .../quantizer/ByteScalarQuantizer.java | 149 ++++++++++++++++-- .../quantizer/MultiBitScalarQuantizer.java | 6 + .../quantizer/OneBitScalarQuantizer.java | 6 + .../knn/quantization/quantizer/Quantizer.java | 3 + 15 files changed, 238 insertions(+), 68 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index 38eb628e2..6aec3b5b4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -185,6 +185,7 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); SpaceType spaceType = SpaceType.getSpace(spaceTypeName); String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null); + QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); engineFiles.addAll( getEngineFileContexts( reader.getSegmentInfo(), @@ -192,14 +193,12 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine fileExtension, spaceType, modelId, - FieldInfoExtractor.extractQuantizationConfig(fieldInfo) == QuantizationConfig.EMPTY - ? VectorDataType.get( - fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()) - ) - : FieldInfoExtractor.extractQuantizationConfig(fieldInfo) - .getQuantizationType() == ScalarQuantizationType.EIGHT_BIT - ? VectorDataType.BYTE - : VectorDataType.BINARY + quantizationConfig == QuantizationConfig.EMPTY + || quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT + ? VectorDataType.get( + fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()) + ) + : VectorDataType.BINARY ) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 07a676dda..707b64dee 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -15,7 +15,6 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -89,8 +88,8 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr case TWO_BIT: case FOUR_BIT: return MultiBitScalarQuantizationState.fromByteArray(stateBytes); - case EIGHT_BIT: - return ByteScalarQuantizationState.fromByteArray(stateBytes); + // case EIGHT_BIT: + // return ByteScalarQuantizationState.fromByteArray(stateBytes); default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 7c8636577..a49ea4e8b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -30,7 +30,9 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -255,10 +257,16 @@ private QuantizationState train( final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null && totalLiveDocs > 0) { - initQuantizationStateWriterIfNecessary(); KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); - quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + if ((quantizationParams.getTypeIdentifier()).equals( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT) + )) { + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo); + } else { + initQuantizationStateWriterIfNecessary(); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + } } return quantizationState; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 23c3ba116..67d4c25f5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -12,6 +12,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import java.io.IOException; import java.security.AccessController; @@ -77,14 +78,14 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept long vectorAddress = vectorTransfer.getVectorAddress(); // Currently this is if else as there are only two cases, with more cases this will have to be made // more maintainable - if (params.containsKey(MODEL_ID)) { + if (params.containsKey(MODEL_ID) || (indexInfo.getQuantizationState() instanceof ByteScalarQuantizationState)) { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( intListToArray(transferredDocIds), vectorAddress, indexBuildSetup.getDimensions(), indexInfo.getIndexOutputWithBuffer(), - (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), + getIndexTemplate(params, indexInfo), params, indexInfo.getKnnEngine() ); @@ -112,4 +113,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept ); } } + + private byte[] getIndexTemplate(Map params, BuildIndexParams indexInfo) { + if (params.containsKey(MODEL_ID)) { + return (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER); + } + + ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) indexInfo.getQuantizationState(); + return byteSQState.getIndexTemplate(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 27a1ecfb6..e2c309ad1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -32,6 +32,7 @@ import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -312,11 +313,20 @@ private static NativeIndexWriter createWriter( @Nullable final QuantizationState quantizationState ) { final KNNEngine knnEngine = extractKNNEngine(fieldInfo); - boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); - boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + boolean iterative = !isTemplate(fieldInfo) && KNNEngine.FAISS == knnEngine; NativeIndexBuildStrategy strategy = iterative ? MemOptimizedNativeIndexBuildStrategy.getInstance() : DefaultIndexBuildStrategy.getInstance(); return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); } + + private static boolean isTemplate(FieldInfo fieldInfo) { + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + return true; + } + + QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); + return quantizationConfig != QuantizationConfig.EMPTY + && quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java index c5994d66b..745802dd3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -58,7 +59,7 @@ static IndexBuildSetup prepareIndexBuild(KNNVectorValues knnVectorValues, Bui int bytesPerVector; int dimensions; - if (quantizationState != null) { + if (quantizationState != null && !(quantizationState instanceof ByteScalarQuantizationState)) { bytesPerVector = quantizationState.getBytesPerVector(); dimensions = quantizationState.getDimensions(); quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams()); diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 3bc6087c8..d8b298df1 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -109,29 +109,35 @@ static KNNLibraryIndexingContext adjustIndexDescription( if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } - if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE - || (encoderContext != null - && Objects.equals(encoderContext.getName(), ENCODER_SQ) - && Objects.equals( - encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16), - FAISS_SQ_ENCODER_INT8 - ))) { + + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) { // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" - String indexDescription = methodAsMapBuilder.indexDescription; - if (StringUtils.isNotEmpty(indexDescription)) { - StringBuilder indexDescriptionBuilder = new StringBuilder(); - indexDescriptionBuilder.append(indexDescription.split(",")[0]); - indexDescriptionBuilder.append(","); - indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); - methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString(); - } + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, FAISS_SIGNED_BYTE_SQ); + } + + if (encoderContext != null + && Objects.equals(encoderContext.getName(), ENCODER_SQ) + && Objects.equals(encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16), FAISS_SQ_ENCODER_INT8)) { + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, "SQ8"); } methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription; return methodAsMapBuilder.build(); } + private static String updateIndexDescription(String indexDescription, String indexDescriptionName) { + if (StringUtils.isEmpty(indexDescription)) { + return indexDescription; + } + + StringBuilder indexDescriptionBuilder = new StringBuilder(); + indexDescriptionBuilder.append(indexDescription.split(",")[0]); + indexDescriptionBuilder.append(","); + indexDescriptionBuilder.append(indexDescriptionName); + return indexDescriptionBuilder.toString(); + } + static MethodComponentContext getEncoderMethodComponent(MethodComponentContext methodComponentContext) { if (!methodComponentContext.getParameters().containsKey(METHOD_ENCODER_PARAMETER)) { return null; diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index ad506e978..9ed18b1b6 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -14,11 +14,11 @@ import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.factory.QuantizerFactory; import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationOutput.ByteQuantizationOutput; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; @@ -62,7 +62,8 @@ public static QuantizationService getInstance() { public QuantizationState train( final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues, - final long liveDocs + final long liveDocs, + final FieldInfo fieldInfo ) throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); @@ -70,6 +71,9 @@ public QuantizationState train( KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); // Train the quantizer and return the quantization state + if (quantizer instanceof ByteScalarQuantizer) { + return quantizer.train(trainingRequest, fieldInfo); + } return quantizer.train(trainingRequest); } @@ -111,7 +115,7 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) { - return VectorDataType.BYTE; + return VectorDataType.FLOAT; } if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { return VectorDataType.BINARY; @@ -130,9 +134,9 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { public QuantizationOutput createQuantizationOutput(final QuantizationParams quantizationParams) { if (quantizationParams instanceof ScalarQuantizationParams) { ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams; - if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) { - return (QuantizationOutput) new ByteQuantizationOutput(scalarParams.getSqType().getId()); - } + // if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) { + // return (QuantizationOutput) new ByteQuantizationOutput(scalarParams.getSqType().getId()); + // } return (QuantizationOutput) new BinaryQuantizationOutput(scalarParams.getSqType().getId()); } throw new IllegalArgumentException("Unsupported quantization parameters: " + quantizationParams.getClass().getName()); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 24fa16a59..ddbb7b0dd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -459,6 +459,7 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) { private boolean isExactSearchRequire(final LeafReaderContext context, final int filterIdsCount, final int annResultCount) { if (annResultCount == 0 && isMissingNativeEngineFiles(context)) { log.debug("Perform exact search after approximate search since no native engine files are available"); + log.info("Perform exact search after approximate search since no native engine files are available"); return true; } if (isFilteredExactSearchRequireAfterANNSearch(filterIdsCount, annResultCount)) { diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java index d25774cdc..7fd1f61e6 100644 --- a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java +++ b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java @@ -11,7 +11,9 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -36,7 +38,10 @@ public class SegmentLevelQuantizationInfo { public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName) throws IOException { final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); - if (quantizationParams == null) { + if (quantizationParams == null + || (quantizationParams.getTypeIdentifier()).equals( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT) + )) { return null; } final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName); diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java index 679e60325..44d0ebfd8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java @@ -22,8 +22,7 @@ @AllArgsConstructor public class ByteScalarQuantizationState implements QuantizationState { private ScalarQuantizationParams quantizationParams; - private float[] min; - private float[] diff; + private byte[] indexTemplate; @Override public QuantizationParams getQuantizationParams() { @@ -34,15 +33,14 @@ public QuantizationParams getQuantizationParams() { public void writeTo(StreamOutput out) throws IOException { out.writeVInt(Version.CURRENT.id); // Write the version quantizationParams.writeTo(out); - out.writeFloatArray(min); - out.writeFloatArray(diff); + out.writeByteArray(indexTemplate); } public ByteScalarQuantizationState(StreamInput in) throws IOException { int version = in.readVInt(); // Read the version this.quantizationParams = new ScalarQuantizationParams(in, version); - this.min = in.readFloatArray(); - this.diff = in.readFloatArray(); + this.indexTemplate = in.readByteArray(); + } @Override @@ -56,20 +54,19 @@ public static ByteScalarQuantizationState fromByteArray(final byte[] bytes) thro @Override public int getBytesPerVector() { - return min.length; + return 0; } @Override public int getDimensions() { - return min.length; + return 0; } @Override public long ramBytesUsed() { long size = RamUsageEstimator.shallowSizeOfInstance(ByteScalarQuantizationState.class); size += RamUsageEstimator.shallowSizeOf(quantizationParams); - size += RamUsageEstimator.sizeOf(min); - size += RamUsageEstimator.sizeOf(diff); + size += RamUsageEstimator.sizeOf(indexTemplate); return size; } } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java index b0c121ab3..c1b73f26a 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java @@ -5,6 +5,19 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -17,7 +30,15 @@ import oshi.util.tuples.Pair; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; public class ByteScalarQuantizer implements Quantizer { private final int bitsPerCoordinate; @@ -58,30 +79,124 @@ public ByteScalarQuantizer(final int bitsPerCoordinate) { // return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); // } + // Train using Quantile + // @Override + // public QuantizationState train(TrainingRequest trainingRequest) throws IOException { + // int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + // float[][] transposedVec = transposeVectors(trainingRequest, sampledIndices); + // Pair minAndDiff = calculateMinAndDiffUsingQuantile(transposedVec); + // ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + // return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); + // } + @Override - public QuantizationState train(TrainingRequest trainingRequest) throws IOException { - int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - float[][] transposedVec = transposeVectors(trainingRequest, sampledIndices); - Pair minAndDiff = calculateMinAndDiffUsingQuantile(transposedVec); - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); - return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { + return null; } @Override - public void quantize(float[] vector, QuantizationState state, QuantizationOutput output) { + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + if (sampledIndices.length == 0) { + return null; + } + float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]); if (vector == null) { - throw new IllegalArgumentException("Vector to quantize must not be null."); + throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null."); + } + int dimension = vector.length; + byte[] indexTemplate = JNIService.trainIndex( + getParameters(fieldInfo), + dimension, + getVectorAddressOfTrainData(sampledIndices, fieldInfo, trainingRequest, dimension), + KNNEngine.FAISS + ); + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + return new ByteScalarQuantizationState(params, indexTemplate); + } + + private long getVectorAddressOfTrainData( + int[] sampledIndices, + FieldInfo fieldInfo, + final TrainingRequest trainingRequest, + int dimension + ) throws IOException { + int totalSamples = sampledIndices.length; + + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + extractVectorDataType(fieldInfo), + 4 * dimension, + sampledIndices.length + ); + final List transferredDocIds = new ArrayList<>(totalSamples); + for (int i = 0; i < totalSamples; i++) { + Object vectorToTransfer = trainingRequest.getVectorAtThePosition(sampledIndices[i]); + vectorTransfer.transfer(vectorToTransfer, true); } - validateState(state); - int vectorLength = vector.length; - ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) state; - float[] minArray = byteSQState.getMin(); - float[] diffArray = byteSQState.getDiff(); - if (minArray == null || minArray.length != vectorLength || diffArray == null || diffArray.length != vectorLength) { - throw new IllegalArgumentException("min and diff arrays must not be null and must match the dimension of the vector."); + vectorTransfer.flush(true); + return vectorTransfer.getVectorAddress(); + } + + private Map getParameters(final FieldInfo fieldInfo) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + Map algoParams = new HashMap<>(); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); } - output.prepareQuantizedVector(vectorLength); - quantizeVector(vector, minArray, diffArray, output.getQuantizedVector()); + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT); + // parameters.put("rangestat", 1); + // parameters.put("rs", 1); + // parameters.put("rangestat_arg", 20); + // parameters.put("rs_arg", 10.5); + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; + } + + @Override + public void quantize(float[] vector, QuantizationState state, QuantizationOutput output) { + // if (vector == null) { + // throw new IllegalArgumentException("Vector to quantize must not be null."); + // } + // validateState(state); + // int vectorLength = vector.length; + // ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) state; + // float[] minArray = byteSQState.getMin(); + // float[] diffArray = byteSQState.getDiff(); + // if (minArray == null || minArray.length != vectorLength || diffArray == null || diffArray.length != vectorLength) { + // throw new IllegalArgumentException("min and diff arrays must not be null and must match the dimension of the vector."); + // } + // output.prepareQuantizedVector(vectorLength); + // quantizeVector(vector, minArray, diffArray, output.getQuantizedVector()); } private void quantizeVector(final float[] vector, final float[] min, final float[] diff, byte[] quantizedVector) { diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 0bcc252d1..bc5022709 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -119,6 +120,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new MultiBitScalarQuantizationState(params, thresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the provided quantization state, producing a quantized output. * The vector is quantized based on the thresholds in the quantization state. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 3cba89c39..70ed81a83 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -64,6 +65,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the given quantization state. * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index 521863205..1e343b5e8 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; @@ -31,6 +32,8 @@ public interface Quantizer { */ QuantizationState train(TrainingRequest trainingRequest) throws IOException; + QuantizationState train(TrainingRequest trainingRequest, FieldInfo fieldInfo) throws IOException; + /** * Quantizes the provided vector using the specified quantization state. *