Skip to content

Commit

Permalink
Personalized intelligent ranking for open search requests (#138)
Browse files Browse the repository at this point in the history
* Personalized intelligent ranking for open search requests

Description:
============
This change uses the search pipeline feature of OpenSearch to allow customers to create a response processor.
The response processor reranks search results obtained from OpenSearch using the Personalization solution before returning them to the user.

Check List:
==========
- [Y] New functionality includes testing.
  - [Y] All tests pass
- [Y] New functionality has been documented.
  - [Y] New functionality has javadoc added
- [Y] Commits are signed as per the DCO using --signoff

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>

* Support passing user id for Personalization as part of search request

Description:
============
This change introduces the ability to pass a user id as part of the search request, to be used by the Personalize response processor.
SearchExtSpec will be used to support getting user id as well as other personalize related request parameters to personalize search results.

Check list:
===========
- [x] New functionality includes testing.
  - [x] All tests pass
- [x] New functionality has been documented.
  - [x] New functionality has javadoc added
- [x] Commits are signed as per the DCO using --signoff

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>

* Incorporate changes to SearchPipelinePlugin interface

Description:
============
SearchPipelinePlugin interface now uses Processor.Factory<SearchResponseProcessor> instead of Processor.Factory.
Also use getResponseProcessors method instead of getProcessors method.

Check list:
===========
- [x] New functionality includes testing.
  - [x] All tests pass
- [x] New functionality has been documented.
  - [x] New functionality has javadoc added
- [x] Commits are signed as per the DCO using --signoff

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>

* Improve test coverage

Added unit tests for PersonalizeRankingResponseProcessor and
PersonalizeRequestParameterUtil.

Also, needed to update JaCoCo to 0.8.9 in order to get rid of warnings
about invalid bytecode versions since we upgraded to JDK 20.

Signed-off-by: Michael Froh <froh@amazon.com>

* Move configuration constants to the search pipeline processor factory

Description:
============
Configuration constants are only used by search response processor factory.
Move related constants from the global configuration constants file to the factory class so that they are closer to
the place where they are used.

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>

---------

Signed-off-by: Ketan Kulkarni <kektnr@amazon.com>
Signed-off-by: Michael Froh <froh@amazon.com>
Co-authored-by: Michael Froh <froh@amazon.com>
  • Loading branch information
kulket and msfroh authored May 24, 2023
1 parent 2e78139 commit 23b5226
Show file tree
Hide file tree
Showing 27 changed files with 1,585 additions and 7 deletions.
9 changes: 9 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,17 @@ dependencies {
implementation 'commons-logging:commons-logging:1.2'
implementation 'com.amazonaws:aws-java-sdk-sts:1.12.300'
implementation 'com.amazonaws:aws-java-sdk-core:1.12.300'
implementation 'com.amazonaws:aws-java-sdk-personalizeruntime:1.12.300'
}


// Pin the JaCoCo tool version for every project that applies the 'jacoco' plugin.
// Per the commit message, 0.8.9 is used to get rid of warnings about invalid
// bytecode versions after the upgrade to JDK 20.
allprojects {
plugins.withId('jacoco') {
jacoco.toolVersion = '0.8.9'
}
}


test {
include '**/*Tests.class'
finalizedBy jacocoTestReport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -27,9 +26,12 @@
import org.opensearch.env.NodeEnvironment;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.relevance.actionfilter.SearchActionFilter;
import org.opensearch.search.relevance.client.OpenSearchClient;
import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory;
Expand All @@ -40,10 +42,13 @@
import org.opensearch.search.relevance.transformer.ResultTransformer;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

public class SearchRelevancePlugin extends Plugin implements ActionPlugin, SearchPlugin {
public class SearchRelevancePlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin {

private OpenSearchClient openSearchClient;
private KendraHttpClient kendraClient;
Expand Down Expand Up @@ -101,10 +106,18 @@ public Collection<Object> createComponents(
public List<SearchExtSpec<?>> getSearchExts() {
// Index the result-transformer configuration factories by name; the map is handed to
// both the stream reader and the XContent parser of SearchConfigurationExtBuilder.
Map<String, ResultTransformerConfigurationFactory> resultTransformerMap = getResultTransformerConfigurationFactories().stream()
.collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i));
// NOTE(review): two return statements appear back to back in this listing. They look like
// the removed (Collections.singletonList) and added (List.of) sides of a diff hunk rather
// than real consecutive statements - confirm against the actual file.
// This form registers only the SearchConfigurationExtBuilder ext spec.
return Collections.singletonList(
new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)));
// This form registers the same SearchConfigurationExtBuilder ext spec plus a second spec
// for PersonalizeRequestParametersExtBuilder, which needs no transformer map.
return List.of(new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME,
input -> new SearchConfigurationExtBuilder(input, resultTransformerMap),
parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap)),
new SearchExtSpec<>(PersonalizeRequestParametersExtBuilder.NAME,
input -> new PersonalizeRequestParametersExtBuilder(input),
parser -> PersonalizeRequestParametersExtBuilder.parse(parser)));
}

