Skip to content

Commit

Permalink
Use Faiss SQ for quantizing vectors
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Jan 15, 2025
1 parent 4554912 commit fee1662
Show file tree
Hide file tree
Showing 15 changed files with 238 additions and 68 deletions.
15 changes: 7 additions & 8 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,20 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
SpaceType spaceType = SpaceType.getSpace(spaceTypeName);
String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null);
QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo);
engineFiles.addAll(
getEngineFileContexts(
reader.getSegmentInfo(),
fieldInfo.name,
fileExtension,
spaceType,
modelId,
FieldInfoExtractor.extractQuantizationConfig(fieldInfo) == QuantizationConfig.EMPTY
? VectorDataType.get(
fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())
)
: FieldInfoExtractor.extractQuantizationConfig(fieldInfo)
.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT
? VectorDataType.BYTE
: VectorDataType.BINARY
quantizationConfig == QuantizationConfig.EMPTY
|| quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT
? VectorDataType.get(
fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())
)
: VectorDataType.BINARY
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
Expand Down Expand Up @@ -89,8 +88,8 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
case EIGHT_BIT:
return ByteScalarQuantizationState.fromByteArray(stateBytes);
// case EIGHT_BIT:
// return ByteScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -255,10 +257,16 @@ private QuantizationState train(
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
if ((quantizationParams.getTypeIdentifier()).equals(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT)
)) {
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo);
} else {
initQuantizationStateWriterIfNecessary();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
}

return quantizationState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -77,14 +78,14 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept
long vectorAddress = vectorTransfer.getVectorAddress();
// Currently this is if else as there are only two cases, with more cases this will have to be made
// more maintainable
if (params.containsKey(MODEL_ID)) {
if (params.containsKey(MODEL_ID) || (indexInfo.getQuantizationState() instanceof ByteScalarQuantizationState)) {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
indexBuildSetup.getDimensions(),
indexInfo.getIndexOutputWithBuffer(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
getIndexTemplate(params, indexInfo),
params,
indexInfo.getKnnEngine()
);
Expand Down Expand Up @@ -112,4 +113,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept
);
}
}

private byte[] getIndexTemplate(Map<String, Object> params, BuildIndexParams indexInfo) {
if (params.containsKey(MODEL_ID)) {
return (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER);
}

ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) indexInfo.getQuantizationState();
return byteSQState.getIndexTemplate();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -312,11 +313,20 @@ private static NativeIndexWriter createWriter(
@Nullable final QuantizationState quantizationState
) {
final KNNEngine knnEngine = extractKNNEngine(fieldInfo);
boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID);
boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine;
boolean iterative = !isTemplate(fieldInfo) && KNNEngine.FAISS == knnEngine;
NativeIndexBuildStrategy strategy = iterative
? MemOptimizedNativeIndexBuildStrategy.getInstance()
: DefaultIndexBuildStrategy.getInstance();
return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState);
}

private static boolean isTemplate(FieldInfo fieldInfo) {
if (fieldInfo.attributes().containsKey(MODEL_ID)) {
return true;
}

QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo);
return quantizationConfig != QuantizationConfig.EMPTY
&& quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -58,7 +59,7 @@ static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, Bui
int bytesPerVector;
int dimensions;

if (quantizationState != null) {
if (quantizationState != null && !(quantizationState instanceof ByteScalarQuantizationState)) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,29 +109,35 @@ static KNNLibraryIndexingContext adjustIndexDescription(
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE
|| (encoderContext != null
&& Objects.equals(encoderContext.getName(), ENCODER_SQ)
&& Objects.equals(
encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16),
FAISS_SQ_ENCODER_INT8
))) {

if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) {

// If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer
// For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed"
String indexDescription = methodAsMapBuilder.indexDescription;
if (StringUtils.isNotEmpty(indexDescription)) {
StringBuilder indexDescriptionBuilder = new StringBuilder();
indexDescriptionBuilder.append(indexDescription.split(",")[0]);
indexDescriptionBuilder.append(",");
indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);
methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString();
}
methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, FAISS_SIGNED_BYTE_SQ);
}

if (encoderContext != null
&& Objects.equals(encoderContext.getName(), ENCODER_SQ)
&& Objects.equals(encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16), FAISS_SQ_ENCODER_INT8)) {
methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, "SQ8");
}
methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription;
return methodAsMapBuilder.build();
}

