Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Mar 11, 2024
1 parent 5c49444 commit 06dc3b1
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 94 deletions.
6 changes: 3 additions & 3 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
public static final String FAISS_SQ_CLIP_TO_RANGE = "clip_to_range";
public static final String FAISS_SQ_CLIP = "clip";

// Parameter defaults/limits
public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1;
Expand All @@ -111,8 +111,8 @@ public class KNNConstants {
public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30;
public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30;

public static final Integer FP16_MAX_VALUE = 65504;
public static final Integer FP16_MIN_VALUE = -65504;
public static final Float FP16_MAX_VALUE = 65504.0f;
public static final Float FP16_MIN_VALUE = -65504.0f;

// Lib names
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
Expand Down
8 changes: 2 additions & 6 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,13 @@ public ValidationException validate(Object value) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("Value not of type Boolean for Boolean " + "parameter \"%s\".", getName())
);
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
return validationException;

Check warning on line 83 in src/main/java/org/opensearch/knn/index/Parameter.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/Parameter.java#L81-L83

Added lines #L81 - L83 were not covered by tests
}

if (!validator.test((Boolean) value)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("Parameter validation failed for Boolean " + "parameter \"%s\".", getName())
);
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));

Check warning on line 88 in src/main/java/org/opensearch/knn/index/Parameter.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/Parameter.java#L87-L88

Added lines #L87 - L88 were not covered by tests
}
return validationException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
Expand Down Expand Up @@ -525,12 +526,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional;
if (isFaissSQfp16(fieldType().getKnnMethodContext())) {
floatsArrayOptional = getFloatsFromContextWithFP16Validation(context, dimension);
} else {
floatsArrayOptional = getFloatsFromContext(context, dimension);
}
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
return;
Expand All @@ -549,36 +545,42 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
context.path().remove();
}

// Verify mapping and return true if it is a "faiss" Index using "sq" encoder
// of type "fp16"
// Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16"
protected boolean isFaissSQfp16(KNNMethodContext knnMethodContext) {
if (knnMethodContext != null) {
if (knnMethodContext.getKnnEngine().getName().equals(FAISS_NAME)
&& knnMethodContext.getMethodComponentContext().getParameters().size() != 0) {
Map<String, Object> methodComponentParams = knnMethodContext.getMethodComponentContext().getParameters();
if (methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) {
MethodComponentContext methodComponentContext = (MethodComponentContext) methodComponentParams.get(
METHOD_ENCODER_PARAMETER
);
if (ENCODER_SQ.equals(methodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(
methodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
)) {
return true;
}
}

}
// KNNMethodContext shouldn't be null
if (Objects.isNull(knnMethodContext)) {
return false;
}

// engine should be faiss
if (!FAISS_NAME.equals(knnMethodContext.getKnnEngine().getName())) {
return false;
}
return false;

// Should have Method Component Parameters
if (knnMethodContext.getMethodComponentContext().getParameters().size() == 0) {
return false;
}
Map<String, Object> methodComponentParams = knnMethodContext.getMethodComponentContext().getParameters();

// The method component parameters should have an encoder
if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) {
return false;
}

MethodComponentContext methodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER);

// returns true if encoder name is "sq" and type is "fp16"
return ENCODER_SQ.equals(methodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(methodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16));
}

// Verify mapping and return the value of "clip_to_range" parameter(default false) for a "faiss" Index
// Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index
// using "sq" encoder of type "fp16".
protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) {
if (methodComponentContext != null) {
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP_TO_RANGE, false);
if (Objects.nonNull(methodComponentContext)) {
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
}
return false;

Check warning on line 585 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L585

Added line #L585 was not covered by tests
}
Expand Down Expand Up @@ -635,71 +637,51 @@ Optional<byte[]> getBytesFromContext(ParseContext context, int dimension) throws
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return Optional.empty();
}
validateVectorDimension(dimension, vector.size());

float[] array = new float[vector.size()];
int i = 0;
for (Float f : vector) {
array[i++] = f;
// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'.
// If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be
// clipped to FP16 range.
boolean isFaissSQfp16Flag = isFaissSQfp16(fieldType().getKnnMethodContext());
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) fieldType().getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER)
);
}
return Optional.of(array);
}

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default.
// If the Index setting, "index.knn.faiss.clip_to_fp16_range" is set to True, if the vector value is
// outside the FP16 range then it will be clipped to FP16 range.
Optional<float[]> getFloatsFromContextWithFP16Validation(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
boolean clipToFP16Range = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) fieldType().getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER)
);
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
if (clipToFP16Range) {
value = clipVectorValueToFP16Range(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);
} else {
validateFP16VectorValue(value);
}
} else {
validateFP16VectorValue(value);
validateFloatVectorValue(value);
}

vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
if (clipToFP16Range) {
value = clipVectorValueToFP16Range(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);

Check warning on line 679 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L679

Added line #L679 was not covered by tests
} else {
validateFP16VectorValue(value);

Check warning on line 681 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L681

Added line #L681 was not covered by tests
}
} else {
validateFP16VectorValue(value);
validateFloatVectorValue(value);

Check warning on line 684 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L684

Added line #L684 was not covered by tests
}
vector.add(value);
context.parser().nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static void validateFP16VectorValue(float value) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]",
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_IVF_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_PQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES;
Expand Down Expand Up @@ -91,7 +91,7 @@ class Faiss extends NativeLibrary {
FAISS_SQ_TYPE,
new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains)
)
.addParameter(FAISS_SQ_CLIP_TO_RANGE, new Parameter.BooleanParameter(FAISS_SQ_CLIP_TO_RANGE, false, Objects::nonNull))
.addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull))
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_SQ_DESCRIPTION,
Expand Down
48 changes: 42 additions & 6 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP_TO_RANGE;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
Expand Down Expand Up @@ -447,7 +447,41 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%d, %d]",
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
)
);

Float[] vector1 = { -65506.84f, 12.56f };

ResponseException ex1 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "2", fieldName, vector1));
assertTrue(
ex1.getMessage()
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
)
);

Float[] vector2 = { -65526.4567f, 65526.4567f };

ResponseException ex2 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "3", fieldName, vector2));
assertTrue(
ex2.getMessage()
.contains(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
Expand Down Expand Up @@ -492,7 +526,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.field(FAISS_SQ_CLIP_TO_RANGE, true)
.field(FAISS_SQ_CLIP, true)
.endObject()
.endObject()
.endObject()
Expand All @@ -507,14 +541,16 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then
createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));
Float[] vector1 = { -65523.76f, 65504.2f };
Float[] vector2 = { -20.89f, 65514.2f };
Float[] vector3 = { -20.89f, 36.23f };
Float[] vector2 = { -270.85f, 65514.2f };
Float[] vector3 = { -150.9f, 65504.0f };
Float[] vector4 = { -20.89f, 100000000.0f };
addKnnDoc(indexName, "1", fieldName, vector1);
addKnnDoc(indexName, "2", fieldName, vector2);
addKnnDoc(indexName, "3", fieldName, vector3);
addKnnDoc(indexName, "4", fieldName, vector4);

float[] queryVector = { -10.5f, 25.48f };
int k = 3;
int k = 4;
Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(k, results.size());
Expand Down
Loading

0 comments on commit 06dc3b1

Please sign in to comment.