/**
 * Exposes the response processors this plugin contributes to search pipelines.
 *
 * @param parameters processor parameters; the node settings are read from {@code parameters.env}
 * @return a single-entry map from {@link PersonalizeRankingResponseProcessor#TYPE} to its factory
 */
@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
    // Resolve the Personalize client settings from the node settings first,
    // then hand them to the factory that builds processors on demand.
    final PersonalizeClientSettings clientSettings =
        PersonalizeClientSettings.getClientSettings(parameters.env.settings());
    final Processor.Factory<SearchResponseProcessor> personalizeFactory =
        new PersonalizeRankingResponseProcessor.Factory(clientSettings);
    return Map.of(PersonalizeRankingResponseProcessor.TYPE, personalizeFactory);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.transformer.personalizeintelligentranking;

import com.amazonaws.auth.AWSCredentialsProvider;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeCredentialsProviderFactory;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameterUtil;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory;

import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;

/**
* This is a {@link SearchResponseProcessor} that applies Personalized intelligent ranking
*/
/**
 * A {@link SearchResponseProcessor} that re-ranks the hits of a search response
 * through a {@link PersonalizedRanker} backed by a {@link PersonalizeClient}.
 * <p>
 * Instances are created by the nested {@link Factory} from a search pipeline
 * processor configuration.
 */
public class PersonalizeRankingResponseProcessor implements SearchResponseProcessor {

    private static final Logger logger = LogManager.getLogger(PersonalizeRankingResponseProcessor.class);

    /** Processor type name used to register and look up this processor. */
    public static final String TYPE = "personalize_ranking";

    private final String tag;
    private final String description;
    private final PersonalizeClient personalizeClient;
    private final PersonalizeIntelligentRankerConfiguration rankerConfig;

    /**
     * Constructor for Personalize ranking response processor
     *
     * @param tag          processor tag
     * @param description  processor description
     * @param rankerConfig personalize ranker config
     * @param client       personalize client
     */
    public PersonalizeRankingResponseProcessor(String tag,
                                               String description,
                                               PersonalizeIntelligentRankerConfiguration rankerConfig,
                                               PersonalizeClient client) {
        this.tag = tag;
        this.description = description;
        this.rankerConfig = rankerConfig;
        this.personalizeClient = client;
    }

    /**
     * Transform the response hits by re ranking results using Personalize.
     * <p>
     * A response with no hits is returned unchanged. Otherwise a new
     * {@link SearchResponse} is built that carries the re-ranked hits and copies
     * every other section from the original response; its reported took time is
     * the original took time plus the time spent re-ranking.
     *
     * @param request  Search request
     * @param response Search response that needs to be transformed
     * @return Transformed search response using personalized re ranking
     * @throws Exception Throws exception for any error while processing response
     */
    @Override
    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
        SearchHits hits = response.getHits();

