Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Mar 8, 2024
1 parent 5c49444 commit 62c8e29
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 67 deletions.
6 changes: 3 additions & 3 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
public static final String FAISS_SQ_CLIP_TO_RANGE = "clip_to_range";
public static final String FAISS_SQ_CLIP = "clip";

// Parameter defaults/limits
public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1;
Expand All @@ -111,8 +111,8 @@ public class KNNConstants {
public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30;
public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30;

public static final Integer FP16_MAX_VALUE = 65504;
public static final Integer FP16_MIN_VALUE = -65504;
public static final Float FP16_MAX_VALUE = 65504.0f;
public static final Float FP16_MIN_VALUE = -65504.0f;

// Lib names
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
Expand Down Expand Up @@ -526,11 +526,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional;
if (isFaissSQfp16(fieldType().getKnnMethodContext())) {
floatsArrayOptional = getFloatsFromContextWithFP16Validation(context, dimension);
} else {
floatsArrayOptional = getFloatsFromContext(context, dimension);
}
floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
return;
Expand Down Expand Up @@ -574,11 +570,11 @@ protected boolean isFaissSQfp16(KNNMethodContext knnMethodContext) {
return false;
}

// Verify mapping and return the value of "clip_to_range" parameter(default false) for a "faiss" Index
// Verify mapping and return the value of the "clip" parameter (default false) for a "faiss" index
// using the "sq" encoder of type "fp16".
protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) {
if (methodComponentContext != null) {
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP_TO_RANGE, false);
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
}
return false;

Check warning on line 579 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L579

Added line #L579 was not covered by tests
}
Expand Down Expand Up @@ -635,71 +631,51 @@ Optional<byte[]> getBytesFromContext(ParseContext context, int dimension) throws
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return Optional.empty();
}
validateVectorDimension(dimension, vector.size());

float[] array = new float[vector.size()];
int i = 0;
for (Float f : vector) {
array[i++] = f;
// Returns an optional array of float values, where each value in the vector is parsed as a float and validated.
// If the Faiss encoder is SQ with type 'fp16', each value must by default be a finite number within the FP16
// range [-65504, 65504]; if the encoder parameter "clip" is set to true, a value outside the FP16 range is
// clipped to that range instead.
boolean isFaissSQfp16Flag = isFaissSQfp16(fieldType().getKnnMethodContext());
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) fieldType().getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER)
);
}
return Optional.of(array);
}

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default.
// If the Index setting, "index.knn.faiss.clip_to_fp16_range" is set to True, if the vector value is
// outside the FP16 range then it will be clipped to FP16 range.
Optional<float[]> getFloatsFromContextWithFP16Validation(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
boolean clipToFP16Range = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) fieldType().getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER)
);
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
if (clipToFP16Range) {
value = clipVectorValueToFP16Range(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);
} else {
validateFP16VectorValue(value);
}
} else {
validateFP16VectorValue(value);
validateFloatVectorValue(value);
}

vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
if (clipToFP16Range) {
value = clipVectorValueToFP16Range(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);

Check warning on line 673 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L673

Added line #L673 was not covered by tests
} else {
validateFP16VectorValue(value);

Check warning on line 675 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L675

Added line #L675 was not covered by tests
}
} else {
validateFP16VectorValue(value);
validateFloatVectorValue(value);

Check warning on line 678 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L678

Added line #L678 was not covered by tests
}
vector.add(value);
context.parser().nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static void validateFP16VectorValue(float value) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]",
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_IVF_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_PQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES;
Expand Down Expand Up @@ -91,7 +91,7 @@ class Faiss extends NativeLibrary {
FAISS_SQ_TYPE,
new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains)
)
.addParameter(FAISS_SQ_CLIP_TO_RANGE, new Parameter.BooleanParameter(FAISS_SQ_CLIP_TO_RANGE, false, Objects::nonNull))
.addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull))
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_SQ_DESCRIPTION,
Expand Down
48 changes: 42 additions & 6 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
Expand Down Expand Up @@ -447,7 +447,41 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]",
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
)
);

Float[] vector1 = { -65506.84f, 12.56f };

ResponseException ex1 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "2", fieldName, vector1));
assertTrue(
ex1.getMessage()
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
)
);

Float[] vector2 = { -65526.4567f, 65526.4567f };

ResponseException ex2 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "3", fieldName, vector2));
assertTrue(
ex2.getMessage()
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
Expand Down Expand Up @@ -492,7 +526,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.field(FAISS_SQ_CLIP_TO_RANGE, true)
.field(FAISS_SQ_CLIP, true)
.endObject()
.endObject()
.endObject()
Expand All @@ -507,14 +541,16 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then
createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));
Float[] vector1 = { -65523.76f, 65504.2f };
Float[] vector2 = { -20.89f, 65514.2f };
Float[] vector3 = { -20.89f, 36.23f };
Float[] vector2 = { -270.85f, 65514.2f };
Float[] vector3 = { -150.9f, 65504.0f };
Float[] vector4 = { -20.89f, 100000000.0f };
addKnnDoc(indexName, "1", fieldName, vector1);
addKnnDoc(indexName, "2", fieldName, vector2);
addKnnDoc(indexName, "3", fieldName, vector3);
addKnnDoc(indexName, "4", fieldName, vector4);

float[] queryVector = { -10.5f, 25.48f };
int k = 3;
int k = 4;
Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(k, results.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
Expand All @@ -68,6 +72,8 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;

public class KNNVectorFieldMapperTests extends KNNTestCase {

Expand Down Expand Up @@ -873,6 +879,45 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() {
assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue());
}

// Verifies that validateFP16VectorValue throws IllegalArgumentException for values above
// FP16_MAX_VALUE and below FP16_MIN_VALUE, and that the exception message contains the
// encoder name, the encoder type, and the FP16 range bounds.
public void testValidateFp16VectorValue_outOfRange_throwsException() {
    // Both out-of-range cases are expected to report the same message, so build it once.
    final String expectedMessage = String.format(
        Locale.ROOT,
        "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
        ENCODER_SQ,
        FAISS_SQ_ENCODER_FP16,
        FP16_MIN_VALUE,
        FP16_MAX_VALUE
    );

    // A value greater than the upper FP16 bound.
    IllegalArgumentException aboveMax = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(65505.25f));
    assertTrue(aboveMax.getMessage().contains(expectedMessage));

    // A value smaller than the lower FP16 bound.
    IllegalArgumentException belowMin = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(-65525.65f));
    assertTrue(belowMin.getMessage().contains(expectedMessage));
}

// Verifies that clipVectorValueToFP16Range returns 65504.0f for inputs above that value and
// -65504.0f for inputs below its negation, for both slightly and far out-of-range inputs.
public void testClipVectorValuetoFP16Range_succeed() {
    // Each row is { input, expected clipped result }.
    final float[][] clipCases = {
        { 65504.10f, 65504.0f },
        { 1000000.89f, 65504.0f },
        { -65504.10f, -65504.0f },
        { -1000000.89f, -65504.0f },
    };
    for (float[] clipCase : clipCases) {
        // Zero delta: the clipped value must match the bound exactly.
        assertEquals(clipCase[1], clipVectorValueToFP16Range(clipCase[0]), 0.0f);
    }
}

private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder(
VectorDataType vectorDataType
) {
Expand Down

0 comments on commit 62c8e29

Please sign in to comment.