From ff500da835443e59c279aec4b6fa4dbcd47c7d53 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 24 Jan 2025 17:14:09 -0800 Subject: [PATCH] Implement multi tenancy in Flow Framework (#980) * Import SdkClient and inject it Signed-off-by: Daniel Widdis * Pass sdkClient to IndicesHandler and EncryptorUtils classes Signed-off-by: Daniel Widdis * Extract tenant id from REST header into RestAction Signed-off-by: Daniel Widdis * Pass tenant id to transport actions in template Signed-off-by: Daniel Widdis * Validate tenant id existence in workflow transport actions Signed-off-by: Daniel Widdis * Pass SdkClient and tenant id to util used for access control checks Signed-off-by: Daniel Widdis * Perform tenant id validation checks for workflow APIs Signed-off-by: Daniel Widdis * Migrate Update workflow get action to SdkCleint Signed-off-by: Daniel Widdis * Pass tenantId to IndicesHandler and use in EncryptorUtils Signed-off-by: Daniel Widdis * Migrate EncryptorUtils getting master key from index Signed-off-by: Daniel Widdis * Refactor fetching master key to permit reuse Signed-off-by: Daniel Widdis * Refactor initializeMasterKey to use common code Signed-off-by: Daniel Widdis * Migrate indexing new key to config Signed-off-by: Daniel Widdis * Migrate template indexing to sdkClient Signed-off-by: Daniel Widdis * Migrate template deletion to sdkClient Signed-off-by: Daniel Widdis * Migrate get template to sdkClient Signed-off-by: Daniel Widdis * Migrate provision template to sdkClient Signed-off-by: Daniel Widdis * Migrate max workflow search to sdkClient Signed-off-by: Daniel Widdis * Add tenantId to GetWorkflowStateRequest Signed-off-by: Daniel Widdis * Migrate GetWorkflowStateRequest to multitenant client Signed-off-by: Daniel Widdis * Migrate getProvisioningProgress to avoid repetition Signed-off-by: Daniel Widdis * Migrate canDeleteWorkflowStateDoc to avoid repetition Signed-off-by: Daniel Widdis * Migrate initial state document creation to metadata client Signed-off-by: Daniel Widdis * Migrate state document deletion to metadata client Signed-off-by: Daniel Widdis * Add Tenant aware Rest Tests for Workflows Signed-off-by: Daniel Widdis * Fix javadocs Signed-off-by: Daniel Widdis * Add publishToMavenLocal for more CI Signed-off-by: Daniel Widdis * Fix some CI Signed-off-by: Daniel Widdis * Enable tenant aware search Signed-off-by: Daniel Widdis * Refactor state index update method using multitenant client Signed-off-by: Daniel Widdis * Get metadata client artifacts from Maven Snapshot Signed-off-by: Daniel Widdis * Update tests for new update async code Signed-off-by: Daniel Widdis * Switch SdkClient to use default generic thread executor Signed-off-by: Daniel Widdis * Migrate last updates to sdkClient Signed-off-by: Daniel Widdis * Revert (most) changes to unit tests based on async client changes Signed-off-by: Daniel Widdis * Pass tenant id when updating state during provisioning Signed-off-by: Daniel Widdis * Integrate tenantId with synchronous provisioning Signed-off-by: Daniel Widdis * Fix failing integ tests after rebase, code review updates Signed-off-by: Daniel Widdis * Replace fakeTenantId placeholders with actual tenant id Signed-off-by: Daniel Widdis * Use version catalog for commons-lang3 and httpcore dependencies Signed-off-by: Daniel Widdis * Exclude transitive httpclient dependency from metadata and rest client Signed-off-by: Daniel Widdis * Fix more test errors and tweak dependencies Signed-off-by: Daniel Widdis * More code review comments and refactoring Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .github/workflows/CI.yml | 19 + .github/workflows/test_security.yml | 1 - CHANGELOG.md | 1 + build.gradle | 50 +- .../flowframework/FlowFrameworkPlugin.java | 61 +- .../flowframework/common/CommonValue.java | 11 + .../common/FlowFrameworkSettings.java | 83 +++ .../indices/FlowFrameworkIndicesHandler.java | 629 ++++++++++++------ .../flowframework/model/Template.java | 64 +- .../flowframework/model/WorkflowState.java | 46 +- .../rest/AbstractSearchWorkflowAction.java | 37 +- .../rest/RestCreateWorkflowAction.java | 114 ++-- .../rest/RestDeleteWorkflowAction.java | 5 +- .../rest/RestDeprovisionWorkflowAction.java | 5 +- .../rest/RestGetWorkflowAction.java | 6 +- .../rest/RestGetWorkflowStateAction.java | 4 +- .../rest/RestGetWorkflowStepAction.java | 2 + .../rest/RestProvisionWorkflowAction.java | 10 +- .../CreateWorkflowTransportAction.java | 324 +++++---- .../DeleteWorkflowTransportAction.java | 56 +- .../DeprovisionWorkflowTransportAction.java | 54 +- .../transport/GetWorkflowStateRequest.java | 21 +- .../GetWorkflowStateTransportAction.java | 65 +- .../transport/GetWorkflowTransportAction.java | 31 +- .../ProvisionWorkflowTransportAction.java | 68 +- .../ReprovisionWorkflowTransportAction.java | 63 +- .../SearchWorkflowStateTransportAction.java | 8 +- .../SearchWorkflowTransportAction.java | 8 +- .../transport/handler/SearchHandler.java | 53 +- .../flowframework/util/EncryptorUtils.java | 288 +++++--- .../flowframework/util/ParseUtils.java | 98 ++- .../flowframework/util/TenantAwareHelper.java | 89 +++ .../util/WorkflowTimeoutUtility.java | 20 +- .../workflow/AbstractCreatePipelineStep.java | 1 + .../AbstractRegisterLocalModelStep.java | 2 + .../AbstractRetryableWorkflowStep.java | 11 +- .../workflow/CreateConnectorStep.java | 1 + .../workflow/CreateIndexStep.java | 1 + .../workflow/DeployModelStep.java | 1 + .../workflow/RegisterAgentStep.java | 1 + .../workflow/RegisterModelGroupStep.java | 1 + .../workflow/RegisterRemoteModelStep.java | 2 + .../workflow/WorkflowProcessSorter.java | 42 +- .../resources/mappings/global-context.json | 5 +- .../resources/mappings/workflow-state.json | 5 +- .../FlowFrameworkPluginTests.java | 17 +- .../FlowFrameworkRestTestCase.java | 14 + .../FlowFrameworkTenantAwareRestTestCase.java | 211 ++++++ .../opensearch/flowframework/TestHelpers.java | 20 + .../common/FlowFrameworkSettingsTests.java | 4 +- .../FlowFrameworkIndicesHandlerTests.java | 83 ++- .../flowframework/model/TemplateTests.java | 27 +- .../rest/FlowFrameworkRestApiIT.java | 4 +- .../rest/RestCreateWorkflowActionTests.java | 3 +- .../rest/RestWorkflowStateTenantAwareIT.java | 356 ++++++++++ .../rest/RestWorkflowTenantAwareIT.java | 310 +++++++++ .../CreateWorkflowTransportActionTests.java | 323 ++++----- .../DeleteWorkflowTransportActionTests.java | 10 +- ...provisionWorkflowTransportActionTests.java | 61 +- .../GetWorkflowStateTransportActionTests.java | 45 +- .../GetWorkflowTransportActionTests.java | 24 +- ...ProvisionWorkflowTransportActionTests.java | 40 +- .../ReprovisionWorkflowRequestTests.java | 2 + ...provisionWorkflowTransportActionTests.java | 29 +- ...archWorkflowStateTransportActionTests.java | 7 +- .../SearchWorkflowTransportActionTests.java | 7 +- .../WorkflowRequestResponseTests.java | 4 +- .../transport/handler/SearchHandlerTests.java | 27 +- .../util/EncryptorUtilsTests.java | 112 ++-- .../flowframework/util/ParseUtilsTests.java | 2 - .../util/TenantAwareHelperTests.java | 158 +++++ .../util/WorkflowTimeoutUtilityTests.java | 6 +- .../workflow/CreateConnectorStepTests.java | 5 +- .../workflow/CreateIndexStepTests.java | 5 +- .../CreateIngestPipelineStepTests.java | 5 +- .../CreateSearchPipelineStepTests.java | 5 +- .../workflow/DeployModelStepTests.java | 5 +- .../workflow/RegisterAgentTests.java | 11 +- .../RegisterLocalCustomModelStepTests.java | 10 +- ...RegisterLocalPretrainedModelStepTests.java | 5 +- ...sterLocalSparseEncodingModelStepTests.java | 5 +- .../workflow/RegisterModelGroupStepTests.java | 5 +- .../RegisterRemoteModelStepTests.java | 24 +- .../workflow/WorkflowProcessSorterTests.java | 16 +- 84 files changed, 3377 insertions(+), 1097 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java create mode 100644 src/test/java/org/opensearch/flowframework/FlowFrameworkTenantAwareRestTestCase.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestWorkflowStateTenantAwareIT.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java create mode 100644 src/test/java/org/opensearch/flowframework/util/TenantAwareHelperTests.java diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a2319bd91..70e192f8f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -102,3 +102,22 @@ jobs: - name: Build and Run Tests run: | ./gradlew integTest -PnumNodes=3 + integTenantAwareTest: + needs: [spotless, javadoc] + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + java: [21] + name: Tenant Aware Integ Test JDK${{ matrix.java }}, ${{ matrix.os }} + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up JDK ${{ matrix.java }} + uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.java }} + distribution: temurin + - name: Build and Run Tests + run: | + ./gradlew integTest "-Dtests.rest.tenantaware=true" diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index aca08b248..c48d13d55 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -34,7 +34,6 @@ jobs: steps: - name: Run start commands run: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-start-command }} - - name: Checkout Flow Framework uses: actions/checkout@v4 - name: Setup Java ${{ matrix.java }} diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5c330e5..80e293a30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.17...2.x) ### Features +- Add multitenant remote metadata client ([#980](https://github.com/opensearch-project/flow-framework/pull/980)) - Add synchronous execution option to workflow provisioning ([#990](https://github.com/opensearch-project/flow-framework/pull/990)) - Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651)) - Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899)) diff --git a/build.gradle b/build.gradle index 970cf0f30..12eac2240 100644 --- a/build.gradle +++ b/build.gradle @@ -167,17 +167,18 @@ configurations { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" - implementation 'org.junit.jupiter:junit-jupiter:5.11.4' api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" - api group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}" + implementation "org.apache.commons:commons-lang3:${versions.commonslang}" + api(group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}") { + exclude group: "org.apache.httpcomponents.client5", module: "httpclient5" + } api group: 'org.slf4j', name: 'slf4j-api', version: '1.7.36' - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.17.0' implementation "org.opensearch:common-utils:${common_utils_version}" implementation "com.amazonaws:aws-encryption-sdk-java:3.0.1" implementation "software.amazon.cryptography:aws-cryptographic-material-providers:1.8.0" implementation "org.dafny:DafnyRuntime:4.9.1" implementation "software.amazon.smithy.dafny:conversion:0.1.1" - implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1' + implementation 'org.bouncycastle:bcprov-jdk18on:1.80' implementation "jakarta.json.bind:jakarta.json.bind-api:3.0.1" implementation "org.glassfish:jakarta.json:2.0.1" implementation "org.eclipse:yasson:3.0.4" @@ -188,7 +189,11 @@ dependencies { implementation "io.swagger.parser.v3:swagger-parser-core:${swaggerVersion}" implementation "io.swagger.parser.v3:swagger-parser:${swaggerVersion}" implementation "io.swagger.parser.v3:swagger-parser-v3:${swaggerVersion}" - // Declare Jackson dependencies for tests (from OpenSearch version catalog) + // Multi-tenant SDK Client + implementation ("org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}") { + exclude group: "org.apache.httpcomponents.client5", module: "httpclient5" + } + testImplementation 'org.junit.jupiter:junit-jupiter:5.11.4' testImplementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" testImplementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" testImplementation "com.fasterxml.jackson.datatype:jackson-datatype-jsr310:${versions.jackson_databind}" @@ -202,7 +207,6 @@ dependencies { configurations.all { resolutionStrategy { force("com.google.guava:guava:33.4.0-jre") // CVE for 31.1, keep to force transitive dependencies - force("org.apache.httpcomponents.core5:httpcore5:5.3.2") // Dependency Jar Hell } } } @@ -262,10 +266,19 @@ integTest { systemProperty('user', user) systemProperty('password', password) + // Only tenant aware test if set + if (System.getProperty("tests.rest.tenantaware") == "true") { + filter { + includeTestsMatching "org.opensearch.flowframework.*TenantAwareIT" + } + systemProperty "plugins.flow_framework.multi_tenancy_enabled", "true" + } + // Only rest case can run with remote cluster - if (System.getProperty("tests.rest.cluster") != null) { + if (System.getProperty("tests.rest.cluster") != null && System.getProperty("tests.rest.tenantaware") == null) { filter { includeTestsMatching "org.opensearch.flowframework.rest.*IT" + excludeTestsMatching "org.opensearch.flowframework.rest.*TenantAwareIT" } } @@ -288,11 +301,34 @@ integTest { filter { includeTestsMatching "org.opensearch.flowframework.rest.FlowFrameworkSecureRestApiIT" excludeTestsMatching "org.opensearch.flowframework.rest.FlowFrameworkRestApiIT" + excludeTestsMatching "org.opensearch.flowframework.rest.*TenantAwareIT" } } // doFirst delays this block until execution time doFirst { + if (System.getProperty("tests.rest.tenantaware") == "true") { + def ymlFile = file("$buildDir/testclusters/integTest-0/config/opensearch.yml") + if (ymlFile.exists()) { + ymlFile.withWriterAppend { + writer -> + writer.write("\n# Set multitenancy\n") + writer.write("plugins.flow_framework.multi_tenancy_enabled: true\n") + } + // TODO this properly uses the remote client factory but needs a remote cluster set up + // TODO get the endpoint from a system property + if (System.getProperty("tests.rest.cluster") != null) { + ymlFile.withWriterAppend { writer -> + writer.write("\n# Use a remote cluster\n") + writer.write("plugins.flow_framework.remote_metadata_type: RemoteOpenSearch\n") + writer.write("plugins.flow_framework.remote_metadata_endpoint: https://127.0.0.1:9200\n") + } + } + } else { + throw new GradleException("opensearch.yml not found at: $ymlFile") + } + } + // Tell the test JVM if the cluster JVM is running under a debugger so that tests can // use longer timeouts for requests. def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 92121dce5..440ddaeac 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -65,6 +65,8 @@ import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SystemIndexPlugin; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; @@ -75,7 +77,9 @@ import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; @@ -83,14 +87,26 @@ import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_ENDPOINT; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_REGION; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_SERVICE_NAME; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_TYPE; import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.TENANT_AWARE_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.TENANT_ID_FIELD_KEY; /** * An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch. @@ -121,9 +137,28 @@ public Collection createComponents( Settings settings = environment.settings(); flowFrameworkSettings = new FlowFrameworkSettings(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client, xContentRegistry); + SdkClient sdkClient = SdkClientFactory.createSdkClient( + client, + xContentRegistry, + // Here we assume remote metadata client is only used with tenant awareness. + // This may change in the future allowing more options for this map + FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED.get(settings) + ? Map.ofEntries( + Map.entry(REMOTE_METADATA_TYPE_KEY, REMOTE_METADATA_TYPE.get(settings)), + Map.entry(REMOTE_METADATA_ENDPOINT_KEY, REMOTE_METADATA_ENDPOINT.get(settings)), + Map.entry(REMOTE_METADATA_REGION_KEY, REMOTE_METADATA_REGION.get(settings)), + Map.entry(REMOTE_METADATA_SERVICE_NAME_KEY, REMOTE_METADATA_SERVICE_NAME.get(settings)), + Map.entry(TENANT_AWARE_KEY, "true"), + Map.entry(TENANT_ID_FIELD_KEY, TENANT_ID_FIELD) + ) + : Collections.emptyMap(), + // TODO: Find a better thread pool or make one + client.threadPool().executor(ThreadPool.Names.GENERIC) + ); + EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client, sdkClient, xContentRegistry); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler( client, + sdkClient, clusterService, encryptorUtils, xContentRegistry @@ -137,7 +172,13 @@ public Collection createComponents( ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, flowFrameworkSettings); - SearchHandler searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + SearchHandler searchHandler = new SearchHandler( + settings, + clusterService, + client, + sdkClient, + FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES + ); return List.of( workflowStepFactory, @@ -145,7 +186,8 @@ public Collection createComponents( encryptorUtils, flowFrameworkIndicesHandler, searchHandler, - flowFrameworkSettings + flowFrameworkSettings, + sdkClient ); } @@ -196,7 +238,12 @@ public List> getSettings() { MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION, - FILTER_BY_BACKEND_ROLES + FILTER_BY_BACKEND_ROLES, + FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED, + REMOTE_METADATA_TYPE, + REMOTE_METADATA_ENDPOINT, + REMOTE_METADATA_REGION, + REMOTE_METADATA_SERVICE_NAME ); } @@ -206,21 +253,21 @@ public List> getExecutorBuilders(Settings settings) { new ScalingExecutorBuilder( WORKFLOW_THREAD_POOL, 1, - Math.max(2, OpenSearchExecutors.allocatedProcessors(settings) - 1), + Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL ), new ScalingExecutorBuilder( PROVISION_WORKFLOW_THREAD_POOL, 1, - Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), + Math.max(8, OpenSearchExecutors.allocatedProcessors(settings) - 1), TimeValue.timeValueMinutes(5), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_WORKFLOW_THREAD_POOL ), new ScalingExecutorBuilder( DEPROVISION_WORKFLOW_THREAD_POOL, 1, - Math.max(2, OpenSearchExecutors.allocatedProcessors(settings) - 1), + Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), TimeValue.timeValueMinutes(1), FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL ) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 0a4af0758..b4d7fb98b 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.common; +import org.opensearch.Version; + /** * Representation of common values that are used across project */ @@ -82,6 +84,10 @@ private CommonValue() {} public static final String USE_CASE = "use_case"; /** The param name for reprovisioning, used by the create workflow API */ public static final String REPROVISION_WORKFLOW = "reprovision"; + /** The REST header containing the tenant id */ + public static final String TENANT_ID_HEADER = "x-tenant-id"; + /** The field name containing the tenant id */ + public static final String TENANT_ID_FIELD = "tenant_id"; /* * Constants associated with plugin configuration @@ -244,4 +250,9 @@ private CommonValue() {} public static final String ML_COMMONS_API_SPEC_YAML_URI = "https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/refs/heads/main/spec/namespaces/ml.yaml"; + /* + * Constants associated with non-BWC features + */ + /** Version 2.19.0 */ + public static final Version VERSION_2_19_0 = Version.fromString("2.19.0"); } diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 922212a38..e701c427a 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -10,9 +10,15 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY; +import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY; + /** The common settings of flow framework */ public class FlowFrameworkSettings { @@ -25,6 +31,8 @@ public class FlowFrameworkSettings { protected volatile Integer maxWorkflows; /** Timeout for internal requests*/ protected volatile TimeValue requestTimeout; + /** Whether multitenancy is enabled */ + private final Boolean isMultiTenancyEnabled; /** The upper limit of max workflows that can be created */ public static final int MAX_WORKFLOWS_LIMIT = 10000; @@ -83,6 +91,72 @@ public class FlowFrameworkSettings { Setting.Property.Dynamic ); + /** + * Indicates whether multi-tenancy is enabled in Flow Framework. + * + * This is a static setting that must be configured before starting OpenSearch. The corresponding setting {@code plugins.ml_commons.multi_tenancy_enabled} in the ML Commons plugin should match. + * + * It can be set in the following ways, in priority order: + * + *
    + *
  1. As a command-line argument using the -E flag (this overrides other options): + *
    +     *       ./bin/opensearch -Eplugins.flow_framework.multi_tenancy_enabled=true
    +     *       
    + *
  2. + *
  3. As a system property using OPENSEARCH_JAVA_OPTS (this overrides opensearch.yml): + *
    +     *       export OPENSEARCH_JAVA_OPTS="-Dplugins.flow_framework.multi_tenancy_enabled=true"
    +     *       ./bin/opensearch
    +     *       
    + * Or inline when starting OpenSearch: + *
    +     *       OPENSEARCH_JAVA_OPTS="-Dplugins.flow_framework.multi_tenancy_enabled=true" ./bin/opensearch
    +     *       
    + *
  4. + *
  5. In the opensearch.yml configuration file: + *
    +     *       plugins.flow_framework.multi_tenancy_enabled: true
    +     *       
    + *
  6. + *
+ * + * After setting this option, a full cluster restart is required for the changes to take effect. + */ + public static final Setting FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED = Setting.boolSetting( + "plugins.flow_framework.multi_tenancy_enabled", + false, + Setting.Property.NodeScope + ); + + /** This setting sets the remote metadata type */ + public static final Setting REMOTE_METADATA_TYPE = Setting.simpleString( + "plugins.flow_framework." + REMOTE_METADATA_TYPE_KEY, + Property.NodeScope, + Property.Final + ); + + /** This setting sets the remote metadata endpoint */ + public static final Setting REMOTE_METADATA_ENDPOINT = Setting.simpleString( + "plugins.flow_framework." + REMOTE_METADATA_ENDPOINT_KEY, + Property.NodeScope, + Property.Final + ); + + /** This setting sets the remote metadata region */ + public static final Setting REMOTE_METADATA_REGION = Setting.simpleString( + "plugins.flow_framework." + REMOTE_METADATA_REGION_KEY, + Property.NodeScope, + Property.Final + ); + + /** This setting sets the remote metadata service name */ + public static final Setting REMOTE_METADATA_SERVICE_NAME = Setting.simpleString( + "plugins.flow_framework." + REMOTE_METADATA_SERVICE_NAME_KEY, + Property.NodeScope, + Property.Final + ); + /** * Instantiate this class. * @@ -97,6 +171,7 @@ public FlowFrameworkSettings(ClusterService clusterService, Settings settings) { this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); this.maxWorkflows = MAX_WORKFLOWS.get(settings); this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings); + this.isMultiTenancyEnabled = FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FLOW_FRAMEWORK_ENABLED, it -> isFlowFrameworkEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_REQUEST_RETRY_DURATION, it -> retryDuration = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); @@ -143,4 +218,12 @@ public Integer getMaxWorkflows() { public TimeValue getRequestTimeout() { return requestTimeout; } + + /** + * Whether multitenancy is enabled. + * @return whether Flow Framework multitenancy is enabled. + */ + public boolean isMultiTenancyEnabled() { + return isMultiTenancyEnabled; + } } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 0ef0a7adb..31c49bb14 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -12,32 +12,27 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteRequest.OpType; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; -import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -48,7 +43,13 @@ import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.remote.metadata.client.DeleteDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import java.io.IOException; import java.util.ArrayList; @@ -79,6 +80,7 @@ public class FlowFrameworkIndicesHandler { private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class); private final Client client; + private final SdkClient sdkClient; private final ClusterService clusterService; private final EncryptorUtils encryptorUtils; private static final Map indexMappingUpdated = new HashMap<>(); @@ -90,17 +92,20 @@ public class FlowFrameworkIndicesHandler { /** * constructor * @param client the open search client + * @param sdkClient the remote metadata client * @param clusterService ClusterService * @param encryptorUtils encryption utility * @param xContentRegistry contentRegister to parse any response */ public FlowFrameworkIndicesHandler( Client client, + SdkClient sdkClient, ClusterService clusterService, EncryptorUtils encryptorUtils, NamedXContentRegistry xContentRegistry ) { this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.encryptorUtils = encryptorUtils; for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) { @@ -332,37 +337,54 @@ public void putTemplateToGlobalContext(Template template, ActionListener { logger.error("Failed to create global_context index"); listener.onFailure(e); })); } + private void putOrReplaceTemplateInGlobalContextIndex(String documentId, Template template, ActionListener listener) { + PutDataObjectRequest request = PutDataObjectRequest.builder() + .index(GLOBAL_CONTEXT_INDEX) + .id(documentId) + .tenantId(template.getTenantId()) + .dataObject(encryptorUtils.encryptTemplateCredentials(template)) + .build(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + sdkClient.putDataObjectAsync(request).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); + listener.onResponse(indexResponse); + } catch (IOException e) { + String errorMessage = "Failed to parse index response"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, INTERNAL_SERVER_ERROR)); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = "Failed to index template in global context index"; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); + } + } + /** * Initializes config index and EncryptorUtils + * @param tenantId the tenant id * @param listener action listener */ - public void initializeConfigIndex(ActionListener listener) { + public void initializeConfigIndex(String tenantId, ActionListener listener) { initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> { if (!indexCreated) { listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR)); return; } - encryptorUtils.initializeMasterKey(listener); + encryptorUtils.initializeMasterKey(tenantId, listener); }, createIndexException -> { logger.error("Failed to create config index"); listener.onFailure(createIndexException); @@ -371,11 +393,12 @@ public void initializeConfigIndex(ActionListener listener) { /** * add document insert into global context index - * @param workflowId the workflowId, corresponds to document ID of + * @param workflowId the workflowId, corresponds to document ID + * @param tenantId the tenant id * @param user passes the user that created the workflow * @param listener action listener */ - public void putInitialStateToWorkflowState(String workflowId, User user, ActionListener listener) { + public void putInitialStateToWorkflowState(String workflowId, String tenantId, User user, ActionListener listener) { WorkflowState state = WorkflowState.builder() .workflowId(workflowId) .state(State.NOT_STARTED.name()) @@ -383,27 +406,38 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL .user(user) .resourcesCreated(Collections.emptyList()) .userOutputs(Collections.emptyMap()) + .tenantId(tenantId) .build(); initWorkflowStateIndexIfAbsent(ActionListener.wrap(indexCreated -> { if (!indexCreated) { listener.onFailure(new FlowFrameworkException("No response to create workflow_state index", INTERNAL_SERVER_ERROR)); return; } - IndexRequest request = new IndexRequest(WORKFLOW_STATE_INDEX); - try ( - XContentBuilder builder = XContentFactory.jsonBuilder(); - ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); - - ) { - request.source(state.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - request.id(workflowId); - client.index(request, ActionListener.runBefore(listener, context::restore)); - } catch (Exception e) { - String errorMessage = "Failed to put state index document"; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + PutDataObjectRequest putRequest = PutDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(workflowId) + .tenantId(tenantId) + .dataObject(state) + .build(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); + listener.onResponse(indexResponse); + } catch (IOException e) { + logger.error("Failed to parse index response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse index response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = "Failed to put state index document"; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } - }, e -> { String errorMessage = "Failed to create workflow_state index"; logger.error(errorMessage, e); @@ -434,36 +468,21 @@ public void updateTemplateInGlobalContext( ActionListener listener, boolean ignoreNotStartedCheck ) { + String tenantId = template.getTenantId(); if (!doesIndexExist(GLOBAL_CONTEXT_INDEX)) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to update template for workflow_id : {}, global_context index does not exist.", + "Failed to update template for workflow_id : {}, global context index does not exist.", documentId ).getFormattedMessage(); logger.error(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); return; } - doesTemplateExist(documentId, templateExists -> { + doesTemplateExist(documentId, tenantId, templateExists -> { if (templateExists) { - getProvisioningProgress(documentId, progress -> { + getProvisioningProgress(documentId, tenantId, progress -> { if (ignoreNotStartedCheck || ProvisioningProgress.NOT_STARTED.equals(progress.orElse(null))) { - IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); - try ( - XContentBuilder builder = XContentFactory.jsonBuilder(); - ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() - ) { - Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template); - request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.runBefore(listener, context::restore)); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to update global_context entry : {}", - documentId - ).getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + putOrReplaceTemplateInGlobalContextIndex(documentId, template, listener); } else { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "The template can not be updated unless its provisioning state is NOT_STARTED: {}. Deprovision the workflow to reset the state.", @@ -486,135 +505,201 @@ public void updateTemplateInGlobalContext( * Check if the given template exists in the template index * * @param documentId document id + * @param tenantId tenant id * @param booleanResultConsumer a consumer based on whether the template exist * @param listener action listener * @param action listener response type */ - public void doesTemplateExist(String documentId, Consumer booleanResultConsumer, ActionListener listener) { - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, documentId); + public void doesTemplateExist( + String documentId, + String tenantId, + Consumer booleanResultConsumer, + ActionListener listener + ) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(response -> { booleanResultConsumer.accept(response.isExists()); }, exception -> { - context.restore(); + getTemplate( + documentId, + tenantId, + ActionListener.wrap(response -> booleanResultConsumer.accept(response.isExists()), exception -> { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to get template {}", documentId) + .getFormattedMessage(); + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }), + context + ); + } + } + + /** + * Get a template from the template index + * + * @param documentId document id + * @param tenantId tenant id + * @param listener action listener + * @param context the thread context + */ + public void getTemplate(String documentId, String tenantId, ActionListener listener, StoredContext context) { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder() + .index(GLOBAL_CONTEXT_INDEX) + .id(documentId) + .tenantId(tenantId) + .build(); + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + GetResponse getResponse = GetResponse.fromXContent(r.parser()); + listener.onResponse(getResponse); + } catch (IOException e) { + logger.error("Failed to parse get response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse get response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to get template {}", documentId) .getFormattedMessage(); - logger.error(errorMessage); + logger.error(errorMessage, exception); listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to retrieve template from global context: {}", - documentId - ).getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + } + }); + } + + /** + * Get a workflow state from the state index + * + * @param workflowId workflow id + * @param tenantId tenant id + * @param listener action listener + * @param context the thread context + */ + public void getWorkflowState(String workflowId, String tenantId, ActionListener listener, StoredContext context) { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(workflowId) + .tenantId(tenantId) + .build(); + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + GetResponse getResponse = GetResponse.fromXContent(r.parser()); + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = ParseUtils.createXContentParserFromRegistry( + xContentRegistry, + getResponse.getSourceAsBytesRef() + ) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(workflowState); + } catch (Exception e) { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to parse workflowState: {}", + getResponse.getId() + ).getFormattedMessage(); + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + listener.onFailure( + new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND) + ); + } + } catch (Exception e) { + logger.error("Failed to parse get response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse get response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + if (exception instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); + } else { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to get workflow status of: {}", + workflowId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } + } + }); } /** * Check workflow provisioning state and executes the consumer * - * @param documentId document id + * @param workflowId workflow id + * @param tenantId tenant id * @param provisioningProgressConsumer consumer function based on if workflow is provisioned. * @param listener action listener * @param action listener response type */ public void getProvisioningProgress( - String documentId, + String workflowId, + String tenantId, Consumer> provisioningProgressConsumer, ActionListener listener ) { - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, documentId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(response -> { - context.restore(); - if (!response.isExists()) { + getWorkflowState(workflowId, tenantId, ActionListener.wrap(workflowState -> { + provisioningProgressConsumer.accept(Optional.of(ProvisioningProgress.valueOf(workflowState.getProvisioningProgress()))); + }, exception -> { + if (exception instanceof FlowFrameworkException + && ((FlowFrameworkException) exception).getRestStatus() == RestStatus.NOT_FOUND) { provisioningProgressConsumer.accept(Optional.empty()); - return; - } - try ( - XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - provisioningProgressConsumer.accept(Optional.of(ProvisioningProgress.valueOf(workflowState.getProvisioningProgress()))); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to parse workflow state {}", documentId) - .getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); + } else { + listener.onFailure(exception); } - }, exception -> { - logger.error("Failed to get workflow state for {} ", documentId); - provisioningProgressConsumer.accept(Optional.empty()); - })); - } catch (Exception e) { - String errorMessage = "Failed to retrieve workflow state to check provisioning status"; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + }), context); } } /** * Check workflow provisioning state and resources to see if state can be deleted with template * - * @param documentId document id + * @param workflowId workflow id + * @param tenantId tenant id * @param clearStatus if set true, always deletes the state document unless status is IN_PROGRESS * @param canDeleteStateConsumer consumer function which will be true if workflow state is not IN_PROGRESS and either no resources or true clearStatus * @param listener action listener from caller to fail on error * @param action listener response type */ public void canDeleteWorkflowStateDoc( - String documentId, + String workflowId, + String tenantId, boolean clearStatus, Consumer canDeleteStateConsumer, ActionListener listener ) { - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, documentId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(response -> { - context.restore(); - if (!response.isExists()) { - // no need to delete if it's not there to start with + getWorkflowState(workflowId, tenantId, ActionListener.wrap(workflowState -> { + canDeleteStateConsumer.accept( + (clearStatus || workflowState.resourcesCreated().isEmpty()) + && !ProvisioningProgress.IN_PROGRESS.equals(ProvisioningProgress.valueOf(workflowState.getProvisioningProgress())) + ); + }, exception -> { + if (exception instanceof FlowFrameworkException + && ((FlowFrameworkException) exception).getRestStatus() == RestStatus.NOT_FOUND) { canDeleteStateConsumer.accept(Boolean.FALSE); - return; - } - try ( - XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - canDeleteStateConsumer.accept( - (clearStatus || workflowState.resourcesCreated().isEmpty()) - && !ProvisioningProgress.IN_PROGRESS.equals( - ProvisioningProgress.valueOf(workflowState.getProvisioningProgress()) - ) - ); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to parse workflow state {}", documentId) - .getFormattedMessage(); - ; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); + } else { + listener.onFailure(exception); } - }, exception -> { - logger.error("Failed to get workflow state for {} ", documentId); - canDeleteStateConsumer.accept(Boolean.FALSE); - })); - } catch (Exception e) { - String errorMessage = "Failed to retrieve workflow state to check provisioning status"; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + }), context); } } /** * Updates a complete document in the workflow state index * @param documentId the document ID + * @param tenantId the tenant ID * @param updatedDocument a complete document to update the global state index with * @param listener action listener */ public void updateFlowFrameworkSystemIndexDoc( String documentId, + String tenantId, ToXContentObject updatedDocument, ActionListener listener ) { @@ -628,13 +713,38 @@ public void updateFlowFrameworkSystemIndexDoc( listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, documentId); - XContentBuilder builder = XContentFactory.jsonBuilder(); - updatedDocument.toXContent(builder, null); - updateRequest.doc(builder); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - updateRequest.retryOnConflict(RETRIES); - client.update(updateRequest, ActionListener.runBefore(listener, context::restore)); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(documentId) + .tenantId(tenantId) + .dataObject(updatedDocument) + .retryOnConflict(RETRIES) + .build(); + sdkClient.updateDataObjectAsync(updateRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + UpdateResponse response; + try { + response = UpdateResponse.fromXContent(r.parser()); + logger.info("Updated workflow state doc: {}", documentId); + listener.onResponse(response); + } catch (Exception e) { + logger.error("Failed to parse update response", e); + listener.onFailure( + new FlowFrameworkException("Failed to parse update response", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update {} entry : {}", + WORKFLOW_STATE_INDEX, + documentId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } catch (Exception e) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update {} entry : {}", @@ -650,11 +760,13 @@ public void updateFlowFrameworkSystemIndexDoc( /** * Updates a partial document in the workflow state index * @param documentId the document ID + * @param tenantId the tenant ID * @param updatedFields the fields to update the global state index with * @param listener action listener */ public void updateFlowFrameworkSystemIndexDoc( String documentId, + String tenantId, Map updatedFields, ActionListener listener ) { @@ -668,13 +780,39 @@ public void updateFlowFrameworkSystemIndexDoc( listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, documentId); Map updatedContent = new HashMap<>(updatedFields); - updateRequest.doc(updatedContent); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - updateRequest.retryOnConflict(RETRIES); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(documentId) + .tenantId(tenantId) + .dataObject(updatedContent) + .retryOnConflict(RETRIES) + .build(); // TODO: decide what condition can be considered as an update conflict and add retry strategy - client.update(updateRequest, ActionListener.runBefore(listener, context::restore)); + sdkClient.updateDataObjectAsync(updateRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + UpdateResponse response = UpdateResponse.fromXContent(r.parser()); + logger.info("Updated workflow state doc: {}", documentId); + listener.onResponse(response); + } catch (Exception e) { + logger.error("Failed to parse update response", e); + listener.onFailure( + new FlowFrameworkException("Failed to parse update response", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update {} entry : {}", + WORKFLOW_STATE_INDEX, + documentId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } catch (Exception e) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update {} entry : {}", @@ -690,9 +828,10 @@ public void updateFlowFrameworkSystemIndexDoc( /** * Deletes a document in the workflow state index * @param documentId the document ID + * @param tenantId the tenant Id * @param listener action listener */ - public void deleteFlowFrameworkSystemIndexDoc(String documentId, ActionListener listener) { + public void deleteFlowFrameworkSystemIndexDoc(String documentId, String tenantId, ActionListener listener) { if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to delete document {} due to missing {} index", @@ -703,17 +842,35 @@ public void deleteFlowFrameworkSystemIndexDoc(String documentId, ActionListener< listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - DeleteRequest deleteRequest = new DeleteRequest(WORKFLOW_STATE_INDEX, documentId); - deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore)); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to delete {} entry : {}", - WORKFLOW_STATE_INDEX, - documentId - ).getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(documentId) + .tenantId(tenantId) + .build(); + sdkClient.deleteDataObjectAsync(deleteRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + DeleteResponse response = DeleteResponse.fromXContent(r.parser()); + logger.info("Deleted workflow state doc: {}", documentId); + listener.onResponse(response); + } catch (Exception e) { + logger.error("Failed to parse delete response", e); + listener.onFailure( + new FlowFrameworkException("Failed to parse delete response", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to delete {} entry : {}", + WORKFLOW_STATE_INDEX, + documentId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } } } @@ -724,6 +881,7 @@ public void deleteFlowFrameworkSystemIndexDoc(String documentId, ActionListener< * @param nodeId current process node (workflow step) id * @param workflowStepName the workflow step name that created the resource * @param resourceId the id of the newly created resource + * @param tenantId the tenant id * @param listener the ActionListener for this step to handle completing the future after update */ public void addResourceToStateIndex( @@ -731,39 +889,56 @@ public void addResourceToStateIndex( String nodeId, String workflowStepName, String resourceId, + String tenantId, ActionListener listener ) { String workflowId = currentNodeInputs.getWorkflowId(); + if (!validateStateIndexExists(workflowId, listener)) { + return; + } String resourceName = getResourceByWorkflowStep(workflowStepName); ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceName, resourceId); - if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to update state for {} due to missing {} index", + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getAndUpdateResourceInStateDocumentWithRetries( workflowId, - WORKFLOW_STATE_INDEX - ).getFormattedMessage(); - logger.error(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - } else { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - getAndUpdateResourceInStateDocumentWithRetries( - workflowId, - newResource, - OpType.INDEX, - RETRIES, - ActionListener.runBefore(listener, context::restore) - ); - } + tenantId, + newResource, + OpType.INDEX, + RETRIES, + ActionListener.runBefore(listener, context::restore) + ); } } /** * Removes a resource from the state index, including common exception handling * @param workflowId The workflow document id in the state index + * @param tenantId The tenant id * @param resourceToDelete The resource to delete * @param listener the ActionListener for this step to handle completing the future after update */ - public void deleteResourceFromStateIndex(String workflowId, ResourceCreated resourceToDelete, ActionListener listener) { + public void deleteResourceFromStateIndex( + String workflowId, + String tenantId, + ResourceCreated resourceToDelete, + ActionListener listener + ) { + if (!validateStateIndexExists(workflowId, listener)) { + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getAndUpdateResourceInStateDocumentWithRetries( + workflowId, + tenantId, + resourceToDelete, + OpType.DELETE, + RETRIES, + ActionListener.runBefore(listener, context::restore) + ); + } + } + + private boolean validateStateIndexExists(String workflowId, ActionListener listener) { if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update state for {} due to missing {} index", @@ -772,22 +947,15 @@ public void deleteResourceFromStateIndex(String workflowId, ResourceCreated reso ).getFormattedMessage(); logger.error(errorMessage); listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - } else { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - getAndUpdateResourceInStateDocumentWithRetries( - workflowId, - resourceToDelete, - OpType.DELETE, - RETRIES, - ActionListener.runBefore(listener, context::restore) - ); - } + return false; } + return true; } /** * Performs a get and update of a State Index document adding or removing a resource with strong consistency and retries * @param workflowId The document id to update + * @param tenantId * @param resource The resource to add or remove from the resources created list * @param operation The operation to perform on the resource (INDEX to append to the list or DELETE to remove) * @param retries The number of retries on update version conflicts @@ -795,17 +963,47 @@ public void deleteResourceFromStateIndex(String workflowId, ResourceCreated reso */ private void getAndUpdateResourceInStateDocumentWithRetries( String workflowId, + String tenantId, ResourceCreated resource, OpType operation, int retries, ActionListener listener ) { - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, workflowId); - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (!getResponse.isExists()) { - listener.onFailure(new FlowFrameworkException("Workflow state not found for " + workflowId, RestStatus.NOT_FOUND)); - return; + GetDataObjectRequest getRequest = GetDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(workflowId) + .tenantId(tenantId) + .build(); + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse getResponse = GetResponse.fromXContent(r.parser()); + handleStateGetResponse(workflowId, tenantId, resource, operation, retries, listener, getResponse); + } catch (Exception e) { + logger.error("Failed to parse get response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse get response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception ex = SdkClientUtils.unwrapAndConvertToException(throwable); + handleStateUpdateException(workflowId, tenantId, resource, operation, 0, listener, ex); } + }); + } + + private void handleStateGetResponse( + String workflowId, + String tenantId, + ResourceCreated resource, + OpType operation, + int retries, + ActionListener listener, + GetResponse getResponse + ) { + if (!getResponse.isExists()) { + listener.onFailure(new FlowFrameworkException("Workflow state not found for " + workflowId, RestStatus.NOT_FOUND)); + return; + } + try { WorkflowState currentState = WorkflowState.parse(getResponse.getSourceAsString()); List resourcesCreated = new ArrayList<>(currentState.resourcesCreated()); if (operation == OpType.DELETE) { @@ -813,21 +1011,31 @@ private void getAndUpdateResourceInStateDocumentWithRetries( } else { resourcesCreated.add(resource); } - XContentBuilder builder = XContentFactory.jsonBuilder(); WorkflowState newState = WorkflowState.builder(currentState).resourcesCreated(resourcesCreated).build(); - newState.toXContent(builder, null); - UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, workflowId).doc(builder) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setIfSeqNo(getResponse.getSeqNo()) - .setIfPrimaryTerm(getResponse.getPrimaryTerm()); - client.update( - updateRequest, - ActionListener.wrap( - r -> handleStateUpdateSuccess(workflowId, resource, operation, listener), - e -> handleStateUpdateException(workflowId, resource, operation, retries, listener, e) - ) - ); - }, ex -> handleStateUpdateException(workflowId, resource, operation, 0, listener, ex))); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest.builder() + .index(WORKFLOW_STATE_INDEX) + .id(workflowId) + .tenantId(tenantId) + .dataObject(newState) + .ifSeqNo(getResponse.getSeqNo()) + .ifPrimaryTerm(getResponse.getPrimaryTerm()) + .build(); + sdkClient.updateDataObjectAsync(updateRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + handleStateUpdateSuccess(workflowId, resource, operation, listener); + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + handleStateUpdateException(workflowId, tenantId, resource, operation, retries, listener, e); + } + }); + } catch (Exception e) { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to parse workflow state response for {}", + workflowId + ).getFormattedMessage(); + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, INTERNAL_SERVER_ERROR)); + } } private void handleStateUpdateSuccess( @@ -852,15 +1060,16 @@ private void handleStateUpdateSuccess( private void handleStateUpdateException( String workflowId, + String tenantId, ResourceCreated newResource, OpType operation, int retries, ActionListener listener, Exception e ) { - if (e instanceof VersionConflictEngineException && retries > 0) { + if (e instanceof OpenSearchStatusException && ((OpenSearchStatusException) e).status() == RestStatus.CONFLICT && retries > 0) { // Retry if we haven't exhausted retries - getAndUpdateResourceInStateDocumentWithRetries(workflowId, newResource, operation, retries - 1, listener); + getAndUpdateResourceInStateDocumentWithRetries(workflowId, tenantId, newResource, operation, retries - 1, listener); return; } String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index 4c8c2c9ef..a0d544f54 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -38,6 +38,7 @@ import static org.opensearch.flowframework.common.CommonValue.LAST_PROVISIONED_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.LAST_UPDATED_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.flowframework.common.CommonValue.UI_METADATA_FIELD; import static org.opensearch.flowframework.common.CommonValue.USER_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; @@ -75,6 +76,7 @@ public class Template implements ToXContentObject { private final Instant createdTime; private final Instant lastUpdatedTime; private final Instant lastProvisionedTime; + private String tenantId; /** * Instantiate the object representing a use case template @@ -90,6 +92,7 @@ public class Template implements ToXContentObject { * @param createdTime Created time as an Instant * @param lastUpdatedTime Last Updated time as an Instant * @param lastProvisionedTime Last Provisioned time as an Instant + * @param tenantId The tenant id */ public Template( String name, @@ -102,7 +105,8 @@ public Template( User user, Instant createdTime, Instant lastUpdatedTime, - Instant lastProvisionedTime + Instant lastProvisionedTime, + String tenantId ) { this.name = name; this.description = description; @@ -115,6 +119,7 @@ public Template( this.createdTime = createdTime; this.lastUpdatedTime = lastUpdatedTime; this.lastProvisionedTime = lastProvisionedTime; + this.tenantId = tenantId; } /** @@ -132,6 +137,7 @@ public static class Builder { private Instant createdTime = null; private Instant lastUpdatedTime = null; private Instant lastProvisionedTime = null; + private String tenantId = null; /** * Empty Constructor for the Builder object @@ -160,6 +166,7 @@ private Builder(Template t) { this.createdTime = t.createdTime(); this.lastUpdatedTime = t.lastUpdatedTime(); this.lastProvisionedTime = t.lastProvisionedTime(); + this.tenantId = t.getTenantId(); } /** @@ -272,6 +279,16 @@ public Builder lastProvisionedTime(Instant lastProvisionedTime) { return this; } + /** + * Builder method for adding user + * @param tenantId the tenant id + * @return the Builder object + */ + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Allows building a template * @return Template Object containing all needed fields @@ -288,7 +305,8 @@ public Template build() { this.user, this.createdTime, this.lastUpdatedTime, - this.lastProvisionedTime + this.lastProvisionedTime, + this.tenantId ); } } @@ -358,6 +376,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(LAST_PROVISIONED_TIME_FIELD, lastProvisionedTime.toEpochMilli()); } + if (tenantId != null) { + xContentBuilder.field(TENANT_ID_FIELD, tenantId); + } + return xContentBuilder.endObject(); } @@ -421,6 +443,7 @@ public static Template parse(XContentParser parser, boolean fieldUpdate) throws Instant createdTime = null; Instant lastUpdatedTime = null; Instant lastProvisionedTime = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -488,6 +511,9 @@ public static Template parse(XContentParser parser, boolean fieldUpdate) throws case LAST_PROVISIONED_TIME_FIELD: lastProvisionedTime = ParseUtils.parseInstant(parser); break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; default: throw new FlowFrameworkException( "Unable to parse field [" + fieldName + "] in a template object.", @@ -507,7 +533,7 @@ public static Template parse(XContentParser parser, boolean fieldUpdate) throws } } - return new Builder().name(name) + Template template = new Builder().name(name) .description(description) .useCase(useCase) .templateVersion(templateVersion) @@ -518,7 +544,9 @@ public static Template parse(XContentParser parser, boolean fieldUpdate) throws .createdTime(createdTime) .lastUpdatedTime(lastUpdatedTime) .lastProvisionedTime(lastProvisionedTime) + .tenantId(tenantId) .build(); + return template; } /** @@ -541,6 +569,20 @@ public static Template parse(String json) throws IOException { } } + /** + * Creates an empty template with the given tenant ID + * + * @param tenantId the tenantID + * @return an empty template containing the tenant id if it's not null, null otherwise + */ + public static Template createEmptyTemplateWithTenantId(String tenantId) { + if (tenantId == null) { + return null; + } + Template emptyTemplate = builder().name("").tenantId(tenantId).build(); + return emptyTemplate; + } + /** * Output this object in a compact JSON string. * @@ -657,6 +699,22 @@ public Instant lastProvisionedTime() { return lastProvisionedTime; } + /** + * The tenant id + * @return the tenant id + */ + public String getTenantId() { + return tenantId; + } + + /** + * Sets the tenant id + * @param tenantId the tenant id to set + */ + public void setTenantId(String tenantId) { + this.tenantId = tenantId; + } + @Override public String toString() { return "Template [name=" diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index 6a4b81a55..2c5a23080 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -21,6 +21,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; @@ -38,6 +39,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.flowframework.common.CommonValue.USER_FIELD; import static org.opensearch.flowframework.common.CommonValue.USER_OUTPUTS_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID_FIELD; @@ -58,6 +60,7 @@ public class WorkflowState implements ToXContentObject, Writeable { private User user; private Map userOutputs; private List resourcesCreated; + private String tenantId; /** * Instantiate the object representing the workflow state @@ -71,6 +74,7 @@ public class WorkflowState implements ToXContentObject, Writeable { * @param user The user extracted from the thread context from the request * @param userOutputs A map of essential API responses for backend to use and lookup. * @param resourcesCreated A list of all the resources created. + * @param tenantId The tenant id */ public WorkflowState( String workflowId, @@ -81,7 +85,8 @@ public WorkflowState( Instant provisionEndTime, User user, Map userOutputs, - List resourcesCreated + List resourcesCreated, + String tenantId ) { this.workflowId = workflowId; this.error = error; @@ -92,6 +97,7 @@ public WorkflowState( this.user = user; this.userOutputs = Map.copyOf(userOutputs); this.resourcesCreated = List.copyOf(resourcesCreated); + this.tenantId = tenantId; } private WorkflowState() {} @@ -116,6 +122,9 @@ public WorkflowState(StreamInput input) throws IOException { for (int r = 0; r < resourceCount; r++) { resourcesCreated.add(new ResourceCreated(input)); } + if (input.getVersion().onOrAfter(CommonValue.VERSION_2_19_0)) { + this.tenantId = input.readOptionalString(); + } } /** @@ -148,6 +157,7 @@ public static class Builder { private User user = null; private Map userOutputs = null; private List resourcesCreated = null; + private String tenantId = null; /** * Empty Constructor for the Builder object @@ -168,6 +178,7 @@ private Builder(WorkflowState existingState) { this.user = existingState.getUser(); this.userOutputs = existingState.userOutputs(); this.resourcesCreated = existingState.resourcesCreated(); + this.tenantId = existingState.getTenantId(); } /** @@ -260,6 +271,16 @@ public Builder resourcesCreated(List resourcesCreated) { return this; } + /** + * Builder method for adding tenant id + * @param tenantId tenant id + * @return the Builder object + */ + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Allows building a workflowState * @return WorkflowState workflowState Object containing all needed fields @@ -275,6 +296,7 @@ public WorkflowState build() { workflowState.user = this.user; workflowState.userOutputs = this.userOutputs; workflowState.resourcesCreated = this.resourcesCreated; + workflowState.tenantId = this.tenantId; return workflowState; } } @@ -314,6 +336,9 @@ public static WorkflowState updateExistingWorkflowState(WorkflowState existingSt if (stateWithNewFields.resourcesCreated() != null) { builder.resourcesCreated(stateWithNewFields.resourcesCreated()); } + if (stateWithNewFields.getTenantId() != null) { + builder.tenantId(stateWithNewFields.getTenantId()); + } return builder.build(); } @@ -347,6 +372,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (resourcesCreated != null && !resourcesCreated.isEmpty()) { xContentBuilder.field(RESOURCES_CREATED_FIELD, resourcesCreated.toArray()); } + if (tenantId != null) { + xContentBuilder.field(TENANT_ID_FIELD, tenantId); + } return xContentBuilder.endObject(); } @@ -377,6 +405,9 @@ public void writeTo(StreamOutput output) throws IOException { for (ResourceCreated resource : resourcesCreated) { resource.writeTo(output); } + if (output.getVersion().onOrAfter(CommonValue.VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } } /** @@ -396,6 +427,7 @@ public static WorkflowState parse(XContentParser parser) throws IOException { User user = null; Map userOutputs = new HashMap<>(); List resourcesCreated = new ArrayList<>(); + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -455,6 +487,9 @@ public static WorkflowState parse(XContentParser parser) throws IOException { throw e; } break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; default: throw new FlowFrameworkException( "Unable to parse field [" + fieldName + "] in a workflowState object.", @@ -471,6 +506,7 @@ public static WorkflowState parse(XContentParser parser) throws IOException { .user(user) .userOutputs(userOutputs) .resourcesCreated(resourcesCreated) + .tenantId(tenantId) .build(); } @@ -562,6 +598,14 @@ public List resourcesCreated() { return resourcesCreated; } + /** + * The tenant id associated with this workflow-state + * @return the tenantId + */ + public String getTenantId() { + return tenantId; + } + @Override public String toString() { return "WorkflowState [workflowId=" diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index 9fbc0f6d5..e5ae25287 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -17,6 +17,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; @@ -73,22 +74,32 @@ public AbstractSearchWorkflowAction( @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { - FlowFrameworkException ffe = new FlowFrameworkException( - "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", - RestStatus.FORBIDDEN - ); + try { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); + + // The transport action needs the tenant id but also only takes a SearchRequest. + // The tenant filtering will be handled by the metadata client. + // We'll use the preference field to communicate the tenant ID and strip it on the other end + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index).preference(tenantId); + return channel -> client.execute(actionType, searchRequest, search(channel)); + } catch (FlowFrameworkException ex) { return channel -> channel.sendResponse( - new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); - searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); - - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); - return channel -> client.execute(actionType, searchRequest, search(channel)); } /** diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index b106b05f2..a3beab336 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -25,6 +25,7 @@ import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -104,60 +105,64 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli .stream() .filter(e -> !request.consumedParams().contains(e.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { - FlowFrameworkException ffe = new FlowFrameworkException( - "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", - RestStatus.FORBIDDEN - ); - return channel -> channel.sendResponse( - new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) - ); - } - if (!provision && !params.isEmpty()) { - FlowFrameworkException ffe = new FlowFrameworkException( - "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter is set to true.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - if (provision && updateFields) { - FlowFrameworkException ffe = new FlowFrameworkException( - "You can not use both the " + PROVISION_WORKFLOW + " and " + UPDATE_WORKFLOW_FIELDS + " parameters in the same request.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - if (reprovision && workflowId == null) { - FlowFrameworkException ffe = new FlowFrameworkException( - "You can not use the " + REPROVISION_WORKFLOW + " parameter to create a new template.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - if (reprovision && useCase != null) { - FlowFrameworkException ffe = new FlowFrameworkException( - "You cannot use the " + REPROVISION_WORKFLOW + " and " + USE_CASE + " parameters in the same request.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - if (reprovision && !params.isEmpty()) { - FlowFrameworkException ffe = new FlowFrameworkException( - "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter is set to true.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - // Ensure wait_for_completion is not set unless reprovision or provision is true - if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) { - FlowFrameworkException ffe = new FlowFrameworkException( - "Request parameters 'wait_for_completion_timeout' are not allowed unless the 'provision' or 'reprovision' parameter is set to true.", - RestStatus.BAD_REQUEST - ); - return processError(ffe, params, request); - } - try { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request); + if (!provision && !params.isEmpty()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + if (provision && updateFields) { + FlowFrameworkException ffe = new FlowFrameworkException( + "You can not use both the " + + PROVISION_WORKFLOW + + " and " + + UPDATE_WORKFLOW_FIELDS + + " parameters in the same request.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + if (reprovision && workflowId == null) { + FlowFrameworkException ffe = new FlowFrameworkException( + "You can not use the " + REPROVISION_WORKFLOW + " parameter to create a new template.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + if (reprovision && useCase != null) { + FlowFrameworkException ffe = new FlowFrameworkException( + "You cannot use the " + REPROVISION_WORKFLOW + " and " + USE_CASE + " parameters in the same request.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + if (reprovision && !params.isEmpty()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Only the parameters " + request.consumedParams() + " are permitted unless the provision parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + // Ensure wait_for_completion is not set unless reprovision or provision is true + if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Request parameter 'wait_for_completion_timeout' is not allowed unless the 'provision' or 'reprovision' parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } Template template; Map useCaseDefaultsMap = Collections.emptyMap(); if (useCase != null) { @@ -234,6 +239,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (waitForCompletionTimeout != TimeValue.MINUS_ONE) { params = Map.of(WAIT_FOR_COMPLETION_TIMEOUT, waitForCompletionTimeout.toString()); } + if (tenantId != null) { + template.setTenantId(tenantId); + } WorkflowRequest workflowRequest = new WorkflowRequest( workflowId, template, diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java index ebef48c86..65251d235 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeleteWorkflowAction.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.DeleteWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -32,6 +33,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.model.Template.createEmptyTemplateWithTenantId; /** * Rest Action to facilitate requests to delete a stored template @@ -71,6 +73,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request RestStatus.FORBIDDEN ); } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request); // Always consume content to silently ignore it // https://github.com/opensearch-project/flow-framework/issues/578 @@ -80,7 +83,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, request.params()); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, createEmptyTemplateWithTenantId(tenantId), request.params()); return channel -> client.execute(DeleteWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java index d75255b8d..783404ff0 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -34,6 +35,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.model.Template.createEmptyTemplateWithTenantId; /** * Rest Action to facilitate requests to de-provision a workflow @@ -68,6 +70,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request RestStatus.FORBIDDEN ); } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request); // Always consume content to silently ignore it // https://github.com/opensearch-project/flow-framework/issues/578 @@ -79,7 +82,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request } WorkflowRequest workflowRequest = new WorkflowRequest( workflowId, - null, + createEmptyTemplateWithTenantId(tenantId), allowDelete == null ? Collections.emptyMap() : Map.of(ALLOW_DELETE, allowDelete) ); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 81141a380..3a05b4fb4 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -31,6 +32,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.model.Template.createEmptyTemplateWithTenantId; /** * Rest Action to facilitate requests to get a stored template @@ -69,7 +71,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request RestStatus.FORBIDDEN ); } - + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request); // Always consume content to silently ignore it // https://github.com/opensearch-project/flow-framework/issues/578 request.content(); @@ -79,7 +81,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, createEmptyTemplateWithTenantId(tenantId)); return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java index bdf4df35f..286d4d1e1 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -65,6 +66,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request RestStatus.FORBIDDEN ); } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request); // Always consume content to silently ignore it // https://github.com/opensearch-project/flow-framework/issues/578 @@ -75,7 +77,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all); + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all, tenantId); return channel -> client.execute(GetWorkflowStateAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java index b889cede7..d970a080d 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStepAction.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowStepAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -70,6 +71,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli RestStatus.FORBIDDEN ); } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request); // Always consume content to silently ignore it // https://github.com/opensearch-project/flow-framework/issues/578 diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index e197312ed..fa53aec27 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -22,6 +22,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -38,6 +39,7 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.model.Template.createEmptyTemplateWithTenantId; /** * Rest action to facilitate requests to provision a workflow from an inline defined or stored use case template @@ -84,12 +86,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli RestStatus.FORBIDDEN ); } + String tenantId = TenantAwareHelper.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request); // Validate params if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params, waitForCompletionTimeout); + WorkflowRequest workflowRequest = new WorkflowRequest( + workflowId, + createEmptyTemplateWithTenantId(tenantId), + params, + waitForCompletionTimeout + ); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 40fea99e0..2cc082994 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -12,8 +12,8 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -26,7 +26,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -34,15 +33,21 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.plugins.PluginsService; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.io.IOException; import java.time.Instant; import java.util.Arrays; import java.util.Collections; @@ -69,6 +74,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction listener) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } User user = getUserContext(client); String workflowId = request.getWorkflowId(); try { - resolveUserAndExecute(user, workflowId, listener, () -> createExecute(request, user, listener)); + resolveUserAndExecute( + user, + workflowId, + tenantId, + flowFrameworkSettings.isMultiTenancyEnabled(), + listener, + () -> createExecute(request, user, tenantId, listener) + ); } catch (Exception e) { logger.error("Failed to create workflow", e); listener.onFailure(e); @@ -129,12 +149,15 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, Runnable function ) { @@ -160,11 +183,14 @@ private void resolveUserAndExecute( getWorkflow( requestedUser, workflowId, + tenantId, filterByBackendRole, false, + isMultitenancyEnabled, listener, function, client, + sdkClient, clusterService, xContentRegistry ); @@ -190,9 +216,10 @@ private void resolveUserAndExecute( * 4. Create or update provisioning progress index * @param request the workflow request * @param user the user making the request + * @param tenantId the tenant id * @param listener the action listener */ - private void createExecute(WorkflowRequest request, User user, ActionListener listener) { + private void createExecute(WorkflowRequest request, User user, String tenantId, ActionListener listener) { Instant creationTime = Instant.now(); Template templateWithUser = new Template( request.getTemplate().name(), @@ -205,7 +232,8 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { if (FALSE.equals(max)) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( @@ -253,7 +282,7 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { + flowFrameworkIndicesHandler.initializeConfigIndex(tenantId, ActionListener.wrap(isInitialized -> { if (FALSE.equals(isInitialized)) { listener.onFailure( new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) @@ -265,13 +294,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { flowFrameworkIndicesHandler.putInitialStateToWorkflowState( globalContextResponse.getId(), + tenantId, user, ActionListener.wrap(stateResponse -> { logger.info("Creating state workflow doc: {}", globalContextResponse.getId()); if (request.isProvision()) { WorkflowRequest workflowRequest = new WorkflowRequest( globalContextResponse.getId(), - null, + Template.createEmptyTemplateWithTenantId(tenantId), request.getParams(), waitForTimeCompletion ); @@ -357,156 +387,177 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - context.restore(); - if (getResponse.isExists()) { - - Template existingTemplate = Template.parse(getResponse.getSourceAsString()); - Template template = isFieldUpdate - ? Template.updateExistingTemplate(existingTemplate, templateWithUser) - : Template.builder(templateWithUser) - .createdTime(existingTemplate.createdTime()) - .lastUpdatedTime(Instant.now()) - .lastProvisionedTime(existingTemplate.lastProvisionedTime()) - .build(); - - if (request.isReprovision()) { - // Reprovision request - ReprovisionWorkflowRequest reprovisionRequest = new ReprovisionWorkflowRequest( - getResponse.getId(), - existingTemplate, - template, - waitForTimeCompletion - ); - logger.info("Reprovisioning parameter is set, continuing to reprovision workflow {}", getResponse.getId()); - client.execute( - ReprovisionWorkflowAction.INSTANCE, - reprovisionRequest, - ActionListener.wrap(reprovisionResponse -> { - listener.onResponse( - reprovisionRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE - ? new WorkflowResponse(reprovisionResponse.getWorkflowId()) - : new WorkflowResponse( - reprovisionResponse.getWorkflowId(), - reprovisionResponse.getWorkflowState() - ) - ); - }, exception -> { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Reprovisioning failed for workflow {}", - workflowId - ).getFormattedMessage(); - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - } - }) - ); - } else { - - // Update existing entry, full document replacement - flowFrameworkIndicesHandler.updateTemplateInGlobalContext( - request.getWorkflowId(), - template, - ActionListener.wrap(response -> { - // Regular update, reset provisioning status, ignore state index if updating fields - if (!isFieldUpdate) { - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - request.getWorkflowId(), - Map.ofEntries( - Map.entry(STATE_FIELD, State.NOT_STARTED), - Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) - ), - ActionListener.wrap(updateResponse -> { - logger.info( - "updated workflow {} state to {}", - request.getWorkflowId(), - State.NOT_STARTED.name() - ); - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - }, exception -> { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to update workflow {} in template index", - request.getWorkflowId() - ).getFormattedMessage(); - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - } - }) - ); - } else { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - } - }, exception -> { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to update use case template {}", - request.getWorkflowId() - ).getFormattedMessage(); - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - } - }), - isFieldUpdate - ); + sdkClient.getDataObjectAsync( + GetDataObjectRequest.builder().index(GLOBAL_CONTEXT_INDEX).id(workflowId).tenantId(tenantId).build() + ).whenComplete((r, throwable) -> { + if (throwable == null) { + context.restore(); + try { + GetResponse getResponse = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (getResponse.isExists()) { + handleWorkflowExists(request, templateWithUser, getResponse, waitForTimeCompletion, listener); + } else { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to retrieve template ({}) from global context.", + workflowId + ).getFormattedMessage(); + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } + } catch (IOException e) { + logger.error("Failed to parse workflow getResponse: {}", workflowId, e); + listener.onFailure(e); } } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to retrieve template ({}) from global context.", workflowId ).getFormattedMessage(); - logger.error(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); } - }, exception -> { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( - "Failed to retrieve template ({}) from global context.", - workflowId - ).getFormattedMessage(); - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + }); } } } + private void handleWorkflowExists( + WorkflowRequest request, + Template templateWithUser, + GetResponse getResponse, + TimeValue waitForTimeCompletion, + ActionListener listener + ) throws IOException { + Template existingTemplate = Template.parse(getResponse.getSourceAsString()); + Template template = request.isUpdateFields() + ? Template.updateExistingTemplate(existingTemplate, templateWithUser) + : Template.builder(templateWithUser) + .createdTime(existingTemplate.createdTime()) + .lastUpdatedTime(Instant.now()) + .lastProvisionedTime(existingTemplate.lastProvisionedTime()) + .tenantId(existingTemplate.getTenantId()) + .build(); + + if (request.isReprovision()) { + handleReprovision(request.getWorkflowId(), existingTemplate, template, waitForTimeCompletion, listener); + } else { + // Update existing entry, full document replacement + handleFullDocUpdate(request, template, listener); + } + } + + private void handleReprovision( + String workflowId, + Template existingTemplate, + Template template, + TimeValue waitForTimeCompletion, + ActionListener listener + ) { + ReprovisionWorkflowRequest reprovisionRequest = new ReprovisionWorkflowRequest( + workflowId, + existingTemplate, + template, + waitForTimeCompletion + ); + logger.info("Reprovisioning parameter is set, continuing to reprovision workflow {}", workflowId); + client.execute(ReprovisionWorkflowAction.INSTANCE, reprovisionRequest, ActionListener.wrap(reprovisionResponse -> { + listener.onResponse(new WorkflowResponse(reprovisionResponse.getWorkflowId())); + }, exception -> { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Reprovisioning failed for workflow {}", workflowId) + .getFormattedMessage(); + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + })); + } + + private void handleFullDocUpdate(WorkflowRequest request, Template template, ActionListener listener) { + final boolean isFieldUpdate = request.isUpdateFields(); + flowFrameworkIndicesHandler.updateTemplateInGlobalContext(request.getWorkflowId(), template, ActionListener.wrap(response -> { + // Regular update, reset provisioning status, ignore state index if updating fields + if (!isFieldUpdate) { + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + request.getWorkflowId(), + template.getTenantId(), + Map.ofEntries( + Map.entry(STATE_FIELD, State.NOT_STARTED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.NOT_STARTED.name()); + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + }, exception -> { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update workflow {} in template index", + request.getWorkflowId() + ).getFormattedMessage(); + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } else { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } + }, exception -> { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update use case template {}", + request.getWorkflowId() + ).getFormattedMessage(); + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }), isFieldUpdate); + } + /** * Checks if the max workflows limit has been reachesd * @param requestTimeOut request time out * @param maxWorkflow max workflows * @param internalListener listener for search request */ - void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionListener internalListener) { - if (!flowFrameworkIndicesHandler.doesIndexExist(CommonValue.GLOBAL_CONTEXT_INDEX)) { + void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, String tenantId, ActionListener internalListener) { + if (!flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { internalListener.onResponse(true); } else { QueryBuilder query = QueryBuilders.matchAllQuery(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut); - - SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder); + SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder() + .indices(GLOBAL_CONTEXT_INDEX) + .searchSourceBuilder(searchSourceBuilder) + .tenantId(tenantId) + .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Querying existing workflows to count the max"); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { - context.restore(); - internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow); - }, exception -> { - String errorMessage = "Unable to fetch the workflows"; - logger.error(errorMessage, exception); - internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + sdkClient.searchDataObjectAsync(searchRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + context.restore(); + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow); + } catch (Exception e) { + logger.error("Failed to parse workflow searchResponse", e); + internalListener.onFailure(e); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = "Unable to fetch the workflows"; + logger.error(errorMessage, exception); + internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }); } catch (Exception e) { String errorMessage = "Unable to fetch the workflows"; logger.error(errorMessage, e); @@ -517,7 +568,12 @@ void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionList private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { - List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null, Collections.emptyMap(), "fakeTenantId"); + List sortedNodes = workflowProcessSorter.sortProcessNodes( + workflow, + null, + Collections.emptyMap(), + template.getTenantId() + ); workflowProcessSorter.validate(sortedNodes, pluginsService); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java index 41e61c01a..81c724065 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java @@ -11,7 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.core.util.Booleans; -import org.opensearch.action.delete.DeleteRequest; +import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -24,8 +24,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.TenantAwareHelper; +import org.opensearch.remote.metadata.client.DeleteDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -43,7 +48,9 @@ public class DeleteWorkflowTransportAction extends HandledTransportAction listener) { if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); User user = getUserContext(client); @@ -90,11 +107,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeDeleteRequest(request, listener, context), + () -> executeDeleteRequest(request, tenantId, listener, context), client, + sdkClient, clusterService, xContentRegistry ); @@ -109,27 +129,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); - DeleteRequest deleteRequest = new DeleteRequest(GLOBAL_CONTEXT_INDEX, workflowId); - logger.info("Deleting workflow doc: {}", workflowId); - client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore)); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder() + .index(GLOBAL_CONTEXT_INDEX) + .id(workflowId) + .tenantId(tenantId) + .build(); + sdkClient.deleteDataObjectAsync(deleteRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + DeleteResponse response = DeleteResponse.fromXContent(r.parser()); + listener.onResponse(response); + } catch (Exception e) { + logger.error("Failed to parse delete response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse delete response", RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to delete template {}", workflowId) + .getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); + } + }); // Whether to force deletion of corresponding state final boolean clearStatus = Booleans.parseBoolean(request.getParams().get(CLEAR_STATUS), false); ActionListener stateListener = ActionListener.wrap(response -> { logger.info("Deleted workflow state doc: {}", workflowId); }, exception -> { logger.info("Failed to delete workflow state doc: {}", workflowId, exception); }); - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(workflowId, clearStatus, canDelete -> { + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(workflowId, tenantId, clearStatus, canDelete -> { if (Boolean.TRUE.equals(canDelete)) { - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc(workflowId, stateListener); + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc(workflowId, tenantId, stateListener); } }, stateListener); } diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index b840e1362..5eacbc699 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -32,10 +32,12 @@ import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -72,6 +74,7 @@ public class DeprovisionWorkflowTransportAction extends HandledTransportAction listener) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); - User user = getUserContext(client); // Stash thread context to interact with system index @@ -128,11 +137,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeDeprovisionRequest(request, listener, context, user), + () -> executeDeprovisionRequest(request, tenantId, listener, context, user), client, + sdkClient, clusterService, xContentRegistry ); @@ -146,13 +158,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, ThreadContext.StoredContext context, User user ) { String workflowId = request.getWorkflowId(); String allowDelete = request.getParams().get(ALLOW_DELETE); - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId); logger.info("Querying state for workflow: {}", workflowId); client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); @@ -163,6 +176,7 @@ private void executeDeprovisionRequest( .execute( () -> executeDeprovisionSequence( workflowId, + tenantId, response.getWorkflowState().resourcesCreated(), deleteAllowedResources, listener, @@ -181,6 +195,7 @@ private void executeDeprovisionRequest( private void executeDeprovisionSequence( String workflowId, + String tenantId, List resourcesCreated, Set deleteAllowedResources, ActionListener listener, @@ -212,7 +227,7 @@ private void executeDeprovisionSequence( this.threadPool, DEPROVISION_WORKFLOW_THREAD_POOL, flowFrameworkSettings.getRequestTimeout(), - "fakeTenantId" + tenantId ) ); } @@ -236,7 +251,7 @@ private void executeDeprovisionSequence( deprovisionFuture.get(); logger.info("Successful {} for {}", deprovisionNode.id(), resourceNameAndId); // Remove from state index resource list - flowFrameworkIndicesHandler.deleteResourceFromStateIndex(workflowId, resource, stateUpdateFuture); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex(workflowId, tenantId, resource, stateUpdateFuture); try { // Wait at most 1 second for state index update. stateUpdateFuture.actionGet(1, TimeUnit.SECONDS); @@ -276,7 +291,7 @@ private void executeDeprovisionSequence( this.threadPool, DEPROVISION_WORKFLOW_THREAD_POOL, pn.nodeTimeout(), - "fakeTenantId" + tenantId ); }).collect(Collectors.toList()); // Pause briefly before next loop @@ -300,11 +315,12 @@ private void executeDeprovisionSequence( logger.info("Resources requiring allow_delete: {}.", deleteNotAllowed); } // This is a redundant best-effort backup to the incremental deletion done earlier - updateWorkflowState(workflowId, remainingResources, deleteNotAllowed, listener, user); + updateWorkflowState(workflowId, tenantId, remainingResources, deleteNotAllowed, listener, user); } private void updateWorkflowState( String workflowId, + String tenantId, List remainingResources, List deleteNotAllowed, ActionListener listener, @@ -312,15 +328,24 @@ private void updateWorkflowState( ) { if (remainingResources.isEmpty() && deleteNotAllowed.isEmpty()) { // Successful deprovision of all resources, reset state to initial - flowFrameworkIndicesHandler.doesTemplateExist(workflowId, templateExists -> { + flowFrameworkIndicesHandler.doesTemplateExist(workflowId, tenantId, templateExists -> { if (Boolean.TRUE.equals(templateExists)) { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState(workflowId, user, ActionListener.wrap(indexResponse -> { - logger.info("Reset workflow {} state to NOT_STARTED", workflowId); - }, exception -> { logger.error("Failed to reset to initial workflow state for {}", workflowId, exception); })); + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + workflowId, + tenantId, + user, + ActionListener.wrap(indexResponse -> { + logger.info("Reset workflow {} state to NOT_STARTED", workflowId); + }, exception -> { logger.error("Failed to reset to initial workflow state for {}", workflowId, exception); }) + ); } else { - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc(workflowId, ActionListener.wrap(deleteResponse -> { - logger.info("Deleted workflow {} state", workflowId); - }, exception -> { logger.error("Failed to delete workflow state for {}", workflowId, exception); })); + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc( + workflowId, + tenantId, + ActionListener.wrap(deleteResponse -> { + logger.info("Deleted workflow {} state", workflowId); + }, exception -> { logger.error("Failed to delete workflow state for {}", workflowId, exception); }) + ); } // return workflow ID listener.onResponse(new WorkflowResponse(workflowId)); @@ -332,6 +357,7 @@ private void updateWorkflowState( stateIndexResources.addAll(deleteNotAllowed); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + tenantId, Map.ofEntries( Map.entry(STATE_FIELD, State.COMPLETED), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java index 7fd546c25..79f1c1f2c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java @@ -13,6 +13,7 @@ import org.opensearch.common.Nullable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.flowframework.common.CommonValue; import java.io.IOException; @@ -32,14 +33,18 @@ public class GetWorkflowStateRequest extends ActionRequest { */ private boolean all; + private String tenantId; + /** * Instantiates a new GetWorkflowStateRequest * @param workflowId the documentId of the workflow * @param all whether the get request is looking for all fields in status + * @param tenantId the tenant id */ - public GetWorkflowStateRequest(@Nullable String workflowId, boolean all) { + public GetWorkflowStateRequest(@Nullable String workflowId, boolean all, String tenantId) { this.workflowId = workflowId; this.all = all; + this.tenantId = tenantId; } /** @@ -51,6 +56,9 @@ public GetWorkflowStateRequest(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); this.all = in.readBoolean(); + if (in.getVersion().onOrAfter(CommonValue.VERSION_2_19_0)) { + this.tenantId = in.readOptionalString(); + } } /** @@ -70,11 +78,22 @@ public boolean getAll() { return this.all; } + /** + * Gets the tenant Id + * @return the tenant id + */ + public String getTenantId() { + return this.tenantId; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(workflowId); out.writeBoolean(all); + if (out.getVersion().onOrAfter(CommonValue.VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java index 4106e492a..59d7f7294 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -22,18 +21,16 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; -import org.opensearch.index.IndexNotFoundException; +import org.opensearch.flowframework.util.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; @@ -47,7 +44,10 @@ public class GetWorkflowStateTransportAction extends HandledTransportAction listener) { + String tenantId = request.getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); User user = ParseUtils.getUserContext(client); @@ -88,11 +101,14 @@ protected void doExecute(Task task, GetWorkflowStateRequest request, ActionListe resolveUserAndExecute( user, workflowId, + tenantId, filterByEnabled, true, + flowFrameworkSettings.isMultiTenancyEnabled(), listener, - () -> executeGetWorkflowStateRequest(request, listener, context), + () -> executeGetWorkflowStateRequest(request, tenantId, listener, context), client, + sdkClient, clusterService, xContentRegistry ); @@ -108,41 +124,20 @@ protected void doExecute(Task task, GetWorkflowStateRequest request, ActionListe /** * Execute the get workflow state request * @param request the get workflow state request + * @param tenantId the tenant id * @param listener the action listener * @param context the thread context */ private void executeGetWorkflowStateRequest( GetWorkflowStateRequest request, + String tenantId, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); - logger.info("Querying state workflow doc: {}", workflowId); - client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - listener.onResponse(new GetWorkflowStateResponse(workflowState, request.getAll())); - } catch (Exception e) { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to parse workflowState: {}", r.getId()) - .getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } - } else { - listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); - } else { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to get workflow status of: {}", workflowId) - .getFormattedMessage(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - } - }), context::restore)); + flowFrameworkIndicesHandler.getWorkflowState(workflowId, tenantId, ActionListener.wrap(workflowState -> { + GetWorkflowStateResponse workflowStateResponse = new GetWorkflowStateResponse(workflowState, request.getAll()); + listener.onResponse(workflowStateResponse); + }, listener::onFailure), context); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 2462a839d..ed70e65b2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -24,11 +23,14 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -45,7 +47,9 @@ public class GetWorkflowTransportAction extends HandledTransportAction listener) { if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { - + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); - User user = getUserContext(client); // Retrieve workflow by ID @@ -97,11 +109,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeGetRequest(request, listener, context), + () -> executeGetRequest(request, tenantId, listener, context), client, + sdkClient, clusterService, xContentRegistry ); @@ -129,15 +144,13 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); logger.info("Querying workflow from global context: {}", workflowId); - client.get(getRequest, ActionListener.wrap(response -> { - context.restore(); - + flowFrameworkIndicesHandler.getTemplate(workflowId, tenantId, ActionListener.wrap(response -> { if (!response.isExists()) { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to retrieve template ({}) from global context.", @@ -158,6 +171,6 @@ private void executeGetRequest( ).getFormattedMessage(); logger.error(errorMessage, exception); listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + }), context); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index e0d405075..d5550a8f5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.PlainActionFuture; @@ -26,6 +25,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -33,12 +33,13 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.time.Instant; @@ -52,7 +53,6 @@ import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; @@ -71,10 +71,11 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction listener) { // Retrieve use case template from global context + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); - User user = getUserContext(client); // Stash thread context to interact with system index @@ -135,11 +142,14 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeProvisionRequest(request, listener, context), + () -> executeProvisionRequest(request, tenantId, listener, context), client, + sdkClient, clusterService, xContentRegistry ); @@ -163,18 +173,19 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); logger.info("Querying workflow from global context: {}", workflowId); - client.get(getRequest, ActionListener.wrap(response -> { + flowFrameworkIndicesHandler.getTemplate(workflowId, tenantId, ActionListener.wrap(response -> { context.restore(); if (!response.isExists()) { @@ -199,15 +210,16 @@ private void executeProvisionRequest( provisionWorkflow, workflowId, request.getParams(), - "fakeTenantId" + tenantId ); workflowProcessSorter.validate(provisionProcessSequence, pluginsService); - flowFrameworkIndicesHandler.getProvisioningProgress(workflowId, progress -> { + flowFrameworkIndicesHandler.getProvisioningProgress(workflowId, tenantId, progress -> { if (ProvisioningProgress.NOT_STARTED.equals(progress.orElse(null))) { // update state index flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + tenantId, Map.ofEntries( Map.entry(STATE_FIELD, State.PROVISIONING), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), @@ -217,10 +229,11 @@ private void executeProvisionRequest( ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + executeWorkflowAsync(workflowId, tenantId, provisionProcessSequence, listener); } else { executeWorkflowSync( workflowId, + tenantId, provisionProcessSequence, listener, request.getWaitForCompletionTimeout().getMillis() @@ -283,7 +296,7 @@ private void executeProvisionRequest( logger.error(errorMessage, exception); listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); } - })); + }), context); } /** @@ -292,10 +305,16 @@ private void executeProvisionRequest( * @param workflowSequence The sorted workflow to execute * @param listener ActionListener for any failures that don't get caught earlier in below step */ - private void executeWorkflowAsync(String workflowId, List workflowSequence, ActionListener listener) { + private void executeWorkflowAsync( + String workflowId, + String tenantId, + List workflowSequence, + ActionListener listener + ) { try { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL) - .execute(() -> { executeWorkflow(workflowSequence, workflowId, listener, false); }); + client.threadPool().executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { + executeWorkflow(workflowSequence, workflowId, tenantId, listener, false); + }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } @@ -312,6 +331,7 @@ private void executeWorkflowAsync(String workflowId, List workflowS */ private void executeWorkflowSync( String workflowId, + String tenantId, List workflowSequence, ActionListener listener, long timeout @@ -320,7 +340,7 @@ private void executeWorkflowSync( CompletableFuture.runAsync(() -> { try { - executeWorkflow(workflowSequence, workflowId, new ActionListener<>() { + executeWorkflow(workflowSequence, workflowId, tenantId, new ActionListener<>() { @Override public void onResponse(WorkflowResponse workflowResponse) { WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener); @@ -334,21 +354,23 @@ public void onFailure(Exception e) { } catch (Exception ex) { WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); } - }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + }, client.threadPool().executor(PROVISION_WORKFLOW_THREAD_POOL)); - WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); + WorkflowTimeoutUtility.scheduleTimeoutHandler(client, client.threadPool(), workflowId, tenantId, listener, timeout, isResponseSent); } /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing + * @param tenantId The tenant id * @param listener The ActionListener to handle the workflow response or failure * @param isSyncExecution Flag indicating whether the workflow should be executed synchronously (true) or asynchronously (false) */ private void executeWorkflow( List workflowSequence, String workflowId, + String tenantId, ActionListener listener, boolean isSyncExecution ) { @@ -382,6 +404,7 @@ private void executeWorkflow( logger.info("Provisioning completed successfully for workflow {}", workflowId); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + tenantId, Map.ofEntries( Map.entry(STATE_FIELD, State.COMPLETED), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), @@ -392,7 +415,7 @@ private void executeWorkflow( if (isSyncExecution) { client.execute( GetWorkflowStateAction.INSTANCE, - new GetWorkflowStateRequest(workflowId, false), + new GetWorkflowStateRequest(workflowId, false, tenantId), ActionListener.wrap(response -> { listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); }, exception -> { @@ -423,6 +446,7 @@ private void executeWorkflow( + status.toString(); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + tenantId, Map.ofEntries( Map.entry(STATE_FIELD, State.FAILED), Map.entry(ERROR_FIELD, errorMessage), diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index ff0dcef62..ec5637cdb 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -35,11 +35,13 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.PluginsService; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -75,6 +77,7 @@ public class ReprovisionWorkflowTransportAction extends HandledTransportAction listener) { - + String tenantId = request.getUpdatedTemplate() == null ? null : request.getUpdatedTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); User user = getUserContext(client); @@ -142,11 +151,14 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi resolveUserAndExecute( user, workflowId, + tenantId, filterByEnabled, false, + flowFrameworkSettings.isMultiTenancyEnabled(), listener, - () -> executeReprovisionRequest(request, listener, context), + () -> executeReprovisionRequest(request, tenantId, listener, context), client, + sdkClient, clusterService, xContentRegistry ); @@ -164,18 +176,20 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi /** * Execute the reprovision request * @param request the reprovision request + * @param tenantId * @param listener the action listener * @param context the thread context */ private void executeReprovisionRequest( ReprovisionWorkflowRequest request, + String tenantId, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); logger.info("Querying state for workflow: {}", workflowId); // Retrieve state and resources created - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId); client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); @@ -199,7 +213,7 @@ private void executeReprovisionRequest( provisionWorkflow, request.getWorkflowId(), Collections.emptyMap(), // TODO : Add suport to reprovision substitution templates - "fakeTenantId" + tenantId ); try { @@ -216,25 +230,34 @@ private void executeReprovisionRequest( workflowId, originalTemplate, updatedTemplate, - resourceCreated + resourceCreated, + tenantId ); // Remove error field if any prior to subsequent execution if (response.getWorkflowState().getError() != null) { WorkflowState newState = WorkflowState.builder(response.getWorkflowState()).error(null).build(); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(workflowId, newState, ActionListener.wrap(updateResponse -> { - - }, exception -> { - String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to update workflow state: {}", workflowId) - .getFormattedMessage(); - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + tenantId, + newState, + ActionListener.wrap(updateResponse -> { + + }, exception -> { + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update workflow state: {}", + workflowId + ).getFormattedMessage(); + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }) + ); } // Update State Index, maintain resources created for subsequent execution flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + tenantId, Map.ofEntries( Map.entry(STATE_FIELD, State.PROVISIONING), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), @@ -337,7 +360,15 @@ public void onFailure(Exception e) { WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); } }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); - WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); + WorkflowTimeoutUtility.scheduleTimeoutHandler( + client, + threadPool, + workflowId, + template.getTenantId(), + listener, + timeout, + isResponseSent + ); } /** @@ -396,6 +427,7 @@ private void executeWorkflow( logger.info("Reprovisioning completed successfully for workflow {}", workflowId); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + template.getTenantId(), Map.ofEntries( Map.entry(STATE_FIELD, State.COMPLETED), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), @@ -407,7 +439,7 @@ private void executeWorkflow( if (isSyncExecution) { client.execute( GetWorkflowStateAction.INSTANCE, - new GetWorkflowStateRequest(workflowId, false), + new GetWorkflowStateRequest(workflowId, false, template.getTenantId()), ActionListener.wrap(response -> { listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); }, exception -> { @@ -438,6 +470,7 @@ private void executeWorkflow( + status.toString(); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, + template.getTenantId(), Map.ofEntries( Map.entry(STATE_FIELD, State.FAILED), Map.entry(ERROR_FIELD, errorMessage), diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java index f20c57adb..138c55ea2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -46,7 +46,13 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try { - searchHandler.search(request, actionListener); + // We used the SearchRequest preference field to convey a tenant id if any + String tenantId = null; + if (request.preference() != null) { + tenantId = request.preference(); + request.preference(null); + } + searchHandler.search(request, tenantId, actionListener); } catch (Exception e) { String errorMessage = "Failed to search workflow states in global context"; logger.error(errorMessage, e); diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java index 46f0afb10..40c0a72e2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java @@ -46,7 +46,13 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try { - searchHandler.search(request, actionListener); + // We used the SearchRequest preference field to convey a tenant id if any + String tenantId = null; + if (request.preference() != null) { + tenantId = request.preference(); + request.preference(null); + } + searchHandler.search(request, tenantId, actionListener); } catch (Exception e) { String errorMessage = "Failed to search workflows in global context"; logger.error(errorMessage, e); diff --git a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java index 512b0bea2..5b914fb4a 100644 --- a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java +++ b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java @@ -19,9 +19,16 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import java.util.Arrays; + +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.flowframework.util.ParseUtils.isAdmin; import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; @@ -31,6 +38,7 @@ public class SearchHandler { private final Logger logger = LogManager.getLogger(SearchHandler.class); private final Client client; + private final SdkClient sdkClient; private volatile Boolean filterByBackendRole; /** @@ -38,10 +46,18 @@ public class SearchHandler { * @param settings settings * @param clusterService cluster service * @param client The node client to retrieve a stored use case template + * @param sdkClient The multitenant client * @param filterByBackendRoleSetting filter role backend settings */ - public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting filterByBackendRoleSetting) { + public SearchHandler( + Settings settings, + ClusterService clusterService, + Client client, + SdkClient sdkClient, + Setting filterByBackendRoleSetting + ) { this.client = client; + this.sdkClient = sdkClient; filterByBackendRole = filterByBackendRoleSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByBackendRole = it); } @@ -49,16 +65,17 @@ public SearchHandler(Settings settings, ClusterService clusterService, Client cl /** * Search workflows in global context * @param request SearchRequest + * @param tenantId the tenant ID * @param actionListener ActionListener */ - public void search(SearchRequest request, ActionListener actionListener) { + public void search(SearchRequest request, String tenantId, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { logger.info("Searching workflows in global context"); SearchSourceBuilder searchSourceBuilder = request.source(); searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); - validateRole(request, user, actionListener, context); + validateRole(request, tenantId, user, actionListener, context); } catch (Exception e) { logger.error("Failed to search workflows in global context", e); actionListener.onFailure(e); @@ -68,12 +85,14 @@ public void search(SearchRequest request, ActionListener actionL /** * Validate user role and call search * @param request SearchRequest + * @param tenantId the tenant id * @param user User * @param listener ActionListener * @param context ThreadContext */ public void validateRole( SearchRequest request, + String tenantId, User user, ActionListener listener, ThreadContext.StoredContext context @@ -83,16 +102,40 @@ public void validateRole( // Case 2: If Security is enabled and filter is disabled, proceed with search as // user is already authenticated to hit this API. // case 3: user is admin which means we don't have to check backend role filtering - client.search(request, ActionListener.runBefore(listener, context::restore)); + doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore)); } else { // Security is enabled, filter is enabled and user isn't admin try { ParseUtils.addUserBackendRolesFilter(user, request.source()); logger.debug("Filtering result by {}", user.getBackendRoles()); - client.search(request, ActionListener.runBefore(listener, context::restore)); + doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore)); } catch (Exception e) { listener.onFailure(e); } } } + + private void doSearch(SearchRequest request, String tenantId, ActionListener listener) { + SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder() + .indices(request.indices()) + .tenantId(tenantId) + .searchSourceBuilder(request.source()) + .build(); + sdkClient.searchDataObjectAsync(searchRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + logger.info(Arrays.toString(request.indices()) + " search complete: {}", searchResponse.getHits().getTotalHits()); + listener.onResponse(searchResponse); + } catch (Exception e) { + logger.error("Failed to parse search response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse search response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + logger.error("Search failed for indices: {}", Arrays.toString(request.indices()), cause); + listener.onFailure(cause); + } + }); + } } diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java index 02daeaff0..52a7d2251 100644 --- a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -11,29 +11,33 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Config; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.time.Instant; import java.util.ArrayList; @@ -41,7 +45,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CommitmentPolicy; @@ -66,38 +75,46 @@ public class EncryptorUtils { // https://github.com/aws/aws-encryption-sdk-java/issues/1879 private static final String WRAPPING_ALGORITHM = "AES/GCM/NOPADDING"; + // concurrent map can't have null as a key. This key is to support single tenancy + public static final String DEFAULT_TENANT_ID = ""; + private final ClusterService clusterService; private final Client client; - private String masterKey; + private final SdkClient sdkClient; + private final Map tenantMasterKeys; private final NamedXContentRegistry xContentRegistry; /** * Instantiates a new EncryptorUtils object * @param clusterService the cluster service * @param client the node client + * @param sdkClient the Multitenant Client * @param xContentRegistry the OpenSearch XContent Registry */ - public EncryptorUtils(ClusterService clusterService, Client client, NamedXContentRegistry xContentRegistry) { - this.masterKey = null; + public EncryptorUtils(ClusterService clusterService, Client client, SdkClient sdkClient, NamedXContentRegistry xContentRegistry) { + this.tenantMasterKeys = new ConcurrentHashMap<>(); this.clusterService = clusterService; this.client = client; + this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; } /** * Sets the master key + * @param tenantId The tenant id. If null, sets the key for the default id. * @param masterKey the master key */ - void setMasterKey(String masterKey) { - this.masterKey = masterKey; + void setMasterKey(@Nullable String tenantId, String masterKey) { + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), masterKey); } /** * Returns the master key + * @param tenantId The tenant id. If null, gets the key for the default id. * @return the master key */ - String getMasterKey() { - return this.masterKey; + String getMasterKey(@Nullable String tenantId) { + return tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)); } /** @@ -135,7 +152,7 @@ public Template decryptTemplateCredentials(Template template) { * @param cipherFunction the encryption/decryption function to apply on credential values * @return template with encrypted credentials */ - private Template processTemplateCredentials(Template template, Function cipherFunction) { + private Template processTemplateCredentials(Template template, BiFunction cipherFunction) { Map processedWorkflows = new HashMap<>(); for (Map.Entry entry : template.workflows().entrySet()) { @@ -145,7 +162,7 @@ private Template processTemplateCredentials(Template template, Function credentials = new HashMap<>((Map) node.userInputs().get(CREDENTIAL_FIELD)); - credentials.replaceAll((key, cred) -> cipherFunction.apply(cred)); + credentials.replaceAll((key, cred) -> cipherFunction.apply(cred, template.getTenantId())); // Replace credentials field in node user inputs Map processedUserInputs = new HashMap<>(); @@ -175,12 +192,23 @@ private Template processTemplateCredentials(Template template, Function latch.countDown()); + try { + if (!latch.await(5, TimeUnit.SECONDS)) { + throw new FlowFrameworkException("Timeout while initializing master key", RestStatus.INTERNAL_SERVER_ERROR); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FlowFrameworkException("Interrupted while initializing master key", RestStatus.REQUEST_TIMEOUT); + } + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); - byte[] bytes = Base64.getDecoder().decode(masterKey); + byte[] bytes = Base64.getDecoder().decode(getMasterKey(tenantId)); JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); final CryptoResult encryptResult = crypto.encryptData( jceMasterKey, @@ -192,13 +220,22 @@ String encrypt(final String credential) { /** * Decrypts the given credential * @param encryptedCredential the credential to decrypt + * @param tenantId The tenant id. If null, decrypts for the default tenant id. * @return the decrypted credential */ - String decrypt(final String encryptedCredential) { - initializeMasterKeyIfAbsent(); + String decrypt(final String encryptedCredential, @Nullable String tenantId) { + CountDownLatch latch = new CountDownLatch(1); + initializeMasterKeyIfAbsent(tenantId).whenComplete((v, throwable) -> latch.countDown()); + try { + if (!latch.await(5, TimeUnit.SECONDS)) { + throw new FlowFrameworkException("Timeout while initializing master key", RestStatus.INTERNAL_SERVER_ERROR); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FlowFrameworkException("Interrupted while initializing master key", RestStatus.REQUEST_TIMEOUT); + } final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); - - byte[] bytes = Base64.getDecoder().decode(masterKey); + byte[] bytes = Base64.getDecoder().decode(getMasterKey(tenantId)); JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); final CryptoResult decryptedResult = crypto.decryptData( @@ -248,99 +285,146 @@ public Template redactTemplateSecuredFields(User user, Template template) { /** * Retrieves an existing master key or generates a new key to index + * @param tenantId The tenant id. If null, initializes the key for the default tenant id. * @param listener the action listener */ - public void initializeMasterKey(ActionListener listener) { - // Index has either been created or it already exists, need to check if master key has been initalized already, if not then - // generate - // This is necessary in case of global context index restoration from snapshot, will need to use the same master key to decrypt - // stored credentials - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - // Using the master_key string as the document id - GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (!getResponse.isExists()) { - Config config = new Config(generateMasterKey(), Instant.now()); - IndexRequest masterKeyIndexRequest = new IndexRequest(CONFIG_INDEX).id(MASTER_KEY) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - masterKeyIndexRequest.source(config.toXContent(builder, ToXContent.EMPTY_PARAMS)); - } - client.index(masterKeyIndexRequest, ActionListener.wrap(indexResponse -> { - context.restore(); - // Set generated key to master - logger.info("Config has been initialized successfully"); - this.masterKey = config.masterKey(); - listener.onResponse(true); - }, indexException -> { - logger.error("Failed to index config", indexException); - listener.onFailure(indexException); - })); + public void initializeMasterKey(@Nullable String tenantId, ActionListener listener) { + // Config index has already been created or verified + cacheMasterKeyFromConfigIndex(tenantId).thenApply(v -> { + // Key exists and has been cached successfully + listener.onResponse(true); + return null; + }).exceptionally(throwable -> { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + // The cacheMasterKey method only completes exceptionally with FFE + if (exception instanceof FlowFrameworkException) { + FlowFrameworkException ffe = (FlowFrameworkException) exception; + if (ffe.status() == RestStatus.NOT_FOUND) { + // Key doesn't exist, need to generate and index a new one + generateAndIndexNewMasterKey(tenantId, listener); + } else { + listener.onFailure(ffe); + } + } else { + // Shouldn't get here + listener.onFailure(exception); + } + return null; + }); + } + private void generateAndIndexNewMasterKey(String tenantId, ActionListener listener) { + String masterKeyId = tenantId == null ? MASTER_KEY : MASTER_KEY + "_" + hashString(tenantId); + Config config = new Config(generateMasterKey(), Instant.now()); + PutDataObjectRequest putRequest = PutDataObjectRequest.builder() + .index(CONFIG_INDEX) + .id(masterKeyId) + .tenantId(tenantId) + .dataObject(config) + .build(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + // Set generated key to master + logger.info("Config has been initialized successfully"); + setMasterKey(tenantId, config.masterKey()); + listener.onResponse(true); } else { - context.restore(); - // Set existing key to master - logger.debug("Config has already been initialized, fetching key"); - try ( - XContentParser parser = ParseUtils.createXContentParserFromRegistry( - xContentRegistry, - getResponse.getSourceAsBytesRef() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Config config = Config.parse(parser); - this.masterKey = config.masterKey(); - listener.onResponse(true); - } catch (FlowFrameworkException e) { - listener.onFailure(e); - } + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + logger.error("Failed to index new master key in config for tenant id {}", tenantId, exception); + listener.onFailure(exception); } - }, getRequestException -> { - logger.error("Failed to search for config from config index", getRequestException); - listener.onFailure(getRequestException); - })); - - } catch (Exception e) { - logger.error("Failed to retrieve config from config index", e); - listener.onFailure(e); + }); } } /** - * Retrieves master key from system index if not yet set + * Called by encrypt and decrypt functions to retrieve master key from tenantMasterKeys map if set. If not, checks config system index (which must exist), fetches key and puts in tenantMasterKeys map. + * @param tenantId The tenant id. If null, initializes the key for the default id. + * @return a future that will complete when the key is initialized (or throws an exception) */ - void initializeMasterKeyIfAbsent() { - if (masterKey != null) { - return; + CompletableFuture initializeMasterKeyIfAbsent(@Nullable String tenantId) { + // Happy case, key already in map + if (this.tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) { + return CompletableFuture.completedFuture(null); } - + // Key not in map if (!clusterService.state().metadata().hasIndex(CONFIG_INDEX)) { - throw new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR); - } else { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, ActionListener.wrap(response -> { - context.restore(); - if (response.isExists()) { - try ( - XContentParser parser = ParseUtils.createXContentParserFromRegistry( - xContentRegistry, - response.getSourceAsBytesRef() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Config config = Config.parse(parser); - this.masterKey = config.masterKey(); + return CompletableFuture.failedFuture( + new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + // Fetch from config index and store in map + return cacheMasterKeyFromConfigIndex(tenantId); + } + + private CompletableFuture cacheMasterKeyFromConfigIndex(String tenantId) { + // This method assumes the config index must exist + final CompletableFuture resultFuture = new CompletableFuture<>(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + FetchSourceContext fetchSourceContext = new FetchSourceContext(true); + String masterKeyId = tenantId == null ? MASTER_KEY : MASTER_KEY + "_" + hashString(tenantId); + sdkClient.getDataObjectAsync( + GetDataObjectRequest.builder() + .index(CONFIG_INDEX) + .id(masterKeyId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build() + ).whenComplete((r, throwable) -> { + context.restore(); + if (throwable == null) { + try { + GetResponse response = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (response != null && response.isExists()) { + try ( + XContentParser parser = ParseUtils.createXContentParserFromRegistry( + xContentRegistry, + response.getSourceAsBytesRef() + ) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config config = Config.parse(parser); + setMasterKey(tenantId, config.masterKey()); + resultFuture.complete(null); + } + } else { + resultFuture.completeExceptionally( + new FlowFrameworkException("Master key has not been initialized in config index", RestStatus.NOT_FOUND) + ); } - } else { - throw new FlowFrameworkException("Master key has not been initialized in config index", RestStatus.NOT_FOUND); + } catch (IOException e) { + logger.error("Failed to parse config index getResponse", e); + resultFuture.completeExceptionally( + new FlowFrameworkException("Failed to parse config index getResponse", RestStatus.INTERNAL_SERVER_ERROR) + ); } - }, - exception -> { - throw new FlowFrameworkException("Failed to get master key from config index", ExceptionsHelper.status(exception)); - } - )); - } + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + logger.error("Failed to get master key from config index", exception); + resultFuture.completeExceptionally( + new FlowFrameworkException("Failed to get master key from config index", ExceptionsHelper.status(exception)) + ); + } + }); + return resultFuture; + } + } + + private String hashString(String input) { + try { + // Create a MessageDigest instance for SHA-256 + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + + // Perform the hashing and get the byte array + byte[] hashBytes = digest.digest(input.getBytes(StandardCharsets.UTF_8)); + + // Convert the byte array to a Base64 encoded string + return Base64.getUrlEncoder().encodeToString(hashBytes); + + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Error: Unable to compute hash", e); } } diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 7428249f5..3d6eea424 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -14,7 +14,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessageFactory; import org.apache.lucene.search.join.ScoreMode; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -37,7 +36,6 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; -import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; @@ -45,6 +43,9 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import java.io.FileNotFoundException; @@ -285,40 +286,50 @@ public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSou * Resolve user and execute the function * @param requestedUser the user to execute the request * @param workflowId workflow id + * @param tenantId tenant id * @param filterByEnabled filter by enabled setting * @param statePresent state present for the transport action + * @param isMultitenancyEnabled whether multitenancy is enabled * @param listener action listener * @param function workflow function * @param client node client + * @param sdkClient multitenant client * @param clusterService cluster service * @param xContentRegistry contentRegister to parse get response */ public static void resolveUserAndExecute( User requestedUser, String workflowId, + String tenantId, Boolean filterByEnabled, Boolean statePresent, + boolean isMultitenancyEnabled, ActionListener listener, Runnable function, Client client, + SdkClient sdkClient, ClusterService clusterService, NamedXContentRegistry xContentRegistry ) { try { - if (requestedUser == null || filterByEnabled == Boolean.FALSE) { + if (!isMultitenancyEnabled && (requestedUser == null || filterByEnabled == Boolean.FALSE)) { // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to - // check if request user have access to the workflow or not. + // check if request user have access to the workflow or not unless we have multitenancy // !filterByEnabled means security is enabled and filterByEnabled is disabled function.run(); } else { + // we need to validate either user access, multitenancy, or both, which requires getting the workflow getWorkflow( requestedUser, workflowId, + tenantId, filterByEnabled, statePresent, + isMultitenancyEnabled, listener, function, client, + sdkClient, clusterService, xContentRegistry ); @@ -381,49 +392,62 @@ public static void checkFilterByBackendRoles(User requestedUser) { * Get workflow * @param requestUser the user to execute the request * @param workflowId workflow id + * @param tenantId tenant id * @param filterByEnabled filter by enabled setting * @param statePresent state present for the transport action + * @param isMultitenancyEnabled if multi tenancy is enabled * @param listener action listener * @param function workflow function * @param client node client + * @param sdkClient the tenant aware client * @param clusterService cluster service * @param xContentRegistry contentRegister to parse get response */ public static void getWorkflow( User requestUser, String workflowId, + String tenantId, Boolean filterByEnabled, Boolean statePresent, - ActionListener listener, + boolean isMultitenancyEnabled, + ActionListener listener, Runnable function, Client client, + SdkClient sdkClient, ClusterService clusterService, NamedXContentRegistry xContentRegistry ) { String index = statePresent ? WORKFLOW_STATE_INDEX : GLOBAL_CONTEXT_INDEX; if (clusterService.state().metadata().hasIndex(index)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - GetRequest request = new GetRequest(index, workflowId); - client.get( - request, - ActionListener.wrap( - response -> onGetWorkflowResponse( - response, - requestUser, - workflowId, - filterByEnabled, - statePresent, - listener, - function, - xContentRegistry, - context - ), - exception -> { - logger.error("Failed to get workflow: {}", workflowId, exception); - listener.onFailure(exception); + GetDataObjectRequest request = GetDataObjectRequest.builder().index(index).id(workflowId).tenantId(tenantId).build(); + sdkClient.getDataObjectAsync(request).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse getResponse = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + onGetWorkflowResponse( + getResponse, + requestUser, + workflowId, + tenantId, + filterByEnabled, + statePresent, + isMultitenancyEnabled, + listener, + function, + xContentRegistry, + context + ); + } catch (IOException e) { + logger.error("Failed to parse workflow getResponse: {}", workflowId, e); + listener.onFailure(e); } - ) - ); + } else { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + logger.error("Failed to get workflow: {}", workflowId, exception); + listener.onFailure(exception); + } + }); } } else { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to retrieve template ({}).", workflowId) @@ -435,11 +459,13 @@ public static void getWorkflow( /** * Execute the function if user has permissions to access the resource - * @param requestUser the user to execute the request * @param response get response + * @param requestUser the user to execute the request * @param workflowId workflow id + * @param tenantId tenant id * @param filterByEnabled filter by enabled setting * @param statePresent state present for the transport action + * @param isMultitenancyEnabled if multi tenancy is enabled * @param listener action listener * @param function workflow function * @param xContentRegistry contentRegister to parse get response @@ -449,9 +475,11 @@ public static void onGetWorkflowResponse( GetResponse response, User requestUser, String workflowId, + String tenantId, Boolean filterByEnabled, Boolean statePresent, - ActionListener listener, + boolean isMultitenancyEnabled, + ActionListener listener, Runnable function, NamedXContentRegistry xContentRegistry, ThreadContext.StoredContext context @@ -462,8 +490,20 @@ public static void onGetWorkflowResponse( ) { context.restore(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - User resourceUser = statePresent ? WorkflowState.parse(parser).getUser() : Template.parse(parser).getUser(); - + User resourceUser; + if (statePresent) { + WorkflowState state = WorkflowState.parse(parser); + resourceUser = state.getUser(); + if (!TenantAwareHelper.validateTenantResource(isMultitenancyEnabled, tenantId, state.getTenantId(), listener)) { + return; + } + } else { + Template template = Template.parse(parser); + resourceUser = template.getUser(); + if (!TenantAwareHelper.validateTenantResource(isMultitenancyEnabled, tenantId, template.getTenantId(), listener)) { + return; + } + } if (!filterByEnabled || checkUserPermissions(requestUser, resourceUser, workflowId) || isAdmin(requestUser)) { function.run(); } else { diff --git a/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java b/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java new file mode 100644 index 000000000..eb78bc02f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.rest.RestRequest; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Helper class for tenant ID validation + */ +public class TenantAwareHelper { + + /** + * Validates the tenant ID based on the multi-tenancy feature setting. + * + * @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled. + * @param tenantId The tenant ID to validate. + * @param listener The action listener to handle failure cases. + * @return true if the tenant ID is valid or if multi-tenancy is not enabled; false if the tenant ID is invalid and multi-tenancy is enabled. + */ + public static boolean validateTenantId(boolean isMultiTenancyEnabled, String tenantId, ActionListener listener) { + if (isMultiTenancyEnabled && tenantId == null) { + listener.onFailure(new FlowFrameworkException("No permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else { + return true; + } + } + + /** + * Validates the tenant resource by comparing the tenant ID from the request with the tenant ID from the resource. + * + * @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled. + * @param tenantIdFromRequest The tenant ID obtained from the request. + * @param tenantIdFromResource The tenant ID obtained from the resource. + * @param listener The action listener to handle failure cases. + * @return true if the tenant IDs match or if multi-tenancy is not enabled; false if the tenant IDs do not match and multi-tenancy is enabled. + */ + public static boolean validateTenantResource( + boolean isMultiTenancyEnabled, + String tenantIdFromRequest, + String tenantIdFromResource, + ActionListener listener + ) { + if (isMultiTenancyEnabled && !Objects.equals(tenantIdFromRequest, tenantIdFromResource)) { + listener.onFailure(new FlowFrameworkException("No permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else return true; + } + + /** + * Finds the tenant id in the REST Headers + * @param isMultiTenancyEnabled whether multitenancy is enabled + * @param restRequest the RestRequest + * @return The tenant ID from the headers or null if multitenancy is not enabled + */ + public static String getTenantID(Boolean isMultiTenancyEnabled, RestRequest restRequest) { + if (!isMultiTenancyEnabled) { + return null; + } + + Map> headers = restRequest.getHeaders(); + + List tenantIdList = headers.get(CommonValue.TENANT_ID_HEADER); + if (tenantIdList == null || tenantIdList.isEmpty()) { + throw new FlowFrameworkException("Tenant ID header is missing or has no value", RestStatus.FORBIDDEN); + } + + String tenantId = tenantIdList.get(0); + if (tenantId == null) { + throw new FlowFrameworkException("Tenant ID can't be null", RestStatus.FORBIDDEN); + } + + return tenantId; + } +} diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index cbed72b3d..2b2e7cc46 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -40,6 +40,7 @@ public class WorkflowTimeoutUtility { * @param client The OpenSearch client used to interact with the cluster. * @param threadPool The thread pool to schedule the timeout task. * @param workflowId The unique identifier of the workflow being executed. + * @param tenantId The tenant id. * @param listener The listener to notify when the task completes or times out. * @param timeout The timeout duration in milliseconds. * @param isResponseSent An atomic boolean to ensure the response is sent only once. @@ -49,6 +50,7 @@ public static ActionListener scheduleTimeoutHandler( Client client, ThreadPool threadPool, final String workflowId, + final String tenantId, ActionListener listener, long timeout, AtomicBoolean isResponseSent @@ -56,7 +58,7 @@ public static ActionListener scheduleTimeoutHandler( // Ensure timeout is within the valid range (non-negative) long adjustedTimeout = Math.max(timeout, TimeValue.timeValueMillis(0).millis()); Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( - new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), + new WorkflowTimeoutListener(client, workflowId, tenantId, listener, isResponseSent), TimeValue.timeValueMillis(adjustedTimeout), PROVISION_WORKFLOW_THREAD_POOL ); @@ -70,12 +72,20 @@ public static ActionListener scheduleTimeoutHandler( private static class WorkflowTimeoutListener implements Runnable { private final Client client; private final String workflowId; + private final String tenantId; private final ActionListener listener; private final AtomicBoolean isResponseSent; - WorkflowTimeoutListener(Client client, String workflowId, ActionListener listener, AtomicBoolean isResponseSent) { + WorkflowTimeoutListener( + Client client, + String workflowId, + String tenantId, + ActionListener listener, + AtomicBoolean isResponseSent + ) { this.client = client; this.workflowId = workflowId; + this.tenantId = tenantId; this.listener = listener; this.isResponseSent = isResponseSent; } @@ -85,7 +95,7 @@ public void run() { // This AtomicBoolean ensures that the timeout logic is executed only once, preventing duplicate responses. if (isResponseSent.compareAndSet(false, true)) { logger.warn("Workflow execution timed out for workflowId: {}", workflowId); - fetchWorkflowStateAfterTimeout(client, workflowId, listener); + fetchWorkflowStateAfterTimeout(client, workflowId, tenantId, listener); } } } @@ -180,17 +190,19 @@ public static void handleFailure( * * @param client The OpenSearch client used to fetch the workflow state. * @param workflowId The unique identifier of the workflow. + * @param tenantId The tenant id * @param listener The listener to notify with the updated state or failure. */ public static void fetchWorkflowStateAfterTimeout( final Client client, final String workflowId, + final String tenantId, final ActionListener listener ) { logger.info("Fetching workflow state after timeout"); client.execute( GetWorkflowStateAction.INSTANCE, - new GetWorkflowStateRequest(workflowId, false), + new GetWorkflowStateRequest(workflowId, false, tenantId), ActionListener.wrap( response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), exception -> listener.onFailure( diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java index e779ddf4f..d1916dc85 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java @@ -105,6 +105,7 @@ public void onResponse(AcknowledgedResponse acknowledgedResponse) { currentNodeId, getName(), pipelineId, + tenantId, createPipelineFuture ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 79adb1e1c..60f714443 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -195,6 +195,7 @@ public PlainActionFuture execute( registerLocalModelFuture, taskId, "Local model registration", + tenantId, ActionListener.wrap(mlTaskWorkflowData -> { // Registered Model Resource has been updated String resourceName = getResourceByWorkflowStep(getName()); @@ -219,6 +220,7 @@ public PlainActionFuture execute( currentNodeId, DeployModelStep.NAME, id, + tenantId, deployUpdateListener ); } else { diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index bf6585dc8..c251e44cb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -68,6 +68,7 @@ protected AbstractRetryableWorkflowStep( * @param future the workflow step future * @param taskId the ml task id * @param workflowStep the workflow step which requires a retry get ml task functionality + * @param tenantId the tenant ID * @param mlTaskListener the ML Task Listener */ protected void retryableGetMlTask( @@ -76,6 +77,7 @@ protected void retryableGetMlTask( PlainActionFuture future, String taskId, String workflowStep, + String tenantId, ActionListener mlTaskListener ) { CompletableFuture.runAsync(() -> { @@ -91,7 +93,14 @@ protected void retryableGetMlTask( content.put(REGISTER_MODEL_STATUS, response.getState().toString()); mlTaskListener.onResponse(new WorkflowData(content, r.getWorkflowId(), r.getNodeId())); }, mlTaskListener::onFailure); - flowFrameworkIndicesHandler.addResourceToStateIndex(currentNodeInputs, nodeId, getName(), id, resourceListener); + flowFrameworkIndicesHandler.addResourceToStateIndex( + currentNodeInputs, + nodeId, + getName(), + id, + tenantId, + resourceListener + ); break; case FAILED: case COMPLETED_WITH_ERROR: diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index bccf636d8..74dc52034 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -91,6 +91,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { currentNodeId, getName(), mlCreateConnectorResponse.getConnectorId(), + tenantId, createConnectorFuture ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 14cbb0736..7f9233f7b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -116,6 +116,7 @@ public PlainActionFuture execute( currentNodeId, getName(), indexName, + tenantId, createIndexFuture ); }, ex -> { diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 46ee3a24b..2d331d7e7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -99,6 +99,7 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { deployModelFuture, taskId, "Deploy model", + tenantId, ActionListener.wrap( deployModelFuture::onResponse, e -> deployModelFuture.onFailure( diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 45cc08bdc..0f5c2b501 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -101,6 +101,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { currentNodeId, getName(), mlRegisterAgentResponse.getAgentId(), + tenantId, registerAgentModelFuture ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index c694eabd8..871e1f24d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -89,6 +89,7 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse currentNodeId, getName(), mlRegisterModelGroupResponse.getModelGroupId(), + tenantId, resourceListener ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 7842a0659..183a757e9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -168,6 +168,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { currentNodeId, getName(), mlRegisterModelResponse.getModelId(), + tenantId, registerUpdateListener ); } @@ -185,6 +186,7 @@ private void updateDeployResource(String resourceName, MLRegisterModelResponse m currentNodeId, DeployModelStep.NAME, mlRegisterModelResponse.getModelId(), + tenantId, deployUpdateListener ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index d0abb8c6e..2f64a2683 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -156,6 +156,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId, * @param originalTemplate the original template currently indexed * @param updatedTemplate the updated template to be executed * @param resourcesCreated the resources previously created for the workflow + * @param tenantId the tenant id * @throws Exception for issues creating the reprovision sequence * @return A list of ProcessNode */ @@ -163,7 +164,8 @@ public List createReprovisionSequence( String workflowId, Template originalTemplate, Template updatedTemplate, - List resourcesCreated + List resourcesCreated, + String tenantId ) throws Exception { Workflow updatedWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW); @@ -205,7 +207,8 @@ public List createReprovisionSequence( updatedWorkflow, sortedUpdatedNodes, originalTemplateMap, - resourcesCreated + resourcesCreated, + tenantId ); // If the reprovision sequence consists entirely of WorkflowDataSteps, then no modifications were made to the exisiting template. @@ -223,6 +226,7 @@ public List createReprovisionSequence( * @param sortedUpdatedNodes the topologically sorted updated template nodes * @param originalTemplateMap a map of node Id to workflow node of the original template * @param resourcesCreated a list of resources created for this template + * @param tenantId the tenant id * @return a list of process node representing the reprovision sequence * @throws Exception for issues creating the reprovision sequence */ @@ -231,7 +235,8 @@ private List createReprovisionSequence( Workflow updatedWorkflow, List sortedUpdatedNodes, Map originalTemplateMap, - List resourcesCreated + List resourcesCreated, + String tenantId ) throws Exception { Map idToNodeMap = new HashMap<>(); List reprovisionSequence = new ArrayList<>(); @@ -243,7 +248,8 @@ private List createReprovisionSequence( originalTemplateMap, resourcesCreated, workflowId, - idToNodeMap + idToNodeMap, + tenantId ); if (processNode != null) { idToNodeMap.put(processNode.id(), processNode); @@ -262,6 +268,7 @@ private List createReprovisionSequence( * @param resourcesCreated a list of resources created for this template * @param workflowId the workflow ID associated with the template * @param idToNodeMap a map of the current reprovision sequence + * @param tenantId the tenant id * @return a ProcessNode * @throws Exception for issues creating the process node */ @@ -271,7 +278,8 @@ private ProcessNode createProcessNode( Map originalTemplateMap, List resourcesCreated, String workflowId, - Map idToNodeMap + Map idToNodeMap, + String tenantId ) throws Exception { WorkflowData data = new WorkflowData(node.userInputs(), updatedWorkflow.userParams(), workflowId, node.id()); List predecessorNodes = updatedWorkflow.edges() @@ -284,15 +292,15 @@ private ProcessNode createProcessNode( if (!originalTemplateMap.containsKey(node.id())) { // Case 1: Additive modification, create new node - return createNewProcessNode(node, data, predecessorNodes, nodeTimeout); + return createNewProcessNode(node, data, predecessorNodes, nodeTimeout, tenantId); } else { WorkflowNode originalNode = originalTemplateMap.get(node.id()); if (shouldUpdateNode(node, originalNode)) { // Case 2: Existing modification, create update step - return createUpdateProcessNode(node, data, predecessorNodes, nodeTimeout); + return createUpdateProcessNode(node, data, predecessorNodes, nodeTimeout, tenantId); } else { // Case 4: No modification to existing node, create proxy step - return createWorkflowDataStepNode(node, data, predecessorNodes, nodeTimeout, resourcesCreated); + return createWorkflowDataStepNode(node, data, predecessorNodes, nodeTimeout, resourcesCreated, tenantId); } } } @@ -303,13 +311,15 @@ private ProcessNode createProcessNode( * @param data the current node data * @param predecessorNodes the current node predecessors * @param nodeTimeout the current node timeout + * @param tenantId the tenant id * @return a Process Node */ private ProcessNode createNewProcessNode( WorkflowNode node, WorkflowData data, List predecessorNodes, - TimeValue nodeTimeout + TimeValue nodeTimeout, + String tenantId ) { WorkflowStep step = workflowStepFactory.createStep(node.type()); return new ProcessNode( @@ -322,7 +332,7 @@ private ProcessNode createNewProcessNode( threadPool, PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout, - "fakeTenantId" + tenantId ); } @@ -332,6 +342,7 @@ private ProcessNode createNewProcessNode( * @param data the current node data * @param predecessorNodes the current node predecessors * @param nodeTimeout the current node timeout + * @param tenantId the tenant id * @return a ProcessNode * @throws FlowFrameworkException if the current node does not support updates */ @@ -339,7 +350,8 @@ private ProcessNode createUpdateProcessNode( WorkflowNode node, WorkflowData data, List predecessorNodes, - TimeValue nodeTimeout + TimeValue nodeTimeout, + String tenantId ) throws FlowFrameworkException { String updateStepName = WorkflowResources.getUpdateStepByWorkflowStep(node.type()); if (updateStepName != null) { @@ -354,7 +366,7 @@ private ProcessNode createUpdateProcessNode( threadPool, PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout, - "fakeTenantId" + tenantId ); } else { // Case 3 : Cannot update step (not supported) @@ -372,6 +384,7 @@ private ProcessNode createUpdateProcessNode( * @param predecessorNodes the current node predecessors * @param nodeTimeout the current node timeout * @param resourcesCreated the list of resources created for the template assoicated with this node + * @param tenantId the tenant id * @return a Process node */ private ProcessNode createWorkflowDataStepNode( @@ -379,7 +392,8 @@ private ProcessNode createWorkflowDataStepNode( WorkflowData data, List predecessorNodes, TimeValue nodeTimeout, - List resourcesCreated + List resourcesCreated, + String tenantId ) { ResourceCreated nodeResource = resourcesCreated.stream() .filter(rc -> rc.workflowStepId().equals(node.id())) @@ -397,7 +411,7 @@ private ProcessNode createWorkflowDataStepNode( threadPool, PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout, - "fakeTenantId" + tenantId ); } else { return null; diff --git a/src/main/resources/mappings/global-context.json b/src/main/resources/mappings/global-context.json index 544b4a9af..b79ba07df 100644 --- a/src/main/resources/mappings/global-context.json +++ b/src/main/resources/mappings/global-context.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "workflow_id": { @@ -86,6 +86,9 @@ "type": "date", "format": "strict_date_time||epoch_millis" }, + "tenant_id": { + "type": "keyword" + }, "ui_metadata": { "type": "object", "enabled": false diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json index ecb635413..c9627d670 100644 --- a/src/main/resources/mappings/workflow-state.json +++ b/src/main/resources/mappings/workflow-state.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "schema_version": { @@ -84,6 +84,9 @@ "type": "keyword" } } + }, + "tenant_id": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 65b0a9a75..bec97becc 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -30,8 +30,13 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_ENDPOINT; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_REGION; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_SERVICE_NAME; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.REMOTE_METADATA_TYPE; import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION; import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT; import static org.mockito.Mockito.mock; @@ -59,6 +64,7 @@ public void setUp() throws Exception { when(client.admin()).thenReturn(adminClient); when(adminClient.cluster()).thenReturn(clusterAdminClient); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); + when(client.threadPool()).thenReturn(threadPool); environment = mock(Environment.class); settings = Settings.builder().build(); @@ -72,7 +78,12 @@ public void setUp() throws Exception { MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION, - FILTER_BY_BACKEND_ROLES + FILTER_BY_BACKEND_ROLES, + FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED, + REMOTE_METADATA_TYPE, + REMOTE_METADATA_ENDPOINT, + REMOTE_METADATA_REGION, + REMOTE_METADATA_SERVICE_NAME ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); @@ -89,13 +100,13 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 6, + 7, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(9, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(10, ffp.getActions().size()); assertEquals(3, ffp.getExecutorBuilders(settings).size()); - assertEquals(6, ffp.getSettings().size()); + assertEquals(11, ffp.getSettings().size()); Collection systemIndexDescriptors = ffp.getSystemIndexDescriptors(Settings.EMPTY); assertEquals(3, systemIndexDescriptors.size()); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 038574a30..9da2d4d18 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -11,6 +11,7 @@ import com.google.gson.JsonArray; import org.apache.commons.lang3.RandomStringUtils; import org.apache.http.Header; +import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.HttpHost; import org.apache.http.auth.AuthScope; @@ -141,6 +142,19 @@ protected String getProtocol() { return isHttps() ? "https" : "http"; } + public static Map responseToMap(Response response) throws IOException { + HttpEntity entity = response.getEntity(); + assertNotNull(response); + String entityString = TestHelpers.httpEntityToString(entity); + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + entityString + ); + parser.nextToken(); + return parser.map(); + } + // Utility fn for deleting indices. Should only be used when not allowed in a regular context // (e.g., deleting system indices) protected static void deleteIndexWithAdminClient(String name) throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkTenantAwareRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkTenantAwareRestTestCase.java new file mode 100644 index 000000000..8113968ea --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkTenantAwareRestTestCase.java @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework; + +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.rest.FakeRestRequest; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.opensearch.common.xcontent.XContentType.JSON; +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_HEADER; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED; + +public abstract class FlowFrameworkTenantAwareRestTestCase extends FlowFrameworkRestTestCase { + + // Toggle to run DDB tests + // TODO: Get this from a property + protected static final boolean DDB = false; + + protected static final String DOC_ID = "_id"; + + // REST methods + protected static final String POST = RestRequest.Method.POST.name(); + protected static final String GET = RestRequest.Method.GET.name(); + protected static final String PUT = RestRequest.Method.PUT.name(); + protected static final String DELETE = RestRequest.Method.DELETE.name(); + + // REST body + protected static final String MATCH_ALL_QUERY = "{\"query\":{\"match_all\":{}}}"; + protected static final String EMPTY_CONTENT = "{}"; + + // REST Response error reasons + protected static final String MISSING_TENANT_REASON = "Tenant ID header is missing or has no value"; + protected static final String NO_PERMISSION_REASON = "No permission to access this resource"; + + protected String tenantId = randomAlphaOfLength(5); + protected String otherTenantId = randomAlphaOfLength(6); + + protected final RestRequest tenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders( + Map.of(TENANT_ID_HEADER, singletonList(tenantId)) + ).build(); + protected final RestRequest otherTenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders( + Map.of(TENANT_ID_HEADER, singletonList(otherTenantId)) + ).build(); + protected final RestRequest nullTenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(emptyMap()) + .build(); + + protected final RestRequest tenantMatchAllRequest = getRestRequestWithHeadersAndContent(tenantId, MATCH_ALL_QUERY); + protected final RestRequest otherTenantMatchAllRequest = getRestRequestWithHeadersAndContent(otherTenantId, MATCH_ALL_QUERY); + protected final RestRequest nullTenantMatchAllRequest = getRestRequestWithHeadersAndContent(null, MATCH_ALL_QUERY); + + protected static boolean isMultiTenancyEnabled() throws IOException { + // pass -Dtests.rest.tenantaware=true on gradle command line to enable + return Boolean.parseBoolean(System.getProperty(FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED.getKey())) + || Boolean.parseBoolean(System.getenv(FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED.getKey())); + } + + protected static Response makeRequest(RestRequest request, String method, String path) throws IOException { + return TestHelpers.makeRequest( + client(), + method, + path, + request.params(), + request.content().utf8ToString(), + getHeadersFromRequest(request) + ); + } + + private static List
getHeadersFromRequest(RestRequest request) { + return request.getHeaders() + .entrySet() + .stream() + .map(e -> new BasicHeader(e.getKey(), e.getValue().stream().collect(Collectors.joining(",")))) + .collect(Collectors.toList()); + } + + protected static RestRequest getRestRequestWithHeadersAndContent(String tenantId, String requestContent) { + Map> headers = new HashMap<>(); + if (tenantId != null) { + headers.put(TENANT_ID_HEADER, singletonList(tenantId)); + } + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders(headers) + .withContent(new BytesArray(requestContent), JSON) + .build(); + } + + @SuppressWarnings("unchecked") + protected static String getErrorReasonFromResponseMap(Map map) { + // FlowFrameworkExceptions have a simple error field + if (map.get("error") instanceof String) { + return (String) map.get("error"); + } + + // OpenSearchStatusExceptions have different possibilities based on client + String type = ((Map) map.get("error")).get("type"); + + // { + // "error": { + // "root_cause": [ + // { + // "type": "status_exception", + // "reason": "You don't have permission to access this resource" + // } + // ], + // "type": "status_exception", + // "reason": "You don't have permission to access this resource" + // }, + // "status": 403 + // } + if ("status_exception".equals(type)) { + return ((Map) map.get("error")).get("reason"); + } + + // { + // "error": { + // "reason": "System Error", + // "details": "You don't have permission to access this resource", + // "type": "OpenSearchStatusException" + // }, + // "status": 403 + // } + return ((Map) map.get("error")).get("details"); + } + + protected static SearchResponse searchResponseFromResponse(Response response) throws IOException { + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + TestHelpers.httpEntityToString(response.getEntity()).getBytes(UTF_8) + ); + return SearchResponse.fromXContent(parser); + } + + protected static void assertBadRequest(Response response) { + assertEquals(RestStatus.BAD_REQUEST.getStatus(), response.getStatusLine().getStatusCode()); + } + + protected static void assertNotFound(Response response) { + assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode()); + } + + protected static void assertForbidden(Response response) { + assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode()); + } + + protected static void assertUnauthorized(Response response) { + assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatusLine().getStatusCode()); + } + + protected static void assertOkOrAccepted(Response response) { + assertTrue(List.of(RestStatus.OK.getStatus(), RestStatus.ACCEPTED.getStatus()).contains(response.getStatusLine().getStatusCode())); + } + + /** + * Delete the specified document and wait until a search matches only the specified number of hits + * @param tenantId The tenant ID to filter the search by + * @param restPath The base path for the REST API + * @param id The document ID to be appended to the REST API for deletion + * @param hits The number of hits to expect after the deletion is processed + * @throws Exception on failures with building or making the request + */ + protected static void deleteAndWaitForSearch(String tenantId, String restPath, String id, int hits) throws Exception { + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders( + Map.of(TENANT_ID_HEADER, singletonList(tenantId)) + ).build(); + // First process the deletion. Dependent resources (e.g. model with connector) may cause 409 status until they are deleted + assertBusy(() -> { + try { + Response deleteResponse = makeRequest(request, DELETE, restPath + id); + // first successful deletion should produce an OK + assertOK(deleteResponse); + } catch (ResponseException e) { + // repeat deletions can produce a 404, treat as a success + assertNotFound(e.getResponse()); + } + }, 20, TimeUnit.SECONDS); + // Deletion processed, now wait for it to disappear from search + RestRequest searchRequest = getRestRequestWithHeadersAndContent(tenantId, MATCH_ALL_QUERY); + assertBusy(() -> { + Response response = makeRequest(searchRequest, GET, restPath + "_search"); + assertOK(response); + SearchResponse searchResponse = searchResponseFromResponse(response); + assertEquals(hits, searchResponse.getHits().getTotalHits().value); + }, 20, TimeUnit.SECONDS); + } +} diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index b65e7b599..cc83b8b75 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -53,9 +53,13 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; import static org.apache.http.entity.ContentType.APPLICATION_JSON; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestHelpers { @@ -191,6 +195,22 @@ public static SearchRequest matchAllRequest() { } public static GetResponse createGetResponse(ToXContentObject o, String id, String indexName) throws IOException { + if (o == null) { + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.getId()).thenReturn(MASTER_KEY); + when(getResponse.getSource()).thenReturn(null); + when(getResponse.toXContent(any(XContentBuilder.class), any())).thenAnswer(invocation -> { + XContentBuilder builder = invocation.getArgument(0); + builder.startObject() + .field("_index", indexName) + .field("_id", id) + .field("found", false) + // .nullField("_source") + .endObject(); + return builder; + }); + return getResponse; + } XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); return new GetResponse( new GetResult( diff --git a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java index 2745fde1e..24254020a 100644 --- a/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java +++ b/src/test/java/org/opensearch/flowframework/common/FlowFrameworkSettingsTests.java @@ -43,7 +43,8 @@ public void setUp() throws Exception { FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION, FlowFrameworkSettings.MAX_WORKFLOW_STEPS, FlowFrameworkSettings.MAX_WORKFLOWS, - FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT + FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT, + FlowFrameworkSettings.FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); @@ -63,5 +64,6 @@ public void testSettings() throws IOException { assertEquals(Optional.of(50), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflowSteps())); assertEquals(Optional.of(1000), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflows())); assertEquals(Optional.of(TimeValue.timeValueSeconds(10)), Optional.ofNullable(flowFrameworkSettings.getRequestTimeout())); + assertFalse(flowFrameworkSettings.isMultiTenancyEnabled()); } } diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index a7dd7f75e..16cabf643 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -35,6 +35,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; @@ -50,6 +51,8 @@ import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -81,6 +84,7 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; @Mock private CreateIndexStep createIndexStep; @Mock @@ -94,6 +98,8 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { @Mock protected ClusterService clusterService; @Mock + protected NamedXContentRegistry namedXContentRegistry; + @Mock private FlowFrameworkIndicesHandler flowMock; private static final String META = "_meta"; private static final String SCHEMA_VERSION_FIELD = "schemaVersion"; @@ -112,7 +118,14 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils, xContentRegistry()); + sdkClient = SdkClientFactory.createSdkClient(client, namedXContentRegistry, Collections.emptyMap()); + flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler( + client, + sdkClient, + clusterService, + encryptorUtils, + xContentRegistry() + ); adminClient = mock(AdminClient.class); indicesAdminClient = mock(IndicesAdminClient.class); metadata = mock(Metadata.class); @@ -138,6 +151,7 @@ public void setUp() throws Exception { TestHelpers.randomUser(), null, null, + null, null ); } @@ -168,7 +182,7 @@ public void testFailedUpdateTemplateInGlobalContext() throws IOException { verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals( - "Failed to update template for workflow_id : 1, global_context index does not exist.", + "Failed to update template for workflow_id : 1, global context index does not exist.", exceptionCaptor.getValue().getMessage() ); } @@ -283,10 +297,10 @@ public void testIsWorkflowProvisionedFailedParsing() { responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.getProvisioningProgress(documentId, function, listener); + flowFrameworkIndicesHandler.getProvisioningProgress(documentId, null, function, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertTrue(exceptionCaptor.getValue().getMessage().contains("Failed to parse workflow state")); + assertTrue(exceptionCaptor.getValue().getMessage().contains("Failed to parse workflowState")); } public void testCanDeleteWorkflowStateDoc() { @@ -302,7 +316,8 @@ public void testCanDeleteWorkflowStateDoc() { Instant.now(), TestHelpers.randomUser(), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + null ); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -315,7 +330,7 @@ public void testCanDeleteWorkflowStateDoc() { return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, false, canDelete -> { assertTrue(canDelete); }, listener); + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, null, false, canDelete -> { assertTrue(canDelete); }, listener); } public void testCanNotDeleteWorkflowStateDocInProgress() { @@ -331,7 +346,8 @@ public void testCanNotDeleteWorkflowStateDocInProgress() { Instant.now(), TestHelpers.randomUser(), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + null ); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -344,7 +360,7 @@ public void testCanNotDeleteWorkflowStateDocInProgress() { return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, true, canDelete -> { assertFalse(canDelete); }, listener); + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, null, true, canDelete -> { assertFalse(canDelete); }, listener); } public void testDeleteWorkflowStateDocResourcesExist() { @@ -360,7 +376,8 @@ public void testDeleteWorkflowStateDocResourcesExist() { Instant.now(), TestHelpers.randomUser(), Collections.emptyMap(), - List.of(new ResourceCreated("w", "x", "y", "z")) + List.of(new ResourceCreated("w", "x", "y", "z")), + null ); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -374,10 +391,10 @@ public void testDeleteWorkflowStateDocResourcesExist() { }).when(client).get(any(GetRequest.class), any()); // Can't delete because resources exist - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, false, canDelete -> { assertFalse(canDelete); }, listener); + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, null, false, canDelete -> { assertFalse(canDelete); }, listener); // But can delete if clearStatus set true - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, true, canDelete -> { assertTrue(canDelete); }, listener); + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(documentId, null, true, canDelete -> { assertTrue(canDelete); }, listener); } public void testDoesTemplateExist() { @@ -396,7 +413,7 @@ public void testDoesTemplateExist() { responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.doesTemplateExist(documentId, function, listener); + flowFrameworkIndicesHandler.doesTemplateExist(documentId, null, function, listener); verify(function).accept(true); } @@ -417,7 +434,7 @@ public void testUpdateFlowFrameworkSystemIndexDoc() throws IOException { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", Map.of("foo", "bar"), listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, Map.of("foo", "bar"), listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(UpdateResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -430,15 +447,15 @@ public void testUpdateFlowFrameworkSystemIndexDoc() throws IOException { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", Map.of("foo", "bar"), listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, Map.of("foo", "bar"), listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to update state", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to update .plugins-flow-framework-state entry : 1", exceptionCaptor.getValue().getMessage()); // test no index when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", Map.of("foo", "bar"), listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, Map.of("foo", "bar"), listener); verify(listener, times(2)).onFailure(exceptionCaptor.capture()); assertEquals( @@ -474,7 +491,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } }; - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, fooBar, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(UpdateResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -487,15 +504,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, fooBar, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to update state", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to update .plugins-flow-framework-state entry : 1", exceptionCaptor.getValue().getMessage()); // test no index when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", null, fooBar, listener); verify(listener, times(2)).onFailure(exceptionCaptor.capture()); assertEquals( @@ -521,7 +538,7 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { return null; }).when(client).delete(any(DeleteRequest.class), any()); - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", listener); + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", null, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -534,15 +551,15 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { return null; }).when(client).delete(any(DeleteRequest.class), any()); - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", listener); + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", null, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to delete state", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to delete .plugins-flow-framework-state entry : 1", exceptionCaptor.getValue().getMessage()); // test no index when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", listener); + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc("1", null, listener); verify(listener, times(2)).onFailure(exceptionCaptor.capture()); assertEquals( @@ -582,6 +599,7 @@ public void testAddResourceToStateIndex() { "node_id", CreateConnectorStep.NAME, "this_id", + null, listener ); @@ -601,6 +619,7 @@ public void testAddResourceToStateIndex() { "node_id", CreateConnectorStep.NAME, "this_id", + null, listener ); @@ -625,6 +644,7 @@ public void testAddResourceToStateIndex() { "node_id", CreateConnectorStep.NAME, "this_id", + null, notFoundListener ); @@ -641,6 +661,7 @@ public void testAddResourceToStateIndex() { "node_id", CreateConnectorStep.NAME, "this_id", + null, indexNotFoundListener ); @@ -679,7 +700,7 @@ public void testDeleteResourceFromStateIndex() { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, listener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -692,7 +713,7 @@ public void testDeleteResourceFromStateIndex() { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, listener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); @@ -710,7 +731,7 @@ public void testDeleteResourceFromStateIndex() { responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, notFoundListener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, notFoundListener); exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(notFoundListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -720,7 +741,7 @@ public void testDeleteResourceFromStateIndex() { when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); @SuppressWarnings("unchecked") ActionListener indexNotFoundListener = mock(ActionListener.class); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, indexNotFoundListener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, indexNotFoundListener); exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(indexNotFoundListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -771,6 +792,7 @@ public void testAddResourceToStateIndexWithRetries() { "node_id", CreateConnectorStep.NAME, "this_id", + null, retryListener ); @@ -817,6 +839,7 @@ public void testAddResourceToStateIndexWithRetries() { "node_id", CreateConnectorStep.NAME, "this_id", + null, threeRetryListener ); @@ -866,7 +889,7 @@ public void testDeleteResourceFromStateIndexWithRetries() { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, retryListener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, retryListener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); verify(retryListener, times(1)).onResponse(responseCaptor.capture()); @@ -906,7 +929,7 @@ public void testDeleteResourceFromStateIndexWithRetries() { return null; }).when(client).update(any(UpdateRequest.class), any()); - flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, threeRetryListener); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", null, resourceToDelete, threeRetryListener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(threeRetryListener, times(1)).onFailure(exceptionCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java index d8550682b..ec419aeeb 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTests.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.List; import java.util.Map; @@ -47,7 +48,7 @@ public void testTemplate() throws IOException { Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); Map uiMetadata = null; - Instant now = Instant.now(); + Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); Template template = new Template( "test", "a test template", @@ -59,6 +60,7 @@ public void testTemplate() throws IOException { null, now, now, + null, null ); @@ -74,6 +76,9 @@ public void testTemplate() throws IOException { assertEquals(now, template.lastUpdatedTime()); assertNull(template.lastProvisionedTime()); assertEquals("Workflow [userParams={key=value}, nodes=[A, B], edges=[A->B]]", wf.toString()); + assertNull(template.getTenantId()); + template.setTenantId("tenant-id"); + assertEquals("tenant-id", template.getTenantId()); String json = TemplateTestJsonUtil.parseToJson(template); @@ -86,10 +91,11 @@ public void testTemplate() throws IOException { assertEquals(uiMetadata, templateX.getUiMetadata()); Workflow wfX = templateX.workflows().get("workflow"); assertNotNull(wfX); - assertEquals(now, template.createdTime()); - assertEquals(now, template.lastUpdatedTime()); - assertNull(template.lastProvisionedTime()); + assertEquals(now, templateX.createdTime()); + assertEquals(now, templateX.lastUpdatedTime()); + assertNull(templateX.lastProvisionedTime()); assertEquals("Workflow [userParams={key=value}, nodes=[A, B], edges=[A->B]]", wfX.toString()); + assertEquals("tenant-id", templateX.getTenantId()); // Test invalid field if updating XContentParser parser = JsonXContent.jsonXContent.createParser( @@ -116,6 +122,7 @@ public void testUpdateExistingTemplate() { null, now, now, + null, null ); Template updated = Template.builder().name("name two").description("description two").useCase("use case two").build(); @@ -184,4 +191,16 @@ public void testNullToEmptyString() throws IOException { assertTrue(json.contains("\"description\":\"\"")); assertTrue(json.contains("\"use_case\":\"\"")); } + + public void testCreateEmptyTemplateWithTenantId() { + String tenantId = "test-tenant"; + Template t = Template.createEmptyTemplateWithTenantId(tenantId); + assertNotNull(t); + assertEquals(tenantId, t.getTenantId()); + } + + public void testCreateEmptyTemplateWithTenantId_NullTenantId() { + Template t = Template.createEmptyTemplateWithTenantId(null); + assertNull(t); + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index c437c32e3..f09559bcd 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -88,7 +88,7 @@ public void testFailedUpdateWorkflow() throws Exception { Map responseMap = entityAsMap(response); String workflowId = (String) responseMap.get(WORKFLOW_ID); - Response provisionResponse = provisionResponse = provisionWorkflow(client(), workflowId); + Response provisionResponse = provisionWorkflow(client(), workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(provisionResponse)); getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); @@ -277,6 +277,8 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { 120, TimeUnit.SECONDS ); + // Force a refresh so that search results are up to date + refreshAllIndices(); // Hit Search State API with the workflow id created above String query = "{\"query\":{\"ids\":{\"values\":[\"" + workflowId + "\"]}}}"; diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index 747de4351..dd230243c 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -83,6 +83,7 @@ public void setUp() throws Exception { TestHelpers.randomUser(), null, null, + null, null ); @@ -176,7 +177,7 @@ public void testCreateWorkflowRequestWithWaitForTimeCompletionTimeoutButNoProvis channel.capturedResponse() .content() .utf8ToString() - .contains("are not allowed unless the 'provision' or 'reprovision' parameter is set to true.") + .contains("is not allowed unless the 'provision' or 'reprovision' parameter is set to true.") ); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestWorkflowStateTenantAwareIT.java b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowStateTenantAwareIT.java new file mode 100644 index 000000000..899b92610 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowStateTenantAwareIT.java @@ -0,0 +1,356 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.flowframework.FlowFrameworkTenantAwareRestTestCase; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +public class RestWorkflowStateTenantAwareIT extends FlowFrameworkTenantAwareRestTestCase { + + private static final String WORKFLOW_PATH = WORKFLOW_URI + "/"; + private static final String WORKFLOW_STATE_PATH = WORKFLOW_URI + "/state/"; + private static final String PROVISION = "/_provision"; + private static final String DEPROVISION = "/_deprovision"; + private static final String STATUS_ALL = "/_status?all=true"; + private static final String CLEAR_STATUS = "?clear_status=true"; + + public void testWorkflowStateCRUD() throws Exception { + boolean multiTenancyEnabled = isMultiTenancyEnabled(); + + /* + * Create + */ + // Create a workflow with a tenant id + RestRequest createWorkflowRequest = getRestRequestWithHeadersAndContent(tenantId, createRemoteModelTemplate()); + Response response = makeRequest(createWorkflowRequest, POST, WORKFLOW_PATH); + assertOK(response); + Map map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + String workflowId = map.get(WORKFLOW_ID).toString(); + + /* + * Get + */ + // Now try to get that workflow's state + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("NOT_STARTED", map.get("state")); + if (multiTenancyEnabled) { + assertEquals(tenantId, map.get(TENANT_ID_FIELD)); + } else { + assertNull(map.get(TENANT_ID_FIELD)); + } + + // Now try again with an other ID + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + } else { + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("NOT_STARTED", map.get("state")); + } + + // Now try again with a null ID + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL) + ); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } else { + response = makeRequest(nullTenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("NOT_STARTED", map.get("state")); + } + + /* + * Provision + */ + // Try to provision with the wrong tenant id + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(otherTenantRequest, POST, WORKFLOW_PATH + workflowId + PROVISION) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + } + + // Verify state still not started + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("NOT_STARTED", map.get("state")); + + // Now try again with a null ID + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantRequest, POST, WORKFLOW_PATH + workflowId + PROVISION) + ); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } + + // Verify state still not started + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("NOT_STARTED", map.get("state")); + + // Now finally provision the right way + response = makeRequest(tenantRequest, POST, WORKFLOW_PATH + workflowId + PROVISION); + assertOK(response); + map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + assertEquals(workflowId, map.get(WORKFLOW_ID).toString()); + + assertBusy(() -> { + // Verify state completed + Response restResponse = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(restResponse); + Map stateMap = responseToMap(restResponse); + assertEquals("COMPLETED", stateMap.get("state")); + }, 20, TimeUnit.SECONDS); + + /* + * Search + */ + // Create and provision second workflow using otherTenantId + RestRequest otherWorkflowRequest = getRestRequestWithHeadersAndContent(otherTenantId, createRemoteModelTemplate()); + response = makeRequest(otherWorkflowRequest, POST, WORKFLOW_URI + "?provision=true"); + assertOK(response); + map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + String otherWorkflowId = map.get(WORKFLOW_ID).toString(); + + assertBusy(() -> { + // Verify state completed + Response restResponse = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId + STATUS_ALL); + assertOK(restResponse); + Map stateMap = responseToMap(restResponse); + assertEquals("COMPLETED", stateMap.get("state")); + }, 20, TimeUnit.SECONDS); + + // Retry these tests until they pass. Search requires refresh, can take 15s on DDB + refreshAllIndices(); + + assertBusy(() -> { + // Search should show only the workflow state for tenant + Response restResponse = makeRequest(tenantMatchAllRequest, GET, WORKFLOW_STATE_PATH + "_search"); + assertOK(restResponse); + SearchResponse searchResponse = searchResponseFromResponse(restResponse); + if (multiTenancyEnabled) { + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals(tenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + }, 20, TimeUnit.SECONDS); + + assertBusy(() -> { + // Search should show only the workflow for other tenant + Response restResponse = makeRequest(otherTenantMatchAllRequest, GET, WORKFLOW_STATE_PATH + "_search"); + assertOK(restResponse); + SearchResponse searchResponse = searchResponseFromResponse(restResponse); + if (multiTenancyEnabled) { + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals(otherTenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + }, 20, TimeUnit.SECONDS); + + // Search should fail without a tenant id + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantMatchAllRequest, GET, WORKFLOW_STATE_PATH + "_search") + ); + response = ex.getResponse(); + assertForbidden(response); + map = responseToMap(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } else { + response = makeRequest(nullTenantMatchAllRequest, GET, WORKFLOW_PATH + "_search"); + assertOK(response); + SearchResponse searchResponse = searchResponseFromResponse(response); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + + /* + * Delete + */ + // Deleting template without state prevents deleting the state later. Working around this with tenant id checks elsewhere is + // possible but complex and better handled after resolving https://github.com/opensearch-project/flow-framework/issues/986 + response = makeRequest(tenantRequest, DELETE, WORKFLOW_PATH + workflowId + (multiTenancyEnabled ? CLEAR_STATUS : "")); + assertOK(response); + map = responseToMap(response); + assertEquals(workflowId, map.get(DOC_ID).toString()); + + // Verify the deletion + ResponseException ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId)); + response = ex.getResponse(); + assertNotFound(response); + map = responseToMap(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + + if (!multiTenancyEnabled) { + // Verify state still exists + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("COMPLETED", map.get("state")); + + // Now delete with clear status + response = makeRequest(tenantRequest, DELETE, WORKFLOW_PATH + workflowId + CLEAR_STATUS); + assertOK(response); + map = responseToMap(response); + assertEquals("not_found", map.get("result")); + } + + // Verify state deleted + ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId + STATUS_ALL)); + response = ex.getResponse(); + assertNotFound(response); + + /* + * Deprovision + */ + // Try to deprovision with the wrong tenant id + if (multiTenancyEnabled) { + ex = assertThrows( + ResponseException.class, + () -> makeRequest(tenantRequest, POST, WORKFLOW_PATH + otherWorkflowId + DEPROVISION) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals( + "Failed to retrieve template (" + otherWorkflowId + ") from global context.", + getErrorReasonFromResponseMap(map) + ); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + } + + // Verify state still completed + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("COMPLETED", map.get("state")); + + // Now try again with a null ID + if (multiTenancyEnabled) { + ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantRequest, POST, WORKFLOW_PATH + otherWorkflowId + DEPROVISION) + ); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } + + // Verify state still completed + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId + STATUS_ALL); + assertOK(response); + map = responseToMap(response); + assertEquals("COMPLETED", map.get("state")); + + // Now finally deprovision the right way + response = makeRequest(otherTenantRequest, POST, WORKFLOW_PATH + otherWorkflowId + DEPROVISION); + // Expect 200, may be 202 + assertOkOrAccepted(response); + map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + assertEquals(otherWorkflowId, map.get(WORKFLOW_ID).toString()); + + assertBusy(() -> { + // Verify state not started + Response restResponse = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId + STATUS_ALL); + assertOK(restResponse); + Map stateMap = responseToMap(restResponse); + assertEquals("NOT_STARTED", stateMap.get("state")); + }, 20, TimeUnit.SECONDS); + + // Delete workflow from tenant without specifying to delete state + response = makeRequest(otherTenantRequest, DELETE, WORKFLOW_PATH + otherWorkflowId); + assertOK(response); + map = responseToMap(response); + assertEquals(otherWorkflowId, map.get(DOC_ID).toString()); + + // Verify the deletion + ex = assertThrows(ResponseException.class, () -> makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId)); + response = ex.getResponse(); + assertNotFound(response); + map = responseToMap(response); + assertEquals("Failed to retrieve template (" + otherWorkflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + + // Verify state deleted + ex = assertThrows( + ResponseException.class, + () -> makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId + STATUS_ALL) + ); + response = ex.getResponse(); + assertNotFound(response); + } + + private static String createRemoteModelTemplate() throws IOException { + return ParseUtils.resourceToString("/template/createconnector-registerremotemodel-deploymodel.json"); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java new file mode 100644 index 000000000..0120f65f2 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java @@ -0,0 +1,310 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.flowframework.FlowFrameworkTenantAwareRestTestCase; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; + +public class RestWorkflowTenantAwareIT extends FlowFrameworkTenantAwareRestTestCase { + + private static final String WORKFLOW_PATH = WORKFLOW_URI + "/"; + + public void testWorkflowCRUD() throws Exception { + boolean multiTenancyEnabled = isMultiTenancyEnabled(); + + /* + * Create + */ + // Create a workflow with a tenant id + RestRequest createWorkflowRequest = getRestRequestWithHeadersAndContent(tenantId, createNoOpTemplate()); + Response response = makeRequest(createWorkflowRequest, POST, WORKFLOW_PATH); + assertOK(response); + Map map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + String workflowId = map.get(WORKFLOW_ID).toString(); + + /* + * Get + */ + // Now try to get that workflow + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("noop", map.get("name")); + if (multiTenancyEnabled) { + assertEquals(tenantId, map.get(TENANT_ID_FIELD)); + } else { + assertNull(map.get(TENANT_ID_FIELD)); + } + + // Now try again with an other ID + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + workflowId) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + } else { + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("noop", map.get("name")); + } + + // Now try again with a null ID + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantRequest, GET, WORKFLOW_PATH + workflowId) + ); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } else { + response = makeRequest(nullTenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("noop", map.get("name")); + } + + /* + * Update + */ + // Now attempt to update the workflow name + RestRequest updateRequest = getRestRequestWithHeadersAndContent(tenantId, "{\"name\":\"Updated name\"}"); + response = makeRequest(updateRequest, PUT, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals(workflowId, map.get(WORKFLOW_ID).toString()); + + // Verify the update + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("Updated name", map.get("name")); + + // Try the update with other tenant ID + RestRequest otherUpdateRequest = getRestRequestWithHeadersAndContent(otherTenantId, "{\"name\":\"Other tenant name\"}"); + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(otherUpdateRequest, PUT, WORKFLOW_PATH + workflowId) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + } else { + response = makeRequest(otherUpdateRequest, PUT, WORKFLOW_PATH + workflowId); + assertOK(response); + // Verify the update + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("Other tenant name", map.get("name")); + } + + // Try the update with no tenant ID + RestRequest nullUpdateRequest = getRestRequestWithHeadersAndContent(null, "{\"name\":\"Null tenant name\"}"); + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullUpdateRequest, PUT, WORKFLOW_PATH + workflowId) + ); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } else { + response = makeRequest(nullUpdateRequest, PUT, WORKFLOW_PATH + workflowId); + assertOK(response); + // Verify the update + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("Null tenant name", map.get("name")); + } + + // Verify no change from original update when multiTenancy enabled + if (multiTenancyEnabled) { + response = makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("Updated name", map.get("name")); + } + + /* + * Search + */ + // Create a second workflow using otherTenantId + RestRequest otherWorkflowRequest = getRestRequestWithHeadersAndContent(otherTenantId, createNoOpTemplate()); + response = makeRequest(otherWorkflowRequest, POST, WORKFLOW_PATH); + assertOK(response); + map = responseToMap(response); + assertTrue(map.containsKey(WORKFLOW_ID)); + String otherWorkflowId = map.get(WORKFLOW_ID).toString(); + + // Verify it + response = makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId); + assertOK(response); + map = responseToMap(response); + assertEquals("noop", map.get("name")); + + // Retry these tests until they pass. Search requires refresh, can take 15s on DDB + refreshAllIndices(); + + assertBusy(() -> { + // Search should show only the workflow for tenant + Response restResponse = makeRequest(tenantMatchAllRequest, GET, WORKFLOW_PATH + "_search"); + assertOK(restResponse); + SearchResponse searchResponse = searchResponseFromResponse(restResponse); + if (multiTenancyEnabled) { + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals(tenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + }, 20, TimeUnit.SECONDS); + + assertBusy(() -> { + // Search should show only the workflow for other tenant + Response restResponse = makeRequest(otherTenantMatchAllRequest, GET, WORKFLOW_PATH + "_search"); + assertOK(restResponse); + SearchResponse searchResponse = searchResponseFromResponse(restResponse); + if (multiTenancyEnabled) { + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals(otherTenantId, searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + } else { + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + }, 20, TimeUnit.SECONDS); + + // Search should fail without a tenant id + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(nullTenantMatchAllRequest, GET, WORKFLOW_PATH + "_search") + ); + response = ex.getResponse(); + assertForbidden(response); + map = responseToMap(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } else { + response = makeRequest(nullTenantMatchAllRequest, GET, WORKFLOW_PATH + "_search"); + assertOK(response); + SearchResponse searchResponse = searchResponseFromResponse(response); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertNull(searchResponse.getHits().getHits()[0].getSourceAsMap().get(TENANT_ID_FIELD)); + assertNull(searchResponse.getHits().getHits()[1].getSourceAsMap().get(TENANT_ID_FIELD)); + } + + /* + * Delete + */ + // Delete the workflows + // First test that we can't delete other tenant workflows + if (multiTenancyEnabled) { + ResponseException ex = assertThrows( + ResponseException.class, + () -> makeRequest(tenantRequest, DELETE, WORKFLOW_PATH + otherWorkflowId) + ); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals( + "Failed to find workflow with the provided workflow id: " + otherWorkflowId, + getErrorReasonFromResponseMap(map) + ); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + + ex = assertThrows(ResponseException.class, () -> makeRequest(otherTenantRequest, DELETE, WORKFLOW_PATH + workflowId)); + response = ex.getResponse(); + map = responseToMap(response); + if (DDB) { + assertNotFound(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } else { + assertForbidden(response); + assertEquals(NO_PERMISSION_REASON, getErrorReasonFromResponseMap(map)); + } + + // and can't delete without a tenant ID either + ex = assertThrows(ResponseException.class, () -> makeRequest(nullTenantRequest, DELETE, WORKFLOW_PATH + workflowId)); + response = ex.getResponse(); + map = responseToMap(response); + assertForbidden(response); + assertEquals(MISSING_TENANT_REASON, getErrorReasonFromResponseMap(map)); + } + + // Now actually do the deletions. Same result whether multi-tenancy is enabled. + // Delete from tenant + response = makeRequest(tenantRequest, DELETE, WORKFLOW_PATH + workflowId); + assertOK(response); + map = responseToMap(response); + assertEquals(workflowId, map.get(DOC_ID).toString()); + + // Verify the deletion + ResponseException ex = assertThrows(ResponseException.class, () -> makeRequest(tenantRequest, GET, WORKFLOW_PATH + workflowId)); + response = ex.getResponse(); + assertNotFound(response); + map = responseToMap(response); + assertEquals("Failed to retrieve template (" + workflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + + // Delete from other tenant + response = makeRequest(otherTenantRequest, DELETE, WORKFLOW_PATH + otherWorkflowId); + assertOK(response); + map = responseToMap(response); + assertEquals(otherWorkflowId, map.get(DOC_ID).toString()); + + // Verify the deletion + ex = assertThrows(ResponseException.class, () -> makeRequest(otherTenantRequest, GET, WORKFLOW_PATH + otherWorkflowId)); + response = ex.getResponse(); + assertNotFound(response); + map = responseToMap(response); + assertEquals("Failed to retrieve template (" + otherWorkflowId + ") from global context.", getErrorReasonFromResponseMap(map)); + } + + private static String createNoOpTemplate() throws IOException { + return ParseUtils.resourceToString("/template/noop.json"); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index e1fecdc6b..52f2e74db 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -15,6 +15,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -30,9 +32,9 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.common.FlowFrameworkSettings; -import org.opensearch.flowframework.indices.FlowFrameworkIndex; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -41,8 +43,11 @@ import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -50,11 +55,9 @@ import java.io.IOException; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -73,10 +76,10 @@ import static org.opensearch.flowframework.common.WorkflowResources.REGISTER_REMOTE_MODEL; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; @@ -93,7 +96,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private WorkflowProcessSorter workflowProcessSorter; private Template template; private Client client; - private ThreadPool threadPool; + private SdkClient sdkClient; private FlowFrameworkSettings flowFrameworkSettings; private PluginsService pluginsService; private ClusterService clusterService; @@ -103,8 +106,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); + this.sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - threadPool = mock(ThreadPool.class); this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.getMaxWorkflows()).thenReturn(2); when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10)); @@ -118,7 +121,6 @@ public void setUp() throws Exception { clusterSettings = new ClusterSettings(Settings.EMPTY, Set.copyOf(List.of(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; ClusterName clusterName = new ClusterName("test"); Settings indexSettings = Settings.builder() @@ -143,19 +145,18 @@ public void setUp() throws Exception { flowFrameworkIndicesHandler, flowFrameworkSettings, client, + sdkClient, pluginsService, clusterService, xContentRegistry(), Settings.EMPTY ) ); - // client = mock(Client.class); + + ThreadPool threadPool = mock(ThreadPool.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - // threadContext = mock(ThreadContext.class); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - // when(threadContext.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT)).thenReturn("123"); - // parseUtils = mock(ParseUtils.class); Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); @@ -172,6 +173,7 @@ public void setUp() throws Exception { TestHelpers.randomUser(), null, null, + null, null ); } @@ -237,6 +239,7 @@ public void testValidation_Failed() throws Exception { TestHelpers.randomUser(), null, null, + null, null ); @@ -267,14 +270,13 @@ public void testMaxWorkflow() { doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); - SearchResponse searchResponse = mock(SearchResponse.class); - SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f); - when(searchResponse.getHits()).thenReturn(searchHits); + SearchResponse searchResponse = generateEmptySearchResponseWithHitCount(3); searchListener.onResponse(searchResponse); return null; }).when(client).search(any(SearchRequest.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(("Maximum workflows limit reached: 2"), exceptionCaptor.getValue().getMessage()); @@ -294,7 +296,7 @@ public void onFailure(Exception e) { fail("Should call onResponse"); } }; - createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), 10, listener); + createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), Integer.valueOf(10), "tenant-id", listener); } public void testFailedToCreateNewWorkflow() { @@ -312,17 +314,17 @@ public void testFailedToCreateNewWorkflow() { // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { - ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + ActionListener checkMaxWorkflowListener = invocation.getArgument(3); checkMaxWorkflowListener.onResponse(true); return null; - }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), any(Integer.class), nullable(String.class), any()); // Bypass initializeConfigIndex and force onResponse doAnswer(invocation -> { - ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(1); initalizeMasterKeyIndexListener.onResponse(true); return null; - }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(nullable(String.class), any()); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -351,17 +353,17 @@ public void testCreateNewWorkflow() { // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { - ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + ActionListener checkMaxWorkflowListener = invocation.getArgument(3); checkMaxWorkflowListener.onResponse(true); return null; - }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), any(Integer.class), nullable(String.class), any()); // Bypass initializeConfigIndex and force onResponse doAnswer(invocation -> { - ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(1); initalizeMasterKeyIndexListener.onResponse(true); return null; - }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(nullable(String.class), any()); // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { @@ -372,10 +374,10 @@ public void testCreateNewWorkflow() { // Bypass putInitialStateToWorkflowState and force on response doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(2); + ActionListener responseListener = invocation.getArgument(3); responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any(), any()); ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -391,7 +393,7 @@ public void testCreateWithUserAndFilterOn() { ThreadContext threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + ThreadPool mockThreadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(mockThreadPool); when(mockThreadPool.getThreadContext()).thenReturn(threadContext); @@ -403,6 +405,7 @@ public void testCreateWithUserAndFilterOn() { flowFrameworkIndicesHandler, flowFrameworkSettings, client, + sdkClient, pluginsService, clusterService, xContentRegistry(), @@ -410,6 +413,7 @@ public void testCreateWithUserAndFilterOn() { ) ); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( null, @@ -423,17 +427,17 @@ public void testCreateWithUserAndFilterOn() { // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { - ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + ActionListener checkMaxWorkflowListener = invocation.getArgument(3); checkMaxWorkflowListener.onResponse(true); return null; - }).when(createWorkflowTransportAction1).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), any(Integer.class), nullable(String.class), any()); // Bypass initializeConfigIndex and force onResponse doAnswer(invocation -> { - ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(1); initalizeMasterKeyIndexListener.onResponse(true); return null; - }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(nullable(String.class), any()); // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { @@ -444,10 +448,10 @@ public void testCreateWithUserAndFilterOn() { // Bypass putInitialStateToWorkflowState and force on response doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(2); + ActionListener responseListener = invocation.getArgument(3); responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any(), any()); ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -457,12 +461,11 @@ public void testCreateWithUserAndFilterOn() { } public void testFailedToCreateNewWorkflowWithNullUser() { - @SuppressWarnings("unchecked") Settings settings = Settings.builder().put(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); ThreadContext threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, null); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + ThreadPool mockThreadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(mockThreadPool); when(mockThreadPool.getThreadContext()).thenReturn(threadContext); @@ -474,6 +477,7 @@ public void testFailedToCreateNewWorkflowWithNullUser() { flowFrameworkIndicesHandler, flowFrameworkSettings, client, + sdkClient, pluginsService, clusterService, xContentRegistry(), @@ -481,6 +485,7 @@ public void testFailedToCreateNewWorkflowWithNullUser() { ) ); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( @@ -500,7 +505,6 @@ public void testFailedToCreateNewWorkflowWithNullUser() { } public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { - @SuppressWarnings("unchecked") Settings settings = Settings.builder().put(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); ThreadContext threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test"); @@ -517,6 +521,7 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { flowFrameworkIndicesHandler, flowFrameworkSettings, client, + sdkClient, pluginsService, clusterService, xContentRegistry(), @@ -524,6 +529,7 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { ) ); + @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( @@ -558,28 +564,9 @@ public void testUpdateWorkflowWithReprovision() throws IOException { null ); - doAnswer(invocation -> { - ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - when(getResponse.getSourceAsString()).thenReturn(template.toJson()); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; + ActionListener getListener = invocation.getArgument(1); getListener.onResponse(getWorkflowResponse); return null; }).when(client).get(any(GetRequest.class), any()); @@ -591,6 +578,7 @@ public void testUpdateWorkflowWithReprovision() throws IOException { }).when(client).execute(any(), any(), any()); createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -610,28 +598,9 @@ public void testFailedToUpdateWorkflowWithReprovision() throws IOException { null ); - doAnswer(invocation -> { - ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - when(getResponse.getSourceAsString()).thenReturn(template.toJson()); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; + ActionListener getListener = invocation.getArgument(1); getListener.onResponse(getWorkflowResponse); return null; }).when(client).get(any(GetRequest.class), any()); @@ -643,6 +612,7 @@ public void testFailedToUpdateWorkflowWithReprovision() throws IOException { }).when(client).execute(any(), any(), any()); createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(responseCaptor.capture()); @@ -654,28 +624,9 @@ public void testFailedToUpdateWorkflow() throws IOException { ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); - doAnswer(invocation -> { - ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - when(getResponse.getSourceAsString()).thenReturn(template.toJson()); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; + ActionListener getListener = invocation.getArgument(1); getListener.onResponse(getWorkflowResponse); return null; }).when(client).get(any(GetRequest.class), any()); @@ -687,6 +638,7 @@ public void testFailedToUpdateWorkflow() throws IOException { }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Failed to update use case template 1", exceptionCaptor.getValue().getMessage()); @@ -699,26 +651,7 @@ public void testFailedToUpdateNonExistingWorkflow() throws IOException { doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(false); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - - GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; - getListener.onFailure(new Exception("Failed to retrieve template (2) from global context.")); + getListener.onFailure(new Exception("test")); return null; }).when(client).get(any(GetRequest.class), any()); @@ -729,9 +662,10 @@ public void testFailedToUpdateNonExistingWorkflow() throws IOException { }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to retrieve template (2) from global context.", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-flow-framework-templates", exceptionCaptor.getValue().getMessage()); } public void testUpdateWorkflow() throws IOException { @@ -739,15 +673,6 @@ public void testUpdateWorkflow() throws IOException { ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); - doAnswer(invocation -> { - ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - when(getResponse.getSourceAsString()).thenReturn(Template.builder().name("test").build().toJson()); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); @@ -756,28 +681,19 @@ public void testUpdateWorkflow() throws IOException { GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; + ActionListener getListener = invocation.getArgument(1); getListener.onResponse(getWorkflowResponse); return null; }).when(client).get(any(GetRequest.class), any()); doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(2); + ActionListener updateResponseListener = invocation.getArgument(3); updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(anyString(), anyMap(), any()); + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(anyString(), nullable(String.class), anyMap(), any()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); @@ -792,28 +708,9 @@ public void testUpdateWorkflowWithField() throws IOException { WorkflowRequest updateWorkflow = new WorkflowRequest("1", template1, Map.of(UPDATE_WORKFLOW_FIELDS, "true")); - doAnswer(invocation -> { - ActionListener getListener = invocation.getArgument(1); - GetResponse getResponse = mock(GetResponse.class); - when(getResponse.isExists()).thenReturn(true); - when(getResponse.getSourceAsString()).thenReturn(template.toJson()); - getListener.onResponse(getResponse); - return null; - }).when(client).get(any(GetRequest.class), any()); - GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertEquals( - String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), - 2, - args.length - ); - - assertTrue(args[0] instanceof GetRequest); - assertTrue(args[1] instanceof ActionListener); - - ActionListener getListener = (ActionListener) args[1]; + ActionListener getListener = invocation.getArgument(1); getListener.onResponse(getWorkflowResponse); return null; }).when(client).get(any(GetRequest.class), any()); @@ -825,6 +722,7 @@ public void testUpdateWorkflowWithField() throws IOException { }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); + verify(listener, times(1)).onResponse(any()); ArgumentCaptor