Skip to content

Commit

Permalink
Address Review Comments
Browse files — browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Jun 30, 2023
1 parent a588973 commit 96140ee
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 200 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number | Diff line number | Diff line change
Expand Up @@ -52,8 +52,8 @@ public class KNNConstants {
public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count";
public static final String SEARCH_SIZE_PARAMETER = "search_size";

public static final String VECTOR_DATA_TYPE = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE = VectorDataType.FLOAT;
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
219 changes: 65 additions & 154 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,192 +5,87 @@

package org.opensearch.knn.index;

import com.google.common.annotations.VisibleForTesting;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE;
import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

/**
* Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin.
* We have two vector data_types, one is float (default) and the other one is byte.
*/
@AllArgsConstructor
public enum VectorDataType {
BYTE("byte") {
/**
* @param dimension Dimension of the vector
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @return FieldType of type KnnByteVectorField
* @return FieldType of type KnnByteVectorField
*/
@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

/**
* @param knnEngine KNNEngine
* @return DocValues FieldType of type Binary and with BYTE VectorEncoding
*/
@Override
public FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
IndexableFieldType indexableFieldType = new IndexableFieldType() {
@Override
public boolean stored() {
return false;
}

@Override
public boolean tokenized() {
return true;
}

@Override
public boolean storeTermVectors() {
return false;
}

@Override
public boolean storeTermVectorOffsets() {
return false;
}

@Override
public boolean storeTermVectorPositions() {
return false;
}

@Override
public boolean storeTermVectorPayloads() {
return false;
}

@Override
public boolean omitNorms() {
return false;
}

@Override
public IndexOptions indexOptions() {
return IndexOptions.NONE;
}

@Override
public DocValuesType docValuesType() {
return DocValuesType.NONE;
}

@Override
public int pointDimensionCount() {
return 0;
}

@Override
public int pointIndexDimensionCount() {
return 0;
}

@Override
public int pointNumBytes() {
return 0;
}

@Override
public int vectorDimension() {
return 0;
}

@Override
public VectorEncoding vectorEncoding() {
return VectorEncoding.BYTE;
}

@Override
public VectorSimilarityFunction vectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;
}

@Override
public Map<String, String> getAttributes() {
return null;
}
};
FieldType field = new FieldType(indexableFieldType);
field.putAttribute(KNN_ENGINE, knnEngine.getName());
field.setDocValuesType(DocValuesType.BINARY);
field.freeze();
return field;
}
},
FLOAT("float") {
/**
* @param dimension Dimension of the vector
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @return FieldType of type KnnFloatVectorField
* @return FieldType of type KnnFloatVectorField
*/
@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

/**
* @param knnEngine KNNEngine
* @return DocValues FieldType of type Binary and with FLOAT32 VectorEncoding
*/
@Override
public FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
FieldType field = new FieldType();
field.putAttribute(KNN_ENGINE, knnEngine.getName());
field.setDocValuesType(DocValuesType.BINARY);
field.freeze();
return field;
}

};

@Getter
private final String value;

VectorDataType(String value) {
this.value = value;
}

/**
* Get VectorDataType name
*
* @return name
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @return FieldType
*/
public String getValue() {
return value;
}

public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

public abstract FieldType buildDocValuesFieldType(KNNEngine knnEngine);
/**
* @param knnEngine KNNEngine
* @return DocValues FieldType of type Binary
*/
public FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
FieldType field = new FieldType();
field.putAttribute(KNN_ENGINE, knnEngine.getName());
field.setDocValuesType(DocValuesType.BINARY);
field.freeze();
return field;
}

/**
* @return Set of names of all the supporting VectorDataTypes
*/
@VisibleForTesting
public static Set<String> getValues() {
Set<String> values = new HashSet<>();

for (VectorDataType dataType : VectorDataType.values()) {
values.add(dataType.getValue());
}
return values;
return Arrays.stream((VectorDataType.values())).map(VectorDataType::getValue).collect(Collectors.toCollection(HashSet::new));
}

/**
Expand All @@ -199,10 +94,14 @@ public static Set<String> getValues() {
* @return the same VectorDataType if it is in the supported values else throw exception.
*/
public static VectorDataType get(String vectorDataType) {
String supportedTypes = String.join(",", getValues());
Objects.requireNonNull(
vectorDataType,
String.format("[{}] should not be null. Supported types are [{}]", VECTOR_DATA_TYPE, supportedTypes)
String.format(
Locale.ROOT,
"[{}] should not be null. Supported types are [{}]",
VECTOR_DATA_TYPE_FIELD,
String.join(",", getValues())
)
);
for (VectorDataType currentDataType : VectorDataType.values()) {
if (currentDataType.getValue().equalsIgnoreCase(vectorDataType)) {
Expand All @@ -211,20 +110,20 @@ public static VectorDataType get(String vectorDataType) {
}
throw new IllegalArgumentException(
String.format(
"[%s] field was set as [%s] in index mapping. But, supported values are [%s]",
VECTOR_DATA_TYPE,
vectorDataType,
supportedTypes
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
String.join(",", getValues())
)
);
}

/**
* Validate the float vector values if it is a number and in the finite range.
* Validate the float vector value and throw exception if it is not a number or not in the finite range.
*
* @param value float vector value
*/
public static void validateFloatVectorValues(float value) {
public static void validateFloatVectorValue(float value) {
if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}
Expand All @@ -236,22 +135,29 @@ public static void validateFloatVectorValues(float value) {

/**
* Validate the float vector value in the byte range if it is a finite number,
* with no decimal values and in the byte range of [-128 to 127].
* with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException.
*
* @param value float value in byte range
*/
public static void validateByteVectorValues(float value) {
validateFloatVectorValues(value);
public static void validateByteVectorValue(float value) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
"[data_type] field was set as [byte] in index mapping. But, KNN vector values are floats instead of byte integers"
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
)

);
}
if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [{}, {}]",
VECTOR_DATA_TYPE,
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
Expand All @@ -268,7 +174,7 @@ public static void validateByteVectorValues(float value) {
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
}

Expand All @@ -281,13 +187,18 @@ public static void validateVectorDimension(int dimension, int vectorSize) {
* @param knnMethodContext KNNMethodContext Parameter
* @param vectorDataType VectorDataType Parameter
*/
public static void validateVectorDataType_Engine(
public static void validateVectorDataTypeWithEngine(
ParametrizedFieldMapper.Parameter<KNNMethodContext> knnMethodContext,
ParametrizedFieldMapper.Parameter<VectorDataType> vectorDataType
) {
if (vectorDataType.getValue() != DEFAULT_VECTOR_DATA_TYPE
&& (knnMethodContext.get() == null || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE)) {
throw new IllegalArgumentException(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME));
if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) {
return;
}
if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE)
|| knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE_FIELD, LUCENE_NAME)
);
}
}
}
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,7 +34,7 @@ public VectorField(String name, byte[] value, IndexableFieldType type) {
try {
this.setBytesValue(value);
} catch (Exception e) {
throw new IllegalArgumentException(e);
throw new RuntimeException(e);
}

}
Expand Down
Loading

0 comments on commit 96140ee

Please sign in to comment.