Skip to content

Commit

Permalink
Introduces NativeEngineKNNQuery which executes ANN on rewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <shatejas@amazon.com>
  • Loading branch information
shatejas committed Aug 1, 2024
1 parent 245995a commit 6f0d4b3
Show file tree
Hide file tree
Showing 16 changed files with 965 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.common.launchcontrol;

import org.opensearch.common.settings.Setting;
import org.opensearch.knn.index.KNNSettings;

import java.util.stream.Stream;

public final class KNNLaunchControl {

// Feature flags
public static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.launch.query.rewrite.enabled";
private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = true;

public static Setting<Boolean> KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting(
KNN_LAUNCH_QUERY_REWRITE_ENABLED,
KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static Stream<Setting<?>> getFeatureFlags() {
return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING);
}

public static boolean isFeatureEnabled(final String featureName) {
return Boolean.parseBoolean(KNNSettings.state().getSettingValue(featureName).toString());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.common.launchcontrol;

import org.apache.lucene.search.Query;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;

public final class LaunchControlUtils {

public static Query getNativeKNNQuery(final KNNQuery knnQuery) {
if (KNNLaunchControl.isFeatureEnabled(KNNLaunchControl.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING.getKey())) {
return new NativeEngineKnnVectorQuery(knnQuery);
}
return knnQuery;
}
}
17 changes: 14 additions & 3 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,30 @@
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.index.IndexModule;
import org.opensearch.knn.common.launchcontrol.KNNLaunchControl;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.monitor.jvm.JvmInfo;
import org.opensearch.monitor.os.OsProbe;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toMap;
import static org.opensearch.common.settings.Setting.Property.Dynamic;
import static org.opensearch.common.settings.Setting.Property.IndexScope;
import static org.opensearch.common.settings.Setting.Property.NodeScope;
import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue;
import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio;
import static org.opensearch.knn.common.launchcontrol.KNNLaunchControl.getFeatureFlags;

/**
* This class defines
Expand Down Expand Up @@ -289,6 +292,9 @@ public class KNNSettings {
}
};

private static Map<String, Setting<?>> FEATURE_FLAGS = KNNLaunchControl.getFeatureFlags()
.collect(toMap(Setting::getKey, Function.identity()));

private ClusterService clusterService;
private Client client;

Expand Down Expand Up @@ -326,7 +332,7 @@ private void setSettingsUpdateConsumers() {
);

NativeMemoryCacheManager.getInstance().rebuildCache(builder.build());
}, new ArrayList<>(dynamicCacheSettings.values()));
}, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toList()));
}

/**
Expand All @@ -346,6 +352,10 @@ private Setting<?> getSetting(String key) {
return dynamicCacheSettings.get(key);
}

if (FEATURE_FLAGS.containsKey(key)) {
return FEATURE_FLAGS.get(key);
}

if (KNN_CIRCUIT_BREAKER_TRIGGERED.equals(key)) {
return KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING;
}
Expand Down Expand Up @@ -390,7 +400,8 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING
);
return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList());
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
}

public static boolean isKNNPluginEnabled() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.launchcontrol.LaunchControlUtils.getNativeKNNQuery;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

/**
Expand Down Expand Up @@ -98,9 +99,10 @@ public static Query create(CreateQueryRequest createQueryRequest) {
methodParameters
);

KNNQuery knnQuery = null;
switch (vectorDataType) {
case BINARY:
return KNNQuery.builder()
knnQuery = KNNQuery.builder()
.field(fieldName)
.byteQueryVector(byteVector)
.indexName(indexName)
Expand All @@ -110,8 +112,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.build();
break;
default:
return KNNQuery.builder()
knnQuery = KNNQuery.builder()
.field(fieldName)
.queryVector(vector)
.indexName(indexName)
Expand All @@ -122,6 +125,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.vectorDataType(vectorDataType)
.build();
}
return getNativeKNNQuery(knnQuery);
}

Integer requestEfSearch = null;
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ public float score() throws IOException {
public int docID() {
return docIdsIter.docID();
}

@Override
public boolean equals(Object obj) {
if (!(obj instanceof Scorer)) return false;
return getWeight().equals(((Scorer) obj).getWeight());
}
};

}
}
17 changes: 13 additions & 4 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,23 @@ public Explanation explain(LeafReaderContext context, int doc) {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
final Map<Integer, Float> docIdToScoreMap = searchLeaf(context);
if (docIdToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
}

return convertSearchResponseToScorer(docIdToScoreMap);
}

public Map<Integer, Float> searchLeaf(LeafReaderContext context) throws IOException {

final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
if (filterWeight != null && cardinality == 0) {
return KNNScorer.emptyScorer(this);
return Collections.emptyMap();
}
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();

Expand All @@ -129,7 +138,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
} else {
Map<Integer, Float> annResults = doANNSearch(context, filterBitSet, cardinality);
if (annResults == null) {
return null;
return Collections.emptyMap();
}
if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) {
log.debug(
Expand All @@ -144,9 +153,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
docIdsToScoreMap.putAll(annResults);
}
if (docIdsToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
return Collections.emptyMap();
}
return convertSearchResponseToScorer(docIdsToScoreMap);
return docIdsToScoreMap;
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
Expand Down
Loading

0 comments on commit 6f0d4b3

Please sign in to comment.