Skip to content

Commit

Permalink
Add Indexing Support for Lucene Byte Sized Vector (#937)
Browse files Browse the repository at this point in the history
* Add Indexing Support for Lucene Byte Sized Vector

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add tests for Indexing

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Add CHANGELOG

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Jul 6, 2023
1 parent a2ee834 commit 1386519
Show file tree
Hide file tree
Showing 9 changed files with 772 additions and 79 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.8...2.x)
### Features
* Added efficient filtering support for Faiss Engine ([#936](https://github.com/opensearch-project/k-NN/pull/936))
* Add Indexing Support for Lucene Byte Sized Vector ([#937](https://github.com/opensearch-project/k-NN/pull/937))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.common;

import org.opensearch.knn.index.VectorDataType;

public class KNNConstants {
// shared across library constants
public static final String DIMENSION = "dimension";
Expand Down Expand Up @@ -50,6 +52,9 @@ public class KNNConstants {
public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count";
public static final String SEARCH_SIZE_PARAMETER = "search_size";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";

Expand Down
89 changes: 89 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;

import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

/**
* Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin.
* We have two vector data_types, one is float (default) and the other one is byte.
*/
@AllArgsConstructor
public enum VectorDataType {
BYTE("byte") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction);
}
},
FLOAT("float") {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction);
}

};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
.map(VectorDataType::getValue)
.collect(Collectors.joining(","));
@Getter
private final String value;

/**
* Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and
* VectorSimilarityFunction.
*
* @param dimension Dimension of the vector
* @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType
* @return FieldType
*/
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Validates if given VectorDataType is in the list of supported data types.
* @param vectorDataType VectorDataType
* @return the same VectorDataType if it is in the supported values
* throws Exception if an invalid value is provided.
*/
public static VectorDataType get(String vectorDataType) {
Objects.requireNonNull(
vectorDataType,
String.format(
Locale.ROOT,
"[%s] should not be null. Supported types are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
try {
return VectorDataType.valueOf(vectorDataType.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
}
}
}
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,19 @@ public VectorField(String name, float[] value, IndexableFieldType type) {
throw new RuntimeException(e);
}
}

/**
* @param name FieldType name
* @param value an array of byte vector values
* @param type FieldType to build DocValues
*/
public VectorField(String name, byte[] value, IndexableFieldType type) {
super(name, new BytesRef(), type);
try {
this.setBytesValue(value);
} catch (Exception e) {
throw new RuntimeException(e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.opensearch.knn.common.KNNConstants;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
Expand All @@ -35,6 +34,7 @@
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
Expand All @@ -49,7 +49,14 @@
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension;

/**
* Field Mapper for KNN vector type.
Expand Down Expand Up @@ -96,6 +103,18 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
return value;
}, m -> toType(m).dimension);

/**
* data_type which defines the datatype of the vector values. This is an optional parameter and
* this is right now only relevant for lucene engine. The default value is float.
*/
private final Parameter<VectorDataType> vectorDataType = new Parameter<>(
VECTOR_DATA_TYPE_FIELD,
false,
() -> DEFAULT_VECTOR_DATA_TYPE_FIELD,
(n, c, o) -> VectorDataType.get((String) o),
m -> toType(m).vectorDataType
);

/**
* modelId provides a way for a user to generate the underlying library indices from an already serialized
* model template index. If this parameter is set, it will take precedence. This parameter is only relevant for
Expand Down Expand Up @@ -168,7 +187,7 @@ public Builder(String name, String spaceType, String m, String efConstruction) {

@Override
protected List<Parameter<?>> getParameters() {
return Arrays.asList(stored, hasDocValues, dimension, meta, knnMethodContext, modelId);
return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId);
}

protected Explicit<Boolean> ignoreMalformed(BuilderContext context) {
Expand Down Expand Up @@ -203,7 +222,8 @@ public KNNVectorFieldMapper build(BuilderContext context) {
buildFullName(context),
metaValue,
dimension.getValue(),
knnMethodContext
knnMethodContext,
vectorDataType.getValue()
);
if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) {
log.debug(String.format("Use [LuceneFieldMapper] mapper for field [%s]", name));
Expand All @@ -216,6 +236,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
.ignoreMalformed(ignoreMalformed)
.stored(stored.get())
.hasDocValues(hasDocValues.get())
.vectorDataType(vectorDataType.getValue())
.knnMethodContext(knnMethodContext)
.build();
return new LuceneFieldMapper(createLuceneFieldMapperInput);
Expand Down Expand Up @@ -327,6 +348,10 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name));
}

// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) with any engine (except lucene).
validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType);

return builder;
}
}
Expand All @@ -336,20 +361,43 @@ public static class KNNVectorFieldType extends MappedFieldType {
int dimension;
String modelId;
KNNMethodContext knnMethodContext;
VectorDataType vectorDataType;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null, null);
this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
this(name, meta, dimension, knnMethodContext, null);
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD);
}

public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
KNNMethodContext knnMethodContext,
VectorDataType vectorDataType
) {
this(name, meta, dimension, knnMethodContext, null, vectorDataType);
}

public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
KNNMethodContext knnMethodContext,
String modelId,
VectorDataType vectorDataType
) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
this.knnMethodContext = knnMethodContext;
this.vectorDataType = vectorDataType;
}

@Override
Expand Down Expand Up @@ -386,6 +434,7 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S
protected boolean stored;
protected boolean hasDocValues;
protected Integer dimension;
protected VectorDataType vectorDataType;
protected ModelDao modelDao;

// These members map to parameters in the builder. They need to be declared in the abstract class due to the
Expand All @@ -408,6 +457,7 @@ public KNNVectorFieldMapper(
this.stored = stored;
this.hasDocValues = hasDocValues;
this.dimension = mappedFieldType.getDimension();
this.vectorDataType = mappedFieldType.getVectorDataType();
updateEngineStats();
}

Expand Down Expand Up @@ -439,9 +489,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (fieldType.stored()) {
context.doc().add(new StoredField(name(), point.toString()));
}
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
context.path().remove();
}

Expand All @@ -459,50 +507,65 @@ void validateIfKNNPluginEnabled() {
}
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
// Returns an optional array of byte values where each value in the vector is parsed as a float and validated
// if it is a finite number without any decimals and within the byte range of [-128 to 127].
Optional<byte[]> getBytesFromContext(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

ArrayList<Float> vector = new ArrayList<>();
ArrayList<Byte> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;

if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();

if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}

if (Float.isInfinite(value)) {
throw new IllegalArgumentException("KNN vector values cannot be infinity");
}

vector.add(value);
validateByteVectorValue(value);
vector.add((byte) value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateByteVectorValue(value);
vector.add((byte) value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return Optional.empty();
}
validateVectorDimension(dimension, vector.size());
byte[] array = new byte[vector.size()];
int i = 0;
for (Byte f : vector) {
array[i++] = f;
}
return Optional.of(array);
}

if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
context.path().add(simpleName());

if (Float.isInfinite(value)) {
throw new IllegalArgumentException("KNN vector values cannot be infinity");
ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
token = context.parser().nextToken();
}

} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
vector.add(value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return Optional.empty();
}

if (dimension != vector.size()) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vector.size());
throw new IllegalArgumentException(errorMessage);
}
validateVectorDimension(dimension, vector.size());

float[] array = new float[vector.size()];
int i = 0;
Expand Down
Loading

0 comments on commit 1386519

Please sign in to comment.