diff --git a/build.gradle b/build.gradle index 7fdc16cea..b795e6031 100644 --- a/build.gradle +++ b/build.gradle @@ -203,7 +203,7 @@ dependencies { version { strictly("${jacksonVersion}") } } // Multi-tenant SDK Client - implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_version}" + implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 01f557c02..440ddaeac 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -151,7 +151,9 @@ public Collection createComponents( Map.entry(TENANT_AWARE_KEY, "true"), Map.entry(TENANT_ID_FIELD_KEY, TENANT_ID_FIELD) ) - : Collections.emptyMap() + : Collections.emptyMap(), + // TODO: Find a better thread pool or make one + client.threadPool().executor(ThreadPool.Names.GENERIC) ); EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client, sdkClient, xContentRegistry); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler( diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 954291368..ccdd3281a 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -68,16 +68,13 @@ import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING; -import static org.opensearch.flowframework.common.CommonValue.DEPROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.META; import static org.opensearch.flowframework.common.CommonValue.NO_SCHEMA_VERSION; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.SCHEMA_VERSION_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX_MAPPING; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** @@ -359,7 +356,7 @@ private void putOrReplaceTemplateInGlobalContextIndex(String documentId, Templat .dataObject(encryptorUtils.encryptTemplateCredentials(template)) .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient.putDataObjectAsync(request, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> { + sdkClient.putDataObjectAsync(request).whenComplete((r, throwable) -> { context.restore(); if (throwable == null) { try { @@ -426,24 +423,23 @@ public void putInitialStateToWorkflowState(String workflowId, String tenantId, U .dataObject(state) .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient.putDataObjectAsync(putRequest, client.threadPool().executor(PROVISION_WORKFLOW_THREAD_POOL)) - .whenComplete((r, throwable) -> { - context.restore(); - if (throwable == null) { - try { - IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); - listener.onResponse(indexResponse); - } catch (IOException e) { - logger.error("Failed to parse index response", e); - listener.onFailure(new FlowFrameworkException("Failed to parse index response", INTERNAL_SERVER_ERROR)); - } - } else { - Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); - String errorMessage = "Failed to put state index document"; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); + listener.onResponse(indexResponse); + } catch (IOException e) { + logger.error("Failed to parse index response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse index response", INTERNAL_SERVER_ERROR)); } - }); + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = "Failed to put state index document"; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } }, e -> { String errorMessage = "Failed to create workflow_state index"; @@ -552,7 +548,7 @@ public void getTemplate(String documentId, String tenantId, ActionListener { + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { context.restore(); if (throwable == null) { try { @@ -586,7 +582,7 @@ public void getWorkflowState(String workflowId, String tenantId, ActionListener< .id(workflowId) .tenantId(tenantId) .build(); - sdkClient.getDataObjectAsync(getRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> { + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { context.restore(); if (throwable == null) { try { @@ -799,31 +795,30 @@ public void deleteFlowFrameworkSystemIndexDoc(String documentId, String tenantId .id(documentId) .tenantId(tenantId) .build(); - sdkClient.deleteDataObjectAsync(deleteRequest, client.threadPool().executor(DEPROVISION_WORKFLOW_THREAD_POOL)) - .whenComplete((r, throwable) -> { - context.restore(); - if (throwable == null) { - try { - DeleteResponse response = DeleteResponse.fromXContent(r.parser()); - logger.info("Deleted workflow state doc: {}", documentId); - listener.onResponse(response); - } catch (Exception e) { - logger.error("Failed to parse delete response", e); - listener.onFailure( - new FlowFrameworkException("Failed to parse delete response", RestStatus.INTERNAL_SERVER_ERROR) - ); - } - } else { - Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to delete {} entry : {}", - WORKFLOW_STATE_INDEX, - documentId - ).getFormattedMessage(); - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + sdkClient.deleteDataObjectAsync(deleteRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + DeleteResponse response = DeleteResponse.fromXContent(r.parser()); + logger.info("Deleted workflow state doc: {}", documentId); + listener.onResponse(response); + } catch (Exception e) { + logger.error("Failed to parse delete response", e); + listener.onFailure( + new FlowFrameworkException("Failed to parse delete response", RestStatus.INTERNAL_SERVER_ERROR) + ); } - }); + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to delete {} entry : {}", + WORKFLOW_STATE_INDEX, + documentId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } } } @@ -959,10 +954,7 @@ private void getAndUpdateResourceInStateDocumentWithRetries( .id(workflowId) .tenantId(tenantId) .build(); - sdkClient.getDataObjectAsync( - getRequest, - client.threadPool().executor(operation == OpType.DELETE ? DEPROVISION_WORKFLOW_THREAD_POOL : PROVISION_WORKFLOW_THREAD_POOL) - ).whenComplete((r, throwable) -> { + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { if (throwable == null) { try { GetResponse getResponse = GetResponse.fromXContent(r.parser()); @@ -1008,10 +1000,7 @@ private void handleStateGetResponse( .ifSeqNo(getResponse.getSeqNo()) .ifPrimaryTerm(getResponse.getPrimaryTerm()) .build(); - sdkClient.updateDataObjectAsync( - updateRequest, - client.threadPool().executor(operation == OpType.DELETE ? DEPROVISION_WORKFLOW_THREAD_POOL : PROVISION_WORKFLOW_THREAD_POOL) - ).whenComplete((r, throwable) -> { + sdkClient.updateDataObjectAsync(updateRequest).whenComplete((r, throwable) -> { if (throwable == null) { handleStateUpdateSuccess(workflowId, resource, operation, listener); } else { diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 2e2f89de5..07465200b 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -58,7 +58,6 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; @@ -372,8 +371,7 @@ private void createExecute(WorkflowRequest request, User user, String tenantId, logger.info("Querying existing workflow from global context: {}", workflowId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { sdkClient.getDataObjectAsync( - GetDataObjectRequest.builder().index(GLOBAL_CONTEXT_INDEX).id(workflowId).tenantId(tenantId).build(), - client.threadPool().executor(WORKFLOW_THREAD_POOL) + GetDataObjectRequest.builder().index(GLOBAL_CONTEXT_INDEX).id(workflowId).tenantId(tenantId).build() ).whenComplete((r, throwable) -> { if (throwable == null) { context.restore(); @@ -516,24 +514,23 @@ void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, String ten .tenantId(tenantId) .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient.searchDataObjectAsync(searchRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL)) - .whenComplete((r, throwable) -> { - if (throwable == null) { - context.restore(); - try { - SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); - internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow); - } catch (Exception e) { - logger.error("Failed to parse workflow searchResponse", e); - internalListener.onFailure(e); - } - } else { - Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); - String errorMessage = "Unable to fetch the workflows"; - logger.error(errorMessage, exception); - internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + sdkClient.searchDataObjectAsync(searchRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + context.restore(); + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow); + } catch (Exception e) { + logger.error("Failed to parse workflow searchResponse", e); + internalListener.onFailure(e); } - }); + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = "Unable to fetch the workflows"; + logger.error(errorMessage, exception); + internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } catch (Exception e) { String errorMessage = "Unable to fetch the workflows"; logger.error(errorMessage, e); diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java index 637bc2d80..070ac3214 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java @@ -36,7 +36,6 @@ import static org.opensearch.flowframework.common.CommonValue.CLEAR_STATUS; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; @@ -143,7 +142,7 @@ private void executeDeleteRequest( .id(workflowId) .tenantId(tenantId) .build(); - sdkClient.deleteDataObjectAsync(deleteRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> { + sdkClient.deleteDataObjectAsync(deleteRequest).whenComplete((r, throwable) -> { context.restore(); if (throwable == null) { try { diff --git a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java index 586aba9cd..96237b46c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java +++ b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java @@ -29,7 +29,6 @@ import java.util.Arrays; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.util.ParseUtils.isAdmin; import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; @@ -122,7 +121,7 @@ private void doSearch(SearchRequest request, String tenantId, ActionListener { + sdkClient.searchDataObjectAsync(searchRequest).whenComplete((r, throwable) -> { if (throwable == null) { try { SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java index 51a301759..5036a248f 100644 --- a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -61,7 +61,6 @@ import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; /** * Encryption utility class @@ -324,7 +323,7 @@ private void generateAndIndexNewMasterKey(String tenantId, ActionListener { + sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> { if (throwable == null) { context.restore(); // Set generated key to master @@ -372,8 +371,7 @@ private CompletableFuture cacheMasterKeyFromConfigIndex(String tenantId) { .id(masterKeyId) .tenantId(tenantId) .fetchSourceContext(fetchSourceContext) - .build(), - client.threadPool().executor(WORKFLOW_THREAD_POOL) + .build() ).whenComplete((r, throwable) -> { context.restore(); if (throwable == null) { diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 36b00e483..1994b6032 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -68,7 +68,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; /** * Utility methods for Template parsing @@ -418,7 +417,7 @@ public static void getWorkflow( .id(workflowId) .tenantId(tenantId) .build(); - sdkClient.getDataObjectAsync(request, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> { + sdkClient.getDataObjectAsync(request).whenComplete((r, throwable) -> { if (throwable == null) { GetResponse getResponse; try { diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index d17ab57a1..bec97becc 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -64,6 +64,7 @@ public void setUp() throws Exception { when(client.admin()).thenReturn(adminClient); when(adminClient.cluster()).thenReturn(clusterAdminClient); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); + when(client.threadPool()).thenReturn(threadPool); environment = mock(Environment.class); settings = Settings.builder().build(); diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index acc0b0fb5..1f862c96f 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -150,7 +150,12 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); when(client.threadPool()).thenReturn(testThreadPool); - sdkClient = SdkClientFactory.createSdkClient(client, namedXContentRegistry, Collections.emptyMap()); + sdkClient = SdkClientFactory.createSdkClient( + client, + namedXContentRegistry, + Collections.emptyMap(), + testThreadPool.executor(ThreadPool.Names.SAME) + ); flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler( client, sdkClient, diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 80bd05d0a..585318a6c 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -126,7 +126,12 @@ public void setUp() throws Exception { super.setUp(); client = mock(Client.class); when(client.threadPool()).thenReturn(testThreadPool); - this.sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + this.sdkClient = SdkClientFactory.createSdkClient( + client, + NamedXContentRegistry.EMPTY, + Collections.emptyMap(), + testThreadPool.executor(ThreadPool.Names.SAME) + ); this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.getMaxWorkflows()).thenReturn(2); @@ -728,7 +733,6 @@ public void testUpdateWorkflowWithField() throws IOException, InterruptedExcepti latchedActionListener = new LatchedActionListener<>(listener, latch); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, latchedActionListener); latch.await(2, TimeUnit.SECONDS); - createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); verify(listener, times(2)).onResponse(any()); ArgumentCaptor