Skip to content

Commit

Permalink
Fix few bugs on binary index with Faiss HNSW
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
heemin32 committed Jul 18, 2024
1 parent 881364f commit 008ec15
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 10 deletions.
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ public static ValidationException validateKnnField(
return exception;
}

String vectorDataType = (String) fieldMap.get(VECTOR_DATA_TYPE_FIELD);
if (VectorDataType.BINARY.toString().equalsIgnoreCase(vectorDataType)) {
exception.addValidationError(
String.format(Locale.ROOT, "Field \"%s\" is of data type %s. Only FLOAT or BYTE is supported.", field, VectorDataType.BINARY)
);
return exception;
}

// Return if dimension does not need to be checked
if (expectedDimension < 0) {
return null;
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -57,8 +58,10 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
Expand Down Expand Up @@ -90,8 +93,10 @@ public ValidationException validateWithData(KNNMethodContext knnMethodContext, V
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public class KNNSettings {
*/
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2";
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hammingbit";
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16;
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100;
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 100;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

// Build legacy
if (this.spaceType == null) {
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings());
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue());
}

if (this.m == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -78,17 +79,19 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() {
);
}

static String getSpaceType(Settings indexSettings) {
static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) {
String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
if (spaceType == null) {
spaceType = VectorDataType.BINARY == vectorDataType
? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY
: KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
log.info(
String.format(
"[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s",
METHOD_PARAMETER_SPACE_TYPE,
KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE
spaceType
)
);
return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
}
return spaceType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ protected Query doToQuery(QueryShardContext context) {
String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine)
);
}
if (vectorDataType == VectorDataType.BINARY) {
throw new UnsupportedOperationException(
String.format(Locale.ROOT, "Binary data type does not support radial search")
);
}
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,15 @@ private boolean canDoExactSearch(final int filterIdsCount) {
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
return filterThresholdValue >= filterIdsCount;
}

// if no setting is set, then use the default max distance computation value to see if we can do exact search.
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * knnQuery.getQueryVector().length;
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
}

/**
Expand Down
21 changes: 21 additions & 0 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ public void testValidateKnnField_EmptyIndexMetadata() {
assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
}

public void testValidateKnnField_whenBinaryDataType_thenThrowException() {
Map<String, Object> fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "BINARY");
Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
Map<String, Object> properties = Map.of("properties", top_level_field);
String field = "top_level_field";
int dimension = 8;

MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assert (Objects.requireNonNull(e).getMessage().contains("is of data type BINARY. Only FLOAT or BYTE is supported"));
}

public void testIsShareableStateContainedInIndex_whenIndexNotModelBased_thenReturnFalse() {
String modelId = null;
KNNEngine knnEngine = KNNEngine.FAISS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,38 @@ public void testBuilder_build_fromLegacy() {
ModelDao modelDao = mock(ModelDao.class);
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);

SpaceType spaceType = SpaceType.COSINESIMIL;
int m = 17;
int efConstruction = 17;

// Setup settings
Settings settings = Settings.builder()
.put(settings(CURRENT).build())
.put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue())
.put(KNNSettings.KNN_ALGO_PARAM_M, m)
.put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction)
.build();

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);

assertNull(knnVectorFieldMapper.modelId);
assertNull(knnVectorFieldMapper.knnMethod);
assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
}

public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() {
// Check legacy is picked up if model context and method context are not set
ModelDao modelDao = mock(ModelDao.class);
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);
builder.vectorDataType.setValue(VectorDataType.BINARY);
builder.dimension.setValue(8);

// Setup settings
Settings settings = Settings.builder().put(settings(CURRENT).build()).build();

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);
assertEquals(SpaceType.HAMMING_BIT.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
}

public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,30 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() {
float[] queryVector = { 1.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.maxDistance(MAX_DISTANCE)
.build();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(8);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
MethodComponentContext methodComponentContext = new MethodComponentContext(
org.opensearch.knn.common.KNNConstants.METHOD_HNSW,
ImmutableMap.of()
);
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.HAMMING_BIT, methodComponentContext);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertTrue(e.getMessage().contains("Binary data type does not support radial search"));
}

public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
// Given
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down
68 changes: 68 additions & 0 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,74 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

/**
* This test ensure that we do the exact search when threshold settings are correct and not using filteredIds<=K
* condition to do exact search on binary index
* FilteredIdThreshold: 10
* FilteredIdThresholdPct: 10%
* FilteredIdsCount: 6
* liveDocs : null, as there is no deleted documents
* MaxDoc: 100
* K : 1
*/
@SneakyThrows
public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryIndex_thenSuccess() {
knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10);
byte[] vector = new byte[] { 1, 3 };
int k = 1;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(100);
when(reader.getLiveDocs()).thenReturn(null);
final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);

when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length));

final KNNQuery query = new KNNQuery(FIELD_NAME, BYTE_QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, VectorDataType.BINARY);

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.HAMMING_BIT.name(),
PARAMETERS,
String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "BHNSW32")
);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);
when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING_BIT.getValue());
when(fieldInfo.getName()).thenReturn(FIELD_NAME);
when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues);
when(binaryDocValues.advance(0)).thenReturn(0);
BytesRef vectorByteRef = new BytesRef(vector);
when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef);

final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down

0 comments on commit 008ec15

Please sign in to comment.