From 23b52269c10992fcee7f15647321275b8d1364ce Mon Sep 17 00:00:00 2001 From: kulket <130191298+kulket@users.noreply.github.com> Date: Wed, 24 May 2023 16:26:24 -0700 Subject: [PATCH] Personalized intelligent ranking for open search requests (#138) * Personalized intelligent ranking for open search requests Description: ============ This change uses search pipeline feature of open search to allow customers to create a response processor. Response processor will rerank search results obtained from open search using Personalization solution before returning results to the user. Check List: ========== - [Y] New functionality includes testing. - [Y] All tests pass - [Y] New functionality has been documented. - [Y] New functionality has javadoc added - [Y] Commits are signed as per the DCO using --signoff Signed-off-by: Ketan Kulkarni * Support passing user id for Personalization as part of search request Description: ============ This change introduces ability ot pass user id as part of search request to be used by Personalize response processor. SearchExtSpec will be used to support getting user id as well as other personalize related request parameters to personalize search results. Check list: =========== - [x] New functionality includes testing. - [x] All tests pass - [x] New functionality has been documented. - [x] New functionality has javadoc added - [x] Commits are signed as per the DCO using --signoff Signed-off-by: Ketan Kulkarni * Incorporate changes to SearchPipelinePlugin interface Description: ============ SearchPipelinePlugin interface now uses Processor.Factory instead of Processor.Factory. Also use getResponseProcessors method instead of getProcessors method. Check list: =========== - [x] New functionality includes testing. - [x] All tests pass - [x] New functionality has been documented. - [x] New functionality has javadoc added - [x] Commits are signed as per the DCO using --signoff Signed-off-by: Ketan Kulkarni * Improve test coverage Added unit tests for PersonalizeRankingResponseProcessor and PersonalizeRequestParameterUtil. Also, needed to update JaCoCo to 0.8.9 in order to get rid of warnings about invalid bytecode versions since we upgraded to JDK 20. Signed-off-by: Michael Froh * Move configruation constants to the search pipeline processor factory Description: ============ Configuration constants are only used by search response processor factory. Move related constants from global configuration constants file to factory class to ensure they are closer to the palce where they are used. Signed-off-by: Ketan Kulkarni --------- Signed-off-by: Ketan Kulkarni Signed-off-by: Michael Froh Co-authored-by: Michael Froh --- build.gradle | 9 + .../relevance/SearchRelevancePlugin.java | 27 ++- .../PersonalizeRankingResponseProcessor.java | 168 ++++++++++++++++++ .../client/PersonalizeClient.java | 71 ++++++++ .../client/PersonalizeClientSettings.java | 95 ++++++++++ ...PersonalizeCredentialsProviderFactory.java | 89 ++++++++++ .../configuration/Constants.java | 16 ++ ...onalizeIntelligentRankerConfiguration.java | 91 ++++++++++ .../requestparameter/Constants.java | 15 ++ .../PersonalizeRequestParameterUtil.java | 35 ++++ .../PersonalizeRequestParameters.java | 85 +++++++++ ...ersonalizeRequestParametersExtBuilder.java | 80 +++++++++ .../reranker/PersonalizedRanker.java | 30 ++++ .../reranker/PersonalizedRankerFactory.java | 40 +++++ .../impl/AmazonPersonalizedRankerImpl.java | 106 +++++++++++ .../PersonalizeResponseProcessorTests.java | 134 ++++++++++++++ .../PersonalizeClientSettingsTests.java | 60 +++++++ .../client/PersonalizeClientTests.java | 42 +++++ ...nalizeCredentialsProviderFactoryTests.java | 68 +++++++ ...zeIntelligentRankerConfigurationTests.java | 32 ++++ .../ranker/PersonalizeRankerFactoryTests.java | 47 +++++ .../AmazonPersonalizeRankerImplTests.java | 62 +++++++ .../PersonalizeRequestParameterUtilTests.java | 30 ++++ ...alizeRequestParametersExtBuilderTests.java | 50 ++++++ .../PersonalizeClientSettingsTestUtil.java | 39 ++++ .../utils/PersonalizeRuntimeTestUtil.java | 34 ++++ .../utils/SearchTestUtil.java | 37 ++++ 27 files changed, 1585 insertions(+), 7 deletions(-) create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClient.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactory.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfiguration.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtil.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettingsTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactoryTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfigurationTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/PersonalizeRankerFactoryTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeClientSettingsTestUtil.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java diff --git a/build.gradle b/build.gradle index 463aa61..d608e18 100644 --- a/build.gradle +++ b/build.gradle @@ -91,8 +91,17 @@ dependencies { implementation 'commons-logging:commons-logging:1.2' implementation 'com.amazonaws:aws-java-sdk-sts:1.12.300' implementation 'com.amazonaws:aws-java-sdk-core:1.12.300' + implementation 'com.amazonaws:aws-java-sdk-personalizeruntime:1.12.300' } + +allprojects { + plugins.withId('jacoco') { + jacoco.toolVersion = '0.8.9' + } +} + + test { include '**/*Tests.class' finalizedBy jacocoTestReport diff --git a/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java b/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java index b2143f9..be43dc4 100644 --- a/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java +++ b/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java @@ -10,7 +10,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Supplier; @@ -27,9 +26,12 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.search.relevance.actionfilter.SearchActionFilter; import org.opensearch.search.relevance.client.OpenSearchClient; import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; @@ -40,10 +42,13 @@ import org.opensearch.search.relevance.transformer.ResultTransformer; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -public class SearchRelevancePlugin extends Plugin implements ActionPlugin, SearchPlugin { +public class SearchRelevancePlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin { private OpenSearchClient openSearchClient; private KendraHttpClient kendraClient; @@ -101,10 +106,18 @@ public Collection createComponents( public List> getSearchExts() { Map resultTransformerMap = getResultTransformerConfigurationFactories().stream() .collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i)); - return Collections.singletonList( - new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, - input -> new SearchConfigurationExtBuilder(input, resultTransformerMap), - parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap))); + return List.of(new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, + input -> new SearchConfigurationExtBuilder(input, resultTransformerMap), + parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)), + new SearchExtSpec<>(PersonalizeRequestParametersExtBuilder.NAME, + input -> new PersonalizeRequestParametersExtBuilder(input), + parser -> PersonalizeRequestParametersExtBuilder.parse(parser))); + } + + @Override + public Map> getResponseProcessors(Processor.Parameters parameters) { + return Map.of(PersonalizeRankingResponseProcessor.TYPE, + new PersonalizeRankingResponseProcessor.Factory( + PersonalizeClientSettings.getClientSettings(parameters.env.settings()))); } - } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java new file mode 100644 index 0000000..e8d6b06 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java @@ -0,0 +1,168 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking; + +import com.amazonaws.auth.AWSCredentialsProvider; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeCredentialsProviderFactory; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameterUtil; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory; + +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; + +/** + * This is a {@link SearchResponseProcessor} that applies Personalized intelligent ranking + */ +public class PersonalizeRankingResponseProcessor implements SearchResponseProcessor { + + private static final Logger logger = LogManager.getLogger(PersonalizeRankingResponseProcessor.class); + + public static final String TYPE = "personalize_ranking"; + private final String tag; + private final String description; + private final PersonalizeClient personalizeClient; + private final PersonalizeIntelligentRankerConfiguration rankerConfig; + + /** + * Constructor for Personalize ranking response processor + * + * @param tag processor tag + * @param description processor description + * @param rankerConfig personalize ranker config + * @param client personalize client + */ + public PersonalizeRankingResponseProcessor(String tag, + String description, + PersonalizeIntelligentRankerConfiguration rankerConfig, + PersonalizeClient client) { + super(); + this.tag = tag; + this.description = description; + this.rankerConfig = rankerConfig; + this.personalizeClient = client; + } + + /** + * Transform the response hits by re ranking results using Personalize + * + * @param request Search request + * @param response Search response that needs to be transformed + * @return Transformed search response using personalized re ranking + * @throws Exception Throws exception for any error while processing response + */ + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + SearchHits hits = response.getHits(); + + if (hits.getHits().length == 0) { + logger.info("TotalHits = 0. Returning search response without applying Personalize transform"); + return response; + } + logger.info("Personalizing search results."); + PersonalizeRequestParameters personalizeRequestParameters = + PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + PersonalizedRankerFactory rankerFactory = new PersonalizedRankerFactory(); + PersonalizedRanker ranker = rankerFactory.getPersonalizedRanker(rankerConfig, personalizeClient); + long startTime = System.nanoTime(); + SearchHits personalizedHits = ranker.rerank(hits, personalizeRequestParameters); + long personalizeTimeTookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime); + + final SearchResponseSections transformedSearchResponseSections = new InternalSearchResponse(personalizedHits, + (InternalAggregations) response.getAggregations(), response.getSuggest(), + new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(), + response.isTerminatedEarly(), response.getNumReducePhases()); + + final SearchResponse transformedResponse = new SearchResponse(transformedSearchResponseSections, response.getScrollId(), + response.getTotalShards(), response.getSuccessfulShards(), + response.getSkippedShards(), response.getTook().getMillis() + personalizeTimeTookMs, response.getShardFailures(), + response.getClusters()); + + logger.info("Personalize ranking processor took " + personalizeTimeTookMs + " ms"); + + return transformedResponse; + } + + /** + * Get the type of the processor. + */ + @Override + public String getType() { + return TYPE; + } + + /** + * Get the tag of a processor. + */ + @Override + public String getTag() { + return tag; + } + + /** + * Gets the description of a processor. + */ + @Override + public String getDescription() { + return description; + } + + public static final class Factory implements Processor.Factory { + + private static final String CAMPAIGN_ARN_CONFIG_NAME = "campaign_arn"; + private static final String ITEM_ID_FIELD_CONFIG_NAME = "item_id_field"; + private static final String IAM_ROLE_ARN_CONFIG_NAME = "iam_role_arn"; + private static final String RECIPE_CONFIG_NAME = "recipe"; + private static final String REGION_CONFIG_NAME = "aws_region"; + private static final String WEIGHT_CONFIG_NAME = "weight"; + PersonalizeClientSettings personalizeClientSettings; + private final BiFunction clientBuilder; + + Factory(PersonalizeClientSettings settings, BiFunction clientBuilder) { + this.personalizeClientSettings = settings; + this.clientBuilder = clientBuilder; + } + + public Factory(PersonalizeClientSettings settings) { + this(settings, PersonalizeClient::new); + } + + @Override + public PersonalizeRankingResponseProcessor create(Map> processorFactories, String tag, String description, Map config) throws Exception { + String personalizeCampaign = ConfigurationUtils.readStringProperty(TYPE, tag, config, CAMPAIGN_ARN_CONFIG_NAME); + String iamRoleArn = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, IAM_ROLE_ARN_CONFIG_NAME); + String recipe = ConfigurationUtils.readStringProperty(TYPE, tag, config, RECIPE_CONFIG_NAME); + String itemIdField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ITEM_ID_FIELD_CONFIG_NAME); + String awsRegion = ConfigurationUtils.readStringProperty(TYPE, tag, config, REGION_CONFIG_NAME); + double weight = ConfigurationUtils.readDoubleProperty(TYPE, tag, config, WEIGHT_CONFIG_NAME); + + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, awsRegion, weight); + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(personalizeClientSettings, iamRoleArn, awsRegion); + PersonalizeClient personalizeClient = clientBuilder.apply(credentialsProvider, awsRegion); + return new PersonalizeRankingResponseProcessor(tag, description, rankerConfig, personalizeClient); + } + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClient.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClient.java new file mode 100644 index 0000000..22c5130 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClient.java @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.AmazonServiceException; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntime; +import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntimeClientBuilder; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; + +import java.io.Closeable; +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; + +/** + * Amazon Personalize client implementation for getting personalized ranking + */ +public class PersonalizeClient implements Closeable { + private final AmazonPersonalizeRuntime personalizeRuntime; + + /** + * Constructor for Amazon Personalize client + * @param credentialsProvider Credentials to be used for accessing Amazon Personalize + * @param awsRegion AWS region where Amazon Personalize campaign is hosted + */ + public PersonalizeClient(AWSCredentialsProvider credentialsProvider, String awsRegion) { + personalizeRuntime = AccessController.doPrivileged( + (PrivilegedAction) () -> AmazonPersonalizeRuntimeClientBuilder.standard() + .withCredentials(credentialsProvider) + .withRegion(awsRegion) + .build()); + } + + /** + * Get Personalize runtime client + * @return Personalize runtime client + */ + public AmazonPersonalizeRuntime getPersonalizeRuntime() { + return personalizeRuntime; + } + + /** + * Get Personalized ranking using Personalized runtime client + * @param request Get personalized ranking request + * @return Personalized ranking results + */ + public GetPersonalizedRankingResult getPersonalizedRanking(GetPersonalizedRankingRequest request) { + GetPersonalizedRankingResult result; + try { + result = AccessController.doPrivileged( + (PrivilegedAction) () -> personalizeRuntime.getPersonalizedRanking(request)); + } catch (AmazonServiceException ex) { + throw ex; + } + return result; + } + + @Override + public void close() throws IOException { + if (personalizeRuntime != null) { + personalizeRuntime.shutdown(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java new file mode 100644 index 0000000..c3de442 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.auth.BasicSessionCredentials; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.SecureString; +import org.opensearch.common.settings.SecureSetting; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsException; + +/** + * Container for personalize client settings such as AWS credentials + */ +public final class PersonalizeClientSettings { + + private static final Logger logger = LogManager.getLogger(PersonalizeClientSettings.class); + + /** + * The access key (ie login id) for connecting to Personalize. + */ + public static final Setting ACCESS_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.access_key", null); + + /** + * The secret key (ie password) for connecting to Personalize. + */ + public static final Setting SECRET_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.secret_key", null); + + /** + * The session token for connecting to Personalize. + */ + public static final Setting SESSION_TOKEN_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.session_token", null); + + private final AWSCredentials credentials; + + protected PersonalizeClientSettings(AWSCredentials credentials) { + this.credentials = credentials; + } + + public AWSCredentials getCredentials() { + return credentials; + } + + /** + * Load AWS credentials from open search keystore if available + * @param settings Open search settings + * @return AWS credentials + */ + static AWSCredentials loadCredentials(Settings settings) { + try (SecureString key = ACCESS_KEY_SETTING.get(settings); + SecureString secret = SECRET_KEY_SETTING.get(settings); + SecureString sessionToken = SESSION_TOKEN_SETTING.get(settings)) { + if (key.length() == 0 && secret.length() == 0) { + if (sessionToken.length() > 0) { + throw new SettingsException("Setting [{}] is set but [{}] and [{}] are not", + SESSION_TOKEN_SETTING.getKey(), ACCESS_KEY_SETTING.getKey(), SECRET_KEY_SETTING.getKey()); + } + logger.info("Using either environment variables, system properties or instance profile credentials"); + return null; + } else if (key.length() == 0 || secret.length() == 0) { + throw new SettingsException("One of settings [{}] and [{}] is not set.", + ACCESS_KEY_SETTING.getKey(), SECRET_KEY_SETTING.getKey()); + } else { + final AWSCredentials credentials; + if (sessionToken.length() == 0) { + logger.info("Using basic key/secret credentials"); + credentials = new BasicAWSCredentials(key.toString(), secret.toString()); + } else { + logger.info("Using basic session credentials"); + credentials = new BasicSessionCredentials(key.toString(), secret.toString(), sessionToken.toString()); + } + return credentials; + } + } + } + + /** + * Get Personalize client settings + * @param settings Open search settings + * @return Personalize client settings instance with AWS credentials + */ + public static PersonalizeClientSettings getClientSettings(Settings settings) { + final AWSCredentials credentials = loadCredentials(settings); + return new PersonalizeClientSettings(credentials); + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactory.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactory.java new file mode 100644 index 0000000..e0e2895 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactory.java @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; +import com.amazonaws.services.securitytoken.AWSSecurityTokenService; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.security.AccessController; +import java.security.PrivilegedAction; + +/** + * Factory implementation for getting Personalize credentials + */ +public final class PersonalizeCredentialsProviderFactory { + private static final Logger logger = LogManager.getLogger(PersonalizeCredentialsProviderFactory.class); + private static final String ASSUME_ROLE_SESSION_NAME = "OpenSearchPersonalizeIntelligentRankingPluginSession"; + + private PersonalizeCredentialsProviderFactory() { + } + + /** + * Get AWS credentials provider either from static credentials from open search keystore or + * using DefaultAWSCredentialsProviderChain. + * @param clientSettings Personalize client settings + * @return AWS credentials provider for accessing Personalize + */ + static AWSCredentialsProvider getCredentialsProvider(PersonalizeClientSettings clientSettings) { + final AWSCredentialsProvider credentialsProvider; + final AWSCredentials credentials = clientSettings.getCredentials(); + if (credentials == null) { + logger.info("Credentials not present in open search keystore. Using DefaultAWSCredentialsProviderChain for credentials."); + credentialsProvider = AccessController.doPrivileged( + (PrivilegedAction) () -> DefaultAWSCredentialsProviderChain.getInstance()); + } else { + logger.info("Using credentials provided in open search keystore"); + credentialsProvider = AccessController.doPrivileged( + (PrivilegedAction) () -> new AWSStaticCredentialsProvider(credentials)); + } + return credentialsProvider; + } + + /** + * Get AWS credentials provider by assuming IAM role if provided or else + * use static credentials or DefaultAWSCredentialsProviderChain. + * @param clientSettings Personalize client settings + * @param personalizeIAMRole IAM role configuration for accessing Personalize + * @param awsRegion AWS region + * @return AWS credentials provider for accessing Amazon Personalize + */ + public static AWSCredentialsProvider getCredentialsProvider(PersonalizeClientSettings clientSettings, + String personalizeIAMRole, + String awsRegion) { + + final AWSCredentialsProvider credentialsProvider; + AWSCredentialsProvider baseCredentialsProvider = getCredentialsProvider(clientSettings); + + if (personalizeIAMRole != null && !personalizeIAMRole.isBlank()) { + logger.info("Using IAM Role provided to access Personalize."); + // If IAM role ARN was provided in config, then use auto-refreshed role credentials. + credentialsProvider = AccessController.doPrivileged( + (PrivilegedAction) () -> { + AWSSecurityTokenService awsSecurityTokenService = AWSSecurityTokenServiceClientBuilder.standard() + .withCredentials(baseCredentialsProvider) + .withRegion(awsRegion) + .build(); + + return new STSAssumeRoleSessionCredentialsProvider.Builder(personalizeIAMRole, ASSUME_ROLE_SESSION_NAME) + .withStsClient(awsSecurityTokenService) + .build(); + }); + } else { + logger.info("IAM Role for accessing Personalize is not provided."); + credentialsProvider = baseCredentialsProvider; + } + return credentialsProvider; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java new file mode 100644 index 0000000..eefea08 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration; + +/** + * Constants for Amazon Perosnalize response processor + */ +public class Constants { + public static final String AMAZON_PERSONALIZED_RANKING_RECIPE_NAME = "aws-personalized-ranking"; +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfiguration.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfiguration.java new file mode 100644 index 0000000..35685de --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfiguration.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration; + +/** + * A container for holding Personalize ranker configuration + */ +public class PersonalizeIntelligentRankerConfiguration { + private final String personalizeCampaign; + private final String iamRoleArn; + private final String recipe; + private final String itemIdField; + private final String region; + private final double weight; + + /** + * + * @param personalizeCampaign Personalize campaign + * @param iamRoleArn IAM Role ARN for accessing Personalize campaign + * @param recipe Personalize recipe associated with campaign + * @param itemIdField Item ID field to pick up item id for Personalize input + * @param region AWS region + * @param weight Configurable coefficient to control Personalization of search results + */ + public PersonalizeIntelligentRankerConfiguration(String personalizeCampaign, + String iamRoleArn, + String recipe, + String itemIdField, + String region, + double weight) { + this.personalizeCampaign = personalizeCampaign; + this.iamRoleArn = iamRoleArn; + this.recipe = recipe; + this.itemIdField = itemIdField; + this.region = region; + this.weight = weight; + } + + /** + * Get PErsonalize campaign + * @return Personalize campaign + */ + public String getPersonalizeCampaign() { + return personalizeCampaign; + } + + /** + * Get recipe + * @return Recipe associated with Personalize campaign + */ + public String getRecipe() { + return recipe; + } + + /** + * Get Item ID field + * @return Item ID field + */ + public String getItemIdField() { + return itemIdField; + } + + /** + * Get AWS region + * @return AWS region + */ + public String getRegion() { + return region; + } + + /** + * + * @return weight value + */ + public double getWeight() { + return weight; + } + + /** + * Get IAM role ARN for Personalize campaign + * @return IAM role for accessing Personalize campaign + */ + public String getIamRoleArn() { + return iamRoleArn; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java new file mode 100644 index 0000000..7d959cf --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/Constants.java @@ -0,0 +1,15 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +public class Constants { + + public static final String PERSONALIZE_REQUEST_PARAMETERS = "personalize_request_parameters"; + public static final String USER_ID_PARAMETER = "user_id"; + +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtil.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtil.java new file mode 100644 index 0000000..bb36927 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtil.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; + +import java.util.List; +import java.util.stream.Collectors; + +public class PersonalizeRequestParameterUtil { + + public static PersonalizeRequestParameters getPersonalizeRequestParameters(SearchRequest searchRequest) { + PersonalizeRequestParametersExtBuilder personalizeRequestParameterExtBuilder = null; + if (searchRequest.source() != null && searchRequest.source().ext() != null && !searchRequest.source().ext().isEmpty()) { + List extBuilders = searchRequest.source().ext().stream() + .filter(extBuilder -> PersonalizeRequestParametersExtBuilder.NAME.equals(extBuilder.getWriteableName())) + .collect(Collectors.toList()); + + if (!extBuilders.isEmpty()) { + personalizeRequestParameterExtBuilder = (PersonalizeRequestParametersExtBuilder) extBuilders.get(0); + } + } + PersonalizeRequestParameters personalizeRequestParameters = null; + if (personalizeRequestParameterExtBuilder != null) { + personalizeRequestParameters = personalizeRequestParameterExtBuilder.getRequestParameters(); + } + return personalizeRequestParameters; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java new file mode 100644 index 0000000..a2130b0 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameters.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.USER_ID_PARAMETER; + +public class PersonalizeRequestParameters implements Writeable, ToXContentObject { + + private static final ObjectParser PARSER; + private static final ParseField USER_ID = new ParseField(USER_ID_PARAMETER); + + static { + PARSER = new ObjectParser<>(PERSONALIZE_REQUEST_PARAMETERS, PersonalizeRequestParameters::new); + PARSER.declareString(PersonalizeRequestParameters::setUserId, USER_ID); + } + + private String userId; + + public PersonalizeRequestParameters() {} + + public PersonalizeRequestParameters(String userId) { + this.userId = userId; + } + + public PersonalizeRequestParameters(StreamInput input) throws IOException { + this.userId = input.readString(); + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.userId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.field(USER_ID.getPreferredName(), this.userId); + } + + public static PersonalizeRequestParameters parse(XContentParser parser) throws IOException { + PersonalizeRequestParameters requestParameters = PARSER.parse(parser, null); + return requestParameters; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PersonalizeRequestParameters config = (PersonalizeRequestParameters) o; + + if (!userId.equals(config.userId)) return false; + return userId.equals(config.userId); + } + + @Override + public int hashCode() { + return Objects.hash(userId); + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java new file mode 100644 index 0000000..0dc4f04 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilder.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.Constants.PERSONALIZE_REQUEST_PARAMETERS; + +public class PersonalizeRequestParametersExtBuilder extends SearchExtBuilder { + private static final Logger logger = LogManager.getLogger(PersonalizeRequestParametersExtBuilder.class); + public static final String NAME = PERSONALIZE_REQUEST_PARAMETERS; + private PersonalizeRequestParameters requestParameters; + + public PersonalizeRequestParametersExtBuilder() {} + + public PersonalizeRequestParametersExtBuilder(StreamInput input) throws IOException { + requestParameters = new PersonalizeRequestParameters(input); + } + + public PersonalizeRequestParameters getRequestParameters() { + return requestParameters; + } + + public void setRequestParameters(PersonalizeRequestParameters requestParameters) { + this.requestParameters = requestParameters; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.requestParameters); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof PersonalizeRequestParametersExtBuilder)) { + return false; + } + PersonalizeRequestParametersExtBuilder o = (PersonalizeRequestParametersExtBuilder) obj; + return this.requestParameters.equals(o.requestParameters); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + requestParameters.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(requestParameters); + } + + public static PersonalizeRequestParametersExtBuilder parse(XContentParser parser) throws IOException{ + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + PersonalizeRequestParameters requestParameters = PersonalizeRequestParameters.parse(parser); + extBuilder.setRequestParameters(requestParameters); + return extBuilder; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java new file mode 100644 index 0000000..4699897 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker; + +import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; + +public interface PersonalizedRanker { + + /** + * Re rank search hits + * @param hits Search hits to re rank + * @param requestParameters Request parameters for Personalize present in search request + * @return Re ranked search hits + */ + SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters); + + /** + * Validate Personalize configuration for calling Personalize service + * @param requestParameters Request parameters for Personalize present in search request + * @return True if valid configuration present else false. + */ + boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters); + +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java new file mode 100644 index 0000000..1e30ec9 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl.AmazonPersonalizedRankerImpl; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +/** + * Factory for creating Personalize ranker instance based on Personalize ranker configuration + */ +public class PersonalizedRankerFactory { + private static final Logger logger = LogManager.getLogger(PersonalizedRankerFactory.class); + + /** + * Create an instance of Personalize ranker based on ranker configuration + * @param config Personalize ranker configuration + * @param client Personalize client + * @return Personalize ranker instance + */ + public PersonalizedRanker getPersonalizedRanker(PersonalizeIntelligentRankerConfiguration config, PersonalizeClient client){ + PersonalizedRanker ranker = null; + if (config.getRecipe().equals(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)) { + ranker = new AmazonPersonalizedRankerImpl(config, client); + } else { + logger.error("Personalize recipe provided in configuration is not supported for re ranking search results"); + //TODO : throw user error exception + } + return ranker; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java new file mode 100644 index 0000000..bd3a964 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl; + +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; +import com.amazonaws.services.personalizeruntime.model.PredictedItem; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Personalize Re Ranker implementation using Amazon Personalized Ranking recipe + */ +public class AmazonPersonalizedRankerImpl implements PersonalizedRanker { + private static final Logger logger = LogManager.getLogger(AmazonPersonalizedRankerImpl.class); + private final PersonalizeIntelligentRankerConfiguration rankerConfig; + private final PersonalizeClient personalizeClient; + public AmazonPersonalizedRankerImpl(PersonalizeIntelligentRankerConfiguration config, + PersonalizeClient client) { + this.rankerConfig = config; + this.personalizeClient = client; + } + + /** + * Re rank search hits using Personalize campaign that uses Personalized Ranking recipe + * @param hits search hits returned by open search + * @param requestParameters request parameters for Personalize present in search request + * @return search hots re ranked using Amazon Personalize + */ + @Override + public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters) { + try { + if (!isValidPersonalizeConfigPresent(requestParameters)) { + throw new IllegalArgumentException("Required configurations missing from Personalize " + + "response processor configuration or search request parameters"); + } + List originalHits = Arrays.asList(hits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List documentIdsToRank; + // If item field is not specified in the configruation then use default _id field. + if (!itemIdfield.isEmpty()) { + documentIdsToRank = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + } else { + documentIdsToRank = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + } + logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray())); + // TODO: Parse context from request parameters + String userId = requestParameters.getUserId(); + logger.info("User ID from request parameters. User ID: {}", userId); + GetPersonalizedRankingRequest personalizeRequest = new GetPersonalizedRankingRequest() + .withCampaignArn(rankerConfig.getPersonalizeCampaign()) + .withInputList(documentIdsToRank) + .withUserId(userId); + GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest); + + //TODO: Combine Personalize and open search result. Change the result after transform logic is implemented + return hits; + } catch (Exception ex) { + logger.error("Failed to re rank with Personalize. Returning original search results without Personalize re ranking.", ex); + return hits; + } + } + + /** + * Validate Personalize configuration for calling Personalize service + * @param requestParameters Request parameters for Personalize present in search request + * @return True if valid configuration present else false. + */ + public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters) { + boolean isValidPersonalizeConfig = true; + + if (requestParameters == null || requestParameters.getUserId().isEmpty()) { + isValidPersonalizeConfig = false; + logger.error("Required Personalize parameters are not provided in the search request"); + } + + if (rankerConfig == null || rankerConfig.getPersonalizeCampaign().isEmpty() || + rankerConfig.getWeight() < 0.0 || rankerConfig.getWeight() > 1.0) { + isValidPersonalizeConfig = false; + logger.error("Required Personalized ranker configuration is missing"); + } + return isValidPersonalizeConfig; + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java new file mode 100644 index 0000000..1bad35b --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking; + +import com.amazonaws.http.IdleConnectionReaper; +import org.apache.lucene.search.TotalHits; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.env.TestEnvironment; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.WeakHashMap; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +public class PersonalizeResponseProcessorTests extends OpenSearchTestCase { + + private static final String TYPE = "personalize_ranking"; + private Settings settings = buildEnvSettings(Settings.EMPTY); + private Environment env = TestEnvironment.newEnvironment(settings); + private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + private String iamRoleArn = ""; + private String recipe = "sample-personalize-recipe"; + private String itemIdField = "ITEM_ID"; + private String region = "us-west-2"; + private double weight = 0.25; + + private PersonalizeClientSettings clientSettings = PersonalizeClientSettings.getClientSettings(env.settings()); + + public void testCreateFactoryThrowsExceptionWithEmptyConfig() { + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + Collections.emptyMap() + )); + } + + public void testCreateFactoryWithAllPersonalizeConfig() throws Exception { + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); + + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", recipe); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration); + + assertEquals(TYPE, personalizeResponseProcessor.getType()); + assertEquals("testTag", personalizeResponseProcessor.getTag()); + assertEquals("testingAllFields", personalizeResponseProcessor.getDescription()); + IdleConnectionReaper.shutdown(); + } + + public void testProcessorWithNoHits() throws Exception { + PersonalizeClient mockClient = mock(PersonalizeClient.class); + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", recipe); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration); + SearchRequest searchRequest = new SearchRequest(); + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + personalizeResponseProcessor.processResponse(searchRequest, searchResponse); + } + + public void testProcessorWithHits() throws Exception { + PersonalizeClient mockClient = mock(PersonalizeClient.class); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", configuration); + SearchRequest searchRequest = new SearchRequest(); + SearchHit[] searchHits = new SearchHit[10]; + for (int i = 0; i < searchHits.length; i++) { + searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); + searchHits[i].score(1.0f); + } + SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + personalizeResponseProcessor.processResponse(searchRequest, searchResponse); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettingsTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettingsTests.java new file mode 100644 index 0000000..a81564c --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettingsTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSSessionCredentials; +import org.opensearch.common.settings.SettingsException; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeClientSettingsTestUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeClientSettingsTestUtil.ACCESS_KEY; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeClientSettingsTestUtil.SECRET_KEY; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeClientSettingsTestUtil.SESSION_TOKEN; + +public class PersonalizeClientSettingsTests extends OpenSearchTestCase { + + public void testWithBasicCredentials() throws IOException { + PersonalizeClientSettings clientSettings = PersonalizeClientSettingsTestUtil.buildClientSettings(true, true, false); + AWSCredentials credentials = clientSettings.getCredentials(); + assertEquals(ACCESS_KEY, credentials.getAWSAccessKeyId()); + assertEquals(SECRET_KEY, credentials.getAWSSecretKey()); + assertFalse(credentials instanceof AWSSessionCredentials); + } + + public void testWithSessionCredentials() throws IOException { + PersonalizeClientSettings clientSettings = PersonalizeClientSettingsTestUtil.buildClientSettings(true, true, true); + AWSCredentials credentials = clientSettings.getCredentials(); + assertEquals(ACCESS_KEY, credentials.getAWSAccessKeyId()); + assertEquals(SECRET_KEY, credentials.getAWSSecretKey()); + assertTrue(credentials instanceof AWSSessionCredentials); + AWSSessionCredentials sessionCredentials = (AWSSessionCredentials) credentials; + assertEquals(SESSION_TOKEN, sessionCredentials.getSessionToken()); + } + + public void testWithoutCredentials() throws IOException { + PersonalizeClientSettings clientSettings = PersonalizeClientSettingsTestUtil.buildClientSettings(false, false, false); + assertNull(clientSettings.getCredentials()); + } + + public void testWithoutAccessKey() { + expectThrows(SettingsException.class, () -> PersonalizeClientSettingsTestUtil.buildClientSettings(false, true, false)); + expectThrows(SettingsException.class, () -> PersonalizeClientSettingsTestUtil.buildClientSettings(false, true, true)); + } + + public void testWithoutSecretKey() { + expectThrows(SettingsException.class, () -> PersonalizeClientSettingsTestUtil.buildClientSettings(true, false, false)); + expectThrows(SettingsException.class, () -> PersonalizeClientSettingsTestUtil.buildClientSettings(true, false, true)); + } + + public void testWithSessionTokenButNoCredentials() { + expectThrows(SettingsException.class, () -> PersonalizeClientSettingsTestUtil.buildClientSettings(false, false, true)); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientTests.java new file mode 100644 index 0000000..57c2ead --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientTests.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicSessionCredentials; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; +import org.mockito.Mockito; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeRuntimeTestUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; + +public class PersonalizeClientTests extends OpenSearchTestCase { + + public void testCreateClient() throws IOException { + AWSCredentials credentials = new BasicSessionCredentials("accessKey", "secretKey", "sessionToken"); + AWSCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials); + String region = "us-west-2"; + try (PersonalizeClient client = new PersonalizeClient(credentialsProvider,region)) { + assertTrue(client.getPersonalizeRuntime() != null); + } + } + + public void testGetPersonalizedRanking() { + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + GetPersonalizedRankingRequest request = PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingRequest(); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + GetPersonalizedRankingResult result = client.getPersonalizedRanking(request); + assertEquals(result.getRecommendationId(), "sampleRecommendationId"); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactoryTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactoryTests.java new file mode 100644 index 0000000..af72a46 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeCredentialsProviderFactoryTests.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client; + +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; +import com.amazonaws.http.IdleConnectionReaper; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeClientSettingsTestUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class PersonalizeCredentialsProviderFactoryTests extends OpenSearchTestCase { + + public void testGetStaticCredentialsProviderWithoutIAMRole() throws IOException { + PersonalizeClientSettings settings = + PersonalizeClientSettingsTestUtil.buildClientSettings(true, true, true); + + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(settings); + assertEquals(credentialsProvider.getClass(), AWSStaticCredentialsProvider.class); + } + + public void testGetDefaultCredentialsProviderWithoutIAMRole() throws IOException { + PersonalizeClientSettings settings = + PersonalizeClientSettingsTestUtil.buildClientSettings(false, false, false); + + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(settings); + assertEquals(credentialsProvider.getClass(), DefaultAWSCredentialsProviderChain.class); + } + + public void testGetCredentialsProviderWithIAMRole() throws IOException { + PersonalizeClientSettings settings = + PersonalizeClientSettingsTestUtil.buildClientSettings(true, true, true); + + String iamRoleArn = "test-iam-role-arn"; + String awsRegion = "us-west-2"; + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(settings, iamRoleArn, awsRegion); + assertEquals(credentialsProvider.getClass(), STSAssumeRoleSessionCredentialsProvider.class); + IdleConnectionReaper.shutdown(); + } + + public void testGetStaticCredentialsProviderWithEmptyIAMRole() throws IOException { + PersonalizeClientSettings settings = + PersonalizeClientSettingsTestUtil.buildClientSettings(true, true, true); + + String iamRoleArn = ""; + String awsRegion = "us-west-2"; + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(settings, iamRoleArn, awsRegion); + assertEquals(credentialsProvider.getClass(), AWSStaticCredentialsProvider.class); + } + + public void testGetDefaultCredentialsProviderWithEmptyIAMRole() throws IOException { + PersonalizeClientSettings settings = + PersonalizeClientSettingsTestUtil.buildClientSettings(false, false, false); + + String iamRoleArn = ""; + String awsRegion = "us-west-2"; + AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(settings, iamRoleArn, awsRegion); + assertEquals(credentialsProvider.getClass(), DefaultAWSCredentialsProviderChain.class); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfigurationTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfigurationTests.java new file mode 100644 index 0000000..b54a2e4 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/PersonalizeIntelligentRankerConfigurationTests.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration; + +import org.opensearch.test.OpenSearchTestCase; + +public class PersonalizeIntelligentRankerConfigurationTests extends OpenSearchTestCase { + + public void createConfigurationTest() { + String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + String iamRoleArn = "sampleRoleArn"; + String recipe = "sample-personalize-recipe"; + String itemIdField = "ITEM_ID"; + String region = "us-west-2"; + double weight = 0.25; + + PersonalizeIntelligentRankerConfiguration config = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + + assertEquals(config.getPersonalizeCampaign(), personalizeCampaign); + assertEquals(config.getIamRoleArn(), iamRoleArn); + assertEquals(config.getRecipe(), recipe); + assertEquals(config.getItemIdField(), itemIdField); + assertEquals(config.getRegion(), region); + assertEquals(config.getWeight(), weight, 0.0); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/PersonalizeRankerFactoryTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/PersonalizeRankerFactoryTests.java new file mode 100644 index 0000000..b88f601 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/PersonalizeRankerFactoryTests.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.ranker; + +import org.mockito.Mockito; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl.AmazonPersonalizedRankerImpl; +import org.opensearch.test.OpenSearchTestCase; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +public class PersonalizeRankerFactoryTests extends OpenSearchTestCase { + + private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + private String iamRoleArn = "sampleRoleArn"; + private String itemIdField = "ITEM_ID"; + private String region = "us-west-2"; + private double weight = 0.25; + + public void testGetPersonalizeRankerForPersonalizedRankingRecipe() { + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + + PersonalizedRankerFactory factory = new PersonalizedRankerFactory(); + PersonalizedRanker ranker = factory.getPersonalizedRanker(rankerConfig, client); + assertEquals(ranker.getClass(), AmazonPersonalizedRankerImpl.class); + } + + public void testGetPersonalizeRankerForUnknownRecipe() { + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, "sample-recipe", itemIdField, region, weight); + + PersonalizedRankerFactory factory = new PersonalizedRankerFactory(); + PersonalizedRanker ranker = factory.getPersonalizedRanker(rankerConfig, client); + assertNull(ranker); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java new file mode 100644 index 0000000..61c31ef --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.ranker.impl; + +import org.mockito.Mockito; +import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl.AmazonPersonalizedRankerImpl; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeRuntimeTestUtil; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.SearchTestUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; + +public class AmazonPersonalizeRankerImplTests extends OpenSearchTestCase { + + private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + private String iamRoleArn = "sampleRoleArn"; + private String recipe = "sample-personalize-recipe"; + private String itemIdField = "ITEM_ID"; + private String region = "us-west-2"; + private double weight = 0.25; + + public void testReRank() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } + + public void testReRankWithoutItemIdFieldInConfig() throws IOException { + String blankItemIdField = ""; + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, blankItemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult()); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java new file mode 100644 index 0000000..c997cb7 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParameterUtilTests.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class PersonalizeRequestParameterUtilTests extends OpenSearchTestCase { + + public void testExtractParameters() { + PersonalizeRequestParameters expected = new PersonalizeRequestParameters("user_1"); + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + PersonalizeRequestParameters actual = PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request); + assertEquals(expected, actual); + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java new file mode 100644 index 0000000..e2b287e --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/requestparameter/PersonalizeRequestParametersExtBuilderTests.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class PersonalizeRequestParametersExtBuilderTests extends OpenSearchTestCase { + + public void testXContentRoundTrip() throws IOException { + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters("28"); + PersonalizeRequestParametersExtBuilder personalizeExtBuilder = new PersonalizeRequestParametersExtBuilder(); + personalizeExtBuilder.setRequestParameters(requestParameters); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(personalizeExtBuilder, xContentType, true); + + XContentParser parser = createParser(xContentType.xContent(), serialized); + + PersonalizeRequestParametersExtBuilder deserialized = + PersonalizeRequestParametersExtBuilder.parse(parser); + + assertEquals(personalizeExtBuilder, deserialized); + } + + public void testStreamRoundTrip() throws IOException { + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + PersonalizeRequestParametersExtBuilder personalizeExtBuilder = new PersonalizeRequestParametersExtBuilder(); + personalizeExtBuilder.setRequestParameters(requestParameters); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + personalizeExtBuilder.writeTo(bytesStreamOutput); + + PersonalizeRequestParametersExtBuilder deserialized = + new PersonalizeRequestParametersExtBuilder(bytesStreamOutput.bytes().streamInput()); + assertEquals(personalizeExtBuilder, deserialized); + } + + +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeClientSettingsTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeClientSettingsTestUtil.java new file mode 100644 index 0000000..d5e4052 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeClientSettingsTestUtil.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; + +import org.opensearch.common.settings.MockSecureSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; + +import java.io.IOException; + +public class PersonalizeClientSettingsTestUtil { + public static final String ACCESS_KEY = "my-access-key"; + public static final String SECRET_KEY = "my-secret-key"; + public static final String SESSION_TOKEN = "session-token"; + + public static PersonalizeClientSettings buildClientSettings(boolean withAccessKey, boolean withSecretKey, + boolean withSessionToken) throws IOException { + try (MockSecureSettings secureSettings = new MockSecureSettings()) { + if (withAccessKey) { + secureSettings.setString(PersonalizeClientSettings.ACCESS_KEY_SETTING.getKey(), ACCESS_KEY); + } + if (withSecretKey) { + secureSettings.setString(PersonalizeClientSettings.SECRET_KEY_SETTING.getKey(), SECRET_KEY); + } + if (withSessionToken) { + secureSettings.setString(PersonalizeClientSettings.SESSION_TOKEN_SETTING.getKey(), SESSION_TOKEN); + } + Settings settings = Settings.builder() + .setSecureSettings(secureSettings) + .build(); + return PersonalizeClientSettings.getClientSettings(settings); + } + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java new file mode 100644 index 0000000..db68d2d --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; + +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; +import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; +import com.amazonaws.services.personalizeruntime.model.PredictedItem; + +import java.util.ArrayList; +import java.util.List; + +public class PersonalizeRuntimeTestUtil { + + public static GetPersonalizedRankingRequest buildGetPersonalizedRankingRequest() { + GetPersonalizedRankingRequest request = new GetPersonalizedRankingRequest() + .withUserId("sampleUserId") + .withInputList(new ArrayList()) + .withCampaignArn("sampleCampaign"); + return request; + } + + public static GetPersonalizedRankingResult buildGetPersonalizedRankingResult() { + List predictedItems = new ArrayList<>(); + GetPersonalizedRankingResult result = new GetPersonalizedRankingResult() + .withPersonalizedRanking(predictedItems) + .withRecommendationId("sampleRecommendationId"); + return result; + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java new file mode 100644 index 0000000..b8e5d13 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; + +import java.io.IOException; +import java.util.Map; + +public class SearchTestUtil { + public static SearchHits getSampleSearchHitsForPersonalize(int numHits) throws IOException { + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("ITEM_ID", String.valueOf(i)) + .field("body", "Body text for document number " + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + SearchHits searchHits = new SearchHits(hitsArray, new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), 1.0f); + return searchHits; + } +}