From 62c8e29e8ffcbe9a6b8bcfe2256783b0b4fc3210 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 6 Mar 2024 21:02:56 -0600 Subject: [PATCH] Address Review Comments Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/common/KNNConstants.java | 6 +- .../index/mapper/KNNVectorFieldMapper.java | 86 +++++++------------ .../mapper/KNNVectorFieldMapperUtil.java | 2 +- .../org/opensearch/knn/index/util/Faiss.java | 4 +- .../org/opensearch/knn/index/FaissIT.java | 48 +++++++++-- .../mapper/KNNVectorFieldMapperTests.java | 45 ++++++++++ 6 files changed, 124 insertions(+), 67 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 455870ef9..861b3f554 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -96,7 +96,7 @@ public class KNNConstants { public static final String FAISS_SQ_TYPE = "type"; public static final String FAISS_SQ_ENCODER_FP16 = "fp16"; public static final List FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16); - public static final String FAISS_SQ_CLIP_TO_RANGE = "clip_to_range"; + public static final String FAISS_SQ_CLIP = "clip"; // Parameter defaults/limits public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1; @@ -111,8 +111,8 @@ public class KNNConstants { public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30; public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30; - public static final Integer FP16_MAX_VALUE = 65504; - public static final Integer FP16_MIN_VALUE = -65504; + public static final Float FP16_MAX_VALUE = 65504.0f; + public static final Float FP16_MIN_VALUE = -65504.0f; // Lib names private static final String JNI_LIBRARY_PREFIX = "opensearchknn_"; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java 
b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 3cb3e1034..46ed15322 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -55,7 +55,7 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; @@ -526,11 +526,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx addStoredFieldForVectorField(context, fieldType, name(), point.toString()); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional; - if (isFaissSQfp16(fieldType().getKnnMethodContext())) { - floatsArrayOptional = getFloatsFromContextWithFP16Validation(context, dimension); - } else { - floatsArrayOptional = getFloatsFromContext(context, dimension); - } + floatsArrayOptional = getFloatsFromContext(context, dimension); if (!floatsArrayOptional.isPresent()) { return; @@ -574,11 +570,11 @@ protected boolean isFaissSQfp16(KNNMethodContext knnMethodContext) { return false; } - // Verify mapping and return the value of "clip_to_range" parameter(default false) for a "faiss" Index + // Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index // using "sq" encoder of type "fp16". 
protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { if (methodComponentContext != null) { - return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP_TO_RANGE, false); + return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); } return false; } @@ -635,60 +631,36 @@ Optional getBytesFromContext(ParseContext context, int dimension) throws Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - ArrayList vector = new ArrayList<>(); - XContentParser.Token token = context.parser().currentToken(); - float value; - if (token == XContentParser.Token.START_ARRAY) { - token = context.parser().nextToken(); - while (token != XContentParser.Token.END_ARRAY) { - value = context.parser().floatValue(); - validateFloatVectorValue(value); - vector.add(value); - token = context.parser().nextToken(); - } - } else if (token == XContentParser.Token.VALUE_NUMBER) { - value = context.parser().floatValue(); - validateFloatVectorValue(value); - vector.add(value); - context.parser().nextToken(); - } else if (token == XContentParser.Token.VALUE_NULL) { - context.path().remove(); - return Optional.empty(); - } - validateVectorDimension(dimension, vector.size()); - - float[] array = new float[vector.size()]; - int i = 0; - for (Float f : vector) { - array[i++] = f; + // Returns an optional array of float values where each value in the vector is parsed as a float and validated + // as a finite number and, if the Faiss encoder is SQ with type 'fp16', as being within the FP16 range [-65504, 65504] by default. + // If the encoder parameter "clip" is set to true, a vector value outside the FP16 range is clipped to + // the FP16 range instead. 
+ boolean isFaissSQfp16Flag = isFaissSQfp16(fieldType().getKnnMethodContext()); + boolean clipVectorValueToFP16RangeFlag = false; + if (isFaissSQfp16Flag) { + clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) fieldType().getKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER) + ); } - return Optional.of(array); - } - - // Returns an optional array of float values where each value in the vector is parsed as a float and validated - // if it is a finite number and within the fp16 range of [-65504 to 65504] by default. - // If the Index setting, "index.knn.faiss.clip_to_fp16_range" is set to True, if the vector value is - // outside the FP16 range then it will be clipped to FP16 range. - Optional getFloatsFromContextWithFP16Validation(ParseContext context, int dimension) throws IOException { - context.path().add(simpleName()); ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); - boolean clipToFP16Range = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) fieldType().getKnnMethodContext() - .getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER) - ); float value; if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - if (clipToFP16Range) { - value = clipVectorValueToFP16Range(value); + if (isFaissSQfp16Flag) { + if (clipVectorValueToFP16RangeFlag) { + value = clipVectorValueToFP16Range(value); + } else { + validateFP16VectorValue(value); + } } else { - validateFP16VectorValue(value); + validateFloatVectorValue(value); } vector.add(value); @@ -696,10 +668,14 @@ Optional getFloatsFromContextWithFP16Validation(ParseContext context, i } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); - if (clipToFP16Range) { - value = 
clipVectorValueToFP16Range(value); + if (isFaissSQfp16Flag) { + if (clipVectorValueToFP16RangeFlag) { + value = clipVectorValueToFP16Range(value); + } else { + validateFP16VectorValue(value); + } } else { - validateFP16VectorValue(value); + validateFloatVectorValue(value); } vector.add(value); context.parser().nextToken(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 04c173ef4..ad3da5975 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -91,7 +91,7 @@ public static void validateFP16VectorValue(float value) { throw new IllegalArgumentException( String.format( Locale.ROOT, - "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]", + "encoder name is set as [%s] and type is set as [%s] in index mapping. 
But, KNN vector values are not within in the FP16 range [%f, %f]", ENCODER_SQ, FAISS_SQ_ENCODER_FP16, FP16_MIN_VALUE, diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 83769b23c..563311c49 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -32,7 +32,7 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.FAISS_IVF_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.FAISS_PQ_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES; @@ -91,7 +91,7 @@ class Faiss extends NativeLibrary { FAISS_SQ_TYPE, new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains) ) - .addParameter(FAISS_SQ_CLIP_TO_RANGE, new Parameter.BooleanParameter(FAISS_SQ_CLIP_TO_RANGE, false, Objects::nonNull)) + .addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull)) .setMapGenerator( ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( FAISS_SQ_DESCRIPTION, diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c4d841d47..94ba0dbd3 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -45,7 +45,7 @@ import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static 
org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; @@ -447,7 +447,41 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { .contains( String.format( Locale.ROOT, - "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]", + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + Float[] vector1 = { -65506.84f, 12.56f }; + + ResponseException ex1 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "2", fieldName, vector1)); + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + Float[] vector2 = { -65526.4567f, 65526.4567f }; + + ResponseException ex2 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "3", fieldName, vector2)); + assertTrue( + ex2.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. 
But, KNN vector values are not within in the FP16 range [%f, %f]", ENCODER_SQ, FAISS_SQ_ENCODER_FP16, FP16_MIN_VALUE, @@ -492,7 +526,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then .field(NAME, ENCODER_SQ) .startObject(PARAMETERS) .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - .field(FAISS_SQ_CLIP_TO_RANGE, true) + .field(FAISS_SQ_CLIP, true) .endObject() .endObject() .endObject() @@ -507,14 +541,16 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then createKnnIndex(indexName, mapping); assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); Float[] vector1 = { -65523.76f, 65504.2f }; - Float[] vector2 = { -20.89f, 65514.2f }; - Float[] vector3 = { -20.89f, 36.23f }; + Float[] vector2 = { -270.85f, 65514.2f }; + Float[] vector3 = { -150.9f, 65504.0f }; + Float[] vector4 = { -20.89f, 100000000.0f }; addKnnDoc(indexName, "1", fieldName, vector1); addKnnDoc(indexName, "2", fieldName, vector2); addKnnDoc(indexName, "3", fieldName, vector3); + addKnnDoc(indexName, "4", fieldName, vector4); float[] queryVector = { -10.5f, 25.48f }; - int k = 3; + int k = 4; Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k); List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); assertEquals(k, results.size()); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 4a5db8c8f..67c34a254 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -53,6 +53,10 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static 
org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; +import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; @@ -68,6 +72,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -873,6 +879,45 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); } + public void testValidateFp16VectorValue_outOfRange_throwsException() { + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(65505.25f)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + IllegalArgumentException ex1 = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(-65525.65f)); + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. 
But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + } + + public void testClipVectorValuetoFP16Range_succeed() { + assertEquals(65504.0f, clipVectorValueToFP16Range(65504.10f), 0.0f); + assertEquals(65504.0f, clipVectorValueToFP16Range(1000000.89f), 0.0f); + assertEquals(-65504.0f, clipVectorValueToFP16Range(-65504.10f), 0.0f); + assertEquals(-65504.0f, clipVectorValueToFP16Range(-1000000.89f), 0.0f); + } + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) {