Skip to content

Commit

Permalink
Added backend role filtering for reprovisioning API
Browse files Browse the repository at this point in the history
Signed-off-by: owaiskazi19 <owaiskazi19@gmail.com>
  • Loading branch information
owaiskazi19 committed Aug 16, 2024
1 parent 83f6eb8 commit aca3d00
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 120 deletions.

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
Expand All @@ -38,28 +37,27 @@
import org.opensearch.flowframework.transport.WorkflowResponse;
import org.opensearch.flowframework.transport.handler.WorkflowFunction;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.util.Map;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Locale;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Optional;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -406,7 +404,9 @@ public static void getWorkflow(
);
}
} else {
listener.onFailure(new IndexNotFoundException(GLOBAL_CONTEXT_INDEX));
String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context.";
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
}
}

Expand Down Expand Up @@ -453,7 +453,9 @@ public static void onGetWorkflowResponse(
listener.onFailure(e);
}
} else {
listener.onFailure(new ResourceNotFoundException(workflowId));
String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context.";
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
import org.junit.Before;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.Map;
import java.util.List;
import java.util.Set;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -377,7 +377,7 @@ public void testCreateNewWorkflow() {
assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId());
}

public void testUpdateWorkflowWithReprovision() {
public void testUpdateWorkflowWithReprovision() throws IOException {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
Expand All @@ -400,6 +400,23 @@ public void testUpdateWorkflowWithReprovision() {
return null;
}).when(client).get(any(GetRequest.class), any());

GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX);
doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertEquals(
String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)),
2,
args.length
);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> getListener = (ActionListener<GetResponse>) args[1];
getListener.onResponse(getWorkflowResponse);
return null;
}).when(client).get(any(GetRequest.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowResponse> responseListener = invocation.getArgument(2);
responseListener.onResponse(new WorkflowResponse("1"));
Expand All @@ -413,7 +430,7 @@ public void testUpdateWorkflowWithReprovision() {
assertEquals("1", responseCaptor.getValue().getWorkflowId());
}

public void testFailedToUpdateWorkflowWithReprovision() {
public void testFailedToUpdateWorkflowWithReprovision() throws IOException {
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(
Expand All @@ -436,6 +453,23 @@ public void testFailedToUpdateWorkflowWithReprovision() {
return null;
}).when(client).get(any(GetRequest.class), any());

GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX);
doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertEquals(
String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)),
2,
args.length
);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> getListener = (ActionListener<GetResponse>) args[1];
getListener.onResponse(getWorkflowResponse);
return null;
}).when(client).get(any(GetRequest.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowResponse> responseListener = invocation.getArgument(2);
responseListener.onFailure(new Exception("failed"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.util.List;
import java.util.Map;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import org.mockito.ArgumentCaptor;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.util.List;
import java.util.Map;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
Expand All @@ -31,7 +33,10 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -75,6 +80,13 @@ public void setUp() throws Exception {
this.encryptorUtils = mock(EncryptorUtils.class);
this.pluginsService = mock(PluginsService.class);

ClusterService clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(
Settings.EMPTY,
Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES)))
);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

this.reprovisionWorkflowTransportAction = new ReprovisionWorkflowTransportAction(
transportService,
actionFilters,
Expand All @@ -85,7 +97,10 @@ public void setUp() throws Exception {
flowFrameworkIndicesHandler,
flowFrameworkSettings,
encryptorUtils,
pluginsService
pluginsService,
clusterService,
xContentRegistry(),
Settings.EMPTY
);

ThreadPool clientThreadPool = mock(ThreadPool.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.opensearch.transport.TransportService;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class SearchWorkflowTransportActionTests extends OpenSearchTestCase {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
import static org.opensearch.flowframework.TestHelpers.matchAllRequest;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class SearchHandlerTests extends OpenSearchTestCase {

Expand Down

0 comments on commit aca3d00

Please sign in to comment.