        // Nothing to re-rank: skip the ranker entirely.
        if (hits.getHits().length == 0) {
            logger.info("TotalHits = 0. Returning search response without applying Personalize transform");
            return response;
        }
        logger.info("Personalizing search results.");
        PersonalizeRequestParameters personalizeRequestParameters =
            PersonalizeRequestParameterUtil.getPersonalizeRequestParameters(request);
        PersonalizedRankerFactory rankerFactory = new PersonalizedRankerFactory();
        PersonalizedRanker ranker = rankerFactory.getPersonalizedRanker(rankerConfig, personalizeClient);

        // Time only the rerank call so its latency can be added to the reported took time.
        long startTime = System.nanoTime();
        SearchHits personalizedHits = ranker.rerank(hits, personalizeRequestParameters);
        long personalizeTimeTookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);

        // Rebuild the response sections around the re-ranked hits; everything else is carried over.
        final SearchResponseSections transformedSearchResponseSections = new InternalSearchResponse(personalizedHits,
            (InternalAggregations) response.getAggregations(), response.getSuggest(),
            new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(),
            response.isTerminatedEarly(), response.getNumReducePhases());

        final SearchResponse transformedResponse = new SearchResponse(transformedSearchResponseSections, response.getScrollId(),
            response.getTotalShards(), response.getSuccessfulShards(),
            response.getSkippedShards(), response.getTook().getMillis() + personalizeTimeTookMs, response.getShardFailures(),
            response.getClusters());

        // Parameterized message: same text as before, but formatted only when INFO is enabled.
        logger.info("Personalize ranking processor took {} ms", personalizeTimeTookMs);

        return transformedResponse;
    }

    /**
     * Get the type of the processor.
     */
    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Get the tag of a processor.
     */
    @Override
    public String getTag() {
        return tag;
    }

    /**
     * Gets the description of a processor.
     */
    @Override
    public String getDescription() {
        return description;
    }

    /**
     * Factory that builds {@link PersonalizeRankingResponseProcessor} instances
     * from a search pipeline processor configuration map.
     */
    public static final class Factory implements Processor.Factory<SearchResponseProcessor> {

        private static final String CAMPAIGN_ARN_CONFIG_NAME = "campaign_arn";
        private static final String ITEM_ID_FIELD_CONFIG_NAME = "item_id_field";
        private static final String IAM_ROLE_ARN_CONFIG_NAME = "iam_role_arn";
        private static final String RECIPE_CONFIG_NAME = "recipe";
        private static final String REGION_CONFIG_NAME = "aws_region";
        private static final String WEIGHT_CONFIG_NAME = "weight";

        // Assigned once in the constructor; visibility is left package-private as before.
        final PersonalizeClientSettings personalizeClientSettings;
        private final BiFunction<AWSCredentialsProvider, String, PersonalizeClient> clientBuilder;

        /**
         * Creates a factory with a caller-supplied client builder.
         *
         * @param settings      Personalize client settings used to obtain credentials
         * @param clientBuilder builds a {@link PersonalizeClient} from a credentials provider and an AWS region
         */
        Factory(PersonalizeClientSettings settings, BiFunction<AWSCredentialsProvider, String, PersonalizeClient> clientBuilder) {
            this.personalizeClientSettings = settings;
            this.clientBuilder = clientBuilder;
        }

        /**
         * Creates a factory that builds clients with the {@link PersonalizeClient} constructor.
         *
         * @param settings Personalize client settings used to obtain credentials
         */
        public Factory(PersonalizeClientSettings settings) {
            this(settings, PersonalizeClient::new);
        }

        /**
         * Builds a processor from its pipeline configuration.
         * <p>
         * {@code campaign_arn}, {@code recipe} and {@code aws_region} are read with
         * {@code readStringProperty}; {@code iam_role_arn} and {@code item_id_field}
         * are read with {@code readOptionalStringProperty}; {@code weight} is read
         * with {@code readDoubleProperty}.
         *
         * @param processorFactories registered processor factories (not used here)
         * @param tag                processor tag
         * @param description        processor description
         * @param config             processor configuration map
         * @return a configured {@link PersonalizeRankingResponseProcessor}
         * @throws Exception if the configuration cannot be read or the client cannot be created
         */
        @Override
        public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String tag, String description, Map<String, Object> config) throws Exception {
            String personalizeCampaign = ConfigurationUtils.readStringProperty(TYPE, tag, config, CAMPAIGN_ARN_CONFIG_NAME);
            String iamRoleArn = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, IAM_ROLE_ARN_CONFIG_NAME);
            String recipe = ConfigurationUtils.readStringProperty(TYPE, tag, config, RECIPE_CONFIG_NAME);
            String itemIdField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ITEM_ID_FIELD_CONFIG_NAME);
            String awsRegion = ConfigurationUtils.readStringProperty(TYPE, tag, config, REGION_CONFIG_NAME);
            double weight = ConfigurationUtils.readDoubleProperty(TYPE, tag, config, WEIGHT_CONFIG_NAME);

            PersonalizeIntelligentRankerConfiguration rankerConfig =
                new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, awsRegion, weight);
            AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(personalizeClientSettings, iamRoleArn, awsRegion);
            PersonalizeClient personalizeClient = clientBuilder.apply(credentialsProvider, awsRegion);
            return new PersonalizeRankingResponseProcessor(tag, description, rankerConfig, personalizeClient);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.transformer.personalizeintelligentranking.client;