private static String updateIndexDescription(String indexDescription, String indexDescriptionName) {
if (StringUtils.isEmpty(indexDescription)) {
return indexDescription;
}

StringBuilder indexDescriptionBuilder = new StringBuilder();
indexDescriptionBuilder.append(indexDescription.split(",")[0]);
indexDescriptionBuilder.append(",");
indexDescriptionBuilder.append(indexDescriptionName);
return indexDescriptionBuilder.toString();
}

static MethodComponentContext getEncoderMethodComponent(MethodComponentContext methodComponentContext) {
if (!methodComponentContext.getParameters().containsKey(METHOD_ENCODER_PARAMETER)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.ByteQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

Expand Down Expand Up @@ -62,14 +62,18 @@ public static <T, R> QuantizationService<T, R> getInstance() {
public QuantizationState train(
final QuantizationParams quantizationParams,
final KNNVectorValues<T> knnVectorValues,
final long liveDocs
final long liveDocs,
final FieldInfo fieldInfo
) throws IOException {
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);

// Create the training request from the vector values
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);

// Train the quantizer and return the quantization state
if (quantizer instanceof ByteScalarQuantizer) {
return quantizer.train(trainingRequest, fieldInfo);
}
return quantizer.train(trainingRequest);
}

Expand Down Expand Up @@ -111,7 +115,7 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY
&& quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) {
return VectorDataType.BYTE;
return VectorDataType.FLOAT;
}
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
Expand All @@ -130,9 +134,9 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
public QuantizationOutput<R> createQuantizationOutput(final QuantizationParams quantizationParams) {
if (quantizationParams instanceof ScalarQuantizationParams) {
ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams;
if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) {
return (QuantizationOutput<R>) new ByteQuantizationOutput(scalarParams.getSqType().getId());
}
// if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) {
// return (QuantizationOutput<R>) new ByteQuantizationOutput(scalarParams.getSqType().getId());
// }
return (QuantizationOutput<R>) new BinaryQuantizationOutput(scalarParams.getSqType().getId());
}
throw new IllegalArgumentException("Unsupported quantization parameters: " + quantizationParams.getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) {
private boolean isExactSearchRequire(final LeafReaderContext context, final int filterIdsCount, final int annResultCount) {
if (annResultCount == 0 && isMissingNativeEngineFiles(context)) {
log.debug("Perform exact search after approximate search since no native engine files are available");
log.info("Perform exact search after approximate search since no native engine files are available");
return true;
}
if (isFilteredExactSearchRequireAfterANNSearch(filterIdsCount, annResultCount)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand All @@ -36,7 +38,10 @@ public class SegmentLevelQuantizationInfo {
public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName)
throws IOException {
final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo);
if (quantizationParams == null) {
if (quantizationParams == null
|| (quantizationParams.getTypeIdentifier()).equals(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT)
)) {
return null;
}
final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
@AllArgsConstructor
public class ByteScalarQuantizationState implements QuantizationState {
private ScalarQuantizationParams quantizationParams;
private float[] min;
private float[] diff;
private byte[] indexTemplate;

@Override
public QuantizationParams getQuantizationParams() {
Expand All @@ -34,15 +33,14 @@ public QuantizationParams getQuantizationParams() {
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(Version.CURRENT.id); // Write the version
quantizationParams.writeTo(out);
out.writeFloatArray(min);
out.writeFloatArray(diff);
out.writeByteArray(indexTemplate);
}

public ByteScalarQuantizationState(StreamInput in) throws IOException {
int version = in.readVInt(); // Read the version
this.quantizationParams = new ScalarQuantizationParams(in, version);
this.min = in.readFloatArray();
this.diff = in.readFloatArray();
this.indexTemplate = in.readByteArray();

}

@Override
Expand All @@ -56,20 +54,19 @@ public static ByteScalarQuantizationState fromByteArray(final byte[] bytes) thro

@Override
public int getBytesPerVector() {
return min.length;
return 0;
}

@Override
public int getDimensions() {
return min.length;
return 0;
}

@Override
public long ramBytesUsed() {
long size = RamUsageEstimator.shallowSizeOfInstance(ByteScalarQuantizationState.class);
size += RamUsageEstimator.shallowSizeOf(quantizationParams);
size += RamUsageEstimator.sizeOf(min);
size += RamUsageEstimator.sizeOf(diff);
size += RamUsageEstimator.sizeOf(indexTemplate);
return size;
}
}
Loading

0 comments on commit fee1662

Please sign in to comment.