Skip to content

Commit

Permalink
Fix shard level rescoring disabled setting flag
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Dec 24, 2024
1 parent 646d8b7 commit b6780d5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled == true) {
if (isShardLevelRescoringDisabled == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ public static RescoreContext getDefault() {
* based on the vector dimension if shard-level rescoring is disabled.
*
* @param finalK The final number of results to return for the entire shard.
* @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
* If true, the dimension-based oversampling logic is bypassed.
* @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled.
* If false, the dimension-based oversampling logic is bypassed.
* @param dimension The dimension of the vector. This is used to determine the oversampling factor when
* shard-level rescoring is disabled.
* @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) {
// Only apply default dimension-based oversampling logic when:
// 1. Shard-level rescoring is disabled
// 2. The oversample factor was not provided by the user
if (!isShardLevelRescoringEnabled && !userProvided) {
if (isShardLevelRescoringDisabled && !userProvided) {
// Apply new dimension-based oversampling logic when shard-level rescoring is disabled
if (dimension >= DIMENSION_THRESHOLD_1000) {
oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}

@SneakyThrows
public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertFalse(shardLevelRescoringDisabled);
}
Expand All @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

// Mock ResultUtil to return valid TopDocs
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt()))
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testRescore() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
Expand Down

0 comments on commit b6780d5

Please sign in to comment.