import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntime;
import com.amazonaws.services.personalizeruntime.AmazonPersonalizeRuntimeClientBuilder;
import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest;
import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult;

import java.io.Closeable;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;

/**
* Amazon Personalize client implementation for getting personalized ranking
*/
/**
 * Amazon Personalize client implementation for getting personalized ranking.
 * <p>
 * Both building the underlying runtime client and calling it are wrapped in
 * {@link AccessController#doPrivileged(PrivilegedAction)}.
 */
public class PersonalizeClient implements Closeable {
    private final AmazonPersonalizeRuntime personalizeRuntime;

    /**
     * Constructor for Amazon Personalize client
     *
     * @param credentialsProvider Credentials to be used for accessing Amazon Personalize
     * @param awsRegion           AWS region where Amazon Personalize campaign is hosted
     */
    public PersonalizeClient(AWSCredentialsProvider credentialsProvider, String awsRegion) {
        personalizeRuntime = AccessController.doPrivileged(
            (PrivilegedAction<AmazonPersonalizeRuntime>) () -> AmazonPersonalizeRuntimeClientBuilder.standard()
                .withCredentials(credentialsProvider)
                .withRegion(awsRegion)
                .build());
    }

    /**
     * Get Personalize runtime client
     *
     * @return Personalize runtime client
     */
    public AmazonPersonalizeRuntime getPersonalizeRuntime() {
        return personalizeRuntime;
    }

    /**
     * Get Personalized ranking using Personalized runtime client
     *
     * @param request Get personalized ranking request
     * @return Personalized ranking results
     * @throws AmazonServiceException propagated unchanged from the runtime client call
     */
    public GetPersonalizedRankingResult getPersonalizedRanking(GetPersonalizedRankingRequest request) {
        // Unchecked exceptions thrown inside a PrivilegedAction propagate through
        // doPrivileged as-is, so no catch-and-rethrow is needed here.
        return AccessController.doPrivileged(
            (PrivilegedAction<GetPersonalizedRankingResult>) () -> personalizeRuntime.getPersonalizedRanking(request));
    }

    /**
     * Shuts down the underlying Personalize runtime client.
     *
     * @throws IOException declared by {@link Closeable}; not thrown by this implementation
     */
    @Override
    public void close() throws IOException {
        if (personalizeRuntime != null) {
            personalizeRuntime.shutdown();
        }
    }
}
Loading

0 comments on commit 23b5226

Please sign in to comment.