Skip to content

Commit

Permalink
Fix Faiss byte vector efficient filter bug
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Dec 5, 2024
1 parent b0d82b7 commit 5af0fa2
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ private boolean canDoExactSearch(final int filterIdsCount) {
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.BINARY
? knnQuery.getByteQueryVector().length
: knnQuery.getQueryVector().length);
}

/**
Expand Down
98 changes: 98 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.http.util.EntityUtils;
import org.junit.After;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -25,11 +27,13 @@
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand Down Expand Up @@ -62,6 +66,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
// Value used for the "type" entry of a k-NN vector field in index mappings.
private static final String KNN_VECTOR_TYPE = "knn_vector";
// NOTE(review): presumably the HNSW ef_construction and m build parameters; their usage is outside this view - confirm.
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
// Name of the document attribute that the filtered-query tests index next to each vector and filter on.
private static final String COLOR_FIELD_NAME = "color";
// Shared match-all query instance.
private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();

@After
Expand Down Expand Up @@ -689,6 +694,99 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
deleteModel(modelId);
}

/**
 * Verifies filtered k-NN search against the Faiss byte-vector index created by
 * {@link #setupKNNFaissByteIndexForFilterQuery()}.
 *
 * <p>Three combinations are exercised with the query vector {@code {6, 6, 4}}:
 * <ul>
 *   <li>k larger than the number of documents matching the filter - every matching document is returned;</li>
 *   <li>k smaller than the number of matching documents - only the nearest matching document is returned;</li>
 *   <li>a filter that matches no document - the result is empty.</li>
 * </ul>
 */
@SneakyThrows
public void testQueryWithFilterFaissByteVector_withDifferentCombination_thenSuccess() {
    setupKNNFaissByteIndexForFilterQuery();
    final Byte[] searchVector = { 6, 6, 4 };

    // K > filteredResults: both "red" documents come back even though k is 5.
    final int kGreaterThanFilterResult = 5;
    final List<String> expectedDocIds = Arrays.asList("1", "3");
    final List<KNNResult> knnResults = searchFaissByteIndexWithColorFilter(searchVector, kGreaterThanFilterResult, "red");

    assertEquals(expectedDocIds.size(), knnResults.size());
    assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));

    // K limits filter results: doc "1" {6, 7, 3} is closer to the query than doc "3" {4, 5, 7} under L2.
    final int kLimitsFilterResult = 1;
    final List<String> expectedDocIdsKLimitsFilterResult = List.of("1");
    final List<KNNResult> knnResultsKLimitsFilterResult = searchFaissByteIndexWithColorFilter(
        searchVector,
        kLimitsFilterResult,
        "red"
    );

    assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
    assertTrue(
        knnResultsKLimitsFilterResult.stream()
            .map(KNNResult::getDocId)
            .collect(Collectors.toList())
            .containsAll(expectedDocIdsKLimitsFilterResult)
    );

    // Empty filter docIds: the query is now built with the same k that is requested as result size
    // (previously it reused kLimitsFilterResult by mistake).
    final int k = 10;
    final List<KNNResult> emptyKNNFilteredResultsFromResponse = searchFaissByteIndexWithColorFilter(
        searchVector,
        k,
        "color_not_present"
    );

    assertEquals(0, emptyKNNFilteredResultsFromResponse.size());
}

// Runs a k-NN query on FIELD_NAME restricted to documents whose color equals the given value and returns the parsed hits.
private List<KNNResult> searchFaissByteIndexWithColorFilter(final Byte[] queryVector, final int k, final String color)
    throws Exception {
    final Response response = searchKNNIndex(
        INDEX_NAME,
        new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), k, QueryBuilders.termQuery(COLOR_FIELD_NAME, color)),
        k
    );
    final String responseBody = EntityUtils.toString(response.getEntity());
    return parseSearchResponse(responseBody, FIELD_NAME);
}

/**
 * Creates {@code INDEX_NAME} with a 3-dimensional byte {@code knn_vector} field (HNSW method, L2 space type,
 * Faiss engine) and indexes three documents that each carry a color attribute: documents "1" and "3" are
 * "red", document "2" is "green". The index is refreshed before returning.
 *
 * @throws Exception if building the mapping, creating the index, or indexing a document fails
 */
protected void setupKNNFaissByteIndexForFilterQuery() throws Exception {
    // Mapping layout: root -> properties -> vector field -> method definition.
    final XContentBuilder mappingBuilder = XContentFactory.jsonBuilder();
    mappingBuilder.startObject();
    mappingBuilder.startObject(PROPERTIES_FIELD);
    mappingBuilder.startObject(FIELD_NAME);
    mappingBuilder.field("type", "knn_vector");
    mappingBuilder.field("dimension", 3);
    mappingBuilder.field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue());
    mappingBuilder.startObject(KNN_METHOD);
    mappingBuilder.field(NAME, METHOD_HNSW);
    mappingBuilder.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2);
    mappingBuilder.field(KNN_ENGINE, KNNEngine.FAISS.getName());
    mappingBuilder.endObject(); // method definition
    mappingBuilder.endObject(); // vector field
    mappingBuilder.endObject(); // properties
    mappingBuilder.endObject(); // root

    createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), mappingBuilder.toString());

    // Two "red" documents and one "green" document, so a color filter selects a strict subset of the index.
    final Map<String, String> redAttributes = ImmutableMap.of(COLOR_FIELD_NAME, "red");
    final Map<String, String> greenAttributes = ImmutableMap.of(COLOR_FIELD_NAME, "green");
    final Byte[] firstVector = { 6, 7, 3 };
    final Byte[] secondVector = { 3, 2, 4 };
    final Byte[] thirdVector = { 4, 5, 7 };

    addKnnDocWithAttributes(INDEX_NAME, "1", FIELD_NAME, firstVector, redAttributes);
    addKnnDocWithAttributes(INDEX_NAME, "2", FIELD_NAME, secondVector, greenAttributes);
    addKnnDocWithAttributes(INDEX_NAME, "3", FIELD_NAME, thirdVector, redAttributes);

    refreshIndex(INDEX_NAME);
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
4 changes: 2 additions & 2 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1786,11 +1786,11 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(
protected <T> void addKnnDocWithAttributes(
String indexName,
String docId,
String vectorFieldName,
float[] vector,
T vector,
Map<String, String> fieldValues
) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");
Expand Down

0 comments on commit 5af0fa2

Please sign in to comment.