From 96140ee9c30fb8a034d7e84cb0034ff718d42896 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Fri, 30 Jun 2023 11:43:26 -0500 Subject: [PATCH] Address Review Comments Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/common/KNNConstants.java | 4 +- .../opensearch/knn/index/VectorDataType.java | 219 ++++++------------ .../org/opensearch/knn/index/VectorField.java | 2 +- .../index/mapper/KNNVectorFieldMapper.java | 32 +-- .../knn/index/mapper/LuceneFieldMapper.java | 23 +- .../knn/index/VectorDataTypeIT.java | 27 ++- .../mapper/KNNVectorFieldMapperTests.java | 20 +- 7 files changed, 127 insertions(+), 200 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 47ce0c957..6d387eec4 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -52,8 +52,8 @@ public class KNNConstants { public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count"; public static final String SEARCH_SIZE_PARAMETER = "search_size"; - public static final String VECTOR_DATA_TYPE = "data_type"; - public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE = VectorDataType.FLOAT; + public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 70db606c0..b791c3488 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -5,192 +5,87 @@ package org.opensearch.knn.index; +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexableFieldType; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.knn.index.util.KNNEngine; +import java.util.Arrays; import java.util.HashSet; -import java.util.Map; +import java.util.Locale; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin. * We have two vector data_types, one is float (default) and the other one is byte. */ +@AllArgsConstructor public enum VectorDataType { BYTE("byte") { /** - * @param dimension Dimension of the vector + * @param dimension Dimension of the vector * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType - * @return FieldType of type KnnByteVectorField + * @return FieldType of type KnnByteVectorField */ @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); } - - /** - * @param knnEngine KNNEngine - * @return DocValues FieldType of type Binary and with BYTE VectorEncoding - */ - @Override - public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - IndexableFieldType indexableFieldType = new IndexableFieldType() { - @Override - public boolean stored() { - return false; - } - - @Override - public boolean tokenized() { - return true; - } - - @Override - public boolean storeTermVectors() { - return false; - } - - @Override - public boolean storeTermVectorOffsets() { - return false; - } - - @Override - public boolean storeTermVectorPositions() { - return false; - } - - @Override - public boolean storeTermVectorPayloads() { - return false; - } - - @Override - public boolean omitNorms() { - return false; - } - - @Override - public IndexOptions indexOptions() { - return IndexOptions.NONE; - } - - @Override - public DocValuesType docValuesType() { - return DocValuesType.NONE; - } - - @Override - public int pointDimensionCount() { - return 0; - } - - @Override - public int pointIndexDimensionCount() { - return 0; - } - - @Override - public int pointNumBytes() { - return 0; - } - - @Override - public int vectorDimension() { - return 0; - } - - @Override - public VectorEncoding vectorEncoding() { - return VectorEncoding.BYTE; - } - - @Override - public VectorSimilarityFunction vectorSimilarityFunction() { - return VectorSimilarityFunction.EUCLIDEAN; - } - - @Override - public Map getAttributes() { - return null; - } - }; - FieldType field = new FieldType(indexableFieldType); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } }, FLOAT("float") { /** - * @param dimension Dimension of the vector + * @param dimension Dimension of the vector * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType - * @return FieldType of type KnnFloatVectorField + * @return FieldType of type KnnFloatVectorField */ @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); } - /** - * @param knnEngine KNNEngine - * @return DocValues FieldType of type Binary and with FLOAT32 VectorEncoding - */ - @Override - public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - FieldType field = new FieldType(); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } - }; + @Getter private final String value; - VectorDataType(String value) { - this.value = value; - } - /** - * Get VectorDataType name - * - * @return name + * @param dimension Dimension of the vector + * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @return FieldType */ - public String getValue() { - return value; - } - public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); - public abstract FieldType buildDocValuesFieldType(KNNEngine knnEngine); + /** + * @param knnEngine KNNEngine + * @return DocValues FieldType of type Binary + */ + public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { + FieldType field = new FieldType(); + field.putAttribute(KNN_ENGINE, knnEngine.getName()); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } /** * @return Set of names of all the supporting VectorDataTypes */ + @VisibleForTesting public static Set getValues() { - Set values = new HashSet<>(); - - for (VectorDataType dataType : VectorDataType.values()) { - values.add(dataType.getValue()); - } - return values; + return Arrays.stream((VectorDataType.values())).map(VectorDataType::getValue).collect(Collectors.toCollection(HashSet::new)); } /** @@ -199,10 +94,14 @@ public static Set getValues() { * @return the same VectorDataType if it is in the supported values else throw exception. */ public static VectorDataType get(String vectorDataType) { - String supportedTypes = String.join(",", getValues()); Objects.requireNonNull( vectorDataType, - String.format("[{}] should not be null. Supported types are [{}]", VECTOR_DATA_TYPE, supportedTypes) + String.format( + Locale.ROOT, + "[{}] should not be null. Supported types are [{}]", + VECTOR_DATA_TYPE_FIELD, + String.join(",", getValues()) + ) ); for (VectorDataType currentDataType : VectorDataType.values()) { if (currentDataType.getValue().equalsIgnoreCase(vectorDataType)) { @@ -211,20 +110,20 @@ public static VectorDataType get(String vectorDataType) { } throw new IllegalArgumentException( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, - supportedTypes + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + String.join(",", getValues()) ) ); } /** - * Validate the float vector values if it is a number and in the finite range. + * Validate the float vector value and throw exception if it is not a number or not in the finite range. * * @param value float vector value */ - public static void validateFloatVectorValues(float value) { + public static void validateFloatVectorValue(float value) { if (Float.isNaN(value)) { throw new IllegalArgumentException("KNN vector values cannot be NaN"); } @@ -236,22 +135,29 @@ public static void validateFloatVectorValues(float value) { /** * Validate the float vector value in the byte range if it is a finite number, - * with no decimal values and in the byte range of [-128 to 127]. + * with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException. * * @param value float value in byte range */ - public static void validateByteVectorValues(float value) { - validateFloatVectorValues(value); + public static void validateByteVectorValue(float value) { + validateFloatVectorValue(value); if (value % 1 != 0) { throw new IllegalArgumentException( - "[data_type] field was set as [byte] in index mapping. But, KNN vector values are floats instead of byte integers" + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ); } if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { throw new IllegalArgumentException( String.format( - "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [{}, {}]", - VECTOR_DATA_TYPE, + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue(), Byte.MIN_VALUE, Byte.MAX_VALUE @@ -268,7 +174,7 @@ public static void validateByteVectorValues(float value) { */ public static void validateVectorDimension(int dimension, int vectorSize) { if (dimension != vectorSize) { - String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); + String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); throw new IllegalArgumentException(errorMessage); } @@ -281,13 +187,18 @@ public static void validateVectorDimension(int dimension, int vectorSize) { * @param knnMethodContext KNNMethodContext Parameter * @param vectorDataType VectorDataType Parameter */ - public static void validateVectorDataType_Engine( + public static void validateVectorDataTypeWithEngine( ParametrizedFieldMapper.Parameter knnMethodContext, ParametrizedFieldMapper.Parameter vectorDataType ) { - if (vectorDataType.getValue() != DEFAULT_VECTOR_DATA_TYPE - && (knnMethodContext.get() == null || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE)) { - throw new IllegalArgumentException(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME)); + if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) { + return; + } + if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE) + || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE_FIELD, LUCENE_NAME) + ); } } } diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index 2c346992d..f28ef6238 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -34,7 +34,7 @@ public VectorField(String name, byte[] value, IndexableFieldType type) { try { this.setBytesValue(value); } catch (Exception e) { - throw new IllegalArgumentException(e); + throw new RuntimeException(e); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 27416a7fb..0211db815 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -50,12 +50,12 @@ import java.util.Optional; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; -import static org.opensearch.knn.index.VectorDataType.validateByteVectorValues; -import static org.opensearch.knn.index.VectorDataType.validateFloatVectorValues; -import static org.opensearch.knn.index.VectorDataType.validateVectorDataType_Engine; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.validateByteVectorValue; +import static org.opensearch.knn.index.VectorDataType.validateFloatVectorValue; +import static org.opensearch.knn.index.VectorDataType.validateVectorDataTypeWithEngine; import static org.opensearch.knn.index.VectorDataType.validateVectorDimension; /** @@ -107,10 +107,10 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * data_type which defines the datatype of the vector values. This is an optional parameter and * this is right now only relevant for lucene engine. The default value is float. */ - protected final Parameter vectorDataType = new Parameter<>( - VECTOR_DATA_TYPE, + private final Parameter vectorDataType = new Parameter<>( + VECTOR_DATA_TYPE_FIELD, false, - () -> DEFAULT_VECTOR_DATA_TYPE, + () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, (n, c, o) -> VectorDataType.get((String) o), m -> toType(m).vectorDataType ); @@ -350,7 +350,7 @@ public Mapper.Builder parse(String name, Map node, ParserCont // Validates and throws exception if data_type field is set in the index mapping // using any VectorDataType (other than float, which is default) with any engine (except lucene). - validateVectorDataType_Engine(builder.knnMethodContext, builder.vectorDataType); + validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType); return builder; } @@ -364,15 +364,15 @@ public static class KNNVectorFieldType extends MappedFieldType { VectorDataType vectorDataType; public KNNVectorFieldType(String name, Map meta, int dimension) { - this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType( @@ -522,13 +522,13 @@ Optional getBytesFromContext(ParseContext context, int dimension) throws token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - validateByteVectorValues(value); + validateByteVectorValue(value); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); - validateByteVectorValues(value); + validateByteVectorValue(value); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -554,13 +554,13 @@ Optional getFloatsFromContext(ParseContext context, int dimension) thro token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - validateFloatVectorValues(value); + validateFloatVectorValue(value); vector.add(value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); - validateFloatVectorValues(value); + validateFloatVectorValue(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index b571a9d2f..52119a6a7 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -21,9 +21,11 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import java.util.Optional; import static org.apache.lucene.index.VectorValues.MAX_DIMENSIONS; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Field mapper for case when Lucene has been set as an engine. @@ -55,6 +57,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { if (dimension > LUCENE_MAX_DIMENSION) { throw new IllegalArgumentException( String.format( + Locale.ROOT, "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", LUCENE_MAX_DIMENSION, dimension, @@ -78,12 +81,12 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - if (vectorDataType.equals(VectorDataType.BYTE)) { - Optional arrayOptional = getBytesFromContext(context, dimension); - if (arrayOptional.isEmpty()) { + if (VectorDataType.BYTE.equals(vectorDataType)) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension); + if (bytesArrayOptional.isEmpty()) { return; } - final byte[] array = arrayOptional.get(); + final byte[] array = bytesArrayOptional.get(); KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); context.doc().add(point); if (fieldType.stored()) { @@ -92,13 +95,13 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } - } else { - Optional arrayOptional = getFloatsFromContext(context, dimension); + } else if (VectorDataType.FLOAT.equals(vectorDataType)) { + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); - if (arrayOptional.isEmpty()) { + if (floatsArrayOptional.isEmpty()) { return; } - final float[] array = arrayOptional.get(); + final float[] array = floatsArrayOptional.get(); KnnVectorField point = new KnnVectorField(name(), array, fieldType); @@ -110,6 +113,10 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) + ); } context.path().remove(); diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 865402225..52ecded02 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -15,11 +15,12 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.getValues; public class VectorDataTypeIT extends KNNRestTestCase { @@ -75,7 +76,6 @@ public void testDeleteDocWithByteVector() throws Exception { // Set an invalid value for data_type field while creating the index which should throw an exception public void testInvalidVectorDataType() { String vectorDataType = "invalidVectorType"; - String supportedTypes = String.join(",", getValues()); ResponseException ex = expectThrows( ResponseException.class, () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, vectorDataType) @@ -84,10 +84,10 @@ public void testInvalidVectorDataType() { ex.getMessage() .contains( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, - supportedTypes + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + String.join(",", getValues()) ) ) ); @@ -100,8 +100,9 @@ public void testVectorDataTypeAsNull() { ex.getMessage() .contains( String.format( + Locale.ROOT, "[%s] on mapper [%s] of type [%s] must not have a [null] value", - VECTOR_DATA_TYPE, + VECTOR_DATA_TYPE_FIELD, FIELD_NAME, KNN_VECTOR_TYPE ) @@ -133,8 +134,9 @@ public void testInvalidByteVectorRange() throws Exception { ex.getMessage() .contains( String.format( - "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [{}, {}]", - VECTOR_DATA_TYPE, + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue(), Byte.MIN_VALUE, Byte.MAX_VALUE @@ -149,7 +151,10 @@ public void testByteVectorDataTypeWithNmslibEngine() { ResponseException.class, () -> createKnnIndexMappingWithNmslibEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()) ); - assertTrue(ex.getMessage().contains(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME))); + assertTrue( + ex.getMessage() + .contains(String.format(Locale.ROOT, "[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE_FIELD, LUCENE_NAME)) + ); } private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { @@ -168,7 +173,7 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac .startObject(FIELD_NAME) .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION, dimension) - .field(VECTOR_DATA_TYPE, vectorDataType) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, METHOD_HNSW) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index d578729bb..e8fbe3dc9 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -45,7 +45,9 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Optional; +import java.util.stream.Collectors; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; @@ -64,7 +66,7 @@ import static org.opensearch.Version.CURRENT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.getValues; public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -94,6 +96,9 @@ public void testBuilder_getParameters() { KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao); assertEquals(7, builder.getParameters().size()); + List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); + List expectedParams = Arrays.asList("store", "doc_values", DIMENSION, VECTOR_DATA_TYPE_FIELD, "meta", KNN_METHOD, MODEL_ID); + assertEquals(expectedParams, actualParams); } public void testBuilder_build_fromKnnMethodContext() { @@ -361,7 +366,7 @@ public void testTypeParser_parse_invalidVectorDataType() throws IOException { .startObject() .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION, 10) - .field(VECTOR_DATA_TYPE, vectorDataType) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNN_METHOD) .field(NAME, METHOD_HNSW) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) @@ -382,9 +387,9 @@ public void testTypeParser_parse_invalidVectorDataType() throws IOException { ); assertEquals( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, supportedTypes ), ex.getMessage() @@ -834,7 +839,6 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() throws } assertEquals(TEST_BYTE_VECTOR_BYTES_REF, vectorField.binaryValue()); - assertEquals(VectorEncoding.BYTE, vectorField.fieldType().vectorEncoding()); assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); // Test when doc values are disabled @@ -889,13 +893,13 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn .knnMethodContext(knnMethodContext); } - public static float[] createInitializedFloatArray(int dimension, float value) { + private static float[] createInitializedFloatArray(int dimension, float value) { float[] array = new float[dimension]; Arrays.fill(array, value); return array; } - public static byte[] createInitializedByteArray(int dimension, byte value) { + private static byte[] createInitializedByteArray(int dimension, byte value) { byte[] array = new byte[dimension]; Arrays.fill(array, value); return array;